natanielruiz
2017-09-23 31fc66b795c0a57b8009d7b03f49f6cd099ceb29
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
 
 
def one_hot(index, classes):
    size = index.size() + (classes,)
    view = index.size() + (1,)
 
    mask = torch.Tensor(*size).fill_(0)
    index = index.view(*view)
    ones = 1.
 
    if isinstance(index, Variable):
        ones = Variable(torch.Tensor(index.size()).fill_(1))
        mask = Variable(mask, volatile=index.volatile)
 
    return mask.scatter_(1, index, ones)
 
 
class FocalLoss(nn.Module):
 
    def __init__(self, gamma=0, eps=1e-7):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps
 
    def forward(self, input, target):
        y = one_hot(target, input.size(-1))
        logit = F.softmax(input)
        logit = logit.clamp(self.eps, 1. - self.eps)
 
        loss = -1 * y * torch.log(logit) # cross entropy
        loss = loss * (1 - logit) ** self.gamma # focal loss
 
        return loss.sum()