查看: 1311|回复: 0

tf.concat()函数解析

[复制链接]

665

主题

1234

帖子

6568

积分

xdtech

Rank: 5Rank: 5

积分
6568
发表于 2020-9-16 17:53:32 | 显示全部楼层 |阅读模式
tf.concat()函数用于数组或者矩阵拼接。

tf.concat的官方解释

tf.concat(    values,    axis,    name='concat')
  • 1
  • 2
  • 3
  • 4
  • 5

其中:
values应该是一个tensor的list或者tuple,里面是准备连接的矩阵或者数组。

axis则是我们准备连接的矩阵或者数组的维度。

  • axis=0代表在第0个维度拼接
  • axis=1代表在第1个维度拼接
  • axis=-1表示在倒数第1个维度拼接

负数在数组索引里面表示倒数,也就算是倒着数,-1是最后一个,-2是倒数第二个,对于二维矩阵拼接来说,axis=-1等价于axis=1。一般在维度非常高的情况下,if 我们想在最’高’的维度进行拼接,一般就直接用倒数机制,直接axis=-1就搞定了。

1. values:import tensorflow as tft1=tf.constant([1,2,3)t2=tf.constant([4,5,6)print(t1)print(t2)concated = tf.concat([t1,t2, 1)> Tensor("Const_20:0", shape=(3,), dtype=int32)> Tensor("Const_21:0", shape=(3,), dtype=int32)> ValueError: Shapes (2, 3) and () are incompatible
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

因为它们对应的shape只有一个维度,当然不能在第二维上拼接了,虽然实际中两个向量可以在行上(axis = 1)拼接,但是放在程序里是会报错的

import tensorflow as tft1=tf.expand_dims(tf.constant([1,2,3),1)t2=tf.expand_dims(tf.constant([4,5,6),1)print(t1)print(t2)concated = tf.concat([t1,t2, 1)> Tensor("ExpandDims_26:0", shape=(3, 1), dtype=int32)> Tensor("ExpandDims_27:0", shape=(3, 1), dtype=int32)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

如果想要拼接,必须要调用tf.expand_dims()来扩维:

2. axis:

第0个维度代表最外面的括号所在的维度,第1个维度代表最外层括号里面的那层括号所在的维度,以此类推。

import tensorflow as tfwith tf.Session() as sess:        t1 = [[1, 2, 3,  [4, 5, 6        t2 = [[7, 8, 9,  [10, 11, 12        print(sess.run(tf.concat([t1, t2, 0)))  >  [[ 1  2  3    [ 4  5  6   [ 7  8  9    [10 11 12        print(sess.run(tf.concat([t1, t2, 1)))  >  [[ 1  2  3  7  8  9    [ 4  5  6 10 11 12        print(sess.run(tf.concat([t1, t2, -1)))  >  [[ 1  2  3  7  8  9    [ 4  5  6 10 11 12
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

快速回复 返回顶部 返回列表