查看: 1747|回复: 0

Pytorch 之squeeze和unsqueeze用法

[复制链接]

665

主题

1234

帖子

6695

积分

xdtech

Rank: 5Rank: 5

积分
6695
发表于 2019-9-6 21:45:54 | 显示全部楼层 |阅读模式

1. torch.squeeze(input, dim = None, out = None): 返回一个tensor,当dim不设值时,去掉输入的tensor的所有维度为1的维度; 当dim为某一整数(0<=dim<input.dim())时,判断dim维的维度是否为1,若是则去掉,否则不变。
另外,当input是一维的时候,squeeze不变



  • >>> x = torch.zeros(1,1,2,1,3)



  • >>> x.dim()



  • 5



  • >>> torch.squeeze(x).size()#去掉dim=1的维度



  • torch.Size([2, 3])



  • >>> torch.squeeze(x,0).size()  # dim=0表示第一维,且第一维的维度为1,所以去掉



  • torch.Size([1, 2, 1, 3])



  • >>> torch.squeeze(x,3).size()



  • torch.Size([1, 1, 2, 3])



  • >>> torch.squeeze(x,2).size()  # dim=2,第三维的维度为2!=1,所以不变



  • torch.Size([1, 1, 2, 1, 3])


2. torch.unqueeze(input, dim, out=None): 和squeeze作用相反,unsqueeze()在dim维插入一个维度为1的维,例如原来x是n×m维的,torch.unqueeze(x,0)这返回1×n×m的tensor







  • >>> x = torch.tensor([1,2,3])#dim=1,即(3)



  • >>> torch.unsqueeze(x,1)#变为(3,1)的矩阵



  • tensor([[ 1],



  •         [ 2],



回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

快速回复 返回顶部 返回列表