nn.Embedding 实质上是矩阵运算,实现维度转换

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None,
 max_norm=None,  norm_type=2.0,   scale_grad_by_freq=False, 
 sparse=False,  _weight=None)

num_embeddings : 输入数据的类别数
embedding_dim : 数据的编码维度

from torch import nn
z = nn.Embedding(3,2)
z.weight

输出

Parameter containing:
tensor([[ 0.5813, -0.4503],
        [-1.8539,  0.6905],
        [ 0.3107, -0.7194]], requires_grad=True)

y = z(torch.tensor([1,2,0]))

输出

y
tensor([[-1.8539,  0.6905],
        [ 0.3107, -0.7194],
        [ 0.5813, -0.4503]], grad_fn=<EmbeddingBackward>)

由上面测试可见,该函数先将输入数据由索引编号转换为onehot编码的矩阵,然后右乘一个权重矩阵完成 输入的embedding,在训练过程中梯度更新权重,寻找让loss减小的embedding方式

Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐