pytorch中 nn.Embedding的原理
nn.Embedding 实质上是矩阵运算,实现维度转换torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None,max_norm=None,norm_type=2.0,scale_grad_by_freq=False,sparse=False,_weight=None)num_embeddings : 输入数据的类别数e
·
nn.Embedding 实质上是矩阵运算,实现维度转换
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None,
max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
sparse=False, _weight=None)
num_embeddings : 输入数据的类别数
embedding_dim : 数据的编码维度
from torch import nn
z = nn.Embedding(3,2)
z.weight
输出
Parameter containing:
tensor([[ 0.5813, -0.4503],
[-1.8539, 0.6905],
[ 0.3107, -0.7194]], requires_grad=True)
y = z(torch.tensor([1,2,0]))
输出
y
tensor([[-1.8539, 0.6905],
[ 0.3107, -0.7194],
[ 0.5813, -0.4503]], grad_fn=<EmbeddingBackward>)
由上面测试可见,该函数先将输入数据由索引编号转换为onehot编码的矩阵,然后右乘一个权重矩阵完成 输入的embedding,在训练过程中梯度更新权重,寻找让loss减小的embedding方式
更多推荐
已为社区贡献1条内容
所有评论(0)