打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服

开通VIP
Python:机器学习造轮子之线性回归

最近看了线性回归,复习了一下微积分和线性代数,想着学以致用,能不能自己动手实现一把呢。于是就动手了。

线性回归是比较基础的算法,是后面逻辑回归的基础。主要是通过一条直线来拟合样本。通常来说只有教学意义。

来说说约定的符号,线性回归参数主要由斜率和截距组成,这里用W表示斜率,b表示截距。大写的W表示这是一个向量。一般来说是n_feauter_num数量,就是有多少个特征,W的shape就是(n_feauter_num,1),截距b是一个常数,通过公式Y=W*X+b计算出目标Y值,一般来说,在机器学习中约定原始值为Y,预测值为Y_hat。下面来谈谈具体实现步骤

  • 构造数据

  • 构造loss function(coss function)

  • 分别对W和b计算梯度(也是对cost function分别对W和b求导)

  • 计算Y_hat

  • 多次迭代计算梯度,直接收敛或者迭代结束

下面给出具体python代码实现,本代码是通用代码,可以任意扩展W,代码中计算loss和梯度的地方采用的向量实现,因此增加W的维度不用修改代码

import matplotlib.pyplot as pltimport numpy as npdef f(X):
w = np.array([1, 3, 2])
b = 10
return np.dot(X, w.T) + bdef cost(X, Y, w, b):
m = X.shape[0]
Z = np.dot(X, w) + b
Y_hat = Z.reshape(m, 1)
cost = np.sum(np.square(Y_hat - Y)) / (2 * m) return costdef gradient_descent(X, Y, W, b, learning_rate):
m = X.shape[0]
W = W - learning_rate * (1 / m) * X.T.dot((np.dot(X, W) + b - Y))
b = b - learning_rate * (1 / m) * np.sum(np.dot(X, W) + b - Y) return W, bdef main():
# sample number
m = 5
# feature number
n = 3
total = m * n # construct data
X = np.random.rand(total).reshape(m, n)
Y = f(X).reshape(m, 1)# iris = datasets.load_iris()# X, Y = iris.data, iris.target.reshape(150, 1)# X = X[Y[:, 0] < 2]# Y = Y[Y[:, 0] < 2]# m = X.shape[0]# n = X.shape[1]
# define parameter
W = np.ones((n, 1), dtype=float).reshape(n, 1)
b = 0.0
# def forward pass++
learning_rate = 0.1
iter_num = 10000
i = 0
J = [] while i < iter_num:
i = i + 1
W, b = gradient_descent(X, Y, W, b, learning_rate)
j = cost(X, Y, W, b)
J.append(j)
print(W, b)
print(j)
plt.plot(J)
plt.show()if __name__ == '__main__':
main()

可以看到,结果输出很接近预设参数[1,3,2]和10

是不是感觉so easy.

step: 4998 loss: 3.46349593719e-07[[ 1.00286704]
[ 3.00463459]
[ 2.00173473]] 9.99528287088step: 4999 loss: 3.45443124835e-07[[ 1.00286329]
[ 3.00462853]
[ 2.00173246]] 9.99528904819step: 5000 loss: 3.44539028368e-07
本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
【热】打开小程序,算一算2024你的财运
感知器及其在Python中的实现
详解 | 如何用Python实现机器学习算法
神经网络
Python 学习之 Numpy!最神奇的模块!了解一下?
神经网络-全连接层(1)
入门numpy(上)【解读numpy官方文档】
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服