import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def one_hot(index, classes):
    """Return a one-hot float encoding of *index* along a new trailing axis.

    Args:
        index: Integer tensor of class ids, of any shape.
        classes: Number of classes, i.e. the size of the appended axis.

    Returns:
        A tensor of shape ``index.size() + (classes,)`` in the default
        floating dtype, allocated on the same device as ``index``, holding
        1 at each position selected by ``index`` and 0 everywhere else.
    """
    size = tuple(index.size()) + (classes,)
    # Allocate on the index's device so scatter_ does not hit a CPU/GPU mismatch.
    mask = torch.zeros(size, device=index.device)
    # The class axis is the one appended last, so scatter along -1; a fixed
    # dim of 1 is only correct when `index` happens to be 1-D.
    # scatter_ requires int64 indices, hence the explicit .long().
    return mask.scatter_(-1, index.long().unsqueeze(-1), 1.0)


class FocalLoss(nn.Module):
    """Summed focal loss computed from unnormalized class scores.

    Each element contributes ``-(1 - p) ** gamma * log(p)`` where ``p`` is
    the softmax probability of its target class. With ``gamma == 0`` this is
    the summed cross entropy.

    Args:
        gamma: Focusing exponent applied to ``(1 - p)``.
        eps: Probabilities are clamped to ``[eps, 1 - eps]`` before the log
            is taken, which keeps ``log`` finite.
    """

    def __init__(self, gamma=0, eps=1e-7):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps

    def forward(self, input, target):
        """Compute the summed focal loss.

        Args:
            input: Scores of shape ``(..., C)``; the last axis indexes classes.
            target: Integer class ids of shape ``input.shape[:-1]``.

        Returns:
            A scalar tensor: the loss summed over every element.
        """
        # Normalise explicitly over the class axis. The implicit-dim form of
        # softmax is deprecated and picks dim 0 for 3-D inputs, which would
        # disagree with the trailing class axis used for the one-hot target.
        prob = F.softmax(input, dim=-1)
        prob = prob.clamp(self.eps, 1. - self.eps)

        y = one_hot(target, input.size(-1)).to(prob.dtype)

        loss = -1 * y * torch.log(prob)  # cross entropy
        loss = loss * (1 - prob) ** self.gamma  # focal loss

        return loss.sum()