RuntimeError: there are no graph nodes that require computing gradients
If I remove the > 0.5 threshold, the error goes away.
The code is as follows:
    import torch
    import torch.nn as nn

    class DiceLoss(nn.Module):
        def __init__(self):
            super(DiceLoss, self).__init__()
            self.sigmoid = nn.Sigmoid()

        def forward(self, output, labels):
            batch_size = labels.size(0)
            # Thresholding here is what triggers the RuntimeError;
            # without "> 0.5" the loss runs fine.
            Sigmoidout = self.sigmoid(output) > 0.5
            loss = 0
            for i in range(batch_size):
                # Negative Dice coefficient for sample i, smoothed by 1e-5
                sampleloss = -(2.0*torch.sum(torch.mul(Sigmoidout[i,:],labels[i,:])) + 1e-5)/(torch.sum(labels[i,:]) + torch.sum(Sigmoidout[i,:]) + 1e-5)
                loss += sampleloss
            meanloss = loss/batch_size
            return meanloss
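A likely explanation: the comparison > 0.5 is not differentiable, so its result is a 0/1 mask with no connection to the autograd graph, and everything computed from it has nothing to backpropagate through. Below is a minimal sketch of a "soft" Dice loss that uses the sigmoid probabilities directly. It assumes a reasonably recent PyTorch; the names SoftDiceLoss and eps are made up for this illustration, and the sign convention (negative Dice) follows the code above.

```python
import torch
import torch.nn as nn


class SoftDiceLoss(nn.Module):
    """Hypothetical variant: Dice loss on probabilities, no hard threshold."""

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps  # smoothing term, same role as the 1e-5 above

    def forward(self, output, labels):
        batch_size = labels.size(0)
        # Keep the sigmoid output as floats so gradients can flow
        probs = torch.sigmoid(output).reshape(batch_size, -1)
        labels = labels.reshape(batch_size, -1).to(probs.dtype)
        intersection = (probs * labels).sum(dim=1)
        denom = probs.sum(dim=1) + labels.sum(dim=1)
        dice = (2.0 * intersection + self.eps) / (denom + self.eps)
        # Negative mean Dice over the batch
        return -dice.mean()


# Usage with random data (shapes are arbitrary for the demo)
logits = torch.randn(4, 1, 8, 8, requires_grad=True)
targets = (torch.rand(4, 1, 8, 8) > 0.5).float()

loss = SoftDiceLoss()(logits, targets)
loss.backward()  # works: the loss is still attached to the graph

# For comparison, the thresholded mask carries no gradient information
hard = torch.sigmoid(logits) > 0.5
print(loss.requires_grad, hard.requires_grad)
```

If a hard 0/1 mask is needed, it can still be computed with > 0.5 at evaluation time for metrics; only the training loss has to stay differentiable.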