#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2020/10/26 15:50
# @Author  : Scheaven
# @File    :  predictor.py
# @description: ReID feature-extraction predictor with ONNX export support.

import io
import warnings

import cv2
import onnx
import torch
import torch.nn.functional as F
from onnxsim import simplify
from torch.onnx import OperatorExportTypes

from modeling.meta_arch.build import build_model
from utils.checkpoint import Checkpointer


class ReID_Model(object):
    """Thin facade over :class:`DefaultPredictor` for inference and ONNX export."""

    def __init__(self, cfg):
        self.cfg = cfg
        self.predictor = DefaultPredictor(cfg)

    def run_on_image(self, original_image):
        """Run the predictor on *original_image* and return its output.

        The output is the L2-normalized feature tensor produced by
        ``DefaultPredictor.__call__``, moved to the CPU.
        """
        predictions = self.predictor(original_image)
        return predictions

    def torch2onnx(self, output_path="fastreid.onnx"):
        """Export the wrapped model to an ONNX file at *output_path*."""
        self.predictor.to_onnx(output_path)


class DefaultPredictor:
    """Build a model from *cfg*, load its weights, and run CUDA inference.

    The model is moved to CUDA, put in eval mode, and has gradients disabled
    on all parameters. Weights are loaded from ``cfg.MODEL.WEIGHTS``.
    """

    def __init__(self, cfg):
        # Work on a private, mutable copy so the caller's config is untouched.
        # NOTE(review): clone()/defrost() suggest a yacs-style CfgNode — confirm.
        self.cfg = cfg.clone()
        self.cfg.defrost()
        # presumably disables loading backbone pretrained weights, since the
        # full checkpoint is loaded below anyway — TODO confirm in build_model.
        self.cfg.MODEL.BACKBONE.PRETRAIN = False
        self.model = build_model(self.cfg)
        for param in self.model.parameters():
            param.requires_grad = False

        self.model.cuda()
        self.model.eval()

        Checkpointer(self.model).load(cfg.MODEL.WEIGHTS)

    def __call__(self, image):
        """Return L2-normalized model outputs for *image*, on the CPU.

        Args:
            image: a tensor accepted by the model; assumed to be an NCHW image
                batch like the dummy input built in ``to_onnx`` — TODO confirm.

        Returns:
            ``F.normalize(model(image))`` (p=2 along dim 1), moved to the CPU.
        """
        with torch.no_grad():
            images = image.cuda()
            # Re-assert eval mode in case a caller switched the model to train().
            self.model.eval()
            predictions = self.model(images)
            pred_feat = F.normalize(predictions)
            return pred_feat.cpu().data

    def to_onnx(self, output_path="fastreid.onnx"):
        """Export, simplify and save the model as ONNX.

        A random ``(1, 3, H, W)`` CUDA tensor is traced, where ``(H, W)`` comes
        from ``cfg.INPUT.SIZE_TEST``.

        Args:
            output_path: destination file for the ONNX model.

        Raises:
            AssertionError: if onnx-simplifier cannot validate the simplified
                model (raised explicitly so the check survives ``python -O``).
        """
        height, width = self.cfg.INPUT.SIZE_TEST[0], self.cfg.INPUT.SIZE_TEST[1]
        dummy_input = torch.randn(1, 3, height, width).cuda()

        onnx_model = self.export_onnx_model(self.model, dummy_input)
        model_simp, check = simplify(onnx_model)
        # Validate before doing any further work on a possibly broken graph.
        if not check:
            raise AssertionError("Simplified ONNX model could not be validated")

        model_simp = self.remove_initializer_from_input(model_simp)
        onnx.save_model(model_simp, output_path)

    def export_onnx_model(self, model, inputs):
        """Trace *model* with *inputs* to an in-memory ONNX model and optimize it.

        Unsupported operators fall back to ATen ops
        (``OperatorExportTypes.ONNX_ATEN_FALLBACK``). If the legacy
        ``onnx.optimizer`` module is available, the constant-extraction,
        unused-initializer and BN-fusion passes are applied; otherwise the
        passes are skipped with a warning (the module was removed in onnx 1.9).

        Args:
            model: a ``torch.nn.Module``; every submodule must be in eval mode.
            inputs: example input(s) used for tracing.

        Returns:
            The exported (and, when possible, optimized) ``onnx.ModelProto``.
        """
        assert isinstance(model, torch.nn.Module)

        def _check_eval(module):
            assert not module.training

        model.apply(_check_eval)

        with torch.no_grad():
            with io.BytesIO() as f:
                torch.onnx.export(
                    model,
                    inputs,
                    f,
                    operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
                )
                onnx_model = onnx.load_from_string(f.getvalue())

        # onnx.optimizer was split out into the separate `onnxoptimizer`
        # package in onnx 1.9; degrade gracefully instead of crashing.
        try:
            import onnx.optimizer as onnx_optimizer
        except ImportError:
            warnings.warn(
                "onnx.optimizer is unavailable; skipping ONNX optimization passes"
            )
            return onnx_model

        all_passes = onnx_optimizer.get_available_passes()
        passes = [
            "extract_constant_to_initializer",
            "eliminate_unused_initializer",
            "fuse_bn_into_conv",
        ]
        assert all(p in all_passes for p in passes)
        onnx_model = onnx_optimizer.optimize(onnx_model, passes)
        return onnx_model

    def remove_initializer_from_input(self, model):
        """Drop graph inputs that merely duplicate an initializer.

        Models with ``ir_version < 4`` must list initializers as graph inputs,
        so they are returned unchanged (after printing a notice).

        Returns:
            The same ``model`` object, modified in place.
        """
        if model.ir_version < 4:
            print(
                'Model with ir_version below 4 requires to include initilizer in graph input'
            )
            # Return the model (not None) so callers can still save it.
            return model

        inputs = model.graph.input
        name_to_input = {graph_input.name: graph_input for graph_input in inputs}
        for initializer in model.graph.initializer:
            # pop() guarantees each graph input is removed at most once.
            graph_input = name_to_input.pop(initializer.name, None)
            if graph_input is not None:
                inputs.remove(graph_input)
        return model