Scheaven
2021-09-18 291deeb1fcf45dbf39a24aa72a213ff3fd6b3405
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# encoding: utf-8
"""
@author:  xingyu liao
@contact: liaoxingyu5@jd.com
"""
 
import math
 
import torch
import torch.nn as nn
import torch.nn.functional as F
 
__all__ = [
    'Mish',
    'Swish',
    'MemoryEfficientSwish',
    'GELU']
 
 
class Mish(nn.Module):
    def __init__(self):
        super().__init__()
 
    def forward(self, x):
        # inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!)
        return x * (torch.tanh(F.softplus(x)))
 
 
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)
 
 
class SwishImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result
 
    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_variables[0]
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
 
 
class MemoryEfficientSwish(nn.Module):
    def forward(self, x):
        return SwishImplementation.apply(x)
 
 
class GELU(nn.Module):
    """
    Paper Section 3.4, last paragraph notice that BERT used the GELU instead of RELU
    """
 
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))