支持向量机(SVM)的 Python 简单实现

1

由于开题方向的改变,我最近在入门深度学习,主要看的是邱锡鹏老师在 B 站上的网课。但网课看多了,总觉得学到的都是空泛的理论,终究还是应该自己动手实验一下。

于是自己动手写了一些关于 SVM(支持向量机)的代码,运行环境为 Python 3。

import random
from random import sample
import matplotlib.pyplot as plt

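先说明一下:这里实现的都是很简单的东西,没有用梯度下降法之类的优化方法,用的只是简单的穷举。具体做法是把斜率 w 固定为真值 2,只在整数范围内穷举截距 b。先找出能让所有点都被正确分开(损失为 0)的 b,再比较每个这样的 b 对应的最小点到直线距离,最小距离最大的那个 b 就是 SVM 意义下的最佳分类超平面。这只是从原理上实现了一下,没什么干货,代码如下。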

import random
from random import sample
import matplotlib.pyplot as plt

def point_margin(x1,x2,w,b):
    """Squared perpendicular distance from point (x1, x2) to the line x2 = w*x1 + b.

    Rewriting the line as w*x1 - x2 + b = 0 gives normal vector (w, -1), so the
    distance is |w*x1 - x2 + b| / sqrt(w**2 + 1).  No square root is taken
    here: the value returned is that distance squared.
    """
    residual = w*x1 - x2 + b
    return residual * residual / (w*w + 1)

def loss_function(real_data_x1,raw_data_x2,tag_y,w,b):
    """Squared loss of one labelled point against the line x2 = w*x1 + b.

    The signed gap is tag_y * (raw_data_x2 - (w*real_data_x1 + b)).  A positive
    gap means the point lies strictly on the side its label (+1 above, -1 below)
    says it should, and the loss is 0.  Otherwise the loss is the gap squared.
    A point exactly on the line also gets 0, since 0*0 == 0.
    """
    signed_gap = tag_y * (raw_data_x2 - (w * real_data_x1 + b))
    return 0 if signed_gap > 0 else signed_gap * signed_gap

w = 2
pointnum = 1000
x1_real_data = []
x2_real_data = []
x2_raw_data = []
y_data = []
# Ground-truth line x2 = w*x1 (no intercept), sampled at x1 = 0 .. pointnum-1.
for i in range(pointnum):
    x1_real_data.append(i)
    x2_real_data.append(w * i)

# Add noise: Gaussian jitter (sigma 10) plus a random +40 / -40 offset that
# pushes each point above or below the true line, leaving a gap (margin) for
# the SVM search to find.
for i in range(pointnum):
    rawdata = x2_real_data[i] + random.gauss(0, 10) + sample([40.0, -40.0], 1)[0]
    x2_raw_data.append(rawdata)
    # Label = which side of the true line the noisy point lies on.  Compare
    # against the true line value (w * i) rather than a hard-coded `i + i`,
    # which was only correct for w == 2.
    if rawdata > x2_real_data[i]:
        y_data.append(1)
    else:
        y_data.append(-1)

# for i in range(0,100):
#     print(x1_real_data[i],x2_raw_data[i],y_data[i])

# losssum = 0
# for j in range(0,100):
#     losssum = losssum + loss_function(x1_real_data[j],y_data[j],w,1)
#     print(losssum)
bmin = -30
bmax = 31  # exclusive upper bound: candidate intercepts are -30 .. 30

# Exhaustive search over integer intercepts.  For each candidate b, the squared
# loss is accumulated over every sample in order; b is kept only when the total
# is exactly zero, i.e. every point lies strictly on the side of
# x2 = w*x1 + b that its label says it should.
good_b = []
for candidate in range(bmin, bmax):
    print('b = ', candidate)
    losssum = 0
    for sample_x1, sample_x2, sample_label in zip(x1_real_data, x2_raw_data, y_data):
        losssum += loss_function(sample_x1, sample_x2, sample_label, w, candidate)
    if losssum == 0:
        good_b.append(candidate)  # record every intercept with zero total loss
    print('loss = ', losssum)

print(good_b)

# Scatter the noisy samples as red dots.  Only the window x1 in [-50, 225],
# x2 in [-50, 250] is shown, so many of the pointnum samples fall off-screen.
plt.xlim(left=-50, right=225)
plt.ylim(bottom=-50, top=250)
plt.plot(x1_real_data, x2_raw_data, 'ro')
for px, py, label in zip(x1_real_data, x2_raw_data, y_data):
    # Print the +1 / -1 label just beside each point.
    plt.annotate(label, xy=(px, py), xytext=(px + 0.1, py + 0.1))

# Guide lines x2 = w*x1 + b for every 10th candidate intercept, to visualise
# the family of candidate separating lines.  Use the actual slope `w` instead
# of the previously hard-coded slope 2 (x1 + x1).
for i in range(bmin, bmax, 10):
    x1 = [-100, 250]
    x2 = [w * x1[0] + i, w * x1[1] + i]
    plt.plot(x1, x2)
plt.show()

# For every zero-loss intercept, record the smallest point_margin value over
# all samples.  point_margin returns the *squared* perpendicular distance, so
# these are squared minimum distances; the intercept with the largest entry is
# the max-margin (SVM) choice among the candidates.
# min() gives the exact minimum; the previous running minimum started from a
# sentinel of 1000 and would under-report any minimum larger than that.
min_margin = []
for b in good_b:
    margin = min(
        (point_margin(px, py, w, b) for px, py in zip(x1_real_data, x2_raw_data)),
        default=float('inf'),  # no samples -> no finite minimum
    )
    min_margin.append(margin)

print(min_margin)

# Scatter plot: candidate intercept b vs. its (squared) minimum distance.
plt.plot(good_b, min_margin, 'ro')
plt.show()
plt.show()
Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐