|
写一个自己遇到的pytorch数据类型报错的例子之前在运行代码时遇到一个数据类型错误:其报错如下:
- RuntimeError: Expected object of type Variable[torch.LongTensor] but found type Variable[torch.cuda.ByteTensor] for argument #1 ‘argument1’
复制代码 这个解决办法为:
pytorch框架在存储labels时,采用LongTensor来存储,所以在一开始dataset返回label时,就要返回与LongTensor对应的数据类型,即numpy.int64
这个希望对大家有帮助
|
|