矩阵连乘(动态规划)
矩阵连乘、动态规划。
问题描述
给定N 个矩阵 { A 1 , A 2 , A 3 … … A n } ,其中 A i 和 A i + 1 可以相乘的,即: A i
的列数等于 A i + 1行数。如何确定矩阵连乘的计算次序,使得按照此次序计算该矩阵连乘所需要的数乘次数最少?
两个矩阵相乘,它的数乘次数等于前一个矩阵的行数乘以后一个矩阵的列数,当这些矩阵相乘的次序不同时,它的数乘次数的有很大差异的。
输入A = [30, 35, 15, 5, 10, 20, 25],表示6个矩阵
输出最优的数乘次数与最佳分隔位置:最优的数乘次数15125 ; 最佳分隔位置 [3, 1, 5] (分隔点之间的顺序任意)。
如下例题:
例:现有四个矩阵 A , B , C , D 他们的维数分别是:A = 50 ∗ 10 , B = 10 ∗ 40 , C = 40 ∗
30 , D = 30 ∗ 5 ,则连乘 A , B , C , D 共有五种相乘的次序,如:( A ( ( B C ) D ) ) ,数乘次数为 16000
( ( A B ) ( C D ) ),数乘次数为 36000
( ( A ( B C ) ) D ) ,数乘次数为 34500
( A ( B ( C D ) ) ) ,数乘次数为 10500
( ( ( A B ) C ) D ) ,数乘次数为 87500
我们可以看到,连乘的次序不同,它的数乘次数大区别也是很大的,最少的 10500
次,最多的,87500 次,相差 8 倍之多。所以,寻找一个矩阵连乘的最优分隔方式就显得尤为重要。
使用动态规划解决该问题的思路:
设
f
[
i
,
j
]
f[i,j]
f[i,j]为矩阵从Ai连乘到Aj的最优数乘值,因此可以写出如下递推公式:
设定一个二维数组m[N][N]表示当前矩阵的连乘次数,比如m[i][j]表示矩阵从Ai连乘到Aj的最优数乘值f[i,j]。
再设定一个二维数组s[N][N]表示当前矩阵连乘时最优分隔位置,比如s[i][j]表示矩阵从Ai连乘到Aj的数乘值最小时的分隔位置,也就是加括号的位置。
接下来画表:
连乘矩阵 :{ A 1 , A 2 , A 3 , A 4 , A 5 , A 6 }
m数组(状态转移矩阵):
- 在计算该列表中的数据时,我首先可以明确对角线的值即f[i,i]都为0,因为这表示只有一个矩阵,数乘为0
- 然后我们首先比较容易计算的就是两个矩阵相乘的最优数乘次数,因为两个矩阵相乘,分隔点只有一个,此时对应于主对角线的相邻对角线的值,即m[1,2], m[2,3], m[3,4], m[4,5], m[5,6]。当这几个值计算得到后,我们就可以接着计算三个矩阵相乘的最优数乘值,也就是m[1,3], m[2, 4], m[3, 5], m[4, 6],例如m[1,3]的分隔点有两个,分别是1和2,当分隔点k为1时:f[1,1]+f[2,3]+a0 * a[1] * a[3],当分隔点k为2时:f[1][2]+f[2][3]+a0 * a2 * a3,而m[1,3]取二者的最小值作为最优数乘值。依次类推,因此状态转移是从主对角线依次计算对角线,最终m[1,6]即为所求值。
- 需要注意,m数组是索引为1的行和列来记录状态的,索引为0的行和列不用,这样做主要为了便去实现上面的推理。
s数组:
对于输出最优分隔位置,根据上图可以看出s[1,6]=3,首先在3处分隔,此时变为s[1,3]与s[4,6],s[1,3]在1处分隔,变为s[1,1]与s[2,3],s[4,6]在5处分隔,变为s[4,5]与s[6,6]。所以在1,3,5处分隔。因此可以利用递归完成。
完整代码:
def getIndex(i, j, s):
global index
if j - i >= 2:
index.append(s[i][j])
else:
return
getIndex(i, s[i][j], s)
getIndex(s[i][j] + 1, j, s)
def matrixChain(a):
n = len(a)
m = [[0 for _ in range(n)] for _ in range(n)]
s = [[0 for _ in range(n)] for _ in range(n)]
for r in range(2, n): # r个矩阵连乘
for i in range(1, n - r + 1): # 填补表格的斜线部分的元素对应i标
j = i + r - 1 # 填补表格的斜线部分的元素对应的j标
m[i][j] = m[i][i] + m[i + 1][j] + a[i - 1] * a[i] * a[j] # 首先计算第一个分割点,例如[1,3]的第一个分割点为1
s[i][j] = i # 然后将分割点暂时记录
for k in range(i + 1, j): # 计算其他的分隔点的数乘次数
temp = m[i][k] + m[k + 1][j] + a[i - 1] * a[k] * a[j]
if temp < m[i][j]: # 如果当前分隔点的数乘次数更优,则更新
s[i][j] = k
m[i][j] = temp
getIndex(1, n-1, s) # 计算分隔点的位置
return m[1][-1], index
if __name__ == "__main__":
A = [30, 35, 15, 5, 10, 20, 25]
index = []
result = matrixChain(A)
print(result)
参考:
https://blog.csdn.net/weixin_44952817/article/details/110124596
https://blog.csdn.net/qq_44755403/article/details/105015330
更多推荐
所有评论(0)