一元线性回归

发布于:2025-04-18 ⋅ 阅读:(70) ⋅ 点赞:(0)

一、数据

 二、原理(以下为大致原理,详细原理关注接下来的文章)

三、代码实现

import matplotlib.pylab as plt
import numpy as np

y=np.array([200,400,500,300,800,150,360,50,750,780,110,160,520,560,145,350,80,600,367,157])
x=np.array([5000,6500,7000,6000,10000,3800,6000,1500,9000,9400,3000,4500,7000,7500,4500,6000,2000,7800,6500,4500])

dot_y=y
dot_x=x#存原数据,用于画点状图;

x_range=np.array([2000,10000])
y_mean=np.mean(y)
y_std=np.std(y)
x_mean=np.mean(x)
x_std=np.std(x)
y=(y-y_mean)/y_std
x=(x-x_mean)/x_std#数据标准化;

max_iter=10000#限制下降次数;
volum=20#数据量
tolerance=1e-9#拟合阈值;

alpha=0.01#初始化步长;
w=0.0#初始化
b=0.0#初始化

for _ in range(max_iter):
    temp_w=0.0
    temp_b=0.0
    for i in range(0,volum):
        error=w*x[i]+b-y[i]
        temp_w+=error*x[i]
        temp_b+=error
    temp_w*=(alpha/volum)
    temp_b*=(alpha/volum)
    w-=temp_w
    b-=temp_b

    if abs(temp_w)<tolerance and abs(temp_b)<tolerance:
        break
##梯度下降

b=b*y_std+y_mean-w*y_std*x_mean/x_std
w*=(y_std/x_std)
##w与b还原

plt.scatter(dot_x,dot_y,color='blue',marker='o')
plt.title("散点图")
plt.xlabel("x轴")
plt.ylabel("y轴")
##绘制散点图

fx=w*x_range+b
plt.plot(x_range,fx,color='red',marker='o')
plt.xlim(0,10000)
plt.ylim(0,800)
plt.legend()
##绘制拟合直线

plt.grid(True)#显示辅助网格

plt.show()#绘图

 四、拟合结果


网站公告

今日签到

点亮在社区的每一天
去签到