梯度下降算法的python实现
前言梯度下降算法 Gradient Descent GD是沿梯度下降的方向连续迭代逼近求最小值的过程,本文将实现以下梯度下降算法的python实现。简单梯度下降算法批量梯度下降算法随机梯度下降算法简单梯度下降算法简单梯度下降算法的核心就是先求出目标函数的导数gkg_kgk,然后利用简单随机梯度西江算法公式迭代求最小值。xk+1=xk−gk∗rkx_{k+1}=x_k-g_k*r_kxk+1=x
前言
梯度下降算法 Gradient Descent GD是沿梯度下降的方向连续迭代逼近求最小值的过程,本文将实现以下梯度下降算法的python实现。
- 简单梯度下降算法
- 批量梯度下降算法
- 随机梯度下降算法
简单梯度下降算法
简单梯度下降算法的核心就是先求出目标函数的导数
g
k
g_k
gk,然后利用简单随机梯度西江算法公式迭代求最小值。
x
k
+
1
=
x
k
−
g
k
∗
r
k
x_{k+1}=x_k-g_k*r_k
xk+1=xk−gk∗rk$$
- x k + 1 x_{k+1} xk+1 下一步位置
- x k x_{k} xk 当前位置
- g k g_k gk 为梯度
- r k r_k rk 学习率,步长
有一个目标函数,
f
(
x
)
=
x
2
f(x)=x^2
f(x)=x2,一元函数的导数就是该曲线某一点切线的斜率,导数越大,该点的斜率越大,下降越快
F
′
=
2
x
F^{'}=2x
F′=2x这就是梯度,
r
k
r_k
rk 控制着梯度下降前进距离,太大太小都不行。
代码
import numpy as np
import matplotlib.pyplot as plt
# 定义目标函数 f(x)=x**2+1
def f(x):
return np.array(x)**2 + 1
# 对目标函数求导 d(x)=x*2
def d1(x):
return x * 2
def Gradient_Descent_d1(current_x = 0.1,learn_rate = 0.01,e = 0.001,count = 50000):
# current_x initial x value
# learn_rate 学习率
# e error
# count number of iterations
for i in range(count):
grad = d1(current_x) # 求当前梯度
if abs(grad) < e: # 梯度收敛到控制误差内
break # 跳出循环
current_x = current_x - grad * learn_rate # 一维梯度的迭代公式
print("第{}次迭代逼近值为{}".format(i+1,current_x))
print("最小值为:",current_x)
print("最小值保存小数点后6位:%.6f"%(current_x))
return current_x
# 显示目标函数曲线及梯度下降最小值毕竟情况
def ShowLine(min_X,max_Y):
x = [x for x in range(10)] + [x * (-1) for x in range(1,10)]
x.sort()
print(x)
plt.plot(x,f(x))
plt.plot(min_X,max_Y,'ro')
plt.show()
minValue = Gradient_Descent_d1(current_x = 0.1,learn_rate = 0.01,e = 0.001,count = 50000)
minY = f(minValue)
print('目标函数最小值约为:',minY)
ShowLine(minValue,minY)
输出
第1次迭代逼近值为0.098
第2次迭代逼近值为0.09604
第3次迭代逼近值为0.0941192
第4次迭代逼近值为0.092236816
第5次迭代逼近值为0.09039207968
第6次迭代逼近值为0.08858423808640001
第7次迭代逼近值为0.08681255332467201
第8次迭代逼近值为0.08507630225817857
第9次迭代逼近值为0.08337477621301499
第10次迭代逼近值为0.08170728068875469
第11次迭代逼近值为0.0800731350749796
第12次迭代逼近值为0.07847167237348
第13次迭代逼近值为0.0769022389260104
第14次迭代逼近值为0.0753641941474902
第15次迭代逼近值为0.07385691026454039
第16次迭代逼近值为0.07237977205924959
第17次迭代逼近值为0.0709321766180646
第18次迭代逼近值为0.06951353308570331
第19次迭代逼近值为0.06812326242398924
第20次迭代逼近值为0.06676079717550945
第21次迭代逼近值为0.06542558123199926
第22次迭代逼近值为0.06411706960735927
第23次迭代逼近值为0.06283472821521209
第24次迭代逼近值为0.061578033650907846
第25次迭代逼近值为0.06034647297788969
第26次迭代逼近值为0.059139543518331894
第27次迭代逼近值为0.05795675264796526
第28次迭代逼近值为0.056797617595005956
第29次迭代逼近值为0.05566166524310584
第30次迭代逼近值为0.05454843193824372
第31次迭代逼近值为0.05345746329947885
第32次迭代逼近值为0.05238831403348927
第33次迭代逼近值为0.05134054775281949
第34次迭代逼近值为0.050313736797763096
第35次迭代逼近值为0.04930746206180783
第36次迭代逼近值为0.04832131282057168
第37次迭代逼近值为0.047354886564160245
第38次迭代逼近值为0.046407788832877044
第39次迭代逼近值为0.0454796330562195
第40次迭代逼近值为0.04457004039509511
第41次迭代逼近值为0.0436786395871932
第42次迭代逼近值为0.04280506679544934
第43次迭代逼近值为0.041948965459540355
第44次迭代逼近值为0.041109986150349546
第45次迭代逼近值为0.040287786427342556
第46次迭代逼近值为0.0394820306987957
第47次迭代逼近值为0.03869239008481979
第48次迭代逼近值为0.03791854228312339
第49次迭代逼近值为0.03716017143746093
第50次迭代逼近值为0.036416968008711706
第51次迭代逼近值为0.035688628648537474
第52次迭代逼近值为0.034974856075566725
第53次迭代逼近值为0.03427535895405539
第54次迭代逼近值为0.03358985177497428
第55次迭代逼近值为0.032918054739474796
第56次迭代逼近值为0.0322596936446853
第57次迭代逼近值为0.03161449977179159
第58次迭代逼近值为0.03098220977635576
第59次迭代逼近值为0.030362565580828647
第60次迭代逼近值为0.029755314269212074
第61次迭代逼近值为0.029160207983827832
第62次迭代逼近值为0.028577003824151275
第63次迭代逼近值为0.028005463747668248
第64次迭代逼近值为0.027445354472714883
第65次迭代逼近值为0.026896447383260587
第66次迭代逼近值为0.026358518435595377
第67次迭代逼近值为0.025831348066883468
第68次迭代逼近值为0.0253147211055458
第69次迭代逼近值为0.024808426683434883
第70次迭代逼近值为0.024312258149766185
第71次迭代逼近值为0.02382601298677086
第72次迭代逼近值为0.023349492727035442
第73次迭代逼近值为0.02288250287249473
第74次迭代逼近值为0.022424852815044836
第75次迭代逼近值为0.02197635575874394
第76次迭代逼近值为0.02153682864356906
第77次迭代逼近值为0.02110609207069768
第78次迭代逼近值为0.020683970229283727
第79次迭代逼近值为0.020270290824698053
第80次迭代逼近值为0.019864885008204092
第81次迭代逼近值为0.01946758730804001
第82次迭代逼近值为0.019078235561879212
第83次迭代逼近值为0.01869667085064163
第84次迭代逼近值为0.018322737433628795
第85次迭代逼近值为0.017956282684956217
第86次迭代逼近值为0.017597157031257093
第87次迭代逼近值为0.017245213890631952
第88次迭代逼近值为0.016900309612819315
第89次迭代逼近值为0.01656230342056293
第90次迭代逼近值为0.01623105735215167
第91次迭代逼近值为0.015906436205108634
第92次迭代逼近值为0.015588307481006461
第93次迭代逼近值为0.015276541331386333
第94次迭代逼近值为0.014971010504758606
第95次迭代逼近值为0.014671590294663434
第96次迭代逼近值为0.014378158488770165
第97次迭代逼近值为0.014090595318994762
第98次迭代逼近值为0.013808783412614867
第99次迭代逼近值为0.01353260774436257
第100次迭代逼近值为0.013261955589475318
第101次迭代逼近值为0.012996716477685811
第102次迭代逼近值为0.012736782148132095
第103次迭代逼近值为0.012482046505169453
第104次迭代逼近值为0.012232405575066064
第105次迭代逼近值为0.011987757463564744
第106次迭代逼近值为0.01174800231429345
第107次迭代逼近值为0.01151304226800758
第108次迭代逼近值为0.01128278142264743
第109次迭代逼近值为0.01105712579419448
第110次迭代逼近值为0.01083598327831059
第111次迭代逼近值为0.010619263612744378
第112次迭代逼近值为0.01040687834048949
第113次迭代逼近值为0.010198740773679701
第114次迭代逼近值为0.009994765958206107
第115次迭代逼近值为0.009794870639041985
第116次迭代逼近值为0.009598973226261145
第117次迭代逼近值为0.009406993761735923
第118次迭代逼近值为0.009218853886501205
第119次迭代逼近值为0.009034476808771182
第120次迭代逼近值为0.008853787272595759
第121次迭代逼近值为0.008676711527143843
第122次迭代逼近值为0.008503177296600965
第123次迭代逼近值为0.008333113750668945
第124次迭代逼近值为0.008166451475655567
第125次迭代逼近值为0.008003122446142456
第126次迭代逼近值为0.007843059997219607
第127次迭代逼近值为0.007686198797275215
第128次迭代逼近值为0.00753247482132971
第129次迭代逼近值为0.0073818253249031155
第130次迭代逼近值为0.007234188818405053
第131次迭代逼近值为0.0070895050420369515
第132次迭代逼近值为0.006947714941196213
第133次迭代逼近值为0.006808760642372289
第134次迭代逼近值为0.006672585429524843
第135次迭代逼近值为0.006539133720934346
第136次迭代逼近值为0.006408351046515659
第137次迭代逼近值为0.006280184025585346
第138次迭代逼近值为0.006154580345073639
第139次迭代逼近值为0.0060314887381721655
第140次迭代逼近值为0.005910858963408722
第141次迭代逼近值为0.005792641784140548
第142次迭代逼近值为0.005676788948457737
第143次迭代逼近值为0.005563253169488583
第144次迭代逼近值为0.005451988106098811
第145次迭代逼近值为0.005342948343976835
第146次迭代逼近值为0.005236089377097298
第147次迭代逼近值为0.005131367589555352
第148次迭代逼近值为0.005028740237764245
第149次迭代逼近值为0.00492816543300896
第150次迭代逼近值为0.004829602124348781
第151次迭代逼近值为0.004733010081861806
第152次迭代逼近值为0.004638349880224569
第153次迭代逼近值为0.004545582882620078
第154次迭代逼近值为0.004454671224967677
第155次迭代逼近值为0.004365577800468323
第156次迭代逼近值为0.004278266244458957
第157次迭代逼近值为0.004192700919569778
第158次迭代逼近值为0.004108846901178382
第159次迭代逼近值为0.004026669963154815
第160次迭代逼近值为0.003946136563891718
第161次迭代逼近值为0.003867213832613884
第162次迭代逼近值为0.0037898695559616066
第163次迭代逼近值为0.0037140721648423742
第164次迭代逼近值为0.003639790721545527
第165次迭代逼近值为0.0035669949071146165
第166次迭代逼近值为0.003495655008972324
第167次迭代逼近值为0.0034257419087928777
第168次迭代逼近值为0.00335722707061702
第169次迭代逼近值为0.00329008252920468
第170次迭代逼近值为0.0032242808786205864
第171次迭代逼近值为0.003159795261048175
第172次迭代逼近值为0.0030965993558272112
第173次迭代逼近值为0.003034667368710667
第174次迭代逼近值为0.0029739740213364538
第175次迭代逼近值为0.0029144945409097247
第176次迭代逼近值为0.0028562046500915303
第177次迭代逼近值为0.0027990805570896997
第178次迭代逼近值为0.0027430989459479057
第179次迭代逼近值为0.0026882369670289475
第180次迭代逼近值为0.0026344722276883687
第181次迭代逼近值为0.0025817827831346014
第182次迭代逼近值为0.0025301471274719093
第183次迭代逼近值为0.002479544184922471
第184次迭代逼近值为0.0024299533012240217
第185次迭代逼近值为0.0023813542351995413
第186次迭代逼近值为0.0023337271504955503
第187次迭代逼近值为0.0022870526074856394
第188次迭代逼近值为0.0022413115553359267
第189次迭代逼近值为0.0021964853242292083
第190次迭代逼近值为0.002152555617744624
第191次迭代逼近值为0.0021095045053897317
第192次迭代逼近值为0.002067314415281937
第193次迭代逼近值为0.0020259681269762984
第194次迭代逼近值为0.0019854487644367725
第195次迭代逼近值为0.001945739789148037
第196次迭代逼近值为0.0019068249933650763
第197次迭代逼近值为0.0018686884934977748
第198次迭代逼近值为0.0018313147236278192
第199次迭代逼近值为0.0017946884291552628
第200次迭代逼近值为0.0017587946605721575
第201次迭代逼近值为0.0017236187673607144
第202次迭代逼近值为0.0016891463920135
第203次迭代逼近值为0.00165536346417323
第204次迭代逼近值为0.0016222561948897654
第205次迭代逼近值为0.0015898110709919701
第206次迭代逼近值为0.0015580148495721307
第207次迭代逼近值为0.001526854552580688
第208次迭代逼近值为0.0014963174615290743
第209次迭代逼近值为0.0014663911122984928
第210次迭代逼近值为0.001437063290052523
第211次迭代逼近值为0.0014083220242514724
第212次迭代逼近值为0.001380155583766443
第213次迭代逼近值为0.0013525524720911142
第214次迭代逼近值为0.0013255014226492918
第215次迭代逼近值为0.001298991394196306
第216次迭代逼近值为0.00127301156631238
第217次迭代逼近值为0.0012475513349861323
第218次迭代逼近值为0.0012226003082864098
第219次迭代逼近值为0.0011981483021206816
第220次迭代逼近值为0.001174185336078268
第221次迭代逼近值为0.0011507016293567027
第222次迭代逼近值为0.0011276875967695687
第223次迭代逼近值为0.0011051338448341773
第224次迭代逼近值为0.0010830311679374937
第225次迭代逼近值为0.001061370544578744
第226次迭代逼近值为0.001040143133687169
第227次迭代逼近值为0.0010193402710134258
第228次迭代逼近值为0.0009989534655931573
第229次迭代逼近值为0.000978974396281294
第230次迭代逼近值为0.0009593949083556683
第231次迭代逼近值为0.000940207010188555
第232次迭代逼近值为0.0009214028699847839
第233次迭代逼近值为0.0009029748125850882
第234次迭代逼近值为0.0008849153163333864
第235次迭代逼近值为0.0008672170100067186
第236次迭代逼近值为0.0008498726698065842
第237次迭代逼近值为0.0008328752164104526
第238次迭代逼近值为0.0008162177120822435
第239次迭代逼近值为0.0007998933578405986
第240次迭代逼近值为0.0007838954906837866
第241次迭代逼近值为0.0007682175808701109
第242次迭代逼近值为0.0007528532292527087
第243次迭代逼近值为0.0007377961646676545
第244次迭代逼近值为0.0007230402413743014
第245次迭代逼近值为0.0007085794365468154
第246次迭代逼近值为0.0006944078478158791
第247次迭代逼近值为0.0006805196908595615
第248次迭代逼近值为0.0006669092970423702
第249次迭代逼近值为0.0006535711111015229
第250次迭代逼近值为0.0006404996888794924
第251次迭代逼近值为0.0006276896951019025
第252次迭代逼近值为0.0006151359011998645
第253次迭代逼近值为0.0006028331831758672
第254次迭代逼近值为0.0005907765195123498
第255次迭代逼近值为0.0005789609891221028
第256次迭代逼近值为0.0005673817693396607
第257次迭代逼近值为0.0005560341339528675
第258次迭代逼近值为0.0005449134512738102
第259次迭代逼近值为0.0005340151822483339
第260次迭代逼近值为0.0005233348786033672
第261次迭代逼近值为0.0005128681810312999
第262次迭代逼近值为0.0005026108174106739
第263次迭代逼近值为0.0004925586010624604
最小值为: 0.0004925586010624604
最小值保存小数点后6位:0.000493
目标函数最小值约为: 1.0000002426139756
[-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
批量梯度下降算法
批量梯度下降,Batch Gradient Descent BGD 算法是对所有的样本数据进行梯度迭代计算,这里“所有”考虑了非凹凸函数(存在多个局部极大值或极小值的情况)
损失函数
J
(
θ
)
=
1
2
n
Σ
i
=
1
n
(
h
θ
i
−
y
i
)
2
J(\theta)=\frac{1}{2n}\Sigma_{i=1}^{n}(h_{\theta}^{i}-y_i)^2
J(θ)=2n1Σi=1n(hθi−yi)2
- n n n是样本个数
- 1/2求偏导时可以相互抵消
-
x
i
,
y
i
x^i,y^i
xi,yi是第i个样本的
(
x
,
y
)
(x,y)
(x,y)坐标值
假设函数的公式为:
h θ ( x i ) = θ 0 + θ 1 x 1 i + θ 2 x 2 i + ⋅ ⋅ ⋅ + θ n x n i h_{\theta}(x^i)=\theta_0+\theta_1x_1^{i}+\theta_2x_2^{i}+···+\theta_nx_n^{i} hθ(xi)=θ0+θ1x1i+θ2x2i+⋅⋅⋅+θnxni
批量梯度是值在对全样本数据(任意维度)计算梯度时,通过计算损失函数求偏导得到梯度计算公式
$ ∇ θ J ( θ ) = ∂ J ( θ ) ∂ θ j = 1 n Σ i = 1 n ( h θ i − y i ) 2 x j i \nabla_\theta J(\theta)=\frac{\partial J(\theta)}{\partial \theta_j}=\frac{1}{n} \Sigma_{i=1}^{n}(h_{\theta}^{i}-y_i)^2x_j^{i} ∇θJ(θ)=∂θj∂J(θ)=n1Σi=1n(hθi−yi)2xji
i = 1 , 2 , ⋅ ⋅ ⋅ n i=1,2,···n i=1,2,⋅⋅⋅n表示样本数, j = 0 , 1 表 示 特 征 数 j=0,1表示特征数 j=0,1表示特征数
批量迭代公式:
θ = θ − μ ⋅ ∇ θ J ( θ ) \theta=\theta-\mu· \nabla_\theta J(\theta) θ=θ−μ⋅∇θJ(θ)
随机梯度下降算法
随机梯度下降(Stochastic Gradient Descent SGD)通过每次梯度迭代随机采用一个样本数量,最后逼近极值得出近似预测结果。损失函数计算公式:
J
(
θ
)
=
1
2
(
h
θ
(
x
i
)
−
y
i
)
2
J(\theta)=\frac{1}{2}(h_\theta(x^i)-y^i)^2
J(θ)=21(hθ(xi)−yi)2
特别指出:随机梯度下降算法每次迭代只对一个样本数据进行计算,与批量梯度下降算法相比,损失函数迭代一次无须求所有样本的值,因此不需要求均值。
$
∇
θ
J
(
θ
)
=
∂
J
(
θ
)
∂
θ
j
=
(
h
θ
i
−
y
i
)
2
x
j
i
\nabla_\theta J(\theta)=\frac{\partial J(\theta)}{\partial \theta_j}=(h_{\theta}^{i}-y_i)^2x_j^{i}
∇θJ(θ)=∂θj∂J(θ)=(hθi−yi)2xji
迭代公式:
θ
=
θ
−
μ
⋅
∇
θ
J
(
θ
)
\theta=\theta-\mu·\nabla_\theta J(\theta)
θ=θ−μ⋅∇θJ(θ)
更多推荐
所有评论(0)