#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2020/10/26 15:50
# @Author : Scheaven
# @File : predictor.py
# @description: inference wrapper and ONNX export for the ReID model
import io

import cv2
import onnx
import torch
import torch.nn.functional as F
from onnxsim import simplify
from torch.onnx import OperatorExportTypes

from modeling.meta_arch.build import build_model
from utils.checkpoint import Checkpointer


class ReID_Model(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.predictor = DefaultPredictor(cfg)

    def run_on_image(self, original_image):
        # Forward a single preprocessed image tensor through the predictor.
        predictions = self.predictor(original_image)
        return predictions

    def torch2onnx(self):
        # Export the wrapped model to ONNX; to_onnx() writes the file and returns nothing.
        self.predictor.to_onnx()


class DefaultPredictor:
    def __init__(self, cfg):
        self.cfg = cfg.clone()
        self.cfg.defrost()
        self.cfg.MODEL.BACKBONE.PRETRAIN = False
        self.model = build_model(self.cfg)
        # Inference only: freeze all parameters and switch to eval mode.
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.cuda()
        self.model.eval()

        Checkpointer(self.model).load(cfg.MODEL.WEIGHTS)

    def __call__(self, image):
        with torch.no_grad():
            images = image.cuda()
            predictions = self.model(images)
            # print(predictions)  # debug: dump the raw embedding
            # L2-normalize the embedding and move it to the CPU before returning.
            pred_feat = F.normalize(predictions)
            pred_feat = pred_feat.cpu().data
        return pred_feat

    def to_onnx(self):
        # Dummy input in NCHW order; INPUT.SIZE_TEST is (height, width).
        inputs = torch.randn(1, 3, self.cfg.INPUT.SIZE_TEST[0], self.cfg.INPUT.SIZE_TEST[1]).cuda()
        onnx_model = self.export_onnx_model(self.model, inputs)

        # Simplify the graph and make sure the simplified model is still valid.
        model_simp, check = simplify(onnx_model)
        assert check, "Simplified ONNX model could not be validated"

        model_simp = self.remove_initializer_from_input(model_simp)

        onnx.save_model(model_simp, "fastreid.onnx")

    def export_onnx_model(self, model, inputs):
        assert isinstance(model, torch.nn.Module)

        # The graph must be traced in eval mode (BatchNorm/Dropout frozen).
        def _check_eval(module):
            assert not module.training

        model.apply(_check_eval)

        with torch.no_grad():
            with io.BytesIO() as f:
                torch.onnx.export(
                    model,
                    inputs,
                    f,
                    operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
                )
                onnx_model = onnx.load_from_string(f.getvalue())

        # Apply ONNX graph optimizations.
        # NOTE: onnx.optimizer was removed from onnx >= 1.9; with newer releases install the
        # separate `onnxoptimizer` package and call onnxoptimizer.optimize(...) instead.
        all_passes = onnx.optimizer.get_available_passes()
        passes = ["extract_constant_to_initializer", "eliminate_unused_initializer", "fuse_bn_into_conv"]
        assert all(p in all_passes for p in passes), f"Unavailable optimization passes: {passes}"
        onnx_model = onnx.optimizer.optimize(onnx_model, passes)
        return onnx_model

    def remove_initializer_from_input(self, model):
        # From ONNX IR v4 onwards, initializers no longer need to be listed as graph
        # inputs; dropping them prevents runtimes from expecting them to be fed at
        # inference time.
        if model.ir_version < 4:
            print("Model with ir_version below 4 requires initializers to be included in the graph inputs")
            return model

        inputs = model.graph.input
        name_to_input = {inp.name: inp for inp in inputs}

        for initializer in model.graph.initializer:
            if initializer.name in name_to_input:
                inputs.remove(name_to_input[initializer.name])

        return model
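

# Minimal usage sketch (assumptions: a YACS-style `cfg` compatible with build_model is
# loaded elsewhere in this repo, and `preprocess` below is a hypothetical helper that
# resizes/normalizes a BGR image into a (1, 3, H, W) float tensor):
#
#     model = ReID_Model(cfg)
#     img = cv2.imread("person.jpg")                 # HWC BGR uint8
#     blob = preprocess(img, cfg.INPUT.SIZE_TEST)    # hypothetical, not defined in this file
#     feat = model.run_on_image(blob)                # L2-normalized feature of shape (1, D)
#     model.torch2onnx()                             # writes fastreid.onnx to the working dir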