函数原型
tf.keras.layers.Embedding(input_dim,
    					  output_dim,
    					  embeddings_initializer='uniform',
    					  embeddings_regularizer=None,
    					  activity_regularizer=None,
    					  embeddings_constraint=None,
    					  mask_zero=False,
    					  input_length=None,
    					  **kwargs
)
函数说明

嵌入层主要负责将一个特征转换成一个向量。嵌入层一般放在第一层,常常用于对自然语言序列的处理。
在这里插入图片描述
如上图所示,每一个单词对应一个标签,比如“late”对应3、“yeah”对应8,这样可以将单词序列转化成一个向量,便于数据的处理。该Embedding层的作用就是把向量中每一个标签值映射为一个3维向量,这样就可以用一个三维向量来表示一个单词。

Embedding函数实现了嵌入层的功能。参数input_dim表示词汇量的大小,比如需要处理的单词序列共有100行,每一行有50个单词,那么总共有5000个单词,假设这5000个单词中不相同的单词有2000个,那么此时输入数据的词汇量就为2000。

参数output_dim表示每一个单词映射的向量维数,如果需要用20维向量表示一个单词,那么output_dim就为20。还有一个常用的参数input_length,这个参数用来规定输入的单词序列的长度,如果单词序列长度为30个,那么这个参数的值就应该设置为30。如果没有设置参数input_length,那么输入序列的长度可以改变。

注意,Embedding层输入是一个二维张量,形状为(batch_size, input_length),输出形状为(batch_size, input_length, output_dim),是一个三维张量。

函数用法
model = tf.keras.models.Sequential([
    # 嵌入层,词汇量为256
    tf.keras.layers.Embedding(256, 125),
    # LSTM层
    tf.keras.layers.LSTM(125, return_sequences=True),
    # LSTM层
    tf.keras.layers.LSTM(125, return_sequences=True),
    # 在时间维度上全连接的Dense层
    tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(256, activation="softmax"))
])
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 embedding (Embedding)       (None, None, 125)         32000     
                                                                 
 lstm (LSTM)                 (None, None, 125)         125500    
                                                                 
 lstm_1 (LSTM)               (None, None, 125)         125500    
                                                                 
 time_distributed (TimeDistr  (None, None, 256)        32256     
 ibuted)                                                         
                                                                 
=================================================================
Total params: 315,256
Trainable params: 315,256
Non-trainable params: 0
_________________________________________________________________
Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐