shaoheshaohe 发表于 2020-5-14 08:54:11

torch.tensor的类型转换以及和numpy的转换

PyTorch中的常用的tensor类型      PyTorch中的常用的tensor类型包括:    32位浮点型torch.FloatTensor,    64位浮点型torch.DoubleTensor,    16位整型torch.ShortTensor,    32位整型torch.IntTensor,    64位整型torch.LongTensor。类型之间的转换      一般只要在tensor后加long(), int(), double(),float(),byte()等函数就能将tensor进行类型转换  https://img2020.cnblogs.com/blog/1800705/202003/1800705-20200326115051106-1596085299.png   https://img2020.cnblogs.com/blog/1800705/202003/1800705-20200326115419894-1768732826.png      此外,还可以使用type()函数,data为Tensor数据类型,data.type()为给出data的类型,如果使用data.type(torch.FloatTensor)则强制转换为torch.FloatTensor类型张量。      a1.type_as(a2)可将a1转换为a2同类型。tensor和numpy.array转换  tensor -> numpy.array: data.numpy(),如:  https://img2020.cnblogs.com/blog/1800705/202003/1800705-20200326115737236-335410822.png  numpy.array -> tensor: torch.from_numpy(data),如:  https://img2020.cnblogs.com/blog/1800705/202003/1800705-20200326120104768-1818464682.pngCPU张量和GPU张量之间的转换  CPU -> GPU: data.cuda()  GPU -> CPU: data.cpu()


页: [1]
查看完整版本: torch.tensor的类型转换以及和numpy的转换