在学习强化学习的过程中,有时需要将动作与价值对应起来,具体可参见莫凡pytorch的DQN程序。

# sample batch transitions
sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
b_memory = self.memory[sample_index, :]
b_s = torch.FloatTensor(b_memory[:, :N_STATES])
b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES+1].astype(int))
b_r = torch.FloatTensor(b_memory[:, N_STATES+1:N_STATES+2])
 b_s_ = torch.FloatTensor(b_memory[:, -N_STATES:])

# q_eval w.r.t the action in experience
q_eval = self.eval_net(b_s).gather(1, b_a)  # shape (batch, 1)
q_next = self.target_net(b_s_).detach()     # detach from graph, don't backpropagate
q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)   # shape (batch, 1)
loss = self.loss_func(q_eval, q_target)

其中有一个gather函数,这里介绍其用法。
内容参考下文
理解pytorch几个高级选择函数(如gather)

gather函数

pytorch和numpy中许多函数都涉及维度运算,gather也不例外,但是它相对于其他函数更难理解。依然先来看一个例子

import torch
a = torch.arange(1, 16).reshape(5, 3)
"""
result:
a = [[1, 2, 3],
      [4, 5, 6],
      [7, 8, 9],
      [10, 11, 12],
      [13, 14, 15]]
"""


# 定义两个index
b = torch.tensor([[0, 1, 2], [2, 3, 4], [0, 2, 4]])
c = torch.tensor([[1, 2, 0, 2, 1], [1, 2, 1, 0, 0]])

# axis=0
output1 = a.gather(0, b)
"""
result:
[[1, 5, 9],
[7, 11, 15],
[1, 8, 15]]
"""

# axis=1
output2 = a.gather(1, c)
"""
result:
[[2, 3, 1, 3, 2],
[5, 6, 5, 4, 4]]
"""

上面的例子看起来可能有点复杂,我们来一步步的分析它,先从gather维度为0开始讲起。

1、a.gather(0, b)分为3个部分,a是需要被提取元素的矩阵,0代表的是提取的维度为0,b是提取元素的索引。

  • 其中规定b和a是同维张量,即a是2维张量,b也必须是2维张量

2、0除了代表往维度0的方向提取元素外,还有一个特权—提取结果output可以在这个维度上的长度与a不同。打个比方,a现在的shape为(5, 3),那么提取结果output1的shape可以是(1,3),(2, 3),甚至(n, 3)。具体维度0的长度到底为多少由b来决定。

3、根据0的特权,导致了给定的b张量除了维度0外,其他的维度大小必须和a一样。其中张量b实际上包含以下两个信息

  • b可以利用除用于gather的维度(此处为维度0)外的维度来定位出唯一一个向量,也就是a[:, ?](三维度也是同理的,有a[:, ?1,?2]),?的取值范围为a同维度的index。
  • 对于上述定位出的向量,通过b中的元素来定位提取向量中的哪一个元素。
  • 上面说得可能有点抽象,实际上b中的每个元素都能在a中提取出一个元素。举个具体点的例子,按照上面所说的,b[0, 0]可以提取a中的一个元素。对于b[0,0],除了维度0外,可以通过维度1来定位出唯一一个向量a[:, 0]。因为b[0,0]的元素为0,即提取的是a[:, 0]的第0个元素—1,并将其作为output1[0, 0]的提取结果。

下图给出了维度0和维度1,gather运算的图示
gather二维

对于3维或者更高维度的张量gather的原理也是一样的
gather三维

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐