1.算法描述

RANSAC的基本假设是:
(1)数据由“局内点”组成,例如:数据的分布可以用一些模型参数来解释;
(2)“局外点”是不能适应该模型的数据;
(3)除此之外的数据属于噪声。
    局外点产生的原因有:噪声的极值;错误的测量方法;对数据的错误假设。
    RANSAC也做了以下假设:给定一组(通常很小的)局内点,存在一个可以估计模型参数的过程;而该模型能够解释或者适用于局内点。

通俗来说就是用RANSAC算法处理一批含有噪音的点,如一批数据有M个点,其中Q个点是有噪音的,即局外点。每一个点我们都可以用(x,y)来表示。

1.首先我们会根据数据提出一个模型来拟合局内点。如f(x) = 0或f(x) = y,这个假设模型的参数需要我们后续求取。由于Q个点是局外点,并不影响我们的模型参数,因此只要剩下M-Q个点满足f(x) < ε,就认为假设的模型是合理的,其中ε是我们设定的阈值,可以理解为误差。

2.此处假设f(x)为一个线性模型,即y = ax + b。由于线性模型是一条直线,只需根据两点即可确定。我们在M个点中随机取出两个点,求出a和b,即可得到一个假设线性模型。

3.对于剩下M-2个点,依次代入我们假设的模型中,计算假设模型的y值与真实y值之间的距离,如果距离小于ε,则我们认为该点为内点,即|ax +b - y| < ε。若内点很少,则返回第二步,重新采点,生成新的假设模型进行计算。此处有两种方案,第一种是迭代足够多的次数,记录内点的数量,迭代完成后内点最多的模型即为最终我们想要的模型。第二种是设定一个阈值,如我们有500个点,如果内点个数为400,就认为当前模型可以拟合足够多的内点,跳出循环。

 

说句实话,清爽的八得了。

 

2.代码实现

下面就用实现一个线性回归的RANSAC算法代码。

import numpy as np
from matplotlib import pyplot as plt
import random

#生成数据集
SIZE = 50
OUT = 20
X = np.linspace(0, 10, 50)
Y = [3 * i + 10 + 2 * random.random() for i in X[:-OUT]] + [random.randint(0, int(i)) for i in X[-OUT:]]
X_data = np.array(X)
Y_data = np.array(Y)

用matplotlib可视化出来,数据看起来是这样的。

很明显,左上角为内点,右下角为外点,即噪音。

# 选点、评估

iters = 1000
epsilon = 5
threshold = 0.8
best_a, best_b = 0, 0
pre_total = 0
for i in range(iters):
    sample_index = random.sample(range(SIZE), 2)
    x_1 = X_data[sample_index[0]]
    x_2 = X_data[sample_index[1]]
    y_1 = Y_data[sample_index[0]]
    y_2 = Y_data[sample_index[1]]

    a = (y_2 - y_1) / (x_2 - x_1)
    b = y_1 - a * x_1

    total_in = 0  #内点计数器
    for index in range(SIZE):
        y_estimate = a * X_data[index] + b
        if abs(y_estimate - Y_data[index]) < epsilon: #符合内点条件
            total_in += 1

        if total_in > pre_total:  #记录最大内点数与对应的参数
            pre_total = total_in
            best_a = a
            best_b = b

        if total_in > SIZE * threshold: #内点数大于设定的阈值,跳出循环
            break

    print("迭代{}次,a = {}, b = {}".format(i, best_a, best_b))

也可以通过plt可视化一下

    x_line = np.linspace(0, 10, 1000)
    y_line = best_a * x_line + b
    plt.plot(x_line, y_line, c = 'r')
    plt.scatter(X_data, Y_data)
    plt.show()

最终拟合结果如下:

 

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐