从Superpoint到SuperGlue再到其它基于深度学习的图像匹配算法,几乎都用到了Sinkhorn,到底什么是Sinkhorn,参考了一篇外文,写的很清晰,翻译了一部分,供大家参考。

(注意,本文不分析Sinkhorn到底如何用于图/图像匹配,但如果看懂本文,我想这一问题就不在话下了)

Sinkhorn解决的是最优传输问题,简单讲就是把一个概率分布以最小代价转换成另外一个分布(此非人话,不理解也无妨,我也很讨厌这种过于学术的表述,不利于知识传播,但是确实没想出来到底怎么表述更合理更易于理解)

用例子解释一下问题

现有5种小吃,merveilleux, eclair, chocolate mousse, bavarois, carrot cake,每种小吃数量如下图所示,

                                            

实验室开party,要将这些小吃分享给8位同事,Bernard, Jan, Willem, Hilde, Steffie, Marlies, Tim, Wouter, 每个人对小吃的需求量如下图所示。

                                     

每个人对各小吃的喜爱程度如下表所示,其中-2分表示非常不喜欢,2分表示非常喜欢。

                                          

我们的任务(问题)是将这些小吃分发给这8个同事,同时使得大家满意度最高。

问题数学表述

  • 令r记录每个人所获得的小吃的数量,\mathbf{r}=[3,3,3,4,2,2,2,1]^T,这里r的维度为8,为使得后续分析更具普遍性,此处记为n.
  • 令c记录每种小吃有多少份,\mathbf{c}=[4,2,6,4,4]^T,这里c的维度为5,记为m.

一般地,r,c 表示边缘分布(就是从一个概率分布转为另一个概率分布的两个分布),因此rc中元素之和需为1,这个在后续程序里会处理一下.

  • 定义

                                                                           U(\mathbf{r},\mathbf{c})=\{\mathbf{P}\in \mathbb{R_+}^{n*m}|\mathbf{P1}_m=\mathbf{r}, \mathbf{P1}_n=\mathbf{c}\}

U(\mathbf{r},\mathbf{c})包含所有可能的小吃分配方案。注意,这里每份小吃是可以随意切割进行分配的.

  • 每个用户对各小吃的喜爱程度存储在矩阵\mathbf{M}\in\mathbb{R}^{n*m}(矩阵\mathbf{M}在其它资料中也被成为代价矩阵,本例中将矩阵\mathbf{M}中的元素取负即可得到其对应的代价矩阵)

最终问题可表示为,

                                                                        d_\mathbf{M}(\mathbf{r},\mathbf{c})=\min_{\mathbf{P}\in U(\mathbf{r},\mathbf{c}))}\sum {P_{ij}M_{ij}}

其中d_\mathbf{M}(\mathbf{r},\mathbf{c})又被称为 Wasserstein 距离.

此外,上述问题还可添加正则项,使得问题的描述更加合理,即,

                                                                      {d^\lambda}_\mathbf{M}(\mathbf{r},\mathbf{c})=\min_{\mathbf{P}\in U(\mathbf{r},\mathbf{c}))}\sum {P_{ij}M_{ij}}+\frac{1}{\lambda}(-P_{ij}logP_{ij})

d^{\lambda}_\mathbf{M}(\mathbf{r,c})又被成为Sinkhorn距离。在本例中,这样做一方面可以尽可能给每个人分配他们最喜欢的食物,但同时又尽量满足均匀分布…(不清楚具体含义,原文为trying to give every person only their favorites or encouraging equal distributions.)

解决方法:Sinkhorn

上述问题最优解可表示为,

                                                                        P_{ij}={\alpha}_i{\beta}_j\exp^{-{\lambda}M_{ij}}

其中{\alpha}_i{\beta }_j为要求解的常数,具体的,Sinkhorn伪代码为

                                                      

具体实现(python)

def compute_optimal_transport(M, r, c, lam, epsilon=1e-8):
    """
    Computes the optimal transport matrix and Slinkhorn distance using the
    Sinkhorn-Knopp algorithm

    Inputs:
        - M : cost matrix (n x m)
        - r : vector of marginals (n, )
        - c : vector of marginals (m, )
        - lam : strength of the entropic regularization
        - epsilon : convergence parameter

    Outputs:
        - P : optimal transport matrix (n x m)
        - dist : Sinkhorn distance
    """
    n, m = M.shape
    P = np.exp(- lam * M)
    P /= P.sum()
    u = np.zeros(n)
    # normalize this matrix
    while np.max(np.abs(u - P.sum(1))) > epsilon:
        u = P.sum(1)
        P *= (r / u).reshape((-1, 1))#行归r化,注意python中*号含义
        P *= (c / P.sum(0)).reshape((1, -1))#列归c化
    return P, np.sum(P * M)

 

假设输入为:

M(这里没有照搬上述例子中的代价矩阵,而是随机生成的,懒得一个一个输入啊~~~):

[[6.74620535e-01 6.51487856e-01 7.63909999e-01 1.22160802e-02  9.84285854e-01]

 [5.21836427e-02 6.98448351e-01 4.21872002e-04 5.77616315e-01  9.98398433e-01]

 [4.81595322e-01 8.59043865e-01 8.91100944e-01 1.27449590e-01  7.85357602e-01]

 [1.40637778e-01 5.98949422e-02 5.23676192e-02 1.44150411e-02  4.74618963e-01]

 [7.16849610e-01 2.82412228e-01 8.81465978e-01 2.55082618e-01  5.39586731e-01]

 [4.49385127e-01 7.78590147e-01 1.31048710e-03 8.68770877e-02  6.10843349e-01]

 [1.78421067e-02 7.53684632e-01 4.42902867e-01 7.38736941e-01  9.92555963e-01]

 [8.16664868e-01 3.12881863e-01 5.54218820e-01 6.13135979e-01  8.86964971e-01]]

c:

[0.2 0.1 0.3 0.2 0.2]#[4,2,6,4,4]/sum([4,2,6,4,4])

r:

[0.15 0.15 0.15 0.2  0.1  0.1  0.1  0.05]#[3,3,3,4,2,2,2,1]/sum([3,3,3,4,2,2,2,1])

lam:5

则经运算输出为:

P:

[[0.01079512 0.01032116 0.01212201 0.09937232 0.01738938]

 [0.04414505 0.00148527 0.10035015 0.00107044 0.0029491 ]

 [0.03008735 0.0038818  0.00681382 0.05929886 0.04991816]

 [0.02832576 0.03612279 0.07728489 0.0178614  0.04040515]

 [0.00322536 0.02411107 0.00248517 0.01088774 0.05929066]

 [0.00433108 0.00071122 0.07141918 0.00890056 0.01463795]

 [0.07703699 0.00165618 0.01614092 0.00070298 0.00446294]

 [0.00205329 0.02171051 0.01338386 0.00190569 0.01094666]]

P col sum:

[0.2 0.1 0.3 0.2 0.2]

P row sum:

[0.14999999 0.15000001 0.15       0.2        0.1        0.1     0.1        0.05      ]

distance:

0.24870527608715712

 

--------------------------------------------------------------------------

翻译自:https://michielstock.github.io/OptimalTransport/     原作者:Michiel Stock

 

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐