Scheaven
2021-09-18 291deeb1fcf45dbf39a24aa72a213ff3fd6b3405
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2020/10/26 15:50
# @Author  : Scheaven
# @File    :  predictor.py
# @description:
import cv2, io
import torch
import torch.nn.functional as F
 
from modeling.meta_arch.build import build_model
from utils.checkpoint import Checkpointer
 
import onnx
import torch
from onnxsim import simplify
from torch.onnx import OperatorExportTypes
 
class ReID_Model(object):
    """Thin facade that delegates inference and ONNX export to a DefaultPredictor."""

    def __init__(self, cfg):
        """Store *cfg* and build the underlying ``DefaultPredictor`` from it.

        Args:
            cfg: model config passed unchanged to ``DefaultPredictor``.
        """
        self.cfg = cfg

        self.predictor = DefaultPredictor(cfg)

    def run_on_image(self, original_image):
        """Run the predictor on *original_image* and return its output.

        Args:
            original_image: input forwarded as-is to ``self.predictor``.

        Returns:
            Whatever ``self.predictor(original_image)`` returns.
        """
        return self.predictor(original_image)

    def torch2onnx(self):
        """Export the wrapped model to ONNX via ``self.predictor.to_onnx()``.

        Returns:
            The value returned by ``self.predictor.to_onnx()`` (previously this
            was bound to an unused local and dropped).
        """
        return self.predictor.to_onnx()
 
class DefaultPredictor:
    """Build a ReID model from a config, run it on CUDA, and export it to ONNX.

    The model is created with ``build_model``, frozen (``requires_grad=False``),
    moved to CUDA, put in eval mode, and loaded from ``cfg.MODEL.WEIGHTS``.
    """

    def __init__(self, cfg):
        """Build the model from a private copy of *cfg* and load its weights.

        Args:
            cfg: config node. It is cloned and defrosted, so the caller's object
                is not modified. ``MODEL.BACKBONE.PRETRAIN`` is forced to False
                on the copy, and ``MODEL.WEIGHTS`` names the checkpoint to load.
        """
        self.cfg = cfg.clone()
        self.cfg.defrost()
        # NOTE(review): presumably disabled because the checkpoint loaded below
        # supplies all weights, making backbone pretraining redundant — confirm.
        self.cfg.MODEL.BACKBONE.PRETRAIN = False
        self.model = build_model(self.cfg)
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.cuda()
        self.model.eval()

        Checkpointer(self.model).load(cfg.MODEL.WEIGHTS)

    def __call__(self, image):
        """Return L2-normalised model outputs for *image* as a CPU tensor.

        Args:
            image: tensor moved to CUDA before the forward pass.
                NOTE(review): presumably an already preprocessed (N, C, H, W)
                batch — confirm against callers.

        Returns:
            ``F.normalize(self.model(image))`` (p=2 along dim 1), on the CPU.
        """
        with torch.no_grad():  # inference only: no autograd graph needed
            images = image.cuda()
            # Defensive: make sure no caller has switched the model to training.
            self.model.eval()
            predictions = self.model(images)
            pred_feat = F.normalize(predictions)
            return pred_feat.cpu().data

    def to_onnx(self, save_path="fastreid.onnx"):
        """Export the model to ONNX, simplify and validate it, then save it.

        Args:
            save_path: destination file. Defaults to ``"fastreid.onnx"`` in the
                current working directory, matching the previous hard-coded path.

        Returns:
            The path the model was written to.

        Raises:
            AssertionError: if onnx-simplifier cannot validate the simplified
                model. Raised explicitly so the check also runs under ``python -O``.
        """
        height, width = self.cfg.INPUT.SIZE_TEST[0], self.cfg.INPUT.SIZE_TEST[1]
        # Random dummy input, used only to trace the graph (batch size 1).
        inputs = torch.randn(1, 3, height, width).cuda()
        onnx_model = self.export_onnx_model(self.model, inputs)

        model_simp, check = simplify(onnx_model)
        # Validate before any further graph surgery.
        if not check:
            raise AssertionError("Simplified ONNX model could not be validated")

        model_simp = self.remove_initializer_from_input(model_simp)

        onnx.save_model(model_simp, save_path)
        return save_path

    def export_onnx_model(self, model, inputs):
        """Trace *model* on *inputs* into an in-memory ONNX model and optimise it.

        Args:
            model: a ``torch.nn.Module`` whose submodules are all in eval mode.
            inputs: example input passed to ``torch.onnx.export`` for tracing.

        Returns:
            The ONNX model after the ``extract_constant_to_initializer``,
            ``eliminate_unused_initializer`` and ``fuse_bn_into_conv`` passes.

        Raises:
            AssertionError: if *model* is not a Module, any submodule is in
                training mode, or a required optimisation pass is unavailable.
            ImportError: if the installed onnx provides no ``optimizer`` module.
        """
        assert isinstance(model, torch.nn.Module)

        def _check_eval(module):
            # Exporting in training mode would bake training-only behaviour
            # (e.g. dropout) into the graph.
            assert not module.training

        model.apply(_check_eval)

        with torch.no_grad(), io.BytesIO() as f:
            torch.onnx.export(
                model,
                inputs,
                f,
                operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
            )
            onnx_model = onnx.load_model_from_string(f.getvalue())

        # Explicit submodule import instead of relying on the `onnx.optimizer`
        # attribute. NOTE: onnx >= 1.9 removed this module (it moved to the
        # separate `onnxoptimizer` package), so this needs an older onnx.
        from onnx import optimizer

        # Apply ONNX's optimisation passes.
        passes = ["extract_constant_to_initializer", "eliminate_unused_initializer", "fuse_bn_into_conv"]
        available = optimizer.get_available_passes()
        missing = [p for p in passes if p not in available]
        assert not missing, f"ONNX optimizer passes unavailable: {missing}"
        return optimizer.optimize(onnx_model, passes)

    def remove_initializer_from_input(self, model):
        """Remove graph inputs that merely duplicate an initializer of the same name.

        Args:
            model: ONNX model. Its ``graph.input`` is modified in place.

        Returns:
            *model*. For ``ir_version < 4`` the model is returned unchanged, since
            that IR requires initializers to also be listed as graph inputs.
            (Previously ``None`` was returned here, which crashed ``to_onnx``.)
        """
        if model.ir_version < 4:
            print(
                'Model with ir_version below 4 requires to include initilizer in graph input'
            )
            return model

        graph_inputs = model.graph.input
        name_to_input = {graph_input.name: graph_input for graph_input in graph_inputs}

        for initializer in model.graph.initializer:
            # pop() so that a repeated initializer name cannot trigger a
            # second remove() of an input that is already gone.
            duplicate = name_to_input.pop(initializer.name, None)
            if duplicate is not None:
                graph_inputs.remove(duplicate)

        return model