tf.stack()是一个矩阵拼接函数,会根据函数中对应的参数调整拼接的维度。 axis=0,表示在第一个维度及逆行数据的拼接,如1x3和1x3的数据拼接会形成一个形状为2x3的数据。axis=1表示在第二维的数据进行拼接。

import tensorflow as tf
import numpy as np
a=tf.constant([[1,2,3],[4,5,6]])
aa1=tf.constant([1,2,3])
aa2=tf.constant([4,5,6])#],[[7,8,9],[10,11,12]])
b=tf.constant([[7,8,9],[10,11,12]])
c=tf.stack([a,b],axis=0)
d=tf.stack([a,b],axis=1)
e=tf.stack(a,axis=0)
ff=tf.stack([aa1,aa2],axis=0)
f=tf.stack(a,axis=1)#会报错,stack中的数据必须包含两个属性
with tf.Session() as sess:
    result0=sess.run(c)
    result1=sess.run(d)
    result2=sess.run(e)
    result3=sess.run(f)
    result4=sess.run(ff)
    print('shape of result0',np.shape(result0),'result0:\n',result0)
    print('shape of result1',np.shape(result1),'result1:\n',result1)
    print('shape of result2(axis=0)',np.shape(result2),'result2:\n',result2)
    #print('shape of result3(axis=1)',np.reshape(result3),'result3:\n',result3)
    print('shape of result4(axis=1)',np.shape(result4),'result4:\n',result4)

结果如下:
在这里插入图片描述

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐