线性回归拟合并绘图

本题出自蓝桥云题库:https://www.lanqiao.cn/problems/78/learning/

介绍

线性回归是机器学习中最基础、最重要的方法之一。接下来,你需要根据题目提供的数据点,完成线性拟合,并绘制出图像。

目标

题目给出一个二维数组如下,共计 20 个数据样本。

data = [[5.06, 5.79], [4.92, 6.61], [4.67, 5.48], [4.54, 6.11], [4.26, 6.39],
        [4.07, 4.81], [4.01, 4.16], [4.01, 5.55], [3.66, 5.05], [3.43, 4.34],
        [3.12, 3.24], [3.02, 4.80], [2.87, 4.01], [2.64, 3.17], [2.48, 1.61],
        [2.48, 2.62], [2.02, 2.50], [1.95, 3.59], [1.79, 1.49], [1.54, 2.10], ]

你需要根据这 20 个样本,使用线性回归拟合,得到自变量系数及截距项。

y = w x + b \displaystyle y = wx + b y=wx+b
其中, w w w 即为自变量系数, b b b 则为常数项。

最后,需要使用 Matplotlib 将数据样本绘制成散点图,并将拟合直线一并绘出。

提示

你可以使用自行实现的最小二乘法函数计算 ww 和 bb 的值,也可以使用 scikit-learn 提供的线性回归类完成。提示代码如下:

def linear_plot():
    """
    参数:无
    
    返回:
    w -- 自变量系数, 保留两位小数
    b -- 截距项, 保留两位小数
    fig -- matplotlib 绘图对象
    """

    data = [[5.06, 5.79], [4.92, 6.61], [4.67, 5.48], [4.54, 6.11], [4.26, 6.39],
            [4.07, 4.81], [4.01, 4.16], [4.01, 5.55], [3.66, 5.05], [3.43, 4.34],
            [3.12, 3.24], [3.02, 4.80], [2.87, 4.01], [2.64, 3.17], [2.48, 1.61],
            [2.48, 2.62], [2.02, 2.50], [1.95, 3.59], [1.79, 1.49], [1.54, 2.10], ]
    
    ### TODO: 线性拟合计算参数 ###
    
    w = None
    b = None
    
    fig = plt.figure() # 务必保留此行,设置绘图对象
    
    ### TODO: 按题目要求绘图 ### 
    
    return w, b, fig # 务必按此顺序返回

图像示例:

在这里插入图片描述

实现思路

这里使用最小二乘法解决这一问题,先来看一下最小二乘法的数学表达式。

在这里插入图片描述
根据上面的公式写出代码,就可以直接求待拟合直线的斜率和截距了。

参考代码

def linear_plot():
    import numpy as np
    import matplotlib.pyplot as plt
    """
    参数:无
    
    返回:
    w -- 自变量系数, 保留两位小数
    b -- 截距项, 保留两位小数
    fig -- matplotlib 绘图对象
    """

    data = [[5.06, 5.79], [4.92, 6.61], [4.67, 5.48], [4.54, 6.11], [4.26, 6.39],
            [4.07, 4.81], [4.01, 4.16], [4.01, 5.55], [3.66, 5.05], [3.43, 4.34],
            [3.12, 3.24], [3.02, 4.80], [2.87, 4.01], [2.64, 3.17], [2.48, 1.61],
            [2.48, 2.62], [2.02, 2.50], [1.95, 3.59], [1.79, 1.49], [1.54, 2.10], ]
    
    ### TODO: 线性拟合计算参数 ###
    SumXiYi = 0
    SumXi = 0
    SumYi = 0
    SumXi2 = 0
    PointX = []
    PointY = []
    for item in range(len(data)):
        XiYi = data[item][0] * data[item][1]
        SumXiYi += XiYi
        SumXi += data[item][0]
        SumYi += data[item][1]
        SumXi2 += data[item][0] * data[item][0]
        PointX.append(data[item][0])
        PointY.append(data[item][1])
    
    
    w = (len(data) * SumXiYi - SumXi * SumYi) / (len(data) * SumXi2 - SumXi * SumXi)
    b = (SumXi2 * SumYi - SumXiYi * SumXi) / (len(data) * SumXi2 - SumXi * SumXi)
    w = round(w,2)
    b = round(b,2)
    
    X = np.arange(0.5, 6, 0.01)
    Y = w * X + b

    plt.plot(X, Y, color='red')
    plt.scatter(PointX, PointY,  color='blue') # 散点图

    fig = plt.figure() # 务必保留此行,设置绘图对象
    
    ### TODO: 按题目要求绘图 ### 
    fig.show()
    
    return w, b, fig # 务必按此顺序返回

if __name__=="__main__":
    linear_plot()

绘制出的图像如下图所示:

在这里插入图片描述

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐