机器学习十大经典算法:另辟蹊径EM算法+高斯混合模型图像像素分割实战——Nemo鱼图像分割(python代码+详细注释)_意疏的学习笔记-CSDN博客_em算法实现图像分割

GMM与聚类一样属于无监督学习统计模型,用以拟合数据的分布特征。本文是根据下面的博客进行改编,输入图片路径便可以直接进行分割(包括普通光学图像、微波图像、SAR图像、遥感图像等)。本例属于二分分割,后面有空再呈上多分分割例子。运行该代码需要修改两个地方:

修改1:src_image = Image.open('face.bmp')  #需修改这里图片的路径

 修改2:图像类型选择
gray_status = False  # 灰度图像分割,打开这个开关
#gray_status = True  #彩色图像分割,打开这个开关

具体思路和原理,请参考上面的博客,直接上代码:

import os
from scipy import io
from scipy.stats import norm
import numpy as np
import PIL.Image as Image
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

plot_dir = 'EM_out'
if os.path.exists(plot_dir) == 0:
    os.mkdir(plot_dir)
# 数据加载与解析

src_image = Image.open('face.bmp') # 修改1:需修改这里图片的路径
RGB_img = np.array(src_image)
Gray_img = np.array(src_image.convert('L'))
sample=np.reshape(Gray_img,(-1,1))/256
Gray_ROI=Gray_img/255
RGB_sample=np.reshape(RGB_img,(-1,3))/256
RGB_ROI=RGB_img/255
# # 通过mask,获取ROI区域
# Gray_ROI = (Gray_img * mask)/256
# RGB_mask = np.array([mask, mask, mask]).transpose(1, 2, 0)
# RGB_ROI = (RGB_img * RGB_mask)/255

# 假设两类数据初始占比相同,即先验概率相同
P_pre1 = 0.5
P_pre2 = 0.5

# 假设每个数据来自两类的初始概率相同,即软标签相同
soft_guess1 = 0.5
soft_guess2 = 0.5

# 修改2:图像类型选择
gray_status = False  # 灰度图像分割,打开这个开关
#gray_status = True  #彩色图像分割,打开这个开关

# 一维时的EM
# ----------------------------------------------------------------------------------------------------#
if gray_status:
    # 观察图像,肉眼估计初始值
    gray1_m = 0.5
    gray1_s = 0.1
    gray2_m = 0.8
    gray2_s = 0.3

    # 绘制假定的PDF
    x = np.arange(0, 1, 1/1000)
    gray1_pdf = norm.pdf(x, gray1_m, gray1_s)
    gray2_pdf = norm.pdf(x, gray2_m, gray2_s)
    plt.figure(0)
    ax = plt.subplot(1, 1, 1)
    ax.plot(x, gray1_pdf, 'r', x, gray2_pdf, 'b')
    ax.set_title('supposed PDF')
    plt.figure(1)
    ax1 = plt.subplot(1, 1, 1)
    ax1.imshow(Gray_img, cmap='gray')
    ax1.set_title('gray ROI')
    plt.show()

    gray = np.zeros((len(sample), 5))
    gray_s_old = gray1_s + gray2_s

    # 迭代更新参数
    for epoch in range(10):

        for i in range(len(sample)):

            # 贝叶斯计算每个数据的后验,即得到软标签(E过程)
            soft_guess1 = (P_pre1*norm.pdf(sample[i], gray1_m, gray1_s))/(P_pre1*norm.pdf(sample[i], gray1_m, gray1_s) +
                                                                             P_pre2*norm.pdf(sample[i], gray2_m, gray2_s))
            soft_guess2 = 1 - soft_guess1
            gray[i][0] = sample[i]
            gray[i][1] = soft_guess1*1                         # 当前一个数据中类别1占的个数,1*后验,显然是小数
            gray[i][2] = soft_guess2*1
            gray[i][3] = soft_guess1*sample[i]              # 对当前数据中属于类别1的部分,当前数据*后验
            gray[i][4] = soft_guess2*sample[i]

            # 根据软标签,再借助最大似然估计出类条件概率PDF参数——均值,标准差(M过程)

        gray1_num = sum(gray)[1]                                # 对每一个数据中类别1占的个数求和,就得到数据中类别1的总数
        gray2_num = sum(gray)[2]
        gray1_m = sum(gray)[3]/gray1_num                        # 对每一个数据中属于类别1的那部分求和,就得到类别1的x的和,用其除以类别1的个数就得到其均值
        gray2_m = sum(gray)[4]/gray2_num

        sum_s1 = 0.0
        sum_s2 = 0.0

        for i in range(len(gray)):
            sum_s1 = sum_s1 + gray[i][1]*(gray[i][0] - gray1_m)*(gray[i][0] - gray1_m)     # 每个数据的波动中,属于类别1的部分
            sum_s2 = sum_s2 + gray[i][2]*(gray[i][0] - gray2_m)*(gray[i][0] - gray2_m)
        gray1_s = pow(sum_s1/gray1_num, 0.5)                                               # 标准差
        gray2_s = pow(sum_s2/gray2_num, 0.5)

        # print(gray1_m, gray2_m, gray1_s, gray2_s)
        P_pre1 = gray1_num/(gray1_num + gray2_num)                                         # 更新先验概率
        P_pre2 = 1 - P_pre1

        gray1_pdf = norm.pdf(x, gray1_m, gray1_s)
        gray2_pdf = norm.pdf(x, gray2_m, gray2_s)
        gray_s_d = abs(gray_s_old - gray2_s - gray1_s)
        gray_s_old = gray2_s + gray1_s
        # if gray_s_d < 0.0001:                                                               # 迭代停止条件,如果两次方差变化较小则停止迭代
        #     break

        # 绘制更新参数后的pdf
        plt.figure(2)
        ax2 = plt.subplot(1, 1, 1)
        ax2.plot(x, gray1_pdf, 'r', x, gray2_pdf, 'b')
        ax2.set_title('epoch' + str(epoch + 1) + ' PDF')
        plt.savefig(plot_dir + '//' + 'PDF_' + str(epoch + 1) + '.jpg', dpi=100)
        plt.close()
        # plt.show()

        if epoch % 1 == 0:                                # 迭代2次进行一次分割测试

            gray_out = np.zeros_like(Gray_img)
            for i in range(len(Gray_ROI)):
                for j in range(len(Gray_ROI[0])):
                    if Gray_ROI[i][j] == 0:
                        continue
                    # 贝叶斯公式分子比较,等价于最大后验
                    elif P_pre1 * norm.pdf(Gray_ROI[i][j], gray1_m, gray1_s) > P_pre2 * norm.pdf(Gray_ROI[i][j],
                                                                                                 gray2_m, gray2_s):
                        gray_out[i][j] = 100
                    else:
                        gray_out[i][j] = 255
            # 显示分割结果
            plt.figure(3)
            ax3 = plt.subplot(1, 1, 1)
            ax3.imshow(gray_out, cmap='gray')
            ax3.set_title('epoch' + str(epoch + 1) + 'gray segment')
            plt.savefig(plot_dir + '//' + 'Gray_segment_' + str(epoch + 1) + '.jpg', dpi=100)
            plt.close()
            plt.show()

# 三维时的EM
# -------------------------------------------------------------------------------------------------------#
else:
    # 观察图像,肉眼估计初始值
    RGB1_m = np.array([0.5, 0.5, 0.5])
    RGB2_m = np.array([0.8, 0.8, 0.8])
    RGB1_cov = np.array([[0.1, 0.05, 0.04],
                        [0.05, 0.1, 0.02],
                        [0.04, 0.02, 0.1]])
    RGB2_cov = np.array([[0.1, 0.05, 0.04],
                        [0.05, 0.1, 0.02],
                        [0.04, 0.02, 0.1]])

    RGB = np.zeros((len(RGB_sample), 11))

    # 显示彩色ROI
    plt.figure(3)
    cx = plt.subplot(1, 1, 1)
    cx.set_title('RGB ROI')
    cx.imshow(RGB_img)
    plt.show()
    # 迭代更新参数
    for epoch in range(20):
        for i in range(len(RGB_sample)):

            # 贝叶斯计算每个数据的后验,即得到软标签
            soft_guess1 = P_pre1*multivariate_normal.pdf(RGB_sample[i], RGB1_m, RGB1_cov)/(P_pre1*multivariate_normal.pdf(RGB_sample[i], RGB1_m, RGB1_cov) + P_pre2*multivariate_normal.pdf(RGB_sample[i], RGB2_m, RGB2_cov))
            soft_guess2 = 1 - soft_guess1
            RGB[i][0:3] = RGB_sample[i]
            RGB[i][3] = soft_guess1*1
            RGB[i][4] = soft_guess2*1
            RGB[i][5:8] = soft_guess1*RGB_sample[i]
            RGB[i][8:11] = soft_guess2*RGB_sample[i]
        # print(RGB[0])

        # 根据软标签,再借助最大似然估计出类条件概率PDF参数——均值,标准差
        RGB1_num = sum(RGB)[3]
        RGB2_num = sum(RGB)[4]
        RGB1_m = sum(RGB)[5:8]/RGB1_num
        RGB2_m = sum(RGB)[8:11]/RGB2_num

        # print(RGB1_num+RGB2_num, RGB1_m, RGB2_m)
        cov_sum1 = np.zeros((3, 3))
        cov_sum2 = np.zeros((3, 3))

        for i in range(len(RGB)):
            # print(np.dot((RGB[i][0:3]-RGB1_m).reshape(3, 1), (RGB[i][0:3]-RGB1_m).reshape(1, 3)))
            cov_sum1 = cov_sum1 + RGB[i][3]*np.dot((RGB[i][0:3]-RGB1_m).reshape(3, 1), (RGB[i][0:3]-RGB1_m).reshape(1, 3))
            cov_sum2 = cov_sum2 + RGB[i][4]*np.dot((RGB[i][0:3]-RGB2_m).reshape(3, 1), (RGB[i][0:3]-RGB2_m).reshape(1, 3))
        RGB1_cov = cov_sum1/(RGB1_num-1)                                                    # 无偏估计除以N-1
        RGB2_cov = cov_sum2/(RGB2_num-1)

        P_pre1 = RGB1_num/(RGB1_num + RGB2_num)
        P_pre2 = 1 - P_pre1

        print(RGB1_cov, P_pre1)

        # 用贝叶斯对彩色图像进行分割

        RGB_out = np.zeros_like(RGB_ROI)

        for i in range(len(RGB_ROI)):
            for j in range(len(RGB_ROI[0])):
                if np.sum(RGB_ROI[i][j]) == 0:
                    continue
                # 贝叶斯公式分子比较
                elif P_pre1 * multivariate_normal.pdf(RGB_ROI[i][j], RGB1_m, RGB1_cov) > P_pre2 * multivariate_normal.pdf(
                        RGB_ROI[i][j], RGB2_m, RGB2_cov):
                    RGB_out[i][j] = [255, 0, 0]
                else:
                    RGB_out[i][j] = [0, 255, 0]
        # print(RGB_ROI.shape)

        # 显示彩色分割结果
        plt.figure(4)
        ax3 = plt.subplot(1, 1, 1)
        ax3.imshow(RGB_out)
        ax3.set_title('epoch' + str(epoch + 1) + ' RGB segment')
        plt.savefig(plot_dir + '//' + 'RGB_segment_' + str(epoch + 1) + '.jpg', dpi=100)
        plt.close()

输入图像 ,输出图像 

 输入图像,输出图像

 

 

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐