tf.stack的用法
tf.stack()是一个矩阵拼接函数,会根据函数中对应的参数调整拼接的维度。 axis=0,表示在第一个维度及逆行数据的拼接,如1x3和1x3的数据拼接会形成一个形状为2x3的数据。axis=1表示在第二维的数据进行拼接。import tensorflow as tfimport numpy as npa=tf.constant([[1,2,3],[4,5,6]])aa1=tf.constant
·
tf.stack()是一个矩阵拼接函数,会根据函数中对应的参数调整拼接的维度。 axis=0,表示在第一个维度及逆行数据的拼接,如1x3和1x3的数据拼接会形成一个形状为2x3的数据。axis=1表示在第二维的数据进行拼接。
import tensorflow as tf
import numpy as np
a=tf.constant([[1,2,3],[4,5,6]])
aa1=tf.constant([1,2,3])
aa2=tf.constant([4,5,6])#],[[7,8,9],[10,11,12]])
b=tf.constant([[7,8,9],[10,11,12]])
c=tf.stack([a,b],axis=0)
d=tf.stack([a,b],axis=1)
e=tf.stack(a,axis=0)
ff=tf.stack([aa1,aa2],axis=0)
f=tf.stack(a,axis=1)#会报错,stack中的数据必须包含两个属性
with tf.Session() as sess:
result0=sess.run(c)
result1=sess.run(d)
result2=sess.run(e)
result3=sess.run(f)
result4=sess.run(ff)
print('shape of result0',np.shape(result0),'result0:\n',result0)
print('shape of result1',np.shape(result1),'result1:\n',result1)
print('shape of result2(axis=0)',np.shape(result2),'result2:\n',result2)
#print('shape of result3(axis=1)',np.reshape(result3),'result3:\n',result3)
print('shape of result4(axis=1)',np.shape(result4),'result4:\n',result4)
结果如下:
更多推荐
活动日历
查看更多
直播时间 2025-02-26 16:00:00


直播时间 2025-01-08 16:30:00


直播时间 2024-12-11 16:30:00


直播时间 2024-11-27 16:30:00


直播时间 2024-11-21 16:30:00


所有评论(0)