nn.Embedding 实质上是矩阵运算,实现维度转换

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None,
 max_norm=None,  norm_type=2.0,   scale_grad_by_freq=False, 
 sparse=False,  _weight=None)

num_embeddings : 输入数据的类别数
embedding_dim : 数据的编码维度

from torch import nn
z = nn.Embedding(3,2)
z.weight

输出

Parameter containing:
tensor([[ 0.5813, -0.4503],
        [-1.8539,  0.6905],
        [ 0.3107, -0.7194]], requires_grad=True)

y = z(torch.tensor([1,2,0]))

输出

y
tensor([[-1.8539,  0.6905],
        [ 0.3107, -0.7194],
        [ 0.5813, -0.4503]], grad_fn=<EmbeddingBackward>)

由上面测试可见,该函数先将输入数据由索引编号转换为onehot编码的矩阵,然后右乘一个权重矩阵完成 输入的embedding,在训练过程中梯度更新权重,寻找让loss减小的embedding方式

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐