From 291deeb1fcf45dbf39a24aa72a213ff3fd6b3405 Mon Sep 17 00:00:00 2001
From: Scheaven <xuepengqiang>
Date: 星期六, 18 九月 2021 14:22:50 +0800
Subject: [PATCH] update

---
 utils/checkpoint.py                                          |  403 ++
 modeling/meta_arch/__pycache__/build.cpython-37.pyc          |    0 
 modeling/meta_arch/build.py                                  |   21 
 tools/__pycache__/predictor.cpython-38.pyc                   |    0 
 layers/activation.py                                         |   59 
 modeling/meta_arch/__pycache__/__init__.cpython-37.pyc       |    0 
 modeling/losses/__pycache__/__init__.cpython-37.pyc          |    0 
 config/__pycache__/__init__.cpython-37.pyc                   |    0 
 modeling/heads/__pycache__/build.cpython-37.pyc              |    0 
 modeling/backbones/resnet.py                                 |  359 ++
 modeling/heads/embedding_head.py                             |   97 
 engine/defaults.py                                           |  115 
 layers/__pycache__/batch_drop.cpython-38.pyc                 |    0 
 modeling/heads/__pycache__/__init__.cpython-38.pyc           |    0 
 tools/03_check_onnx.py                                       |  168 +
 utils/env.py                                                 |  119 
 utils/logger.py                                              |  209 +
 data/__pycache__/data_utils.cpython-37.pyc                   |    0 
 engine/__init__.py                                           |    6 
 modeling/backbones/__pycache__/__init__.cpython-37.pyc       |    0 
 layers/arcface.py                                            |   54 
 layers/__pycache__/batch_norm.cpython-37.pyc                 |    0 
 modeling/losses/__pycache__/focal_loss.cpython-38.pyc        |    0 
 utils/__pycache__/file_io.cpython-38.pyc                     |    0 
 data/transforms/__pycache__/build.cpython-38.pyc             |    0 
 modeling/heads/bnneck_head.py                                |   61 
 modeling/backbones/resnest.py                                |  411 ++
 config/__pycache__/defaults.cpython-38.pyc                   |    0 
 layers/context_block.py                                      |  113 
 data/transforms/autoaugment.py                               |  812 +++++
 modeling/heads/build.py                                      |   25 
 utils/comm.py                                                |  255 +
 modeling/meta_arch/__pycache__/baseline.cpython-37.pyc       |    0 
 modeling/backbones/__pycache__/osnet.cpython-37.pyc          |    0 
 modeling/losses/__pycache__/cross_entroy_loss.cpython-38.pyc |    0 
 data/transforms/transforms.py                                |  204 +
 layers/__pycache__/non_local.cpython-37.pyc                  |    0 
 modeling/heads/__pycache__/embedding_head.cpython-38.pyc     |    0 
 utils/__pycache__/__init__.cpython-37.pyc                    |    0 
 utils/file_io.py                                             |  520 +++
 modeling/heads/__pycache__/attr_head.cpython-37.pyc          |    0 
 modeling/backbones/__pycache__/resnet.cpython-38.pyc         |    0 
 layers/__pycache__/frn.cpython-37.pyc                        |    0 
 modeling/backbones/__pycache__/resnest.cpython-38.pyc        |    0 
 data/transforms/__pycache__/autoaugment.cpython-38.pyc       |    0 
 layers/splat.py                                              |   97 
 modeling/heads/linear_head.py                                |   50 
 utils/__pycache__/registry.cpython-37.pyc                    |    0 
 modeling/meta_arch/__pycache__/mgn.cpython-38.pyc            |    0 
 utils/__pycache__/__init__.cpython-38.pyc                    |    0 
 tools/predictor.py                                           |  114 
 layers/__pycache__/context_block.cpython-38.pyc              |    0 
 modeling/backbones/__pycache__/resnet.cpython-37.pyc         |    0 
 utils/__pycache__/checkpoint.cpython-37.pyc                  |    0 
 tools/03_py2onnx.py                                          |   40 
 layers/__pycache__/splat.cpython-37.pyc                      |    0 
 modeling/heads/reduction_head.py                             |   73 
 layers/__pycache__/arcface.cpython-37.pyc                    |    0 
 data/transforms/build.py                                     |   73 
 modeling/backbones/__pycache__/resnext.cpython-38.pyc        |    0 
 utils/__pycache__/history_buffer.cpython-37.pyc              |    0 
 modeling/__pycache__/__init__.cpython-37.pyc                 |    0 
 utils/weight_init.py                                         |   37 
 layers/__pycache__/se_layer.cpython-38.pyc                   |    0 
 modeling/meta_arch/__pycache__/baseline.cpython-38.pyc       |    0 
 utils/__pycache__/events.cpython-38.pyc                      |    0 
 modeling/__init__.py                                         |    6 
 utils/__pycache__/comm.cpython-38.pyc                        |    0 
 data/transforms/__pycache__/functional.cpython-38.pyc        |    0 
 layers/__pycache__/pooling.cpython-37.pyc                    |    0 
 modeling/losses/__pycache__/metric_loss.cpython-37.pyc       |    0 
 layers/__pycache__/activation.cpython-38.pyc                 |    0 
 layers/non_local.py                                          |   54 
 tools/inference_net.py                                       |   60 
 layers/__pycache__/batch_drop.cpython-37.pyc                 |    0 
 config/__pycache__/config.cpython-37.pyc                     |    0 
 modeling/meta_arch/__pycache__/__init__.cpython-38.pyc       |    0 
 layers/__pycache__/__init__.cpython-37.pyc                   |    0 
 data/__pycache__/__init__.cpython-38.pyc                     |    0 
 modeling/meta_arch/__init__.py                               |   12 
 modeling/meta_arch/mgn.py                                    |  280 +
 utils/__pycache__/weight_init.cpython-37.pyc                 |    0 
 layers/__pycache__/circle.cpython-37.pyc                     |    0 
 utils/collect_env.py                                         |  158 +
 data/transforms/__pycache__/transforms.cpython-38.pyc        |    0 
 modeling/heads/attr_head.py                                  |   77 
 data/__pycache__/data_utils.cpython-38.pyc                   |    0 
 data/data_utils.py                                           |   45 
 modeling/losses/__init__.py                                  |    9 
 config/__pycache__/__init__.cpython-38.pyc                   |    0 
 layers/__pycache__/frn.cpython-38.pyc                        |    0 
 modeling/heads/__pycache__/attr_head.cpython-38.pyc          |    0 
 modeling/losses/focal_loss.py                                |  110 
 tools/__pycache__/predictor.cpython-37.pyc                   |    0 
 modeling/meta_arch/baseline.py                               |  119 
 modeling/meta_arch/__pycache__/build.cpython-38.pyc          |    0 
 modeling/losses/__pycache__/metric_loss.cpython-38.pyc       |    0 
 modeling/backbones/__pycache__/__init__.cpython-38.pyc       |    0 
 data/__init__.py                                             |    7 
 modeling/heads/__pycache__/build.cpython-38.pyc              |    0 
 layers/__pycache__/non_local.cpython-38.pyc                  |    0 
 config/defaults.py                                           |  273 +
 data/transforms/__pycache__/__init__.cpython-38.pyc          |    0 
 modeling/heads/__pycache__/__init__.cpython-37.pyc           |    0 
 utils/__pycache__/collect_env.cpython-38.pyc                 |    0 
 modeling/backbones/__pycache__/resnest.cpython-37.pyc        |    0 
 modeling/backbones/resnext.py                                |  198 +
 utils/registry.py                                            |   66 
 config/__init__.py                                           |    9 
 tools/__init__.py                                            |    6 
 utils/__pycache__/registry.cpython-38.pyc                    |    0 
 README.md                                                    |   48 
 modeling/losses/metric_loss.py                               |  215 +
 modeling/backbones/__pycache__/build.cpython-38.pyc          |    0 
 layers/circle.py                                             |   42 
 modeling/losses/__pycache__/__init__.cpython-38.pyc          |    0 
 utils/events.py                                              |  445 +++
 modeling/heads/__pycache__/embedding_head.cpython-37.pyc     |    0 
 modeling/losses/__pycache__/cross_entroy_loss.cpython-37.pyc |    0 
 modeling/backbones/__pycache__/osnet.cpython-38.pyc          |    0 
 layers/__pycache__/se_layer.cpython-37.pyc                   |    0 
 modeling/heads/__init__.py                                   |   12 
 utils/history_buffer.py                                      |   71 
 utils/__pycache__/logger.cpython-38.pyc                      |    0 
 modeling/backbones/__pycache__/build.cpython-37.pyc          |    0 
 utils/__pycache__/weight_init.cpython-38.pyc                 |    0 
 data/transforms/__init__.py                                  |   10 
 modeling/backbones/build.py                                  |   28 
 tools/04_trt_inference.py                                    |   96 
 modeling/backbones/osnet.py                                  |  487 +++
 config/__pycache__/config.cpython-38.pyc                     |    0 
 layers/batch_norm.py                                         |  208 +
 engine/__pycache__/__init__.cpython-38.pyc                   |    0 
 layers/frn.py                                                |  199 +
 layers/__pycache__/arcface.cpython-38.pyc                    |    0 
 layers/__pycache__/context_block.cpython-37.pyc              |    0 
 layers/__pycache__/splat.cpython-38.pyc                      |    0 
 modeling/losses/cross_entroy_loss.py                         |   62 
 config/config.py                                             |  161 +
 layers/pooling.py                                            |   79 
 utils/__pycache__/history_buffer.cpython-38.pyc              |    0 
 modeling/__pycache__/__init__.cpython-38.pyc                 |    0 
 utils/__pycache__/env.cpython-38.pyc                         |    0 
 utils/__pycache__/file_io.cpython-37.pyc                     |    0 
 layers/__pycache__/__init__.cpython-38.pyc                   |    0 
 layers/__pycache__/activation.cpython-37.pyc                 |    0 
 data/__pycache__/__init__.cpython-37.pyc                     |    0 
 layers/__pycache__/batch_norm.cpython-38.pyc                 |    0 
 modeling/losses/__pycache__/focal_loss.cpython-37.pyc        |    0 
 utils/__pycache__/comm.cpython-37.pyc                        |    0 
 layers/__pycache__/pooling.cpython-38.pyc                    |    0 
 /dev/null                                                    |   86 
 layers/__pycache__/circle.cpython-38.pyc                     |    0 
 utils/__init__.py                                            |    6 
 utils/__pycache__/checkpoint.cpython-38.pyc                  |    0 
 modeling/meta_arch/__pycache__/mgn.cpython-37.pyc            |    0 
 utils/__pycache__/events.cpython-37.pyc                      |    0 
 modeling/backbones/__pycache__/resnext.cpython-37.pyc        |    0 
 engine/__pycache__/defaults.cpython-38.pyc                   |    0 
 layers/batch_drop.py                                         |   32 
 data/transforms/functional.py                                |  190 +
 layers/__init__.py                                           |   17 
 layers/se_layer.py                                           |   25 
 modeling/backbones/__init__.py                               |   12 
 164 files changed, 8,427 insertions(+), 88 deletions(-)

diff --git a/Makefile b/Makefile
deleted file mode 100644
index 4eec39b..0000000
--- a/Makefile
+++ /dev/null
@@ -1,93 +0,0 @@
-GPU=1
-CUDNN=1
-DEBUG=1
-
-ARCH= -gencode arch=compute_30,code=sm_30 \
-      -gencode arch=compute_35,code=sm_35 \
-      -gencode arch=compute_50,code=[sm_50,compute_50] \
-      -gencode arch=compute_52,code=[sm_52,compute_52]
-#      -gencode arch=compute_20,code=[sm_20,sm_21] \ This one is deprecated?
-
-# This is what I use, uncomment if you know your arch and want to specify
-# ARCH= -gencode arch=compute_52,code=compute_52
-
-VPATH=.
-EXEC=reid
-OBJDIR=./obj/
-
-CC=gcc
-CPP=g++
-NVCC=nvcc 
-AR=ar
-ARFLAGS=rcs
-OPTS=-Ofast
-LDFLAGS= -lm -pthread 
-CFLAGS=-Wall -Wno-unused-result -Wno-unknown-pragmas -Wfatal-errors -fPIC
-
-CFLAGS+= -I/home/disk1/s_opt/opencv/include/opencv4
-LDFLAGS+= -L/home/disk1/s_opt/opencv/lib -lopencv_core -lopencv_imgcodecs -lopencv_highgui -lopencv_imgproc
-
-
-COMMON = -std=c++11
-#flag = -Wl,-rpath=/home/disk1/s_opt/libtorch/lib
-#COMMON += -D_GLIBCXX_USE_CXX11_ABI=0 -I/home/disk1/s_opt/libtorch/include -I/home/disk1/s_opt/libtorch/include/torch/csrc/api/include
-#LDFLAGS+= -L/home/disk1/s_opt/libtorch/lib -ltorch -lc10 -lc10_cuda -lcudart -lgomp -lnvToolsExt
-
-flag = -Wl,-rpath=/home/disk1/s_opt/libtorch_CPP11/libtorch/lib
-COMMON += -I/home/disk1/s_opt/libtorch_CPP11/libtorch/include -I/home/disk1/s_opt/libtorch_CPP11/libtorch/include/torch/csrc/api/include
-LDFLAGS+= -L/home/disk1/s_opt/libtorch_CPP11/libtorch/lib -ltorch -lc10 -lc10_cuda -lcudart -lgomp -lnvToolsExt
-
-#COMMON = -std=c++11
-#COMMON= -I/home/disk1/data/s_software/CPP11_torch/libtorch/include  -std=c++11
-#COMMON= -D_GLIBCXX_USE_CXX11_ABI=0 -I/home/disk1/data/s_software/CPP11_torch/libtorch/include  -std=c++11
-#LDFLAGS+= -L/home/disk1/data/s_software/CPP11_torch/libtorch/lib -ltorch -lc10 -lc10_cuda -lcudart -lgomp -lnvToolsExt
-
-
-ifeq ($(DEBUG), 1) 
-OPTS=-O0 -g
-endif
-
-CFLAGS+=$(OPTS)
-
-
-COMMON+= -DGPU -I/usr/local/cuda/include/
-
-ifeq ($(CUDNN), 1) 
-COMMON+= -DCUDNN 
-CFLAGS+= -DCUDNN
-LDFLAGS+= -lcudnn
-endif
-
-OBJ=test.o reid_feature.o
-
-ifeq ($(GPU), 1) 
-LDFLAGS+= -lstdc++ 
-OBJ+=
-endif
-
-OBJS = $(addprefix $(OBJDIR), $(OBJ))
-DEPS = $(wildcard */*.h) Makefile 
-
-all: obj $(EXEC)
-#all: obj  results $(SLIB) $(ALIB) $(EXEC)
-
-
-$(EXEC): $(OBJS) 
-	$(CC) $(COMMON) $(CFLAGS) $^ -o $@ $(LDFLAGS) $(ALIB)  $(flag)
-
-$(OBJDIR)%.o: %.cpp $(DEPS)
-	$(CPP) $(COMMON) $(CFLAGS) -c $< -o $@
-
-$(OBJDIR)%.o: %.c $(DEPS)
-	$(CC) $(COMMON) $(CFLAGS) -c $< -o $@
-
-$(OBJDIR)%.o: %.cu $(DEPS)
-	$(NVCC) $(ARCH) $(COMMON) --compiler-options "$(CFLAGS)" -c $< -o $@
-
-obj:
-	mkdir -p obj
-
-.PHONY: clean
-
-clean:
-	rm -rf $(OBJS) $(EXEC) $(EXECOBJ) $(OBJDIR)/*
diff --git a/README.md b/README.md
index f4c8be8..7ea9697 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,48 @@
-## reID
+鏁版嵁璇存槑璺緞锛�
+https://github.com/JDAI-CV/fast-reid/tree/master/datasets
+澶ф濡備笅锛�
 
-浜轰綋閲嶈瘑鍒�
+璁粌
+
+./tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml MODEL.DEVICE "cuda:0"
+
+璇勪及锛�
+python tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml --eval-only \
+MODEL.WEIGHTS /path/to/checkpoint_file MODEL.DEVICE "cuda:0"
+
+娴嬭瘯锛�
+python tools/test_net.py --img_a1 a_1.jpg --img_a2 a_2.jpg --img_b1 b_1.jpg --img_b2 b_2.jpg --config-file ./configs/Market1501/bagtricks_R101-ibn.yml MODEL.WEIGHTS ../market_bot_R101-ibn.pth MODEL.DEVICE "cuda:6"
+
+鎺ㄧ悊锛�
+python tools/inference_net.py --img_a1 /data/disk1/workspace/05_DarknetSort/02_humanID/1.jpg --img_a2 /data/disk1/workspace/05_DarknetSort/02_humanID/21.jpg --img_b1 b_1.jpg --img_b2 b_2.jpg --config-file ../01_fast-reid/fast-reid/configs/Market1501/bagtricks_R101-ibn.yml MODEL.WEIGHTS ../market_bot_R101-ibn.pth MODEL.DEVICE "cuda:6"
+
+
+杞琽nnx
+python tools/03_py2onnx.py --config-file ../fast-reid-master/configs/Market1501/bagtricks_R101-ibn.yml MODEL.WEIGHTS /data/disk1/project/model_dump/01_reid/market_bot_R101-ibn.pth MODEL.DEVICE "cuda:0"
+
+python tools/03_py2onnx.py --config-file ./configs/Market1501/bagtricks_R101-ibn.yml MODEL.WEIGHTS ../market_bot_R101-ibn.pth MODEL.DEVICE "cuda:0"
+
+杞瑃rt
+python tools/03_check_onnx.py
+
+./trtexec --onnx=/data/disk1/project/data/01_reid/02_changchuang/human_direction2.onnx --shapes=the_input:1x224x224x3 --workspace=4096 --saveEngine=/data/disk1/project/data/01_reid/02_changchuang/person_count_engine.trt
+
+
+杩愯
+python tools/04_trt_inference.py --model_path 4batch_fp16_True.trt
+
+
+瀹夎瑕佹眰锛氫互鍙婃墍闇�姹傜殑鍖�
+
+Linux or macOS with python 鈮� 3.6
+PyTorch 鈮� 1.6
+torchvision that matches the Pytorch installation. You can install them together at pytorch.org to make sure of this.
+yacs
+Cython (optional to compile evaluation code)
+tensorboard (needed for visualization): pip install tensorboard
+gdown (for automatically downloading pre-train model)
+sklearn
+termcolor
+tabulate
+faiss pip install faiss-cpu
 
diff --git a/config/__init__.py b/config/__init__.py
new file mode 100644
index 0000000..37e5cb6
--- /dev/null
+++ b/config/__init__.py
@@ -0,0 +1,9 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time    : 2020/10/26 17:40
+# @Author  : Scheaven
+# @File    :  __init__.py.py
+# @description:
+
+from .config import CfgNode, get_cfg
+from .defaults import _C as cfg
diff --git a/config/__pycache__/__init__.cpython-37.pyc b/config/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..c77849f
--- /dev/null
+++ b/config/__pycache__/__init__.cpython-37.pyc
Binary files differ
diff --git a/config/__pycache__/__init__.cpython-38.pyc b/config/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..f8290d7
--- /dev/null
+++ b/config/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/config/__pycache__/config.cpython-37.pyc b/config/__pycache__/config.cpython-37.pyc
new file mode 100644
index 0000000..1d6dbd8
--- /dev/null
+++ b/config/__pycache__/config.cpython-37.pyc
Binary files differ
diff --git a/config/__pycache__/config.cpython-38.pyc b/config/__pycache__/config.cpython-38.pyc
new file mode 100644
index 0000000..4f03814
--- /dev/null
+++ b/config/__pycache__/config.cpython-38.pyc
Binary files differ
diff --git a/config/__pycache__/defaults.cpython-38.pyc b/config/__pycache__/defaults.cpython-38.pyc
new file mode 100644
index 0000000..8b85b95
--- /dev/null
+++ b/config/__pycache__/defaults.cpython-38.pyc
Binary files differ
diff --git a/config/config.py b/config/config.py
new file mode 100644
index 0000000..85831f5
--- /dev/null
+++ b/config/config.py
@@ -0,0 +1,161 @@
+# encoding: utf-8
+"""
+@author:  l1aoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import logging
+import os
+from typing import Any
+
+import yaml
+from yacs.config import CfgNode as _CfgNode
+
+from utils.file_io import PathManager
+
+BASE_KEY = "_BASE_"
+
+
+class CfgNode(_CfgNode):
+    """
+    Our own extended version of :class:`yacs.config.CfgNode`.
+    It contains the following extra features:
+    1. The :meth:`merge_from_file` method supports the "_BASE_" key,
+       which allows the new CfgNode to inherit all the attributes from the
+       base configuration file.
+    2. Keys that start with "COMPUTED_" are treated as insertion-only
+       "computed" attributes. They can be inserted regardless of whether
+       the CfgNode is frozen or not.
+    3. With "allow_unsafe=True", it supports pyyaml tags that evaluate
+       expressions in config. See examples in
+       https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types
+       Note that this may lead to arbitrary code execution: you must not
+       load a config file from untrusted sources before manually inspecting
+       the content of the file.
+    """
+
+    @staticmethod
+    def load_yaml_with_base(filename: str, allow_unsafe: bool = False):
+        """
+        Just like `yaml.load(open(filename))`, but inherit attributes from its
+            `_BASE_`.
+        Args:
+            filename (str): the file name of the current config. Will be used to
+                find the base config file.
+            allow_unsafe (bool): whether to allow loading the config file with
+                `yaml.unsafe_load`.
+        Returns:
+            (dict): the loaded yaml
+        """
+        with PathManager.open(filename, "r") as f:
+            try:
+                cfg = yaml.safe_load(f)
+            except yaml.constructor.ConstructorError:
+                if not allow_unsafe:
+                    raise
+                logger = logging.getLogger(__name__)
+                logger.warning(
+                    "Loading config {} with yaml.unsafe_load. Your machine may "
+                    "be at risk if the file contains malicious content.".format(
+                        filename
+                    )
+                )
+                f.close()
+                with open(filename, "r") as f:
+                    cfg = yaml.unsafe_load(f)
+
+        def merge_a_into_b(a, b):
+            # merge dict a into dict b. values in a will overwrite b.
+            for k, v in a.items():
+                if isinstance(v, dict) and k in b:
+                    assert isinstance(
+                        b[k], dict
+                    ), "Cannot inherit key '{}' from base!".format(k)
+                    merge_a_into_b(v, b[k])
+                else:
+                    b[k] = v
+
+        if BASE_KEY in cfg:
+            base_cfg_file = cfg[BASE_KEY]
+            if base_cfg_file.startswith("~"):
+                base_cfg_file = os.path.expanduser(base_cfg_file)
+            if not any(
+                    map(base_cfg_file.startswith, ["/", "https://", "http://"])
+            ):
+                # the path to base cfg is relative to the config file itself.
+                base_cfg_file = os.path.join(
+                    os.path.dirname(filename), base_cfg_file
+                )
+            base_cfg = CfgNode.load_yaml_with_base(
+                base_cfg_file, allow_unsafe=allow_unsafe
+            )
+            del cfg[BASE_KEY]
+
+            merge_a_into_b(cfg, base_cfg)
+            return base_cfg
+        return cfg
+
+    def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = False):
+        """
+        Merge configs from a given yaml file.
+        Args:
+            cfg_filename: the file name of the yaml config.
+            allow_unsafe: whether to allow loading the config file with
+                `yaml.unsafe_load`.
+        """
+        loaded_cfg = CfgNode.load_yaml_with_base(
+            cfg_filename, allow_unsafe=allow_unsafe
+        )
+        loaded_cfg = type(self)(loaded_cfg)
+        # self.merge_from_other_cfg(loaded_cfg)
+
+    # Forward the following calls to base, but with a check on the BASE_KEY.
+    def merge_from_other_cfg(self, cfg_other):
+        """
+        Args:
+            cfg_other (CfgNode): configs to merge from.
+        """
+        assert (
+                BASE_KEY not in cfg_other
+        ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
+
+
+        return super().merge_from_other_cfg(cfg_other)
+
+    def merge_from_list(self, cfg_list: list):
+        """
+        Args:
+            cfg_list (list): list of configs to merge from.
+        """
+        keys = set(cfg_list[0::2])
+        assert (
+                BASE_KEY not in keys
+        ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
+        return super().merge_from_list(cfg_list)
+
+    def __setattr__(self, name: str, val: Any):
+        if name.startswith("COMPUTED_"):
+            if name in self:
+                old_val = self[name]
+                if old_val == val:
+                    return
+                raise KeyError(
+                    "Computed attributed '{}' already exists "
+                    "with a different value! old={}, new={}.".format(
+                        name, old_val, val
+                    )
+                )
+            self[name] = val
+        else:
+            super().__setattr__(name, val)
+
+
+def get_cfg() -> CfgNode:
+    """
+    Get a copy of the default config.
+    Returns:
+        a fastreid CfgNode instance.
+    """
+    from .defaults import _C
+
+    return _C.clone()
diff --git a/config/defaults.py b/config/defaults.py
new file mode 100644
index 0000000..61651a5
--- /dev/null
+++ b/config/defaults.py
@@ -0,0 +1,273 @@
+from .config import CfgNode as CN
+
+# -----------------------------------------------------------------------------
+# Convention about Training / Test specific parameters
+# -----------------------------------------------------------------------------
+# Whenever an argument can be either used for training or for testing, the
+# corresponding name will be post-fixed by a _TRAIN for a training parameter,
+# or _TEST for a test-specific parameter.
+# For example, the number of images during training will be
+# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
+# IMAGES_PER_BATCH_TEST
+
+# -----------------------------------------------------------------------------
+# Config definition
+# -----------------------------------------------------------------------------
+
+_C = CN()
+
+# -----------------------------------------------------------------------------
+# MODEL
+# -----------------------------------------------------------------------------
+_C.MODEL = CN()
+_C.MODEL.DEVICE = "cuda"
+_C.MODEL.META_ARCHITECTURE = 'Baseline'
+_C.MODEL.FREEZE_LAYERS = ['']
+
+# ---------------------------------------------------------------------------- #
+# Backbone options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.BACKBONE = CN()
+
+_C.MODEL.BACKBONE.NAME = "build_resnet_backbone"
+_C.MODEL.BACKBONE.DEPTH = "50x"
+_C.MODEL.BACKBONE.LAST_STRIDE = 1
+# Backbone feature dimension
+_C.MODEL.BACKBONE.FEAT_DIM = 2048
+# Normalization method for the convolution layers.
+_C.MODEL.BACKBONE.NORM = "BN"
+# If use IBN block in backbone
+_C.MODEL.BACKBONE.WITH_IBN = False
+# If use SE block in backbone
+_C.MODEL.BACKBONE.WITH_SE = False
+# If use Non-local block in backbone
+_C.MODEL.BACKBONE.WITH_NL = False
+# If use ImageNet pretrain model
+_C.MODEL.BACKBONE.PRETRAIN = True
+# Pretrain model path
+_C.MODEL.BACKBONE.PRETRAIN_PATH = ''
+
+# ---------------------------------------------------------------------------- #
+# REID HEADS options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.HEADS = CN()
+_C.MODEL.HEADS.NAME = "EmbeddingHead"
+# Normalization method for the convolution layers.
+_C.MODEL.HEADS.NORM = "BN"
+# Number of identity
+_C.MODEL.HEADS.NUM_CLASSES = 0
+# Embedding dimension in head
+_C.MODEL.HEADS.EMBEDDING_DIM = 0
+# If use BNneck in embedding
+_C.MODEL.HEADS.WITH_BNNECK = True
+# Triplet feature using feature before(after) bnneck
+_C.MODEL.HEADS.NECK_FEAT = "before"  # options: before, after
+# Pooling layer type
+_C.MODEL.HEADS.POOL_LAYER = "avgpool"
+
+# Classification layer type
+_C.MODEL.HEADS.CLS_LAYER = "linear"  # "arcSoftmax" or "circleSoftmax"
+
+# Margin and Scale for margin-based classification layer
+_C.MODEL.HEADS.MARGIN = 0.15
+_C.MODEL.HEADS.SCALE = 128
+
+# ---------------------------------------------------------------------------- #
+# REID LOSSES options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.LOSSES = CN()
+_C.MODEL.LOSSES.NAME = ("CrossEntropyLoss",)
+
+# Cross Entropy Loss options
+_C.MODEL.LOSSES.CE = CN()
+# if epsilon == 0, it means no label smooth regularization,
+# if epsilon == -1, it means adaptive label smooth regularization
+_C.MODEL.LOSSES.CE.EPSILON = 0.0
+_C.MODEL.LOSSES.CE.ALPHA = 0.2
+_C.MODEL.LOSSES.CE.SCALE = 1.0
+
+# Triplet Loss options
+_C.MODEL.LOSSES.TRI = CN()
+_C.MODEL.LOSSES.TRI.MARGIN = 0.3
+_C.MODEL.LOSSES.TRI.NORM_FEAT = False
+_C.MODEL.LOSSES.TRI.HARD_MINING = True
+_C.MODEL.LOSSES.TRI.SCALE = 1.0
+
+# Circle Loss options
+_C.MODEL.LOSSES.CIRCLE = CN()
+_C.MODEL.LOSSES.CIRCLE.MARGIN = 0.25
+_C.MODEL.LOSSES.CIRCLE.ALPHA = 128
+_C.MODEL.LOSSES.CIRCLE.SCALE = 1.0
+
+# Focal Loss options
+_C.MODEL.LOSSES.FL = CN()
+_C.MODEL.LOSSES.FL.ALPHA = 0.25
+_C.MODEL.LOSSES.FL.GAMMA = 2
+_C.MODEL.LOSSES.FL.SCALE = 1.0
+
+# Path to a checkpoint file to be loaded to the model. You can find available models in the model zoo.
+_C.MODEL.WEIGHTS = ""
+
+# Values to be used for image normalization
+_C.MODEL.PIXEL_MEAN = [0.485*255, 0.456*255, 0.406*255]
+# Values to be used for image normalization
+_C.MODEL.PIXEL_STD = [0.229*255, 0.224*255, 0.225*255]
+
+
+# -----------------------------------------------------------------------------
+# INPUT
+# -----------------------------------------------------------------------------
+_C.INPUT = CN()
+# Size of the image during training
+_C.INPUT.SIZE_TRAIN = [256, 128]
+# Size of the image during test
+_C.INPUT.SIZE_TEST = [256, 128]  #鍙傛暟鏄紙h,w锛�
+
+# Random probability for image horizontal flip
+_C.INPUT.DO_FLIP = True
+_C.INPUT.FLIP_PROB = 0.5
+
+# Value of padding size
+_C.INPUT.DO_PAD = True
+_C.INPUT.PADDING_MODE = 'constant'
+_C.INPUT.PADDING = 10
+# Random color jitter
+_C.INPUT.CJ = CN()
+_C.INPUT.CJ.ENABLED = False
+_C.INPUT.CJ.PROB = 0.8
+_C.INPUT.CJ.BRIGHTNESS = 0.15
+_C.INPUT.CJ.CONTRAST = 0.15
+_C.INPUT.CJ.SATURATION = 0.1
+_C.INPUT.CJ.HUE = 0.1
+# Auto augmentation
+_C.INPUT.DO_AUTOAUG = False
+# Augmix augmentation
+_C.INPUT.DO_AUGMIX = False
+# Random Erasing
+_C.INPUT.REA = CN()
+_C.INPUT.REA.ENABLED = False
+_C.INPUT.REA.PROB = 0.5
+_C.INPUT.REA.MEAN = [0.596*255, 0.558*255, 0.497*255]  # [0.485*255, 0.456*255, 0.406*255]
+# Random Patch
+_C.INPUT.RPT = CN()
+_C.INPUT.RPT.ENABLED = False
+_C.INPUT.RPT.PROB = 0.5
+
+# -----------------------------------------------------------------------------
+# Dataset
+# -----------------------------------------------------------------------------
+_C.DATASETS = CN()
+# List of the dataset names for training
+_C.DATASETS.NAMES = ("Market1501",)
+# List of the dataset names for testing
+_C.DATASETS.TESTS = ("Market1501",)
+# Combine trainset and testset joint training
+_C.DATASETS.COMBINEALL = False
+
+# -----------------------------------------------------------------------------
+# DataLoader
+# -----------------------------------------------------------------------------
+_C.DATALOADER = CN()
+# P/K Sampler for data loading
+_C.DATALOADER.PK_SAMPLER = True
+# Naive sampler which don't consider balanced identity sampling
+_C.DATALOADER.NAIVE_WAY = False
+# Number of instance for each person
+_C.DATALOADER.NUM_INSTANCE = 4
+_C.DATALOADER.NUM_WORKERS = 8
+
+# ---------------------------------------------------------------------------- #
+# Solver
+# ---------------------------------------------------------------------------- #
+_C.SOLVER = CN()
+
+# AUTOMATIC MIXED PRECISION
+_C.SOLVER.AMP_ENABLED = False
+
+# Optimizer
+_C.SOLVER.OPT = "Adam"
+
+_C.SOLVER.MAX_ITER = 120
+
+_C.SOLVER.BASE_LR = 3e-4
+_C.SOLVER.BIAS_LR_FACTOR = 1.
+_C.SOLVER.HEADS_LR_FACTOR = 1.
+
+_C.SOLVER.MOMENTUM = 0.9
+
+_C.SOLVER.WEIGHT_DECAY = 0.0005
+_C.SOLVER.WEIGHT_DECAY_BIAS = 0.
+
+# Multi-step learning rate options
+_C.SOLVER.SCHED = "WarmupMultiStepLR"
+_C.SOLVER.GAMMA = 0.1
+_C.SOLVER.STEPS = [30, 55]
+
+# Cosine annealing learning rate options
+_C.SOLVER.DELAY_ITERS = 0
+_C.SOLVER.ETA_MIN_LR = 3e-7
+
+# Warmup options
+_C.SOLVER.WARMUP_FACTOR = 0.1
+_C.SOLVER.WARMUP_ITERS = 10
+_C.SOLVER.WARMUP_METHOD = "linear"
+
+_C.SOLVER.FREEZE_ITERS = 0
+
+# SWA options
+_C.SOLVER.SWA = CN()
+_C.SOLVER.SWA.ENABLED = False
+_C.SOLVER.SWA.ITER = 10
+_C.SOLVER.SWA.PERIOD = 2
+_C.SOLVER.SWA.LR_FACTOR = 10.
+_C.SOLVER.SWA.ETA_MIN_LR = 3.5e-6
+_C.SOLVER.SWA.LR_SCHED = False
+
+_C.SOLVER.CHECKPOINT_PERIOD = 20
+
+# Number of images per batch across all machines.
+# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
+# see 2 images per batch
+_C.SOLVER.IMS_PER_BATCH = 64
+
+# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
+# see 2 images per batch
+_C.TEST = CN()
+
+_C.TEST.EVAL_PERIOD = 20
+
+# Number of images per batch in one process.
+_C.TEST.IMS_PER_BATCH = 64
+_C.TEST.METRIC = "cosine"
+_C.TEST.ROC_ENABLED = False
+
+# Average query expansion
+_C.TEST.AQE = CN()
+_C.TEST.AQE.ENABLED = False
+_C.TEST.AQE.ALPHA = 3.0
+_C.TEST.AQE.QE_TIME = 1
+_C.TEST.AQE.QE_K = 5
+
+# Re-rank
+_C.TEST.RERANK = CN()
+_C.TEST.RERANK.ENABLED = False
+_C.TEST.RERANK.K1 = 20
+_C.TEST.RERANK.K2 = 6
+_C.TEST.RERANK.LAMBDA = 0.3
+
+# Precise batchnorm
+_C.TEST.PRECISE_BN = CN()
+_C.TEST.PRECISE_BN.ENABLED = False
+_C.TEST.PRECISE_BN.DATASET = 'Market1501'
+_C.TEST.PRECISE_BN.NUM_ITER = 300
+
+# ---------------------------------------------------------------------------- #
+# Misc options
+# ---------------------------------------------------------------------------- #
+_C.OUTPUT_DIR = "logs/"
+
+# Benchmark different cudnn algorithms.
+# If input images have very different sizes, this option will have large overhead
+# for about 10k iterations. It usually hurts total time, but can benefit for certain models.
+# If input images have the same or similar sizes, benchmark is often helpful.
+_C.CUDNN_BENCHMARK = False
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000..dbf5adc
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1,7 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time    : 2020/10/26 16:24
+# @Author  : Scheaven
+# @File    :  __init__.py.py
+# @description:
+
diff --git a/data/__pycache__/__init__.cpython-37.pyc b/data/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..36ec137
--- /dev/null
+++ b/data/__pycache__/__init__.cpython-37.pyc
Binary files differ
diff --git a/data/__pycache__/__init__.cpython-38.pyc b/data/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..dda59a4
--- /dev/null
+++ b/data/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/data/__pycache__/data_utils.cpython-37.pyc b/data/__pycache__/data_utils.cpython-37.pyc
new file mode 100644
index 0000000..332f74a
--- /dev/null
+++ b/data/__pycache__/data_utils.cpython-37.pyc
Binary files differ
diff --git a/data/__pycache__/data_utils.cpython-38.pyc b/data/__pycache__/data_utils.cpython-38.pyc
new file mode 100644
index 0000000..da72b54
--- /dev/null
+++ b/data/__pycache__/data_utils.cpython-38.pyc
Binary files differ
diff --git a/data/data_utils.py b/data/data_utils.py
new file mode 100644
index 0000000..3b65111
--- /dev/null
+++ b/data/data_utils.py
@@ -0,0 +1,45 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+import numpy as np
+from PIL import Image, ImageOps
+
+from utils.file_io import PathManager
+
+
+def read_image(file_name, format=None):
+    """
+    Read an image into the given format.
+    Will apply rotation and flipping if the image has such exif information.
+    Args:
+        file_name (str): image file path
+        format (str): one of the supported image modes in PIL, or "BGR"
+    Returns:
+        image (np.ndarray): an HWC image
+    """
+    with PathManager.open(file_name, "rb") as f:
+        image = Image.open(f)
+
+        # capture and ignore this bug: https://github.com/python-pillow/Pillow/issues/3973
+        try:
+            image = ImageOps.exif_transpose(image)
+        except Exception:
+            pass
+
+        if format is not None:
+            # PIL only supports RGB, so convert to RGB and flip channels over below
+            conversion_format = format
+            if format == "BGR":
+                conversion_format = "RGB"
+            image = image.convert(conversion_format)
+        image = np.asarray(image)
+        if format == "BGR":
+            # flip channels if needed
+            image = image[:, :, ::-1]
+        # PIL squeezes out the channel dimension for "L", so make it HWC
+        if format == "L":
+            image = np.expand_dims(image, -1)
+        image = Image.fromarray(image)
+        return image
diff --git a/data/transforms/__init__.py b/data/transforms/__init__.py
new file mode 100644
index 0000000..45c7d13
--- /dev/null
+++ b/data/transforms/__init__.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time    : 2020/10/26 16:24
+# @Author  : Scheaven
+# @File    :  __init__.py.py
+# @description:
+
+from .build import build_transforms
+from .transforms import *
+from .autoaugment import *
\ No newline at end of file
diff --git a/data/transforms/__pycache__/__init__.cpython-38.pyc b/data/transforms/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..ae2715d
--- /dev/null
+++ b/data/transforms/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/data/transforms/__pycache__/autoaugment.cpython-38.pyc b/data/transforms/__pycache__/autoaugment.cpython-38.pyc
new file mode 100644
index 0000000..78532f6
--- /dev/null
+++ b/data/transforms/__pycache__/autoaugment.cpython-38.pyc
Binary files differ
diff --git a/data/transforms/__pycache__/build.cpython-38.pyc b/data/transforms/__pycache__/build.cpython-38.pyc
new file mode 100644
index 0000000..642959c
--- /dev/null
+++ b/data/transforms/__pycache__/build.cpython-38.pyc
Binary files differ
diff --git a/data/transforms/__pycache__/functional.cpython-38.pyc b/data/transforms/__pycache__/functional.cpython-38.pyc
new file mode 100644
index 0000000..d7e6a43
--- /dev/null
+++ b/data/transforms/__pycache__/functional.cpython-38.pyc
Binary files differ
diff --git a/data/transforms/__pycache__/transforms.cpython-38.pyc b/data/transforms/__pycache__/transforms.cpython-38.pyc
new file mode 100644
index 0000000..9a7430f
--- /dev/null
+++ b/data/transforms/__pycache__/transforms.cpython-38.pyc
Binary files differ
diff --git a/data/transforms/autoaugment.py b/data/transforms/autoaugment.py
new file mode 100644
index 0000000..487e70d
--- /dev/null
+++ b/data/transforms/autoaugment.py
@@ -0,0 +1,812 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+""" AutoAugment, RandAugment, and AugMix for PyTorch
+This code implements the searched ImageNet policies with various tweaks and improvements and
+does not include any of the search code.
+AA and RA Implementation adapted from:
+    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
+AugMix adapted from:
+    https://github.com/google-research/augmix
+Papers:
+    AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
+    Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
+    RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
+    AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
+Hacked together by Ross Wightman
+"""
+import math
+import random
+import re
+
+import PIL
+import numpy as np
+from PIL import Image, ImageOps, ImageEnhance
+
+_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
+
+_FILL = (128, 128, 128)
+
+# This signifies the max integer that the controller RNN could predict for the
+# augmentation scheme.
+_MAX_LEVEL = 10.
+
+_HPARAMS_DEFAULT = dict(
+    translate_const=57,
+    img_mean=_FILL,
+)
+
+_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
+
+
+def _interpolation(kwargs):
+    interpolation = kwargs.pop('resample', Image.BILINEAR)
+    if isinstance(interpolation, (list, tuple)):
+        return random.choice(interpolation)
+    else:
+        return interpolation
+
+
+def _check_args_tf(kwargs):
+    if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
+        kwargs.pop('fillcolor')
+    kwargs['resample'] = _interpolation(kwargs)
+
+
+def shear_x(img, factor, **kwargs):
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
+
+
+def shear_y(img, factor, **kwargs):
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
+
+
+def translate_x_rel(img, pct, **kwargs):
+    pixels = pct * img.size[0]
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+
+
+def translate_y_rel(img, pct, **kwargs):
+    pixels = pct * img.size[1]
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+
+
+def translate_x_abs(img, pixels, **kwargs):
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+
+
+def translate_y_abs(img, pixels, **kwargs):
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+
+
+def rotate(img, degrees, **kwargs):
+    _check_args_tf(kwargs)
+    if _PIL_VER >= (5, 2):
+        return img.rotate(degrees, **kwargs)
+    elif _PIL_VER >= (5, 0):
+        w, h = img.size
+        post_trans = (0, 0)
+        rotn_center = (w / 2.0, h / 2.0)
+        angle = -math.radians(degrees)
+        matrix = [
+            round(math.cos(angle), 15),
+            round(math.sin(angle), 15),
+            0.0,
+            round(-math.sin(angle), 15),
+            round(math.cos(angle), 15),
+            0.0,
+        ]
+
+        def transform(x, y, matrix):
+            (a, b, c, d, e, f) = matrix
+            return a * x + b * y + c, d * x + e * y + f
+
+        matrix[2], matrix[5] = transform(
+            -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
+        )
+        matrix[2] += rotn_center[0]
+        matrix[5] += rotn_center[1]
+        return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
+    else:
+        return img.rotate(degrees, resample=kwargs['resample'])
+
+
+def auto_contrast(img, **__):
+    return ImageOps.autocontrast(img)
+
+
+def invert(img, **__):
+    return ImageOps.invert(img)
+
+
+def equalize(img, **__):
+    return ImageOps.equalize(img)
+
+
+def solarize(img, thresh, **__):
+    return ImageOps.solarize(img, thresh)
+
+
+def solarize_add(img, add, thresh=128, **__):
+    lut = []
+    for i in range(256):
+        if i < thresh:
+            lut.append(min(255, i + add))
+        else:
+            lut.append(i)
+    if img.mode in ("L", "RGB"):
+        if img.mode == "RGB" and len(lut) == 256:
+            lut = lut + lut + lut
+        return img.point(lut)
+    else:
+        return img
+
+
+def posterize(img, bits_to_keep, **__):
+    if bits_to_keep >= 8:
+        return img
+    return ImageOps.posterize(img, bits_to_keep)
+
+
+def contrast(img, factor, **__):
+    return ImageEnhance.Contrast(img).enhance(factor)
+
+
+def color(img, factor, **__):
+    return ImageEnhance.Color(img).enhance(factor)
+
+
+def brightness(img, factor, **__):
+    return ImageEnhance.Brightness(img).enhance(factor)
+
+
+def sharpness(img, factor, **__):
+    return ImageEnhance.Sharpness(img).enhance(factor)
+
+
+def _randomly_negate(v):
+    """With 50% prob, negate the value"""
+    return -v if random.random() > 0.5 else v
+
+
+def _rotate_level_to_arg(level, _hparams):
+    # range [-30, 30]
+    level = (level / _MAX_LEVEL) * 30.
+    level = _randomly_negate(level)
+    return level,
+
+
+def _enhance_level_to_arg(level, _hparams):
+    # range [0.1, 1.9]
+    return (level / _MAX_LEVEL) * 1.8 + 0.1,
+
+
+def _enhance_increasing_level_to_arg(level, _hparams):
+    # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
+    # range [0.1, 1.9]
+    level = (level / _MAX_LEVEL) * .9
+    level = 1.0 + _randomly_negate(level)
+    return level,
+
+
+def _shear_level_to_arg(level, _hparams):
+    # range [-0.3, 0.3]
+    level = (level / _MAX_LEVEL) * 0.3
+    level = _randomly_negate(level)
+    return level,
+
+
+def _translate_abs_level_to_arg(level, hparams):
+    translate_const = hparams['translate_const']
+    level = (level / _MAX_LEVEL) * float(translate_const)
+    level = _randomly_negate(level)
+    return level,
+
+
+def _translate_rel_level_to_arg(level, hparams):
+    # default range [-0.45, 0.45]
+    translate_pct = hparams.get('translate_pct', 0.45)
+    level = (level / _MAX_LEVEL) * translate_pct
+    level = _randomly_negate(level)
+    return level,
+
+
+def _posterize_level_to_arg(level, _hparams):
+    # As per Tensorflow TPU EfficientNet impl
+    # range [0, 4], 'keep 0 up to 4 MSB of original image'
+    # intensity/severity of augmentation decreases with level
+    return int((level / _MAX_LEVEL) * 4),
+
+
+def _posterize_increasing_level_to_arg(level, hparams):
+    # As per Tensorflow models research and UDA impl
+    # range [4, 0], 'keep 4 down to 0 MSB of original image',
+    # intensity/severity of augmentation increases with level
+    return 4 - _posterize_level_to_arg(level, hparams)[0],
+
+
+def _posterize_original_level_to_arg(level, _hparams):
+    # As per original AutoAugment paper description
+    # range [4, 8], 'keep 4 up to 8 MSB of image'
+    # intensity/severity of augmentation decreases with level
+    return int((level / _MAX_LEVEL) * 4) + 4,
+
+
+def _solarize_level_to_arg(level, _hparams):
+    # range [0, 256]
+    # intensity/severity of augmentation decreases with level
+    return int((level / _MAX_LEVEL) * 256),
+
+
+def _solarize_increasing_level_to_arg(level, _hparams):
+    # range [0, 256]
+    # intensity/severity of augmentation increases with level
+    return 256 - _solarize_level_to_arg(level, _hparams)[0],
+
+
+def _solarize_add_level_to_arg(level, _hparams):
+    # range [0, 110]
+    return int((level / _MAX_LEVEL) * 110),
+
+
+LEVEL_TO_ARG = {
+    'AutoContrast': None,
+    'Equalize': None,
+    'Invert': None,
+    'Rotate': _rotate_level_to_arg,
+    # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
+    'Posterize': _posterize_level_to_arg,
+    'PosterizeIncreasing': _posterize_increasing_level_to_arg,
+    'PosterizeOriginal': _posterize_original_level_to_arg,
+    'Solarize': _solarize_level_to_arg,
+    'SolarizeIncreasing': _solarize_increasing_level_to_arg,
+    'SolarizeAdd': _solarize_add_level_to_arg,
+    'Color': _enhance_level_to_arg,
+    'ColorIncreasing': _enhance_increasing_level_to_arg,
+    'Contrast': _enhance_level_to_arg,
+    'ContrastIncreasing': _enhance_increasing_level_to_arg,
+    'Brightness': _enhance_level_to_arg,
+    'BrightnessIncreasing': _enhance_increasing_level_to_arg,
+    'Sharpness': _enhance_level_to_arg,
+    'SharpnessIncreasing': _enhance_increasing_level_to_arg,
+    'ShearX': _shear_level_to_arg,
+    'ShearY': _shear_level_to_arg,
+    'TranslateX': _translate_abs_level_to_arg,
+    'TranslateY': _translate_abs_level_to_arg,
+    'TranslateXRel': _translate_rel_level_to_arg,
+    'TranslateYRel': _translate_rel_level_to_arg,
+}
+
+NAME_TO_OP = {
+    'AutoContrast': auto_contrast,
+    'Equalize': equalize,
+    'Invert': invert,
+    'Rotate': rotate,
+    'Posterize': posterize,
+    'PosterizeIncreasing': posterize,
+    'PosterizeOriginal': posterize,
+    'Solarize': solarize,
+    'SolarizeIncreasing': solarize,
+    'SolarizeAdd': solarize_add,
+    'Color': color,
+    'ColorIncreasing': color,
+    'Contrast': contrast,
+    'ContrastIncreasing': contrast,
+    'Brightness': brightness,
+    'BrightnessIncreasing': brightness,
+    'Sharpness': sharpness,
+    'SharpnessIncreasing': sharpness,
+    'ShearX': shear_x,
+    'ShearY': shear_y,
+    'TranslateX': translate_x_abs,
+    'TranslateY': translate_y_abs,
+    'TranslateXRel': translate_x_rel,
+    'TranslateYRel': translate_y_rel,
+}
+
+
+class AugmentOp:
+
+    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
+        hparams = hparams or _HPARAMS_DEFAULT
+        self.aug_fn = NAME_TO_OP[name]
+        self.level_fn = LEVEL_TO_ARG[name]
+        self.prob = prob
+        self.magnitude = magnitude
+        self.hparams = hparams.copy()
+        self.kwargs = dict(
+            fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
+            resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
+        )
+
+        # If magnitude_std is > 0, we introduce some randomness
+        # in the usually fixed policy and sample magnitude from a normal distribution
+        # with mean `magnitude` and std-dev of `magnitude_std`.
+        # NOTE This is my own hack, being tested, not in papers or reference impls.
+        self.magnitude_std = self.hparams.get('magnitude_std', 0)
+
+    def __call__(self, img):
+        if self.prob < 1.0 and random.random() > self.prob:
+            return img
+        magnitude = self.magnitude
+        if self.magnitude_std and self.magnitude_std > 0:
+            magnitude = random.gauss(magnitude, self.magnitude_std)
+        magnitude = min(_MAX_LEVEL, max(0, magnitude))  # clip to valid range
+        level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
+        return self.aug_fn(img, *level_args, **self.kwargs)
+
+
+def auto_augment_policy_v0(hparams):
+    # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
+    policy = [
+        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
+        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
+        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
+        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
+        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
+        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
+        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
+        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
+        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
+        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
+        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
+        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
+        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],  # This results in black image with Tpu posterize
+        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
+    ]
+    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+    return pc
+
+
+def auto_augment_policy_v0r(hparams):
+    # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
+    # in Google research implementation (number of bits discarded increases with magnitude)
+    policy = [
+        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
+        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
+        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
+        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
+        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
+        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+        [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
+        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
+        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
+        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
+        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
+        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
+        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
+        [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
+        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
+    ]
+    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+    return pc
+
+
+def auto_augment_policy_original(hparams):
+    # ImageNet policy from https://arxiv.org/abs/1805.09501
+    policy = [
+        [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
+        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+        [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
+        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
+        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
+        [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
+        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
+        [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
+        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
+        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
+        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
+        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
+        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
+        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
+        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
+        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
+        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+    ]
+    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+    return pc
+
+
+def auto_augment_policy_originalr(hparams):
+    # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
+    policy = [
+        [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
+        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+        [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
+        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
+        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
+        [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
+        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
+        [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
+        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
+        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
+        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
+        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
+        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
+        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
+        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
+        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
+        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+    ]
+    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+    return pc
+
+
+def auto_augment_policy(name="original"):
+    hparams = _HPARAMS_DEFAULT
+    if name == 'original':
+        return auto_augment_policy_original(hparams)
+    elif name == 'originalr':
+        return auto_augment_policy_originalr(hparams)
+    elif name == 'v0':
+        return auto_augment_policy_v0(hparams)
+    elif name == 'v0r':
+        return auto_augment_policy_v0r(hparams)
+    else:
+        assert False, 'Unknown AA policy (%s)' % name
+
+
+class AutoAugment:
+
+    def __init__(self, total_iter):
+        self.total_iter = total_iter
+        self.gamma = 0
+        self.policy = auto_augment_policy()
+
+    def __call__(self, img):
+        if random.uniform(0, 1) > self.gamma:
+            sub_policy = random.choice(self.policy)
+            self.gamma = min(1.0, self.gamma + 1.0 / self.total_iter)
+            for op in sub_policy:
+                img = op(img)
+            return img
+        else:
+            return img
+
+
+def auto_augment_transform(config_str, hparams):
+    """
+    Create a AutoAugment transform
+    :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
+    dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
+    The remaining sections, not order sepecific determine
+        'mstd' -  float std deviation of magnitude noise applied
+    Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
+    :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
+    :return: A PyTorch compatible Transform
+    """
+    config = config_str.split('-')
+    policy_name = config[0]
+    config = config[1:]
+    for c in config:
+        cs = re.split(r'(\d.*)', c)
+        if len(cs) < 2:
+            continue
+        key, val = cs[:2]
+        if key == 'mstd':
+            # noise param injected via hparams for now
+            hparams.setdefault('magnitude_std', float(val))
+        else:
+            assert False, 'Unknown AutoAugment config section'
+    aa_policy = auto_augment_policy(policy_name)
+    return AutoAugment(aa_policy)
+
+
+_RAND_TRANSFORMS = [
+    'AutoContrast',
+    'Equalize',
+    'Invert',
+    'Rotate',
+    'Posterize',
+    'Solarize',
+    'SolarizeAdd',
+    'Color',
+    'Contrast',
+    'Brightness',
+    'Sharpness',
+    'ShearX',
+    'ShearY',
+    'TranslateXRel',
+    'TranslateYRel',
+    # 'Cutout'  # NOTE I've implement this as random erasing separately
+]
+
+_RAND_INCREASING_TRANSFORMS = [
+    'AutoContrast',
+    'Equalize',
+    'Invert',
+    'Rotate',
+    'PosterizeIncreasing',
+    'SolarizeIncreasing',
+    'SolarizeAdd',
+    'ColorIncreasing',
+    'ContrastIncreasing',
+    'BrightnessIncreasing',
+    'SharpnessIncreasing',
+    'ShearX',
+    'ShearY',
+    'TranslateXRel',
+    'TranslateYRel',
+    # 'Cutout'  # NOTE I've implement this as random erasing separately
+]
+
+# These experimental weights are based loosely on the relative improvements mentioned in paper.
+# They may not result in increased performance, but could likely be tuned to so.
+_RAND_CHOICE_WEIGHTS_0 = {
+    'Rotate': 0.3,
+    'ShearX': 0.2,
+    'ShearY': 0.2,
+    'TranslateXRel': 0.1,
+    'TranslateYRel': 0.1,
+    'Color': .025,
+    'Sharpness': 0.025,
+    'AutoContrast': 0.025,
+    'Solarize': .005,
+    'SolarizeAdd': .005,
+    'Contrast': .005,
+    'Brightness': .005,
+    'Equalize': .005,
+    'Posterize': 0,
+    'Invert': 0,
+}
+
+
+def _select_rand_weights(weight_idx=0, transforms=None):
+    transforms = transforms or _RAND_TRANSFORMS
+    assert weight_idx == 0  # only one set of weights currently
+    rand_weights = _RAND_CHOICE_WEIGHTS_0
+    probs = [rand_weights[k] for k in transforms]
+    probs /= np.sum(probs)
+    return probs
+
+
+def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
+    hparams = hparams or _HPARAMS_DEFAULT
+    transforms = transforms or _RAND_TRANSFORMS
+    return [AugmentOp(
+        name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
+
+
+class RandAugment:
+    def __init__(self, ops, num_layers=2, choice_weights=None):
+        self.ops = ops
+        self.num_layers = num_layers
+        self.choice_weights = choice_weights
+
+    def __call__(self, img):
+        # no replacement when using weighted choice
+        ops = np.random.choice(
+            self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
+        for op in ops:
+            img = op(img)
+        return img
+
+
+def rand_augment_transform(config_str, hparams):
+    """
+    Create a RandAugment transform
+    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
+    dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
+    sections, not order sepecific determine
+        'm' - integer magnitude of rand augment
+        'n' - integer num layers (number of transform ops selected per image)
+        'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
+        'mstd' -  float std deviation of magnitude noise applied
+        'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
+    Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
+    'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
+    :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
+    :return: A PyTorch compatible Transform
+    """
+    magnitude = _MAX_LEVEL  # default to _MAX_LEVEL for magnitude (currently 10)
+    num_layers = 2  # default to 2 ops per image
+    weight_idx = None  # default to no probability weights for op choice
+    transforms = _RAND_TRANSFORMS
+    config = config_str.split('-')
+    assert config[0] == 'rand'
+    config = config[1:]
+    for c in config:
+        cs = re.split(r'(\d.*)', c)
+        if len(cs) < 2:
+            continue
+        key, val = cs[:2]
+        if key == 'mstd':
+            # noise param injected via hparams for now
+            hparams.setdefault('magnitude_std', float(val))
+        elif key == 'inc':
+            if bool(val):
+                transforms = _RAND_INCREASING_TRANSFORMS
+        elif key == 'm':
+            magnitude = int(val)
+        elif key == 'n':
+            num_layers = int(val)
+        elif key == 'w':
+            weight_idx = int(val)
+        else:
+            assert False, 'Unknown RandAugment config section'
+    ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
+    choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
+    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
+
+
+_AUGMIX_TRANSFORMS = [
+    'AutoContrast',
+    'ColorIncreasing',  # not in paper
+    'ContrastIncreasing',  # not in paper
+    'BrightnessIncreasing',  # not in paper
+    'SharpnessIncreasing',  # not in paper
+    'Equalize',
+    'Rotate',
+    'PosterizeIncreasing',
+    'SolarizeIncreasing',
+    'ShearX',
+    'ShearY',
+    'TranslateXRel',
+    'TranslateYRel',
+]
+
+
+def augmix_ops(magnitude=10, hparams=None, transforms=None):
+    hparams = hparams or _HPARAMS_DEFAULT
+    transforms = transforms or _AUGMIX_TRANSFORMS
+    return [AugmentOp(
+        name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
+
+
+class AugMixAugment:
+    """ AugMix Transform
+    Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
+    From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
+    https://arxiv.org/abs/1912.02781
+    """
+
+    def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
+        self.ops = ops
+        self.alpha = alpha
+        self.width = width
+        self.depth = depth
+        self.blended = blended  # blended mode is faster but not well tested
+
+    def _calc_blended_weights(self, ws, m):
+        ws = ws * m
+        cump = 1.
+        rws = []
+        for w in ws[::-1]:
+            alpha = w / cump
+            cump *= (1 - alpha)
+            rws.append(alpha)
+        return np.array(rws[::-1], dtype=np.float32)
+
+    def _apply_blended(self, img, mixing_weights, m):
+        # This is my first crack and implementing a slightly faster mixed augmentation. Instead
+        # of accumulating the mix for each chain in a Numpy array and then blending with original,
+        # it recomputes the blending coefficients and applies one PIL image blend per chain.
+        # TODO the results appear in the right ballpark but they differ by more than rounding.
+        img_orig = img.copy()
+        ws = self._calc_blended_weights(mixing_weights, m)
+        for w in ws:
+            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
+            ops = np.random.choice(self.ops, depth, replace=True)
+            img_aug = img_orig  # no ops are in-place, deep copy not necessary
+            for op in ops:
+                img_aug = op(img_aug)
+            img = Image.blend(img, img_aug, w)
+        return img
+
+    def _apply_basic(self, img, mixing_weights, m):
+        # This is a literal adaptation of the paper/official implementation without normalizations and
+        # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
+        # typical augmentation transforms, could use a GPU / Kornia implementation.
+        img_shape = img.size[0], img.size[1], len(img.getbands())
+        mixed = np.zeros(img_shape, dtype=np.float32)
+        for mw in mixing_weights:
+            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
+            ops = np.random.choice(self.ops, depth, replace=True)
+            img_aug = img  # no ops are in-place, deep copy not necessary
+            for op in ops:
+                img_aug = op(img_aug)
+            mixed += mw * np.asarray(img_aug, dtype=np.float32)
+        np.clip(mixed, 0, 255., out=mixed)
+        mixed = Image.fromarray(mixed.astype(np.uint8))
+        return Image.blend(img, mixed, m)
+
+    def __call__(self, img):
+        mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
+        m = np.float32(np.random.beta(self.alpha, self.alpha))
+        if self.blended:
+            mixed = self._apply_blended(img, mixing_weights, m)
+        else:
+            mixed = self._apply_basic(img, mixing_weights, m)
+        return mixed
+
+
+def augment_and_mix_transform(config_str, hparams):
+    """ Create AugMix PyTorch transform
+    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
+    dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
+    sections, not order sepecific determine
+        'm' - integer magnitude (severity) of augmentation mix (default: 3)
+        'w' - integer width of augmentation chain (default: 3)
+        'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
+        'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
+        'mstd' -  float std deviation of magnitude noise applied (default: 0)
+    Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
+    :param hparams: Other hparams (kwargs) for the Augmentation transforms
+    :return: A PyTorch compatible Transform
+    """
+    magnitude = 3
+    width = 3
+    depth = -1
+    alpha = 1.
+    blended = False
+    config = config_str.split('-')
+    assert config[0] == 'augmix'
+    config = config[1:]
+    for c in config:
+        cs = re.split(r'(\d.*)', c)
+        if len(cs) < 2:
+            continue
+        key, val = cs[:2]
+        if key == 'mstd':
+            # noise param injected via hparams for now
+            hparams.setdefault('magnitude_std', float(val))
+        elif key == 'm':
+            magnitude = int(val)
+        elif key == 'w':
+            width = int(val)
+        elif key == 'd':
+            depth = int(val)
+        elif key == 'a':
+            alpha = float(val)
+        elif key == 'b':
+            blended = bool(val)
+        else:
+            assert False, 'Unknown AugMix config section'
+    ops = augmix_ops(magnitude=magnitude, hparams=hparams)
+    return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
diff --git a/data/transforms/build.py b/data/transforms/build.py
new file mode 100644
index 0000000..7351aed
--- /dev/null
+++ b/data/transforms/build.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time    : 2020/10/26 16:25
+# @Author  : Scheaven
+# @File    :  build.py
+# @description:
+
+import torchvision.transforms as T
+
+from .transforms import *
+from .autoaugment import AutoAugment
+
+
+def build_transforms(cfg, is_train=True):
+    res = []
+
+    if is_train:
+        size_train = cfg.INPUT.SIZE_TRAIN
+
+        # augmix augmentation
+        do_augmix = cfg.INPUT.DO_AUGMIX
+
+        # auto augmentation
+        do_autoaug = cfg.INPUT.DO_AUTOAUG
+        total_iter = cfg.SOLVER.MAX_ITER
+
+        # horizontal filp
+        do_flip = cfg.INPUT.DO_FLIP
+        flip_prob = cfg.INPUT.FLIP_PROB
+
+        # padding
+        do_pad = cfg.INPUT.DO_PAD
+        padding = cfg.INPUT.PADDING
+        padding_mode = cfg.INPUT.PADDING_MODE
+
+        # color jitter
+        do_cj = cfg.INPUT.CJ.ENABLED
+        cj_prob = cfg.INPUT.CJ.PROB
+        cj_brightness = cfg.INPUT.CJ.BRIGHTNESS
+        cj_contrast = cfg.INPUT.CJ.CONTRAST
+        cj_saturation = cfg.INPUT.CJ.SATURATION
+        cj_hue = cfg.INPUT.CJ.HUE
+
+        # random erasing
+        do_rea = cfg.INPUT.REA.ENABLED
+        rea_prob = cfg.INPUT.REA.PROB
+        rea_mean = cfg.INPUT.REA.MEAN
+        # random patch
+        do_rpt = cfg.INPUT.RPT.ENABLED
+        rpt_prob = cfg.INPUT.RPT.PROB
+
+        if do_autoaug:
+            res.append(AutoAugment(total_iter))
+        res.append(T.Resize(size_train, interpolation=3))
+        if do_flip:
+            res.append(T.RandomHorizontalFlip(p=flip_prob))
+        if do_pad:
+            res.extend([T.Pad(padding, padding_mode=padding_mode),
+                        T.RandomCrop(size_train)])
+        if do_cj:
+            res.append(T.RandomApply([T.ColorJitter(cj_brightness, cj_contrast, cj_saturation, cj_hue)], p=cj_prob))
+        if do_augmix:
+            res.append(AugMix())
+        if do_rea:
+            res.append(RandomErasing(probability=rea_prob, mean=rea_mean))
+        if do_rpt:
+            res.append(RandomPatch(prob_happen=rpt_prob))
+    else:
+        size_test = cfg.INPUT.SIZE_TEST
+        res.append(T.Resize(size_test, interpolation=3))
+    res.append(ToTensor())
+    res.append(T.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255]))
+    return T.Compose(res)
diff --git a/data/transforms/functional.py b/data/transforms/functional.py
new file mode 100644
index 0000000..359bea8
--- /dev/null
+++ b/data/transforms/functional.py
@@ -0,0 +1,190 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import numpy as np
+import torch
+from PIL import Image, ImageOps, ImageEnhance
+
+
+def to_tensor(pic):
+    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+
+    See ``ToTensor`` for more details.
+
+    Args:
+        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+
+    Returns:
+        Tensor: Converted image.
+    """
+    if isinstance(pic, np.ndarray):
+        assert len(pic.shape) in (2, 3)
+        # handle numpy array
+        if pic.ndim == 2:
+            pic = pic[:, :, None]
+
+        img = torch.from_numpy(pic.transpose((2, 0, 1)))
+        # backward compatibility
+        if isinstance(img, torch.ByteTensor):
+            return img.float()
+        else:
+            return img
+
+    # handle PIL Image
+    if pic.mode == 'I':
+        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+    elif pic.mode == 'I;16':
+        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+    elif pic.mode == 'F':
+        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
+    elif pic.mode == '1':
+        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
+    else:
+        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+    # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK
+    if pic.mode == 'YCbCr':
+        nchannel = 3
+    elif pic.mode == 'I;16':
+        nchannel = 1
+    else:
+        nchannel = len(pic.mode)
+    img = img.view(pic.size[1], pic.size[0], nchannel)
+    # put it from HWC to CHW format
+    # yikes, this transpose takes 80% of the loading time/CPU
+    img = img.transpose(0, 1).transpose(0, 2).contiguous()
+    if isinstance(img, torch.ByteTensor):
+        return img.float()
+    else:
+        return img
+
+
+def int_parameter(level, maxval):
+    """Helper function to scale `val` between 0 and maxval .
+    Args:
+      level: Level of the operation that will be between [0, `PARAMETER_MAX`].
+      maxval: Maximum value that the operation can have. This will be scaled to
+        level/PARAMETER_MAX.
+    Returns:
+      An int that results from scaling `maxval` according to `level`.
+    """
+    return int(level * maxval / 10)
+
+
+def float_parameter(level, maxval):
+    """Helper function to scale `val` between 0 and maxval.
+    Args:
+      level: Level of the operation that will be between [0, `PARAMETER_MAX`].
+      maxval: Maximum value that the operation can have. This will be scaled to
+        level/PARAMETER_MAX.
+    Returns:
+      A float that results from scaling `maxval` according to `level`.
+    """
+    return float(level) * maxval / 10.
+
+
+def sample_level(n):
+    return np.random.uniform(low=0.1, high=n)
+
+
+def autocontrast(pil_img, *args):
+    return ImageOps.autocontrast(pil_img)
+
+
+def equalize(pil_img, *args):
+    return ImageOps.equalize(pil_img)
+
+
+def posterize(pil_img, level, *args):
+    level = int_parameter(sample_level(level), 4)
+    return ImageOps.posterize(pil_img, 4 - level)
+
+
+def rotate(pil_img, level, *args):
+    degrees = int_parameter(sample_level(level), 30)
+    if np.random.uniform() > 0.5:
+        degrees = -degrees
+    return pil_img.rotate(degrees, resample=Image.BILINEAR)
+
+
+def solarize(pil_img, level, *args):
+    level = int_parameter(sample_level(level), 256)
+    return ImageOps.solarize(pil_img, 256 - level)
+
+
+def shear_x(pil_img, level, image_size):
+    level = float_parameter(sample_level(level), 0.3)
+    if np.random.uniform() > 0.5:
+        level = -level
+    return pil_img.transform(image_size,
+                             Image.AFFINE, (1, level, 0, 0, 1, 0),
+                             resample=Image.BILINEAR)
+
+
+def shear_y(pil_img, level, image_size):
+    level = float_parameter(sample_level(level), 0.3)
+    if np.random.uniform() > 0.5:
+        level = -level
+    return pil_img.transform(image_size,
+                             Image.AFFINE, (1, 0, 0, level, 1, 0),
+                             resample=Image.BILINEAR)
+
+
+def translate_x(pil_img, level, image_size):
+    level = int_parameter(sample_level(level), image_size[0] / 3)
+    if np.random.random() > 0.5:
+        level = -level
+    return pil_img.transform(image_size,
+                             Image.AFFINE, (1, 0, level, 0, 1, 0),
+                             resample=Image.BILINEAR)
+
+
+def translate_y(pil_img, level, image_size):
+    level = int_parameter(sample_level(level), image_size[1] / 3)
+    if np.random.random() > 0.5:
+        level = -level
+    return pil_img.transform(image_size,
+                             Image.AFFINE, (1, 0, 0, 0, 1, level),
+                             resample=Image.BILINEAR)
+
+
+# operation that overlaps with ImageNet-C's test set
+def color(pil_img, level, *args):
+    level = float_parameter(sample_level(level), 1.8) + 0.1
+    return ImageEnhance.Color(pil_img).enhance(level)
+
+
+# operation that overlaps with ImageNet-C's test set
+def contrast(pil_img, level, *args):
+    level = float_parameter(sample_level(level), 1.8) + 0.1
+    return ImageEnhance.Contrast(pil_img).enhance(level)
+
+
+# operation that overlaps with ImageNet-C's test set
+def brightness(pil_img, level, *args):
+    level = float_parameter(sample_level(level), 1.8) + 0.1
+    return ImageEnhance.Brightness(pil_img).enhance(level)
+
+
+# operation that overlaps with ImageNet-C's test set
+def sharpness(pil_img, level, *args):
+    level = float_parameter(sample_level(level), 1.8) + 0.1
+    return ImageEnhance.Sharpness(pil_img).enhance(level)
+
+
+augmentations_reid = [
+    autocontrast, equalize, posterize, shear_x, shear_y,
+    color, contrast, brightness, sharpness
+]
+
+augmentations = [
+    autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
+    translate_x, translate_y
+]
+
+augmentations_all = [
+    autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
+    translate_x, translate_y, color, contrast, brightness, sharpness
+]
diff --git a/data/transforms/transforms.py b/data/transforms/transforms.py
new file mode 100644
index 0000000..2e2def2
--- /dev/null
+++ b/data/transforms/transforms.py
@@ -0,0 +1,204 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+__all__ = ['ToTensor', 'RandomErasing', 'RandomPatch', 'AugMix',]
+
+import math
+import random
+from collections import deque
+
+import numpy as np
+from PIL import Image
+
+from .functional import to_tensor, augmentations_reid
+
+
+class ToTensor(object):
+    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+
+    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
+    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 255.0]
+    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
+    or if the numpy.ndarray has dtype = np.uint8
+
+    In the other cases, tensors are returned without scaling.
+    """
+
+    def __call__(self, pic):
+        """
+        Args:
+            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+
+        Returns:
+            Tensor: Converted image.
+        """
+        return to_tensor(pic)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '()'
+
+
+class RandomErasing(object):
+    """ Randomly selects a rectangle region in an image and erases its pixels.
+        'Random Erasing Data Augmentation' by Zhong et al.
+        See https://arxiv.org/pdf/1708.04896.pdf
+    Args:
+        probability: The probability that the Random Erasing operation will be performed.
+        sl: Minimum proportion of erased area against input image.
+        sh: Maximum proportion of erased area against input image.
+        r1: Minimum aspect ratio of erased area.
+        mean: Erasing value.
+    """
+
+    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=255 * (0.49735, 0.4822, 0.4465)):
+        self.probability = probability
+        self.mean = mean
+        self.sl = sl
+        self.sh = sh
+        self.r1 = r1
+
+    def __call__(self, img):
+        img = np.asarray(img, dtype=np.float32).copy()
+        if random.uniform(0, 1) > self.probability:
+            return img
+
+        for attempt in range(100):
+            area = img.shape[0] * img.shape[1]
+            target_area = random.uniform(self.sl, self.sh) * area
+            aspect_ratio = random.uniform(self.r1, 1 / self.r1)
+
+            h = int(round(math.sqrt(target_area * aspect_ratio)))
+            w = int(round(math.sqrt(target_area / aspect_ratio)))
+
+            if w < img.shape[1] and h < img.shape[0]:
+                x1 = random.randint(0, img.shape[0] - h)
+                y1 = random.randint(0, img.shape[1] - w)
+                if img.shape[2] == 3:
+                    img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
+                    img[x1:x1 + h, y1:y1 + w, 1] = self.mean[1]
+                    img[x1:x1 + h, y1:y1 + w, 2] = self.mean[2]
+                else:
+                    img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
+                return img
+        return img
+
+
+class RandomPatch(object):
+    """Random patch data augmentation.
+    There is a patch pool that stores randomly extracted pathces from person images.
+    For each input image, RandomPatch
+        1) extracts a random patch and stores the patch in the patch pool;
+        2) randomly selects a patch from the patch pool and pastes it on the
+           input (at random position) to simulate occlusion.
+    Reference:
+        - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
+        - Zhou et al. Learning Generalisable Omni-Scale Representations
+          for Person Re-Identification. arXiv preprint, 2019.
+    """
+
+    def __init__(self, prob_happen=0.5, pool_capacity=50000, min_sample_size=100,
+                 patch_min_area=0.01, patch_max_area=0.5, patch_min_ratio=0.1,
+                 prob_rotate=0.5, prob_flip_leftright=0.5,
+                 ):
+        self.prob_happen = prob_happen
+
+        self.patch_min_area = patch_min_area
+        self.patch_max_area = patch_max_area
+        self.patch_min_ratio = patch_min_ratio
+
+        self.prob_rotate = prob_rotate
+        self.prob_flip_leftright = prob_flip_leftright
+
+        self.patchpool = deque(maxlen=pool_capacity)
+        self.min_sample_size = min_sample_size
+
+    def generate_wh(self, W, H):
+        area = W * H
+        for attempt in range(100):
+            target_area = random.uniform(self.patch_min_area, self.patch_max_area) * area
+            aspect_ratio = random.uniform(self.patch_min_ratio, 1. / self.patch_min_ratio)
+            h = int(round(math.sqrt(target_area * aspect_ratio)))
+            w = int(round(math.sqrt(target_area / aspect_ratio)))
+            if w < W and h < H:
+                return w, h
+        return None, None
+
+    def transform_patch(self, patch):
+        if random.uniform(0, 1) > self.prob_flip_leftright:
+            patch = patch.transpose(Image.FLIP_LEFT_RIGHT)
+        if random.uniform(0, 1) > self.prob_rotate:
+            patch = patch.rotate(random.randint(-10, 10))
+        return patch
+
+    def __call__(self, img):
+        if isinstance(img, np.ndarray):
+            img = Image.fromarray(img.astype(np.uint8))
+
+        W, H = img.size  # original image size
+
+        # collect new patch
+        w, h = self.generate_wh(W, H)
+        if w is not None and h is not None:
+            x1 = random.randint(0, W - w)
+            y1 = random.randint(0, H - h)
+            new_patch = img.crop((x1, y1, x1 + w, y1 + h))
+            self.patchpool.append(new_patch)
+
+        if len(self.patchpool) < self.min_sample_size:
+            return img
+
+        if random.uniform(0, 1) > self.prob_happen:
+            return img
+
+        # paste a randomly selected patch on a random position
+        patch = random.sample(self.patchpool, 1)[0]
+        patchW, patchH = patch.size
+        x1 = random.randint(0, W - patchW)
+        y1 = random.randint(0, H - patchH)
+        patch = self.transform_patch(patch)
+        img.paste(patch, (x1, y1))
+
+        return img
+
+
+class AugMix(object):
+    """ Perform AugMix augmentation and compute mixture.
+    Args:
+        aug_prob_coeff: Probability distribution coefficients.
+        mixture_width: Number of augmentation chains to mix per augmented example.
+        mixture_depth: Depth of augmentation chains. -1 denotes stochastic depth in [1, 3]'
+        severity: Severity of underlying augmentation operators (between 1 to 10).
+    """
+
+    def __init__(self, aug_prob_coeff=1, mixture_width=3, mixture_depth=-1, severity=1):
+        self.aug_prob_coeff = aug_prob_coeff
+        self.mixture_width = mixture_width
+        self.mixture_depth = mixture_depth
+        self.severity = severity
+        self.aug_list = augmentations_reid
+
+    def __call__(self, image):
+        """Perform AugMix augmentations and compute mixture.
+        Returns:
+          mixed: Augmented and mixed image.
+        """
+        ws = np.float32(
+            np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width))
+        m = np.float32(np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff))
+
+        image = np.asarray(image, dtype=np.float32).copy()
+        mix = np.zeros_like(image)
+        h, w = image.shape[0], image.shape[1]
+        for i in range(self.mixture_width):
+            image_aug = Image.fromarray(image.copy().astype(np.uint8))
+            depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(1, 4)
+            for _ in range(depth):
+                op = np.random.choice(self.aug_list)
+                image_aug = op(image_aug, self.severity, (w, h))
+            mix += ws[i] * np.asarray(image_aug, dtype=np.float32)
+
+        mixed = (1 - m) * image + m * mix
+        return mixed
diff --git a/engine/__init__.py b/engine/__init__.py
new file mode 100644
index 0000000..77c1ec2
--- /dev/null
+++ b/engine/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time    : 2020/10/26 13:32
+# @Author  : Scheaven
+# @File    :  __init__.py.py
+# @description: 
\ No newline at end of file
diff --git a/engine/__pycache__/__init__.cpython-38.pyc b/engine/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..58e7ffd
--- /dev/null
+++ b/engine/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/engine/__pycache__/defaults.cpython-38.pyc b/engine/__pycache__/defaults.cpython-38.pyc
new file mode 100644
index 0000000..f07dd58
--- /dev/null
+++ b/engine/__pycache__/defaults.cpython-38.pyc
Binary files differ
diff --git a/engine/defaults.py b/engine/defaults.py
new file mode 100644
index 0000000..9c1261e
--- /dev/null
+++ b/engine/defaults.py
@@ -0,0 +1,115 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time    : 2020/10/26 13:38
+# @Author  : Scheaven
+# @File    :  defaults.py
+# @description:
+
+"""
+    This file contains components with some default boilerplate logic user may need
+    in training / testing. They will not work for everyone, but many users may find them useful.
+    The behavior of functions/classes in this file is subject to change,
+    since they are meant to represent the "common default behavior" people need in their projects.
+"""
+import argparse
+import logging
+import os
+import sys
+from utils import comm
+from utils.env import seed_all_rng
+from utils.file_io import PathManager
+from utils.logger import setup_logger
+from utils.collect_env import collect_env_info
+from collections import OrderedDict
+
+import torch
+# import torch.nn.functional as F
+# from torch.nn.parallel import DistributedDataParallel
+
+def default_argument_parser():
+    """
+    Create a parser with some common arguments used by fastreid users.
+    Returns:
+        argparse.ArgumentParser:
+    """
+    parser = argparse.ArgumentParser(description="fastreid Training")
+    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
+    parser.add_argument(
+        "--finetune",
+        action="store_true",
+        help="whether to attempt to finetune from the trained model",
+    )
+    parser.add_argument(
+        "--resume",
+        action="store_true",
+        help="whether to attempt to resume from the checkpoint directory",
+    )
+    parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
+    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
+    parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
+    parser.add_argument("--img_a1", default="1.jpg", help="input image")
+    parser.add_argument("--img_a2", default="2.jpg", help="input image2")
+    parser.add_argument("--img_b1", default="1.jpg", help="input image")
+    parser.add_argument("--img_b2", default="2.jpg", help="input image2")
+    parser.add_argument(
+        "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
+    )
+
+    # PyTorch still may leave orphan processes in multi-gpu training.
+    # Therefore we use a deterministic way to obtain port,
+    # so that users are aware of orphan processes by seeing the port occupied.
+    port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
+    parser.add_argument("--dist-url", default="tcp://127.0.0.1:{}".format(port))
+    parser.add_argument(
+        "opts",
+        help="Modify config options using the command-line",
+        default=None,
+        nargs=argparse.REMAINDER,
+    )
+    return parser
+
+def default_setup(cfg, args):
+    """
+    Perform some basic common setups at the beginning of a job, including:
+    1. Set up the detectron2 logger
+    2. Log basic information about environment, cmdline arguments, and config
+    3. Backup the config to the output directory
+    Args:
+        cfg (CfgNode): the full config to be used
+        args (argparse.NameSpace): the command line arguments to be logged
+    """
+    output_dir = cfg.OUTPUT_DIR
+    if comm.is_main_process() and output_dir:
+        PathManager.mkdirs(output_dir)
+
+    rank = comm.get_rank()
+    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
+    logger = setup_logger(output_dir, distributed_rank=rank)
+
+    logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
+    logger.info("Environment info:\n" + collect_env_info())
+
+    logger.info("Command line arguments: " + str(args))
+    if hasattr(args, "config_file") and args.config_file != "":
+        logger.info(
+            "Contents of args.config_file={}:\n{}".format(
+                args.config_file, PathManager.open(args.config_file, "r").read()
+            )
+        )
+
+    logger.info("Running with full config:\n{}".format(cfg))
+    if comm.is_main_process() and output_dir:
+        # Note: some of our scripts may expect the existence of
+        # config.yaml in output directory
+        path = os.path.join(output_dir, "config.yaml")
+        with PathManager.open(path, "w") as f:
+            f.write(cfg.dump())
+        logger.info("Full config saved to {}".format(os.path.abspath(path)))
+
+    # make sure each worker has a different, yet deterministic seed if specified
+    seed_all_rng()
+
+    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
+    # typical validation set.
+    if not (hasattr(args, "eval_only") and args.eval_only):
+        torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK
diff --git a/layers/__init__.py b/layers/__init__.py
new file mode 100644
index 0000000..2e67b13
--- /dev/null
+++ b/layers/__init__.py
@@ -0,0 +1,17 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from .activation import *
+from .arcface import Arcface
+from .batch_drop import BatchDrop
+from .batch_norm import *
+from .circle import Circle
+from .context_block import ContextBlock
+from .frn import FRN, TLU
+from .non_local import Non_local
+from .pooling import *
+from .se_layer import SELayer
+from .splat import SplAtConv2d
diff --git a/layers/__pycache__/__init__.cpython-37.pyc b/layers/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..502959c
--- /dev/null
+++ b/layers/__pycache__/__init__.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/__init__.cpython-38.pyc b/layers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..5537f1f
--- /dev/null
+++ b/layers/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/activation.cpython-37.pyc b/layers/__pycache__/activation.cpython-37.pyc
new file mode 100644
index 0000000..8490b01
--- /dev/null
+++ b/layers/__pycache__/activation.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/activation.cpython-38.pyc b/layers/__pycache__/activation.cpython-38.pyc
new file mode 100644
index 0000000..9d59e01
--- /dev/null
+++ b/layers/__pycache__/activation.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/arcface.cpython-37.pyc b/layers/__pycache__/arcface.cpython-37.pyc
new file mode 100644
index 0000000..a01ae38
--- /dev/null
+++ b/layers/__pycache__/arcface.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/arcface.cpython-38.pyc b/layers/__pycache__/arcface.cpython-38.pyc
new file mode 100644
index 0000000..f62c979
--- /dev/null
+++ b/layers/__pycache__/arcface.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/batch_drop.cpython-37.pyc b/layers/__pycache__/batch_drop.cpython-37.pyc
new file mode 100644
index 0000000..c8c49cf
--- /dev/null
+++ b/layers/__pycache__/batch_drop.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/batch_drop.cpython-38.pyc b/layers/__pycache__/batch_drop.cpython-38.pyc
new file mode 100644
index 0000000..e0b0d5e
--- /dev/null
+++ b/layers/__pycache__/batch_drop.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/batch_norm.cpython-37.pyc b/layers/__pycache__/batch_norm.cpython-37.pyc
new file mode 100644
index 0000000..944c044
--- /dev/null
+++ b/layers/__pycache__/batch_norm.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/batch_norm.cpython-38.pyc b/layers/__pycache__/batch_norm.cpython-38.pyc
new file mode 100644
index 0000000..4b45f1d
--- /dev/null
+++ b/layers/__pycache__/batch_norm.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/circle.cpython-37.pyc b/layers/__pycache__/circle.cpython-37.pyc
new file mode 100644
index 0000000..a4aa86b
--- /dev/null
+++ b/layers/__pycache__/circle.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/circle.cpython-38.pyc b/layers/__pycache__/circle.cpython-38.pyc
new file mode 100644
index 0000000..1b31d9e
--- /dev/null
+++ b/layers/__pycache__/circle.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/context_block.cpython-37.pyc b/layers/__pycache__/context_block.cpython-37.pyc
new file mode 100644
index 0000000..9782731
--- /dev/null
+++ b/layers/__pycache__/context_block.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/context_block.cpython-38.pyc b/layers/__pycache__/context_block.cpython-38.pyc
new file mode 100644
index 0000000..23f3cfa
--- /dev/null
+++ b/layers/__pycache__/context_block.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/frn.cpython-37.pyc b/layers/__pycache__/frn.cpython-37.pyc
new file mode 100644
index 0000000..677945a
--- /dev/null
+++ b/layers/__pycache__/frn.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/frn.cpython-38.pyc b/layers/__pycache__/frn.cpython-38.pyc
new file mode 100644
index 0000000..26d8109
--- /dev/null
+++ b/layers/__pycache__/frn.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/non_local.cpython-37.pyc b/layers/__pycache__/non_local.cpython-37.pyc
new file mode 100644
index 0000000..84b216d
--- /dev/null
+++ b/layers/__pycache__/non_local.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/non_local.cpython-38.pyc b/layers/__pycache__/non_local.cpython-38.pyc
new file mode 100644
index 0000000..70f1863
--- /dev/null
+++ b/layers/__pycache__/non_local.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/pooling.cpython-37.pyc b/layers/__pycache__/pooling.cpython-37.pyc
new file mode 100644
index 0000000..72e8113
--- /dev/null
+++ b/layers/__pycache__/pooling.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/pooling.cpython-38.pyc b/layers/__pycache__/pooling.cpython-38.pyc
new file mode 100644
index 0000000..db3c40b
--- /dev/null
+++ b/layers/__pycache__/pooling.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/se_layer.cpython-37.pyc b/layers/__pycache__/se_layer.cpython-37.pyc
new file mode 100644
index 0000000..a33eee6
--- /dev/null
+++ b/layers/__pycache__/se_layer.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/se_layer.cpython-38.pyc b/layers/__pycache__/se_layer.cpython-38.pyc
new file mode 100644
index 0000000..15871c0
--- /dev/null
+++ b/layers/__pycache__/se_layer.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/splat.cpython-37.pyc b/layers/__pycache__/splat.cpython-37.pyc
new file mode 100644
index 0000000..bbcd0ec
--- /dev/null
+++ b/layers/__pycache__/splat.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/splat.cpython-38.pyc b/layers/__pycache__/splat.cpython-38.pyc
new file mode 100644
index 0000000..0b7289f
--- /dev/null
+++ b/layers/__pycache__/splat.cpython-38.pyc
Binary files differ
diff --git a/layers/activation.py b/layers/activation.py
new file mode 100644
index 0000000..3e0aea6
--- /dev/null
+++ b/layers/activation.py
@@ -0,0 +1,59 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: liaoxingyu5@jd.com
+"""
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = [
+    'Mish',
+    'Swish',
+    'MemoryEfficientSwish',
+    'GELU']
+
+
+class Mish(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        # inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!)
+        return x * (torch.tanh(F.softplus(x)))
+
+
+class Swish(nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(x)
+
+
+class SwishImplementation(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, i):
+        result = i * torch.sigmoid(i)
+        ctx.save_for_backward(i)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        i = ctx.saved_variables[0]
+        sigmoid_i = torch.sigmoid(i)
+        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+
+class MemoryEfficientSwish(nn.Module):
+    def forward(self, x):
+        return SwishImplementation.apply(x)
+
+
+class GELU(nn.Module):
+    """
+    Paper Section 3.4, last paragraph notice that BERT used the GELU instead of RELU
+    """
+
+    def forward(self, x):
+        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
diff --git a/layers/arcface.py b/layers/arcface.py
new file mode 100644
index 0000000..be7f9f6
--- /dev/null
+++ b/layers/arcface.py
@@ -0,0 +1,54 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Parameter
+
+
+class Arcface(nn.Module):
+    def __init__(self, cfg, in_feat, num_classes):
+        super().__init__()
+        self.in_feat = in_feat
+        self._num_classes = num_classes
+        self._s = cfg.MODEL.HEADS.SCALE
+        self._m = cfg.MODEL.HEADS.MARGIN
+
+        self.cos_m = math.cos(self._m)
+        self.sin_m = math.sin(self._m)
+        self.threshold = math.cos(math.pi - self._m)
+        self.mm = math.sin(math.pi - self._m) * self._m
+
+        self.weight = Parameter(torch.Tensor(num_classes, in_feat))
+        self.register_buffer('t', torch.zeros(1))
+
+    def forward(self, features, targets):
+        # get cos(theta)
+        cos_theta = F.linear(F.normalize(features), F.normalize(self.weight))
+        cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
+
+        target_logit = cos_theta[torch.arange(0, features.size(0)), targets].view(-1, 1)
+
+        sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
+        cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m  # cos(target+margin)
+        mask = cos_theta > cos_theta_m
+        final_target_logit = torch.where(target_logit > self.threshold, cos_theta_m, target_logit - self.mm)
+
+        hard_example = cos_theta[mask]
+        with torch.no_grad():
+            self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t
+        cos_theta[mask] = hard_example * (self.t + hard_example)
+        cos_theta.scatter_(1, targets.view(-1, 1).long(), final_target_logit)
+        pred_class_logits = cos_theta * self._s
+        return pred_class_logits
+
+    def extra_repr(self):
+        return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
+            self.in_feat, self._num_classes, self._s, self._m
+        )
diff --git a/layers/batch_drop.py b/layers/batch_drop.py
new file mode 100644
index 0000000..5c25697
--- /dev/null
+++ b/layers/batch_drop.py
@@ -0,0 +1,32 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import random
+
+from torch import nn
+
+
+class BatchDrop(nn.Module):
+    """ref: https://github.com/daizuozhuo/batch-dropblock-network/blob/master/models/networks.py
+    batch drop mask
+    """
+
+    def __init__(self, h_ratio, w_ratio):
+        super(BatchDrop, self).__init__()
+        self.h_ratio = h_ratio
+        self.w_ratio = w_ratio
+
+    def forward(self, x):
+        if self.training:
+            h, w = x.size()[-2:]
+            rh = round(self.h_ratio * h)
+            rw = round(self.w_ratio * w)
+            sx = random.randint(0, h - rh)
+            sy = random.randint(0, w - rw)
+            mask = x.new_ones(x.size())
+            mask[:, :, sx:sx + rh, sy:sy + rw] = 0
+            x = x * mask
+        return x
diff --git a/layers/batch_norm.py b/layers/batch_norm.py
new file mode 100644
index 0000000..e0e88e3
--- /dev/null
+++ b/layers/batch_norm.py
@@ -0,0 +1,208 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import logging
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+__all__ = [
+    "BatchNorm",
+    "IBN",
+    "GhostBatchNorm",
+    "FrozenBatchNorm",
+    "SyncBatchNorm",
+    "get_norm",
+]
+
+
+class BatchNorm(nn.BatchNorm2d):
+    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
+                 bias_init=0.0, **kwargs):
+        super().__init__(num_features, eps=eps, momentum=momentum)
+        if weight_init is not None: nn.init.constant_(self.weight, weight_init)
+        if bias_init is not None: nn.init.constant_(self.bias, bias_init)
+        self.weight.requires_grad_(not weight_freeze)
+        self.bias.requires_grad_(not bias_freeze)
+
+
+class SyncBatchNorm(nn.SyncBatchNorm):
+    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
+                 bias_init=0.0):
+        super().__init__(num_features, eps=eps, momentum=momentum)
+        if weight_init is not None: nn.init.constant_(self.weight, weight_init)
+        if bias_init is not None: nn.init.constant_(self.bias, bias_init)
+        self.weight.requires_grad_(not weight_freeze)
+        self.bias.requires_grad_(not bias_freeze)
+
+
+class IBN(nn.Module):
+    def __init__(self, planes, bn_norm, **kwargs):
+        super(IBN, self).__init__()
+        half1 = int(planes / 2)
+        self.half = half1
+        half2 = planes - half1
+        self.IN = nn.InstanceNorm2d(half1, affine=True)
+        self.BN = get_norm(bn_norm, half2, **kwargs)
+
+    def forward(self, x):
+        split = torch.split(x, self.half, 1)
+        out1 = self.IN(split[0].contiguous())
+        out2 = self.BN(split[1].contiguous())
+        out = torch.cat((out1, out2), 1)
+        return out
+
+
+class GhostBatchNorm(BatchNorm):
+    def __init__(self, num_features, num_splits=1, **kwargs):
+        super().__init__(num_features, **kwargs)
+        self.num_splits = num_splits
+        self.register_buffer('running_mean', torch.zeros(num_features))
+        self.register_buffer('running_var', torch.ones(num_features))
+
+    def forward(self, input):
+        N, C, H, W = input.shape
+        if self.training or not self.track_running_stats:
+            self.running_mean = self.running_mean.repeat(self.num_splits)
+            self.running_var = self.running_var.repeat(self.num_splits)
+            outputs = F.batch_norm(
+                input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var,
+                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
+                True, self.momentum, self.eps).view(N, C, H, W)
+            self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0)
+            self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0)
+            return outputs
+        else:
+            return F.batch_norm(
+                input, self.running_mean, self.running_var,
+                self.weight, self.bias, False, self.momentum, self.eps)
+
+
+class FrozenBatchNorm(BatchNorm):
+    """
+    BatchNorm2d where the batch statistics and the affine parameters are fixed.
+    It contains non-trainable buffers called
+    "weight" and "bias", "running_mean", "running_var",
+    initialized to perform identity transformation.
+    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
+    which are computed from the original four parameters of BN.
+    The affine transform `x * weight + bias` will perform the equivalent
+    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
+    When loading a backbone model from Caffe2, "running_mean" and "running_var"
+    will be left unchanged as identity transformation.
+    Other pre-trained backbone models may contain all 4 parameters.
+    The forward is implemented by `F.batch_norm(..., training=False)`.
+    """
+
+    _version = 3
+
+    def __init__(self, num_features, eps=1e-5, **kwargs):
+        super().__init__(num_features, weight_freeze=True, bias_freeze=True, **kwargs)
+        self.num_features = num_features
+        self.eps = eps
+
+    def forward(self, x):
+        if x.requires_grad:
+            # When gradients are needed, F.batch_norm will use extra memory
+            # because its backward op computes gradients for weight/bias as well.
+            scale = self.weight * (self.running_var + self.eps).rsqrt()
+            bias = self.bias - self.running_mean * scale
+            scale = scale.reshape(1, -1, 1, 1)
+            bias = bias.reshape(1, -1, 1, 1)
+            return x * scale + bias
+        else:
+            # When gradients are not needed, F.batch_norm is a single fused op
+            # and provide more optimization opportunities.
+            return F.batch_norm(
+                x,
+                self.running_mean,
+                self.running_var,
+                self.weight,
+                self.bias,
+                training=False,
+                eps=self.eps,
+            )
+
+    def _load_from_state_dict(
+            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        version = local_metadata.get("version", None)
+
+        if version is None or version < 2:
+            # No running_mean/var in early versions
+            # This will silent the warnings
+            if prefix + "running_mean" not in state_dict:
+                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
+            if prefix + "running_var" not in state_dict:
+                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)
+
+        if version is not None and version < 3:
+            logger = logging.getLogger(__name__)
+            logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip(".")))
+            # In version < 3, running_var are used without +eps.
+            state_dict[prefix + "running_var"] -= self.eps
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
+    def __repr__(self):
+        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)
+
+    @classmethod
+    def convert_frozen_batchnorm(cls, module):
+        """
+        Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.
+        Args:
+            module (torch.nn.Module):
+        Returns:
+            If module is BatchNorm/SyncBatchNorm, returns a new module.
+            Otherwise, in-place convert module and return it.
+        Similar to convert_sync_batchnorm in
+        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
+        """
+        bn_module = nn.modules.batchnorm
+        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
+        res = module
+        if isinstance(module, bn_module):
+            res = cls(module.num_features)
+            if module.affine:
+                res.weight.data = module.weight.data.clone().detach()
+                res.bias.data = module.bias.data.clone().detach()
+            res.running_mean.data = module.running_mean.data
+            res.running_var.data = module.running_var.data
+            res.eps = module.eps
+        else:
+            for name, child in module.named_children():
+                new_child = cls.convert_frozen_batchnorm(child)
+                if new_child is not child:
+                    res.add_module(name, new_child)
+        return res
+
+
+def get_norm(norm, out_channels, **kwargs):
+    """
+    Args:
+        norm (str or callable): either one of BN, GhostBN, FrozenBN, GN or SyncBN;
+            or a callable that thakes a channel number and returns
+            the normalization layer as a nn.Module
+        out_channels: number of channels for normalization layer
+
+    Returns:
+        nn.Module or None: the normalization layer
+    """
+    if isinstance(norm, str):
+        if len(norm) == 0:
+            return None
+        norm = {
+            "BN": BatchNorm,
+            "GhostBN": GhostBatchNorm,
+            "FrozenBN": FrozenBatchNorm,
+            "GN": lambda channels, **args: nn.GroupNorm(32, channels),
+            "syncBN": SyncBatchNorm,
+        }[norm]
+    return norm(out_channels, **kwargs)
diff --git a/layers/circle.py b/layers/circle.py
new file mode 100644
index 0000000..6182b60
--- /dev/null
+++ b/layers/circle.py
@@ -0,0 +1,42 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Parameter
+
+
+class Circle(nn.Module):
+    def __init__(self, cfg, in_feat, num_classes):
+        super().__init__()
+        self.in_feat = in_feat
+        self._num_classes = num_classes
+        self._s = cfg.MODEL.HEADS.SCALE
+        self._m = cfg.MODEL.HEADS.MARGIN
+
+        self.weight = Parameter(torch.Tensor(num_classes, in_feat))
+
+    def forward(self, features, targets):
+        sim_mat = F.linear(F.normalize(features), F.normalize(self.weight))
+        alpha_p = F.relu(-sim_mat.detach() + 1 + self._m)
+        alpha_n = F.relu(sim_mat.detach() + self._m)
+        delta_p = 1 - self._m
+        delta_n = self._m
+
+        s_p = self._s * alpha_p * (sim_mat - delta_p)
+        s_n = self._s * alpha_n * (sim_mat - delta_n)
+
+        targets = F.one_hot(targets, num_classes=self._num_classes)
+
+        pred_class_logits = targets * s_p + (1.0 - targets) * s_n
+
+        return pred_class_logits
+
+    def extra_repr(self):
+        return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
+            self.in_feat, self._num_classes, self._s, self._m
+        )
diff --git a/layers/context_block.py b/layers/context_block.py
new file mode 100644
index 0000000..7b1098a
--- /dev/null
+++ b/layers/context_block.py
@@ -0,0 +1,113 @@
+# copy from https://github.com/xvjiarui/GCNet/blob/master/mmdet/ops/gcb/context_block.py
+
+import torch
+from torch import nn
+
+__all__ = ['ContextBlock']
+
+
+def last_zero_init(m):
+    if isinstance(m, nn.Sequential):
+        nn.init.constant_(m[-1].weight, val=0)
+        if hasattr(m[-1], 'bias') and m[-1].bias is not None:
+            nn.init.constant_(m[-1].bias, 0)
+    else:
+        nn.init.constant_(m.weight, val=0)
+        if hasattr(m, 'bias') and m.bias is not None:
+            nn.init.constant_(m.bias, 0)
+
+
+class ContextBlock(nn.Module):
+
+    def __init__(self,
+                 inplanes,
+                 ratio,
+                 pooling_type='att',
+                 fusion_types=('channel_add',)):
+        super(ContextBlock, self).__init__()
+        assert pooling_type in ['avg', 'att']
+        assert isinstance(fusion_types, (list, tuple))
+        valid_fusion_types = ['channel_add', 'channel_mul']
+        assert all([f in valid_fusion_types for f in fusion_types])
+        assert len(fusion_types) > 0, 'at least one fusion should be used'
+        self.inplanes = inplanes
+        self.ratio = ratio
+        self.planes = int(inplanes * ratio)
+        self.pooling_type = pooling_type
+        self.fusion_types = fusion_types
+        if pooling_type == 'att':
+            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
+            self.softmax = nn.Softmax(dim=2)
+        else:
+            self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        if 'channel_add' in fusion_types:
+            self.channel_add_conv = nn.Sequential(
+                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
+                nn.LayerNorm([self.planes, 1, 1]),
+                nn.ReLU(inplace=True),  # yapf: disable
+                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
+        else:
+            self.channel_add_conv = None
+        if 'channel_mul' in fusion_types:
+            self.channel_mul_conv = nn.Sequential(
+                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
+                nn.LayerNorm([self.planes, 1, 1]),
+                nn.ReLU(inplace=True),  # yapf: disable
+                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
+        else:
+            self.channel_mul_conv = None
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if self.pooling_type == 'att':
+            nn.init.kaiming_normal_(self.conv_mask.weight, a=0, mode='fan_in', nonlinearity='relu')
+            if hasattr(self.conv_mask, 'bias') and self.conv_mask.bias is not None:
+                nn.init.constant_(self.conv_mask.bias, 0)
+            self.conv_mask.inited = True
+
+        if self.channel_add_conv is not None:
+            last_zero_init(self.channel_add_conv)
+        if self.channel_mul_conv is not None:
+            last_zero_init(self.channel_mul_conv)
+
+    def spatial_pool(self, x):
+        batch, channel, height, width = x.size()
+        if self.pooling_type == 'att':
+            input_x = x
+            # [N, C, H * W]
+            input_x = input_x.view(batch, channel, height * width)
+            # [N, 1, C, H * W]
+            input_x = input_x.unsqueeze(1)
+            # [N, 1, H, W]
+            context_mask = self.conv_mask(x)
+            # [N, 1, H * W]
+            context_mask = context_mask.view(batch, 1, height * width)
+            # [N, 1, H * W]
+            context_mask = self.softmax(context_mask)
+            # [N, 1, H * W, 1]
+            context_mask = context_mask.unsqueeze(-1)
+            # [N, 1, C, 1]
+            context = torch.matmul(input_x, context_mask)
+            # [N, C, 1, 1]
+            context = context.view(batch, channel, 1, 1)
+        else:
+            # [N, C, 1, 1]
+            context = self.avg_pool(x)
+
+        return context
+
+    def forward(self, x):
+        # [N, C, 1, 1]
+        context = self.spatial_pool(x)
+
+        out = x
+        if self.channel_mul_conv is not None:
+            # [N, C, 1, 1]
+            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
+            out = out * channel_mul_term
+        if self.channel_add_conv is not None:
+            # [N, C, 1, 1]
+            channel_add_term = self.channel_add_conv(context)
+            out = out + channel_add_term
+
+        return out
diff --git a/layers/frn.py b/layers/frn.py
new file mode 100644
index 0000000..f00a1e4
--- /dev/null
+++ b/layers/frn.py
@@ -0,0 +1,199 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+from torch import nn
+from torch.nn.modules.batchnorm import BatchNorm2d
+from torch.nn import ReLU, LeakyReLU
+from torch.nn.parameter import Parameter
+
+
+class TLU(nn.Module):
+    def __init__(self, num_features):
+        """max(y, tau) = max(y - tau, 0) + tau = ReLU(y - tau) + tau"""
+        super(TLU, self).__init__()
+        self.num_features = num_features
+        self.tau = Parameter(torch.Tensor(num_features))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.zeros_(self.tau)
+
+    def extra_repr(self):
+        return 'num_features={num_features}'.format(**self.__dict__)
+
+    def forward(self, x):
+        return torch.max(x, self.tau.view(1, self.num_features, 1, 1))
+
+
+class FRN(nn.Module):
+    def __init__(self, num_features, eps=1e-6, is_eps_leanable=False):
+        """
+        weight = gamma, bias = beta
+        beta, gamma:
+            Variables of shape [1, 1, 1, C]. if TensorFlow
+            Variables of shape [1, C, 1, 1]. if PyTorch
+        eps: A scalar constant or learnable variable.
+        """
+        super(FRN, self).__init__()
+
+        self.num_features = num_features
+        self.init_eps = eps
+        self.is_eps_leanable = is_eps_leanable
+
+        self.weight = Parameter(torch.Tensor(num_features))
+        self.bias = Parameter(torch.Tensor(num_features))
+        if is_eps_leanable:
+            self.eps = Parameter(torch.Tensor(1))
+        else:
+            self.register_buffer('eps', torch.Tensor([eps]))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.ones_(self.weight)
+        nn.init.zeros_(self.bias)
+        if self.is_eps_leanable:
+            nn.init.constant_(self.eps, self.init_eps)
+
+    def extra_repr(self):
+        return 'num_features={num_features}, eps={init_eps}'.format(**self.__dict__)
+
+    def forward(self, x):
+        """
+        0, 1, 2, 3 -> (B, H, W, C) in TensorFlow
+        0, 1, 2, 3 -> (B, C, H, W) in PyTorch
+        TensorFlow code
+            nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True)
+            x = x * tf.rsqrt(nu2 + tf.abs(eps))
+            # This Code include TLU function max(y, tau)
+            return tf.maximum(gamma * x + beta, tau)
+        """
+        # Compute the mean norm of activations per channel.
+        nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)
+
+        # Perform FRN.
+        x = x * torch.rsqrt(nu2 + self.eps.abs())
+
+        # Scale and Bias
+        x = self.weight.view(1, self.num_features, 1, 1) * x + self.bias.view(1, self.num_features, 1, 1)
+        # x = self.weight * x + self.bias
+        return x
+
+
+def bnrelu_to_frn(module):
+    """
+    Convert 'BatchNorm2d + ReLU' to 'FRN + TLU'
+    """
+    mod = module
+    before_name = None
+    before_child = None
+    is_before_bn = False
+
+    for name, child in module.named_children():
+        if is_before_bn and isinstance(child, (ReLU, LeakyReLU)):
+            # Convert BN to FRN
+            if isinstance(before_child, BatchNorm2d):
+                mod.add_module(
+                    before_name, FRN(num_features=before_child.num_features))
+            else:
+                raise NotImplementedError()
+
+            # Convert ReLU to TLU
+            mod.add_module(name, TLU(num_features=before_child.num_features))
+        else:
+            mod.add_module(name, bnrelu_to_frn(child))
+
+        before_name = name
+        before_child = child
+        is_before_bn = isinstance(child, BatchNorm2d)
+    return mod
+
+
+def convert(module, flag_name):
+    mod = module
+    before_ch = None
+    for name, child in module.named_children():
+        if hasattr(child, flag_name) and getattr(child, flag_name):
+            if isinstance(child, BatchNorm2d):
+                before_ch = child.num_features
+                mod.add_module(name, FRN(num_features=child.num_features))
+            # TODO bn is no good...
+            if isinstance(child, (ReLU, LeakyReLU)):
+                mod.add_module(name, TLU(num_features=before_ch))
+        else:
+            mod.add_module(name, convert(child, flag_name))
+    return mod
+
+
+def remove_flags(module, flag_name):
+    mod = module
+    for name, child in module.named_children():
+        if hasattr(child, 'is_convert_frn'):
+            delattr(child, flag_name)
+            mod.add_module(name, remove_flags(child, flag_name))
+        else:
+            mod.add_module(name, remove_flags(child, flag_name))
+    return mod
+
+
+def bnrelu_to_frn2(model, input_size=(3, 128, 128), batch_size=2, flag_name='is_convert_frn'):
+    forard_hooks = list()
+    backward_hooks = list()
+
+    is_before_bn = [False]
+
+    def register_forward_hook(module):
+        def hook(self, input, output):
+            if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
+                is_before_bn.append(False)
+                return
+
+            # input and output is required in hook def
+            is_converted = is_before_bn[-1] and isinstance(self, (ReLU, LeakyReLU))
+            if is_converted:
+                setattr(self, flag_name, True)
+            is_before_bn.append(isinstance(self, BatchNorm2d))
+
+        forard_hooks.append(module.register_forward_hook(hook))
+
+    is_before_relu = [False]
+
+    def register_backward_hook(module):
+        def hook(self, input, output):
+            if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
+                is_before_relu.append(False)
+                return
+            is_converted = is_before_relu[-1] and isinstance(self, BatchNorm2d)
+            if is_converted:
+                setattr(self, flag_name, True)
+            is_before_relu.append(isinstance(self, (ReLU, LeakyReLU)))
+
+        backward_hooks.append(module.register_backward_hook(hook))
+
+    # multiple inputs to the network
+    if isinstance(input_size, tuple):
+        input_size = [input_size]
+
+    # batch_size of 2 for batchnorm
+    x = [torch.rand(batch_size, *in_size) for in_size in input_size]
+
+    # register hook
+    model.apply(register_forward_hook)
+    model.apply(register_backward_hook)
+
+    # make a forward pass
+    output = model(*x)
+    output.sum().backward()  # Raw output is not enabled to use backward()
+
+    # remove these hooks
+    for h in forard_hooks:
+        h.remove()
+    for h in backward_hooks:
+        h.remove()
+
+    model = convert(model, flag_name=flag_name)
+    model = remove_flags(model, flag_name=flag_name)
+    return model
diff --git a/layers/non_local.py b/layers/non_local.py
new file mode 100644
index 0000000..876ec43
--- /dev/null
+++ b/layers/non_local.py
@@ -0,0 +1,54 @@
+# encoding: utf-8
+
+
+import torch
+from torch import nn
+from .batch_norm import get_norm
+
+
+class Non_local(nn.Module):
+    def __init__(self, in_channels, bn_norm, num_splits, reduc_ratio=2):
+        super(Non_local, self).__init__()
+
+        self.in_channels = in_channels
+        self.inter_channels = reduc_ratio // reduc_ratio
+
+        self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
+                           kernel_size=1, stride=1, padding=0)
+
+        self.W = nn.Sequential(
+            nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
+                      kernel_size=1, stride=1, padding=0),
+            get_norm(bn_norm, self.in_channels, num_splits),
+        )
+        nn.init.constant_(self.W[1].weight, 0.0)
+        nn.init.constant_(self.W[1].bias, 0.0)
+
+        self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
+                               kernel_size=1, stride=1, padding=0)
+
+        self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
+                             kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x):
+        '''
+                :param x: (b, t, h, w)
+                :return x: (b, t, h, w)
+        '''
+        batch_size = x.size(0)
+        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
+        g_x = g_x.permute(0, 2, 1)
+
+        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
+        theta_x = theta_x.permute(0, 2, 1)
+        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
+        f = torch.matmul(theta_x, phi_x)
+        N = f.size(-1)
+        f_div_C = f / N
+
+        y = torch.matmul(f_div_C, g_x)
+        y = y.permute(0, 2, 1).contiguous()
+        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
+        W_y = self.W(y)
+        z = W_y + x
+        return z
diff --git a/layers/pooling.py b/layers/pooling.py
new file mode 100644
index 0000000..5aec39e
--- /dev/null
+++ b/layers/pooling.py
@@ -0,0 +1,79 @@
+# encoding: utf-8
+"""
+@author:  l1aoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Flatten(nn.Module):
+    def forward(self, input):
+        return input.view(input.size(0), -1)
+
+
+class GeneralizedMeanPooling(nn.Module):
+    r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes.
+    The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)`
+        - At p = infinity, one gets Max Pooling
+        - At p = 1, one gets Average Pooling
+    The output is of size H x W, for any input size.
+    The number of output features is equal to the number of input planes.
+    Args:
+        output_size: the target output size of the image of the form H x W.
+                     Can be a tuple (H, W) or a single H for a square image H x H
+                     H and W can be either a ``int``, or ``None`` which means the size will
+                     be the same as that of the input.
+    """
+
+    def __init__(self, norm, output_size=1, eps=1e-6):
+        super(GeneralizedMeanPooling, self).__init__()
+        assert norm > 0
+        self.p = float(norm)
+        self.output_size = output_size
+        self.eps = eps
+
+    def forward(self, x):
+        x = x.clamp(min=self.eps).pow(self.p)
+        return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(' \
+               + str(self.p) + ', ' \
+               + 'output_size=' + str(self.output_size) + ')'
+
+
+class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
+    """ Same, but norm is trainable
+    """
+
+    def __init__(self, norm=3, output_size=1, eps=1e-6):
+        super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps)
+        self.p = nn.Parameter(torch.ones(1) * norm)
+
+
+class AdaptiveAvgMaxPool2d(nn.Module):
+    def __init__(self):
+        super(AdaptiveAvgMaxPool2d, self).__init__()
+        self.avgpool = FastGlobalAvgPool2d()
+
+    def forward(self, x):
+        x_avg = self.avgpool(x, self.output_size)
+        x_max = F.adaptive_max_pool2d(x, 1)
+        x = x_max + x_avg
+        return x
+
+
+class FastGlobalAvgPool2d(nn.Module):
+    def __init__(self, flatten=False):
+        super(FastGlobalAvgPool2d, self).__init__()
+        self.flatten = flatten
+
+    def forward(self, x):
+        if self.flatten:
+            in_size = x.size()
+            return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
+        else:
+            return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
diff --git a/layers/se_layer.py b/layers/se_layer.py
new file mode 100644
index 0000000..04e1dc6
--- /dev/null
+++ b/layers/se_layer.py
@@ -0,0 +1,25 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from torch import nn
+
+
+class SELayer(nn.Module):
+    def __init__(self, channel, reduction=16):
+        super(SELayer, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, int(channel / reduction), bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(int(channel / reduction), channel, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y.expand_as(x)
diff --git a/layers/splat.py b/layers/splat.py
new file mode 100644
index 0000000..8a8901d
--- /dev/null
+++ b/layers/splat.py
@@ -0,0 +1,97 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: liaoxingyu5@jd.com
+"""
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import Conv2d, ReLU
+from torch.nn.modules.utils import _pair
+from .batch_norm import get_norm
+
+
+class SplAtConv2d(nn.Module):
+    """Split-Attention Conv2d
+    """
+
+    def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0),
+                 dilation=(1, 1), groups=1, bias=True,
+                 radix=2, reduction_factor=4,
+                 rectify=False, rectify_avg=False, norm_layer=None, num_splits=1,
+                 dropblock_prob=0.0, **kwargs):
+        super(SplAtConv2d, self).__init__()
+        padding = _pair(padding)
+        self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
+        self.rectify_avg = rectify_avg
+        inter_channels = max(in_channels * radix // reduction_factor, 32)
+        self.radix = radix
+        self.cardinality = groups
+        self.channels = channels
+        self.dropblock_prob = dropblock_prob
+        if self.rectify:
+            from rfconv import RFConv2d
+            self.conv = RFConv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation,
+                                 groups=groups * radix, bias=bias, average_mode=rectify_avg, **kwargs)
+        else:
+            self.conv = Conv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation,
+                               groups=groups * radix, bias=bias, **kwargs)
+        self.use_bn = norm_layer is not None
+        if self.use_bn:
+            self.bn0 = get_norm(norm_layer, channels * radix, num_splits)
+        self.relu = ReLU(inplace=True)
+        self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
+        if self.use_bn:
+            self.bn1 = get_norm(norm_layer, inter_channels, num_splits)
+        self.fc2 = Conv2d(inter_channels, channels * radix, 1, groups=self.cardinality)
+
+        self.rsoftmax = rSoftMax(radix, groups)
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.use_bn:
+            x = self.bn0(x)
+        if self.dropblock_prob > 0.0:
+            x = self.dropblock(x)
+        x = self.relu(x)
+
+        batch, rchannel = x.shape[:2]
+        if self.radix > 1:
+            splited = torch.split(x, rchannel // self.radix, dim=1)
+            gap = sum(splited)
+        else:
+            gap = x
+        gap = F.adaptive_avg_pool2d(gap, 1)
+        gap = self.fc1(gap)
+
+        if self.use_bn:
+            gap = self.bn1(gap)
+        gap = self.relu(gap)
+
+        atten = self.fc2(gap)
+        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
+
+        if self.radix > 1:
+            attens = torch.split(atten, rchannel // self.radix, dim=1)
+            out = sum([att * split for (att, split) in zip(attens, splited)])
+        else:
+            out = atten * x
+        return out.contiguous()
+
+
+class rSoftMax(nn.Module):
+    def __init__(self, radix, cardinality):
+        super().__init__()
+        self.radix = radix
+        self.cardinality = cardinality
+
+    def forward(self, x):
+        batch = x.size(0)
+        if self.radix > 1:
+            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
+            x = F.softmax(x, dim=1)
+            x = x.reshape(batch, -1)
+        else:
+            x = torch.sigmoid(x)
+        return x
diff --git a/libtorch.tar.gz b/libtorch.tar.gz
deleted file mode 100644
index e3d2ead..0000000
--- a/libtorch.tar.gz
+++ /dev/null
Binary files differ
diff --git a/modeling/__init__.py b/modeling/__init__.py
new file mode 100644
index 0000000..0873432
--- /dev/null
+++ b/modeling/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time    : 2020/10/26 14:48
+# @Author  : Scheaven
+# @File    :  __init__.py.py
+# @description: 
\ No newline at end of file
diff --git a/modeling/__pycache__/__init__.cpython-37.pyc b/modeling/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..b0d4e9b
--- /dev/null
+++ b/modeling/__pycache__/__init__.cpython-37.pyc
Binary files differ
diff --git a/modeling/__pycache__/__init__.cpython-38.pyc b/modeling/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..7293942
--- /dev/null
+++ b/modeling/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/modeling/backbones/__init__.py b/modeling/backbones/__init__.py
new file mode 100644
index 0000000..ec6f22d
--- /dev/null
+++ b/modeling/backbones/__init__.py
@@ -0,0 +1,12 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from .build import build_backbone, BACKBONE_REGISTRY
+
+from .resnet import build_resnet_backbone
+from .osnet import build_osnet_backbone
+from .resnest import build_resnest_backbone
+from .resnext import build_resnext_backbone
\ No newline at end of file
diff --git a/modeling/backbones/__pycache__/__init__.cpython-37.pyc b/modeling/backbones/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..adfac7c
--- /dev/null
+++ b/modeling/backbones/__pycache__/__init__.cpython-37.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/__init__.cpython-38.pyc b/modeling/backbones/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..2bb2c2f
--- /dev/null
+++ b/modeling/backbones/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/build.cpython-37.pyc b/modeling/backbones/__pycache__/build.cpython-37.pyc
new file mode 100644
index 0000000..ff5bf99
--- /dev/null
+++ b/modeling/backbones/__pycache__/build.cpython-37.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/build.cpython-38.pyc b/modeling/backbones/__pycache__/build.cpython-38.pyc
new file mode 100644
index 0000000..ab9d690
--- /dev/null
+++ b/modeling/backbones/__pycache__/build.cpython-38.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/osnet.cpython-37.pyc b/modeling/backbones/__pycache__/osnet.cpython-37.pyc
new file mode 100644
index 0000000..e567de5
--- /dev/null
+++ b/modeling/backbones/__pycache__/osnet.cpython-37.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/osnet.cpython-38.pyc b/modeling/backbones/__pycache__/osnet.cpython-38.pyc
new file mode 100644
index 0000000..24f2adf
--- /dev/null
+++ b/modeling/backbones/__pycache__/osnet.cpython-38.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/resnest.cpython-37.pyc b/modeling/backbones/__pycache__/resnest.cpython-37.pyc
new file mode 100644
index 0000000..34c7f8a
--- /dev/null
+++ b/modeling/backbones/__pycache__/resnest.cpython-37.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/resnest.cpython-38.pyc b/modeling/backbones/__pycache__/resnest.cpython-38.pyc
new file mode 100644
index 0000000..415db6c
--- /dev/null
+++ b/modeling/backbones/__pycache__/resnest.cpython-38.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/resnet.cpython-37.pyc b/modeling/backbones/__pycache__/resnet.cpython-37.pyc
new file mode 100644
index 0000000..d60009a
--- /dev/null
+++ b/modeling/backbones/__pycache__/resnet.cpython-37.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/resnet.cpython-38.pyc b/modeling/backbones/__pycache__/resnet.cpython-38.pyc
new file mode 100644
index 0000000..ffdab08
--- /dev/null
+++ b/modeling/backbones/__pycache__/resnet.cpython-38.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/resnext.cpython-37.pyc b/modeling/backbones/__pycache__/resnext.cpython-37.pyc
new file mode 100644
index 0000000..b1f70eb
--- /dev/null
+++ b/modeling/backbones/__pycache__/resnext.cpython-37.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/resnext.cpython-38.pyc b/modeling/backbones/__pycache__/resnext.cpython-38.pyc
new file mode 100644
index 0000000..72bc6b1
--- /dev/null
+++ b/modeling/backbones/__pycache__/resnext.cpython-38.pyc
Binary files differ
diff --git a/modeling/backbones/build.py b/modeling/backbones/build.py
new file mode 100644
index 0000000..3006786
--- /dev/null
+++ b/modeling/backbones/build.py
@@ -0,0 +1,28 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from utils.registry import Registry
+
+BACKBONE_REGISTRY = Registry("BACKBONE")
+BACKBONE_REGISTRY.__doc__ = """
+Registry for backbones, which extract feature maps from images
+The registered object must be a callable that accepts two arguments:
+1. A :class:`detectron2.config.CfgNode`
+2. A :class:`detectron2.layers.ShapeSpec`, which contains the input shape specification.
+It must returns an instance of :class:`Backbone`.
+"""
+
+
+def build_backbone(cfg):
+    """
+    Build a backbone from `cfg.MODEL.BACKBONE.NAME`.
+    Returns:
+        an instance of :class:`Backbone`
+    """
+
+    backbone_name = cfg.MODEL.BACKBONE.NAME
+    backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg)
+    return backbone
diff --git a/modeling/backbones/osnet.py b/modeling/backbones/osnet.py
new file mode 100644
index 0000000..d71615a
--- /dev/null
+++ b/modeling/backbones/osnet.py
@@ -0,0 +1,487 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: liaoxingyu5@jd.com
+"""
+
+# based on:
+# https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/models/osnet.py
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from .build import BACKBONE_REGISTRY
+
+model_urls = {
+    'osnet_x1_0':
+        'https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY',
+    'osnet_x0_75':
+        'https://drive.google.com/uc?id=1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq',
+    'osnet_x0_5':
+        'https://drive.google.com/uc?id=16DGLbZukvVYgINws8u8deSaOqjybZ83i',
+    'osnet_x0_25':
+        'https://drive.google.com/uc?id=1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs',
+    'osnet_ibn_x1_0':
+        'https://drive.google.com/uc?id=1sr90V6irlYYDd4_4ISU2iruoRG8J__6l'
+}
+
+
+##########
+# Basic layers
+##########
+class ConvLayer(nn.Module):
+    """Convolution layer (conv + bn + relu)."""
+
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=1,
+            padding=0,
+            groups=1,
+            IN=False
+    ):
+        super(ConvLayer, self).__init__()
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            bias=False,
+            groups=groups
+        )
+        if IN:
+            self.bn = nn.InstanceNorm2d(out_channels, affine=True)
+        else:
+            self.bn = nn.BatchNorm2d(out_channels)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        return x
+
+
+class Conv1x1(nn.Module):
+    """1x1 convolution + bn + relu."""
+
+    def __init__(self, in_channels, out_channels, stride=1, groups=1):
+        super(Conv1x1, self).__init__()
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            1,
+            stride=stride,
+            padding=0,
+            bias=False,
+            groups=groups
+        )
+        self.bn = nn.BatchNorm2d(out_channels)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        return x
+
+
+class Conv1x1Linear(nn.Module):
+    """1x1 convolution + bn (w/o non-linearity)."""
+
+    def __init__(self, in_channels, out_channels, stride=1):
+        super(Conv1x1Linear, self).__init__()
+        self.conv = nn.Conv2d(
+            in_channels, out_channels, 1, stride=stride, padding=0, bias=False
+        )
+        self.bn = nn.BatchNorm2d(out_channels)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        return x
+
+
+class Conv3x3(nn.Module):
+    """3x3 convolution + bn + relu."""
+
+    def __init__(self, in_channels, out_channels, stride=1, groups=1):
+        super(Conv3x3, self).__init__()
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            3,
+            stride=stride,
+            padding=1,
+            bias=False,
+            groups=groups
+        )
+        self.bn = nn.BatchNorm2d(out_channels)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        return x
+
+
+class LightConv3x3(nn.Module):
+    """Lightweight 3x3 convolution.
+    1x1 (linear) + dw 3x3 (nonlinear).
+    """
+
+    def __init__(self, in_channels, out_channels):
+        super(LightConv3x3, self).__init__()
+        self.conv1 = nn.Conv2d(
+            in_channels, out_channels, 1, stride=1, padding=0, bias=False
+        )
+        self.conv2 = nn.Conv2d(
+            out_channels,
+            out_channels,
+            3,
+            stride=1,
+            padding=1,
+            bias=False,
+            groups=out_channels
+        )
+        self.bn = nn.BatchNorm2d(out_channels)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        return x
+
+
+##########
+# Building blocks for omni-scale feature learning
+##########
+class ChannelGate(nn.Module):
+    """A mini-network that generates channel-wise gates conditioned on input tensor."""
+
+    def __init__(
+            self,
+            in_channels,
+            num_gates=None,
+            return_gates=False,
+            gate_activation='sigmoid',
+            reduction=16,
+            layer_norm=False
+    ):
+        super(ChannelGate, self).__init__()
+        if num_gates is None: num_gates = in_channels
+        self.return_gates = return_gates
+
+        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+
+        self.fc1 = nn.Conv2d(
+            in_channels,
+            in_channels // reduction,
+            kernel_size=1,
+            bias=True,
+            padding=0
+        )
+        self.norm1 = None
+        if layer_norm: self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
+        self.relu = nn.ReLU(inplace=True)
+        self.fc2 = nn.Conv2d(
+            in_channels // reduction,
+            num_gates,
+            kernel_size=1,
+            bias=True,
+            padding=0
+        )
+        if gate_activation == 'sigmoid':  self.gate_activation = nn.Sigmoid()
+        elif gate_activation == 'relu':   self.gate_activation = nn.ReLU(inplace=True)
+        elif gate_activation == 'linear': self.gate_activation = nn.Identity()
+        else:
+            raise RuntimeError(
+                "Unknown gate activation: {}".format(gate_activation)
+            )
+
+    def forward(self, x):
+        input = x
+        x = self.global_avgpool(x)
+        x = self.fc1(x)
+        if self.norm1 is not None: x = self.norm1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        x = self.gate_activation(x)
+        if self.return_gates: return x
+        return input * x
+
+
+class OSBlock(nn.Module):
+    """Omni-scale feature learning block."""
+
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            IN=False,
+            bottleneck_reduction=4,
+            **kwargs
+    ):
+        super(OSBlock, self).__init__()
+        mid_channels = out_channels // bottleneck_reduction
+        self.conv1 = Conv1x1(in_channels, mid_channels)
+        self.conv2a = LightConv3x3(mid_channels, mid_channels)
+        self.conv2b = nn.Sequential(
+            LightConv3x3(mid_channels, mid_channels),
+            LightConv3x3(mid_channels, mid_channels),
+        )
+        self.conv2c = nn.Sequential(
+            LightConv3x3(mid_channels, mid_channels),
+            LightConv3x3(mid_channels, mid_channels),
+            LightConv3x3(mid_channels, mid_channels),
+        )
+        self.conv2d = nn.Sequential(
+            LightConv3x3(mid_channels, mid_channels),
+            LightConv3x3(mid_channels, mid_channels),
+            LightConv3x3(mid_channels, mid_channels),
+            LightConv3x3(mid_channels, mid_channels),
+        )
+        self.gate = ChannelGate(mid_channels)
+        self.conv3 = Conv1x1Linear(mid_channels, out_channels)
+        self.downsample = None
+        if in_channels != out_channels:
+            self.downsample = Conv1x1Linear(in_channels, out_channels)
+        self.IN = None
+        if IN: self.IN = nn.InstanceNorm2d(out_channels, affine=True)
+        self.relu = nn.ReLU(True)
+
+    def forward(self, x):
+        identity = x
+        x1 = self.conv1(x)
+        x2a = self.conv2a(x1)
+        x2b = self.conv2b(x1)
+        x2c = self.conv2c(x1)
+        x2d = self.conv2d(x1)
+        x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d)
+        x3 = self.conv3(x2)
+        if self.downsample is not None:
+            identity = self.downsample(identity)
+        out = x3 + identity
+        if self.IN is not None:
+            out = self.IN(out)
+        return self.relu(out)
+
+
+##########
+# Network architecture
+##########
+class OSNet(nn.Module):
+    """Omni-Scale Network.
+
+    Reference:
+        - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
+        - Zhou et al. Learning Generalisable Omni-Scale Representations
+          for Person Re-Identification. arXiv preprint, 2019.
+    """
+
+    def __init__(
+            self,
+            blocks,
+            layers,
+            channels,
+            IN=False,
+            **kwargs
+    ):
+        super(OSNet, self).__init__()
+        num_blocks = len(blocks)
+        assert num_blocks == len(layers)
+        assert num_blocks == len(channels) - 1
+
+        # convolutional backbone
+        self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)
+        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
+        self.conv2 = self._make_layer(
+            blocks[0],
+            layers[0],
+            channels[0],
+            channels[1],
+            reduce_spatial_size=True,
+            IN=IN
+        )
+        self.conv3 = self._make_layer(
+            blocks[1],
+            layers[1],
+            channels[1],
+            channels[2],
+            reduce_spatial_size=True
+        )
+        self.conv4 = self._make_layer(
+            blocks[2],
+            layers[2],
+            channels[2],
+            channels[3],
+            reduce_spatial_size=False
+        )
+        self.conv5 = Conv1x1(channels[3], channels[3])
+
+        self._init_params()
+
+    def _make_layer(
+            self,
+            block,
+            layer,
+            in_channels,
+            out_channels,
+            reduce_spatial_size,
+            IN=False
+    ):
+        layers = []
+
+        layers.append(block(in_channels, out_channels, IN=IN))
+        for i in range(1, layer):
+            layers.append(block(out_channels, out_channels, IN=IN))
+
+        if reduce_spatial_size:
+            layers.append(
+                nn.Sequential(
+                    Conv1x1(out_channels, out_channels),
+                    nn.AvgPool2d(2, stride=2),
+                )
+            )
+
+        return nn.Sequential(*layers)
+
+    def _init_params(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(
+                    m.weight, mode='fan_out', nonlinearity='relu'
+                )
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+            elif isinstance(m, nn.BatchNorm1d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.maxpool(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = self.conv4(x)
+        x = self.conv5(x)
+        return x
+
+
+def init_pretrained_weights(model, key=''):
+    """Initializes model with pretrained weights.
+
+    Layers that don't match with pretrained layers in name or size are kept unchanged.
+    """
+    import os
+    import errno
+    import gdown
+    from collections import OrderedDict
+    import warnings
+    import logging
+
+    logger = logging.getLogger(__name__)
+
+    def _get_torch_home():
+        ENV_TORCH_HOME = 'TORCH_HOME'
+        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+        DEFAULT_CACHE_DIR = '~/.cache'
+        torch_home = os.path.expanduser(
+            os.getenv(
+                ENV_TORCH_HOME,
+                os.path.join(
+                    os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
+                )
+            )
+        )
+        return torch_home
+
+    torch_home = _get_torch_home()
+    model_dir = os.path.join(torch_home, 'checkpoints')
+    try:
+        os.makedirs(model_dir)
+    except OSError as e:
+        if e.errno == errno.EEXIST:
+            # Directory already exists, ignore.
+            pass
+        else:
+            # Unexpected OSError, re-raise.
+            raise
+    filename = key + '_imagenet.pth'
+    cached_file = os.path.join(model_dir, filename)
+
+    if not os.path.exists(cached_file):
+        gdown.download(model_urls[key], cached_file, quiet=False)
+
+    state_dict = torch.load(cached_file)
+    model_dict = model.state_dict()
+    new_state_dict = OrderedDict()
+    matched_layers, discarded_layers = [], []
+
+    for k, v in state_dict.items():
+        if k.startswith('module.'):
+            k = k[7:]  # discard module.
+
+        if k in model_dict and model_dict[k].size() == v.size():
+            new_state_dict[k] = v
+            matched_layers.append(k)
+        else:
+            discarded_layers.append(k)
+
+    model_dict.update(new_state_dict)
+    model.load_state_dict(model_dict)
+
+    if len(matched_layers) == 0:
+        warnings.warn(
+            'The pretrained weights from "{}" cannot be loaded, '
+            'please check the key names manually '
+            '(** ignored and continue **)'.format(cached_file)
+        )
+    else:
+        logger.info(
+            'Successfully loaded imagenet pretrained weights from "{}"'.
+                format(cached_file)
+        )
+        if len(discarded_layers) > 0:
+            logger.info(
+                '** The following layers are discarded '
+                'due to unmatched keys or layer size: {}'.
+                    format(discarded_layers)
+            )
+
+
+@BACKBONE_REGISTRY.register()
+def build_osnet_backbone(cfg):
+    """
+    Create a OSNet instance from config.
+    Returns:
+        OSNet: a :class:`OSNet` instance
+    """
+
+    # fmt: off
+    pretrain = cfg.MODEL.BACKBONE.PRETRAIN
+    with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
+
+    num_blocks_per_stage = [2, 2, 2]
+    num_channels_per_stage = [64, 256, 384, 512]
+    model = OSNet([OSBlock, OSBlock, OSBlock], num_blocks_per_stage, num_channels_per_stage, with_ibn)
+    pretrain_key = 'osnet_ibn_x1_0' if with_ibn else 'osnet_x1_0'
+    if pretrain:
+        init_pretrained_weights(model, pretrain_key)
+    return model
diff --git a/modeling/backbones/resnest.py b/modeling/backbones/resnest.py
new file mode 100644
index 0000000..f5eaa9e
--- /dev/null
+++ b/modeling/backbones/resnest.py
@@ -0,0 +1,411 @@
+# encoding: utf-8
+# based on:
+# https://github.com/zhanghang1989/ResNeSt/blob/master/resnest/torch/resnest.py
+"""ResNeSt models"""
+
+import logging
+import math
+
+import torch
+from torch import nn
+
+from layers import (
+    IBN,
+    Non_local,
+    SplAtConv2d,
+    get_norm,
+)
+
+from utils.checkpoint import get_unexpected_parameters_message, get_missing_parameters_message
+
+from .build import BACKBONE_REGISTRY
+
+_url_format = 'https://hangzh.s3.amazonaws.com/encoding/models/{}-{}.pth'
+
+_model_sha256 = {name: checksum for checksum, name in [
+    ('528c19ca', 'resnest50'),
+    ('22405ba7', 'resnest101'),
+    ('75117900', 'resnest200'),
+    ('0cc87c48', 'resnest269'),
+]}
+
+
+def short_hash(name):
+    if name not in _model_sha256:
+        raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
+    return _model_sha256[name][:8]
+
+
+model_urls = {name: _url_format.format(name, short_hash(name)) for
+              name in _model_sha256.keys()
+              }
+
+
+class Bottleneck(nn.Module):
+    """ResNet Bottleneck
+    """
+    # pylint: disable=unused-argument
+    expansion = 4
+
+    def __init__(self, inplanes, planes, bn_norm, num_splits, with_ibn=False, stride=1, downsample=None,
+                 radix=1, cardinality=1, bottleneck_width=64,
+                 avd=False, avd_first=False, dilation=1, is_first=False,
+                 rectified_conv=False, rectify_avg=False,
+                 dropblock_prob=0.0, last_gamma=False):
+        super(Bottleneck, self).__init__()
+        group_width = int(planes * (bottleneck_width / 64.)) * cardinality
+        self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
+        if with_ibn:
+            self.bn1 = IBN(group_width, bn_norm, num_splits)
+        else:
+            self.bn1 = get_norm(bn_norm, group_width, num_splits)
+        self.dropblock_prob = dropblock_prob
+        self.radix = radix
+        self.avd = avd and (stride > 1 or is_first)
+        self.avd_first = avd_first
+
+        if self.avd:
+            self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
+            stride = 1
+
+        if radix > 1:
+            self.conv2 = SplAtConv2d(
+                group_width, group_width, kernel_size=3,
+                stride=stride, padding=dilation,
+                dilation=dilation, groups=cardinality, bias=False,
+                radix=radix, rectify=rectified_conv,
+                rectify_avg=rectify_avg,
+                norm_layer=bn_norm, num_splits=num_splits,
+                dropblock_prob=dropblock_prob)
+        elif rectified_conv:
+            from rfconv import RFConv2d
+            self.conv2 = RFConv2d(
+                group_width, group_width, kernel_size=3, stride=stride,
+                padding=dilation, dilation=dilation,
+                groups=cardinality, bias=False,
+                average_mode=rectify_avg)
+            self.bn2 = get_norm(bn_norm, group_width, num_splits)
+        else:
+            self.conv2 = nn.Conv2d(
+                group_width, group_width, kernel_size=3, stride=stride,
+                padding=dilation, dilation=dilation,
+                groups=cardinality, bias=False)
+            self.bn2 = get_norm(bn_norm, group_width, num_splits)
+
+        self.conv3 = nn.Conv2d(
+            group_width, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = get_norm(bn_norm, planes * 4, num_splits)
+
+        if last_gamma:
+            from torch.nn.init import zeros_
+            zeros_(self.bn3.weight)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.dilation = dilation
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        if self.dropblock_prob > 0.0:
+            out = self.dropblock1(out)
+        out = self.relu(out)
+
+        if self.avd and self.avd_first:
+            out = self.avd_layer(out)
+
+        out = self.conv2(out)
+        if self.radix == 1:
+            out = self.bn2(out)
+            if self.dropblock_prob > 0.0:
+                out = self.dropblock2(out)
+            out = self.relu(out)
+
+        if self.avd and not self.avd_first:
+            out = self.avd_layer(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+        if self.dropblock_prob > 0.0:
+            out = self.dropblock3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class ResNest(nn.Module):
+    """ResNet Variants ResNest
+    Parameters
+    ----------
+    block : Block
+        Class for the residual block. Options are BasicBlockV1, BottleneckV1.
+    layers : list of int
+        Numbers of layers in each block
+    classes : int, default 1000
+        Number of classification classes.
+    dilated : bool, default False
+        Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
+        typically used in Semantic Segmentation.
+    norm_layer : object
+        Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
+        for Synchronized Cross-GPU BachNormalization).
+    Reference:
+        - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
+        - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
+    """
+
+    # pylint: disable=unused-variable
+    def __init__(self, last_stride, bn_norm, num_splits, with_ibn, with_nl, block, layers, non_layers, radix=1, groups=1,
+                 bottleneck_width=64,
+                 dilated=False, dilation=1,
+                 deep_stem=False, stem_width=64, avg_down=False,
+                 rectified_conv=False, rectify_avg=False,
+                 avd=False, avd_first=False,
+                 final_drop=0.0, dropblock_prob=0,
+                 last_gamma=False):
+        self.cardinality = groups
+        self.bottleneck_width = bottleneck_width
+        # ResNet-D params
+        self.inplanes = stem_width * 2 if deep_stem else 64
+        self.avg_down = avg_down
+        self.last_gamma = last_gamma
+        # ResNeSt params
+        self.radix = radix
+        self.avd = avd
+        self.avd_first = avd_first
+
+        super().__init__()
+        self.rectified_conv = rectified_conv
+        self.rectify_avg = rectify_avg
+        if rectified_conv:
+            from rfconv import RFConv2d
+            conv_layer = RFConv2d
+        else:
+            conv_layer = nn.Conv2d
+        conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {}
+        if deep_stem:
+            self.conv1 = nn.Sequential(
+                conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs),
+                get_norm(bn_norm, stem_width, num_splits),
+                nn.ReLU(inplace=True),
+                conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
+                get_norm(bn_norm, stem_width, num_splits),
+                nn.ReLU(inplace=True),
+                conv_layer(stem_width, stem_width * 2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
+            )
+        else:
+            self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
+                                    bias=False, **conv_kwargs)
+        self.bn1 = get_norm(bn_norm, self.inplanes, num_splits)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, num_splits, with_ibn=with_ibn, is_first=False)
+        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, num_splits, with_ibn=with_ibn)
+        if dilated or dilation == 4:
+            self.layer3 = self._make_layer(block, 256, layers[2], 1, bn_norm, num_splits, with_ibn=with_ibn,
+                                           dilation=2, dropblock_prob=dropblock_prob)
+            self.layer4 = self._make_layer(block, 512, layers[3], 1, bn_norm, num_splits, with_ibn=with_ibn,
+                                           dilation=4, dropblock_prob=dropblock_prob)
+        elif dilation == 2:
+            self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, num_splits, with_ibn=with_ibn,
+                                           dilation=1, dropblock_prob=dropblock_prob)
+            self.layer4 = self._make_layer(block, 512, layers[3], 1, bn_norm, num_splits, with_ibn=with_ibn,
+                                           dilation=2, dropblock_prob=dropblock_prob)
+        else:
+            self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, num_splits, with_ibn=with_ibn,
+                                           dropblock_prob=dropblock_prob)
+            self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, num_splits, with_ibn=with_ibn,
+                                           dropblock_prob=dropblock_prob)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+        if with_nl:
+            self._build_nonlocal(layers, non_layers, bn_norm, num_splits)
+        else:
+            self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
+
+    def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", num_splits=1, with_ibn=False,
+                    dilation=1, dropblock_prob=0.0, is_first=True):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            down_layers = []
+            if self.avg_down:
+                if dilation == 1:
+                    down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride,
+                                                    ceil_mode=True, count_include_pad=False))
+                else:
+                    down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1,
+                                                    ceil_mode=True, count_include_pad=False))
+                down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
+                                             kernel_size=1, stride=1, bias=False))
+            else:
+                down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
+                                             kernel_size=1, stride=stride, bias=False))
+            down_layers.append(get_norm(bn_norm, planes * block.expansion, num_splits))
+            downsample = nn.Sequential(*down_layers)
+
+        layers = []
+        if planes == 512:
+            with_ibn = False
+        if dilation == 1 or dilation == 2:
+            layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn, stride, downsample=downsample,
+                                radix=self.radix, cardinality=self.cardinality,
+                                bottleneck_width=self.bottleneck_width,
+                                avd=self.avd, avd_first=self.avd_first,
+                                dilation=1, is_first=is_first, rectified_conv=self.rectified_conv,
+                                rectify_avg=self.rectify_avg,
+                                dropblock_prob=dropblock_prob,
+                                last_gamma=self.last_gamma))
+        elif dilation == 4:
+            layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn, stride, downsample=downsample,
+                                radix=self.radix, cardinality=self.cardinality,
+                                bottleneck_width=self.bottleneck_width,
+                                avd=self.avd, avd_first=self.avd_first,
+                                dilation=2, is_first=is_first, rectified_conv=self.rectified_conv,
+                                rectify_avg=self.rectify_avg,
+                                dropblock_prob=dropblock_prob,
+                                last_gamma=self.last_gamma))
+        else:
+            raise RuntimeError("=> unknown dilation size: {}".format(dilation))
+
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn,
+                                radix=self.radix, cardinality=self.cardinality,
+                                bottleneck_width=self.bottleneck_width,
+                                avd=self.avd, avd_first=self.avd_first,
+                                dilation=dilation, rectified_conv=self.rectified_conv,
+                                rectify_avg=self.rectify_avg,
+                                dropblock_prob=dropblock_prob,
+                                last_gamma=self.last_gamma))
+
+        return nn.Sequential(*layers)
+
+    def _build_nonlocal(self, layers, non_layers, bn_norm, num_splits):
+        self.NL_1 = nn.ModuleList(
+            [Non_local(256, bn_norm, num_splits) for _ in range(non_layers[0])])
+        self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
+        self.NL_2 = nn.ModuleList(
+            [Non_local(512, bn_norm, num_splits) for _ in range(non_layers[1])])
+        self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
+        self.NL_3 = nn.ModuleList(
+            [Non_local(1024, bn_norm, num_splits) for _ in range(non_layers[2])])
+        self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
+        self.NL_4 = nn.ModuleList(
+            [Non_local(2048, bn_norm, num_splits) for _ in range(non_layers[3])])
+        self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        NL1_counter = 0
+        if len(self.NL_1_idx) == 0:
+            self.NL_1_idx = [-1]
+        for i in range(len(self.layer1)):
+            x = self.layer1[i](x)
+            if i == self.NL_1_idx[NL1_counter]:
+                _, C, H, W = x.shape
+                x = self.NL_1[NL1_counter](x)
+                NL1_counter += 1
+        # Layer 2
+        NL2_counter = 0
+        if len(self.NL_2_idx) == 0:
+            self.NL_2_idx = [-1]
+        for i in range(len(self.layer2)):
+            x = self.layer2[i](x)
+            if i == self.NL_2_idx[NL2_counter]:
+                _, C, H, W = x.shape
+                x = self.NL_2[NL2_counter](x)
+                NL2_counter += 1
+        # Layer 3
+        NL3_counter = 0
+        if len(self.NL_3_idx) == 0:
+            self.NL_3_idx = [-1]
+        for i in range(len(self.layer3)):
+            x = self.layer3[i](x)
+            if i == self.NL_3_idx[NL3_counter]:
+                _, C, H, W = x.shape
+                x = self.NL_3[NL3_counter](x)
+                NL3_counter += 1
+        # Layer 4
+        NL4_counter = 0
+        if len(self.NL_4_idx) == 0:
+            self.NL_4_idx = [-1]
+        for i in range(len(self.layer4)):
+            x = self.layer4[i](x)
+            if i == self.NL_4_idx[NL4_counter]:
+                _, C, H, W = x.shape
+                x = self.NL_4[NL4_counter](x)
+                NL4_counter += 1
+
+        return x
+
+
+@BACKBONE_REGISTRY.register()
+def build_resnest_backbone(cfg):
+    """
+    Create a ResNest instance from config.
+    Returns:
+        ResNet: a :class:`ResNet` instance.
+    """
+
+    # fmt: off
+    pretrain = cfg.MODEL.BACKBONE.PRETRAIN
+    last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
+    bn_norm = cfg.MODEL.BACKBONE.NORM
+    num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
+    with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
+    with_se = cfg.MODEL.BACKBONE.WITH_SE
+    with_nl = cfg.MODEL.BACKBONE.WITH_NL
+    depth = cfg.MODEL.BACKBONE.DEPTH
+
+    num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 200: [3, 24, 36, 3], 269: [3, 30, 48, 8]}[depth]
+    nl_layers_per_stage = {50: [0, 2, 3, 0], 101: [0, 2, 3, 0]}[depth]
+    stem_width = {50: 32, 101: 64, 200: 64, 269: 64}[depth]
+    model = ResNest(last_stride, bn_norm, num_splits, with_ibn, with_nl, Bottleneck, num_blocks_per_stage,
+                    nl_layers_per_stage, radix=2, groups=1, bottleneck_width=64,
+                    deep_stem=True, stem_width=stem_width, avg_down=True,
+                    avd=True, avd_first=False)
+    if pretrain:
+        # if not with_ibn:
+        # original resnet
+        state_dict = torch.hub.load_state_dict_from_url(
+            model_urls['resnest' + str(depth)], progress=True, check_hash=True)
+        # else:
+        #     raise KeyError('Not implementation ibn in resnest')
+        # # ibn resnet
+        # state_dict = torch.load(pretrain_path)['state_dict']
+        # # remove module in name
+        # new_state_dict = {}
+        # for k in state_dict:
+        #     new_k = '.'.join(k.split('.')[1:])
+        #     if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
+        #         new_state_dict[new_k] = state_dict[k]
+        # state_dict = new_state_dict
+        incompatible = model.load_state_dict(state_dict, strict=False)
+        logger = logging.getLogger(__name__)
+        if incompatible.missing_keys:
+            logger.info(
+                get_missing_parameters_message(incompatible.missing_keys)
+            )
+        if incompatible.unexpected_keys:
+            logger.info(
+                get_unexpected_parameters_message(incompatible.unexpected_keys)
+            )
+    return model
diff --git a/modeling/backbones/resnet.py b/modeling/backbones/resnet.py
new file mode 100644
index 0000000..5dc3f63
--- /dev/null
+++ b/modeling/backbones/resnet.py
@@ -0,0 +1,359 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import logging
+import math
+
+import torch
+from torch import nn
+
+from layers import (
+    IBN,
+    SELayer,
+    Non_local,
+    get_norm,
+)
+from utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
+from .build import BACKBONE_REGISTRY
+from utils import comm
+
+
+logger = logging.getLogger(__name__)
+model_urls = {
+    '18x': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    '34x': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    '50x': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    '101x': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'ibn_18x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_a-2f571257.pth',
+    'ibn_34x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_a-94bc1577.pth',
+    'ibn_50x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth',
+    'ibn_101x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_a-59ea0ac6.pth',
+    'se_ibn_101x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/se_resnet101_ibn_a-fabed4e2.pth',
+}
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, bn_norm, with_ibn=False, with_se=False,
+                 stride=1, downsample=None, reduction=16):
+        super(BasicBlock, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        if with_ibn:
+            self.bn1 = IBN(planes, bn_norm)
+        else:
+            self.bn1 = get_norm(bn_norm, planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn2 = get_norm(bn_norm, planes)
+        self.relu = nn.ReLU(inplace=True)
+        if with_se:
+            self.se = SELayer(planes, reduction)
+        else:
+            self.se = nn.Identity()
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, bn_norm, with_ibn=False, with_se=False,
+                 stride=1, downsample=None, reduction=16):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        if with_ibn:
+            self.bn1 = IBN(planes, bn_norm)
+        else:
+            self.bn1 = get_norm(bn_norm, planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                               padding=1, bias=False)
+        self.bn2 = get_norm(bn_norm, planes)
+        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = get_norm(bn_norm, planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        if with_se:
+            self.se = SELayer(planes * self.expansion, reduction)
+        else:
+            self.se = nn.Identity()
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+        out = self.se(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+    def __init__(self, last_stride, bn_norm, with_ibn, with_se, with_nl, block, layers, non_layers):
+        self.inplanes = 64
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = get_norm(bn_norm, 64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
+        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, with_ibn, with_se)
+        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, with_ibn, with_se)
+        self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, with_ibn, with_se)
+        self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, with_se=with_se)
+
+        self.random_init()
+
+        # fmt: off
+        if with_nl: self._build_nonlocal(layers, non_layers, bn_norm)
+        else:       self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
+        # fmt: on
+
+    def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", with_ibn=False, with_se=False):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                get_norm(bn_norm, planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, bn_norm, with_ibn, with_se, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes, bn_norm, with_ibn, with_se))
+
+        return nn.Sequential(*layers)
+
+    def _build_nonlocal(self, layers, non_layers, bn_norm):
+        self.NL_1 = nn.ModuleList(
+            [Non_local(256, bn_norm) for _ in range(non_layers[0])])
+        self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
+        self.NL_2 = nn.ModuleList(
+            [Non_local(512, bn_norm) for _ in range(non_layers[1])])
+        self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
+        self.NL_3 = nn.ModuleList(
+            [Non_local(1024, bn_norm) for _ in range(non_layers[2])])
+        self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
+        self.NL_4 = nn.ModuleList(
+            [Non_local(2048, bn_norm) for _ in range(non_layers[3])])
+        self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        NL1_counter = 0
+        if len(self.NL_1_idx) == 0:
+            self.NL_1_idx = [-1]
+        for i in range(len(self.layer1)):
+            x = self.layer1[i](x)
+            if i == self.NL_1_idx[NL1_counter]:
+                _, C, H, W = x.shape
+                x = self.NL_1[NL1_counter](x)
+                NL1_counter += 1
+        # Layer 2
+        NL2_counter = 0
+        if len(self.NL_2_idx) == 0:
+            self.NL_2_idx = [-1]
+        for i in range(len(self.layer2)):
+            x = self.layer2[i](x)
+            if i == self.NL_2_idx[NL2_counter]:
+                _, C, H, W = x.shape
+                x = self.NL_2[NL2_counter](x)
+                NL2_counter += 1
+        # Layer 3
+        NL3_counter = 0
+        if len(self.NL_3_idx) == 0:
+            self.NL_3_idx = [-1]
+        for i in range(len(self.layer3)):
+            x = self.layer3[i](x)
+            if i == self.NL_3_idx[NL3_counter]:
+                _, C, H, W = x.shape
+                x = self.NL_3[NL3_counter](x)
+                NL3_counter += 1
+        # Layer 4
+        NL4_counter = 0
+        if len(self.NL_4_idx) == 0:
+            self.NL_4_idx = [-1]
+        for i in range(len(self.layer4)):
+            x = self.layer4[i](x)
+            if i == self.NL_4_idx[NL4_counter]:
+                _, C, H, W = x.shape
+                x = self.NL_4[NL4_counter](x)
+                NL4_counter += 1
+
+        return x
+
+    def random_init(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+
+def init_pretrained_weights(key):
+    """Initializes model with pretrained weights.
+
+    Layers that don't match with pretrained layers in name or size are kept unchanged.
+    """
+    import os
+    import errno
+    import gdown
+
+    def _get_torch_home():
+        ENV_TORCH_HOME = 'TORCH_HOME'
+        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+        DEFAULT_CACHE_DIR = '~/.cache'
+        torch_home = os.path.expanduser(
+            os.getenv(
+                ENV_TORCH_HOME,
+                os.path.join(
+                    os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
+                )
+            )
+        )
+        return torch_home
+
+    torch_home = _get_torch_home()
+    model_dir = os.path.join(torch_home, 'checkpoints')
+    try:
+        os.makedirs(model_dir)
+    except OSError as e:
+        if e.errno == errno.EEXIST:
+            # Directory already exists, ignore.
+            pass
+        else:
+            # Unexpected OSError, re-raise.
+            raise
+
+    filename = model_urls[key].split('/')[-1]
+
+    cached_file = os.path.join(model_dir, filename)
+
+    if not os.path.exists(cached_file):
+        if comm.is_main_process():
+            gdown.download(model_urls[key], cached_file, quiet=False)
+
+    comm.synchronize()
+
+    logger.info(f"Loading pretrained model from {cached_file}")
+    state_dict = torch.load(cached_file, map_location=torch.device('cpu'))
+
+    return state_dict
+
+
+@BACKBONE_REGISTRY.register()
+def build_resnet_backbone(cfg):
+    """
+    Create a ResNet instance from config.
+    Returns:
+        ResNet: a :class:`ResNet` instance.
+    """
+
+    # fmt: off
+    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
+    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
+    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
+    bn_norm       = cfg.MODEL.BACKBONE.NORM
+    with_ibn      = cfg.MODEL.BACKBONE.WITH_IBN
+    with_se       = cfg.MODEL.BACKBONE.WITH_SE
+    with_nl       = cfg.MODEL.BACKBONE.WITH_NL
+    depth         = cfg.MODEL.BACKBONE.DEPTH
+    # fmt: on
+
+    num_blocks_per_stage = {
+        '18x': [2, 2, 2, 2],
+        '34x': [3, 4, 6, 3],
+        '50x': [3, 4, 6, 3],
+        '101x': [3, 4, 23, 3],
+    }[depth]
+
+    nl_layers_per_stage = {
+        '18x': [0, 0, 0, 0],
+        '34x': [0, 0, 0, 0],
+        '50x': [0, 2, 3, 0],
+        '101x': [0, 2, 9, 0]
+    }[depth]
+
+    block = {
+        '18x': BasicBlock,
+        '34x': BasicBlock,
+        '50x': Bottleneck,
+        '101x': Bottleneck
+    }[depth]
+
+    model = ResNet(last_stride, bn_norm, with_ibn, with_se, with_nl, block,
+                   num_blocks_per_stage, nl_layers_per_stage)
+    if pretrain:
+        # Load pretrain path if specifically
+        if pretrain_path:
+            try:
+                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
+                logger.info(f"Loading pretrained model from {pretrain_path}")
+            except FileNotFoundError as e:
+                logger.info(f'{pretrain_path} is not found! Please check this path.')
+                raise e
+            except KeyError as e:
+                logger.info("State dict keys error! Please check the state dict.")
+                raise e
+        else:
+            key = depth
+            if with_ibn: key = 'ibn_' + key
+            if with_se:  key = 'se_' + key
+
+            state_dict = init_pretrained_weights(key)
+
+        incompatible = model.load_state_dict(state_dict, strict=False)
+        if incompatible.missing_keys:
+            logger.info(
+                get_missing_parameters_message(incompatible.missing_keys)
+            )
+        if incompatible.unexpected_keys:
+            logger.info(
+                get_unexpected_parameters_message(incompatible.unexpected_keys)
+            )
+
+    return model
diff --git a/modeling/backbones/resnext.py b/modeling/backbones/resnext.py
new file mode 100644
index 0000000..b68253d
--- /dev/null
+++ b/modeling/backbones/resnext.py
@@ -0,0 +1,198 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: liaoxingyu5@jd.com
+"""
+
+# based on:
+# https://github.com/XingangPan/IBN-Net/blob/master/models/imagenet/resnext_ibn_a.py
+
+import math
+import logging
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+import torch
+from layers import IBN
+from .build import BACKBONE_REGISTRY
+
+
+class Bottleneck(nn.Module):
+    """
+    RexNeXt bottleneck type C
+    """
+    expansion = 4
+
+    def __init__(self, inplanes, planes, with_ibn, baseWidth, cardinality, stride=1, downsample=None):
+        """ Constructor
+        Args:
+            inplanes: input channel dimensionality
+            planes: output channel dimensionality
+            baseWidth: base width.
+            cardinality: num of convolution groups.
+            stride: conv stride. Replaces pooling layer.
+        """
+        super(Bottleneck, self).__init__()
+
+        D = int(math.floor(planes * (baseWidth / 64)))
+        C = cardinality
+        self.conv1 = nn.Conv2d(inplanes, D * C, kernel_size=1, stride=1, padding=0, bias=False)
+        if with_ibn:
+            self.bn1 = IBN(D * C)
+        else:
+            self.bn1 = nn.BatchNorm2d(D * C)
+        self.conv2 = nn.Conv2d(D * C, D * C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
+        self.bn2 = nn.BatchNorm2d(D * C)
+        self.conv3 = nn.Conv2d(D * C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * 4)
+        self.relu = nn.ReLU(inplace=True)
+
+        self.downsample = downsample
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class ResNeXt(nn.Module):
+    """
+    ResNext optimized for the ImageNet dataset, as specified in
+    https://arxiv.org/pdf/1611.05431.pdf
+    """
+
+    def __init__(self, last_stride, with_ibn, block, layers, baseWidth=4, cardinality=32):
+        """ Constructor
+        Args:
+            baseWidth: baseWidth for ResNeXt.
+            cardinality: number of convolution groups.
+            layers: config of layers, e.g., [3, 4, 6, 3]
+            num_classes: number of classes
+        """
+        super(ResNeXt, self).__init__()
+
+        self.cardinality = cardinality
+        self.baseWidth = baseWidth
+        self.inplanes = 64
+        self.output_size = 64
+
+        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0], with_ibn=with_ibn)
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, with_ibn=with_ibn)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, with_ibn=with_ibn)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride, with_ibn=with_ibn)
+
+        self.random_init()
+
+    def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False):
+        """ Stack n bottleneck modules where n is inferred from the depth of the network.
+        Args:
+            block: block type used to construct ResNext
+            planes: number of output channels (need to multiply by block.expansion)
+            blocks: number of blocks to be built
+            stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
+        Returns: a Module consisting of n sequential bottlenecks.
+        """
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        if planes == 512:
+            with_ibn = False
+        layers.append(block(self.inplanes, planes, with_ibn, self.baseWidth, self.cardinality, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes, with_ibn, self.baseWidth, self.cardinality, 1, None))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool1(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        return x
+
+    def random_init(self):
+        self.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.InstanceNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+
+@BACKBONE_REGISTRY.register()
+def build_resnext_backbone(cfg):
+    """
+    Create a ResNeXt instance from config.
+    Returns:
+        ResNeXt: a :class:`ResNeXt` instance.
+    """
+
+    # fmt: off
+    pretrain = cfg.MODEL.BACKBONE.PRETRAIN
+    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
+    last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
+    with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
+    with_se = cfg.MODEL.BACKBONE.WITH_SE
+    with_nl = cfg.MODEL.BACKBONE.WITH_NL
+    depth = cfg.MODEL.BACKBONE.DEPTH
+
+    num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], }[depth]
+    nl_layers_per_stage = {50: [0, 2, 3, 0], 101: [0, 2, 3, 0]}[depth]
+    model = ResNeXt(last_stride, with_ibn, Bottleneck, num_blocks_per_stage)
+    if pretrain:
+        # if not with_ibn:
+        # original resnet
+        # state_dict = model_zoo.load_url(model_urls[depth])
+        # else:
+        # ibn resnet
+        state_dict = torch.load(pretrain_path)['state_dict']
+        # remove module in name
+        new_state_dict = {}
+        for k in state_dict:
+            new_k = '.'.join(k.split('.')[1:])
+            if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
+                new_state_dict[new_k] = state_dict[k]
+        state_dict = new_state_dict
+        res = model.load_state_dict(state_dict, strict=False)
+        logger = logging.getLogger(__name__)
+        logger.info('missing keys is {}'.format(res.missing_keys))
+        logger.info('unexpected keys is {}'.format(res.unexpected_keys))
+    return model
diff --git a/modeling/heads/__init__.py b/modeling/heads/__init__.py
new file mode 100644
index 0000000..0413982
--- /dev/null
+++ b/modeling/heads/__init__.py
@@ -0,0 +1,12 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from .build import REID_HEADS_REGISTRY, build_heads
+
+# import all the meta_arch, so they will be registered
+from .embedding_head import EmbeddingHead
+from .attr_head import AttrHead
+
diff --git a/modeling/heads/__pycache__/__init__.cpython-37.pyc b/modeling/heads/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..16cd05e
--- /dev/null
+++ b/modeling/heads/__pycache__/__init__.cpython-37.pyc
Binary files differ
diff --git a/modeling/heads/__pycache__/__init__.cpython-38.pyc b/modeling/heads/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..4d162e3
--- /dev/null
+++ b/modeling/heads/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/modeling/heads/__pycache__/attr_head.cpython-37.pyc b/modeling/heads/__pycache__/attr_head.cpython-37.pyc
new file mode 100644
index 0000000..8a0b0f1
--- /dev/null
+++ b/modeling/heads/__pycache__/attr_head.cpython-37.pyc
Binary files differ
diff --git a/modeling/heads/__pycache__/attr_head.cpython-38.pyc b/modeling/heads/__pycache__/attr_head.cpython-38.pyc
new file mode 100644
index 0000000..806b1d5
--- /dev/null
+++ b/modeling/heads/__pycache__/attr_head.cpython-38.pyc
Binary files differ
diff --git a/modeling/heads/__pycache__/build.cpython-37.pyc b/modeling/heads/__pycache__/build.cpython-37.pyc
new file mode 100644
index 0000000..2facead
--- /dev/null
+++ b/modeling/heads/__pycache__/build.cpython-37.pyc
Binary files differ
diff --git a/modeling/heads/__pycache__/build.cpython-38.pyc b/modeling/heads/__pycache__/build.cpython-38.pyc
new file mode 100644
index 0000000..f51780b
--- /dev/null
+++ b/modeling/heads/__pycache__/build.cpython-38.pyc
Binary files differ
diff --git a/modeling/heads/__pycache__/embedding_head.cpython-37.pyc b/modeling/heads/__pycache__/embedding_head.cpython-37.pyc
new file mode 100644
index 0000000..e1cd5a7
--- /dev/null
+++ b/modeling/heads/__pycache__/embedding_head.cpython-37.pyc
Binary files differ
diff --git a/modeling/heads/__pycache__/embedding_head.cpython-38.pyc b/modeling/heads/__pycache__/embedding_head.cpython-38.pyc
new file mode 100644
index 0000000..982513d
--- /dev/null
+++ b/modeling/heads/__pycache__/embedding_head.cpython-38.pyc
Binary files differ
diff --git a/modeling/heads/attr_head.py b/modeling/heads/attr_head.py
new file mode 100644
index 0000000..62bc941
--- /dev/null
+++ b/modeling/heads/attr_head.py
@@ -0,0 +1,77 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+from torch import nn
+
+from layers import *
+from utils.weight_init import weights_init_kaiming, weights_init_classifier
+from .build import REID_HEADS_REGISTRY
+
+
+@REID_HEADS_REGISTRY.register()
+class AttrHead(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        # fmt: off
+        feat_dim      = cfg.MODEL.BACKBONE.FEAT_DIM
+        num_classes   = cfg.MODEL.HEADS.NUM_CLASSES
+        pool_type     = cfg.MODEL.HEADS.POOL_LAYER
+        cls_type      = cfg.MODEL.HEADS.CLS_LAYER
+        with_bnneck   = cfg.MODEL.HEADS.WITH_BNNECK
+        norm_type     = cfg.MODEL.HEADS.NORM
+
+        if pool_type == 'fastavgpool':   self.pool_layer = FastGlobalAvgPool2d()
+        elif pool_type == 'avgpool':     self.pool_layer = nn.AdaptiveAvgPool2d(1)
+        elif pool_type == 'maxpool':     self.pool_layer = nn.AdaptiveMaxPool2d(1)
+        elif pool_type == 'gempoolP':    self.pool_layer = GeneralizedMeanPoolingP()
+        elif pool_type == 'gempool':     self.pool_layer = GeneralizedMeanPooling()
+        elif pool_type == "avgmaxpool":  self.pool_layer = AdaptiveAvgMaxPool2d()
+        elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d()
+        elif pool_type == "identity":    self.pool_layer = nn.Identity()
+        elif pool_type == "flatten":     self.pool_layer = Flatten()
+        else:                            raise KeyError(f"{pool_type} is not supported!")
+
+        # Classification layer
+        if cls_type == 'linear':          self.classifier = nn.Linear(feat_dim, num_classes, bias=False)
+        elif cls_type == 'arcSoftmax':    self.classifier = ArcSoftmax(cfg, feat_dim, num_classes)
+        elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, feat_dim, num_classes)
+        elif cls_type == 'amSoftmax':     self.classifier = AMSoftmax(cfg, feat_dim, num_classes)
+        else:                             raise KeyError(f"{cls_type} is not supported!")
+        # fmt: on
+
+        # bottleneck = []
+        # if with_bnneck:
+        #     bottleneck.append(get_norm(norm_type, feat_dim, bias_freeze=True))
+        bottleneck = [nn.BatchNorm1d(num_classes)]
+
+        self.bottleneck = nn.Sequential(*bottleneck)
+
+        self.bottleneck.apply(weights_init_kaiming)
+        self.classifier.apply(weights_init_classifier)
+
+    def forward(self, features, targets=None):
+        """
+        See :class:`ReIDHeads.forward`.
+        """
+        global_feat = self.pool_layer(features)
+        global_feat = global_feat[..., 0, 0]
+
+        classifier_name = self.classifier.__class__.__name__
+        # fmt: off
+        if classifier_name == 'Linear': cls_outputs = self.classifier(global_feat)
+        else:                           cls_outputs = self.classifier(global_feat, targets)
+        # fmt: on
+
+        cls_outputs = self.bottleneck(cls_outputs)
+
+        if self.training:
+            return {
+                "cls_outputs": cls_outputs,
+            }
+        else:
+            cls_outputs = torch.sigmoid(cls_outputs)
+            return cls_outputs
diff --git a/modeling/heads/bnneck_head.py b/modeling/heads/bnneck_head.py
new file mode 100644
index 0000000..e798a81
--- /dev/null
+++ b/modeling/heads/bnneck_head.py
@@ -0,0 +1,61 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from layers import *
+from modeling.losses import *
+from utils.weight_init import weights_init_kaiming, weights_init_classifier
+from .build import REID_HEADS_REGISTRY
+
+
+@REID_HEADS_REGISTRY.register()
+class BNneckHead(nn.Module):
+    def __init__(self, cfg, in_feat, num_classes, pool_layer):
+        super().__init__()
+        self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT
+        self.pool_layer = pool_layer
+
+        self.bnneck = get_norm(cfg.MODEL.HEADS.NORM, in_feat, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True)
+        self.bnneck.apply(weights_init_kaiming)
+
+        # identity classification layer
+        cls_type = cfg.MODEL.HEADS.CLS_LAYER
+        if cls_type == 'linear':    self.classifier = nn.Linear(in_feat, num_classes, bias=False)
+        elif cls_type == 'arcface': self.classifier = Arcface(cfg, in_feat, num_classes)
+        elif cls_type == 'circle':  self.classifier = Circle(cfg, in_feat, num_classes)
+        else:
+            raise KeyError(f"{cls_type} is invalid, please choose from "
+                           f"'linear', 'arcface' and 'circle'.")
+
+        self.classifier.apply(weights_init_classifier)
+
+    def forward(self, features, targets=None):
+        """
+        See :class:`ReIDHeads.forward`.
+        """
+        global_feat = self.pool_layer(features)
+        bn_feat = self.bnneck(global_feat)
+        bn_feat = bn_feat[..., 0, 0]
+
+        # Evaluation
+        if not self.training: return bn_feat
+
+        # Training
+        try:
+            cls_outputs = self.classifier(bn_feat)
+            pred_class_logits = cls_outputs.detach()
+        except TypeError:
+            cls_outputs = self.classifier(bn_feat, targets)
+            pred_class_logits = F.linear(F.normalize(bn_feat.detach()), F.normalize(self.classifier.weight.detach()))
+        # Log prediction accuracy
+        CrossEntropyLoss.log_accuracy(pred_class_logits, targets)
+
+        if self.neck_feat == "before":
+            feat = global_feat[..., 0, 0]
+        elif self.neck_feat == "after":
+            feat = bn_feat
+        else:
+            raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
+        return cls_outputs, feat
diff --git a/modeling/heads/build.py b/modeling/heads/build.py
new file mode 100644
index 0000000..b88bef5
--- /dev/null
+++ b/modeling/heads/build.py
@@ -0,0 +1,25 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+
+from utils.registry import Registry
+
+REID_HEADS_REGISTRY = Registry("HEADS")
+REID_HEADS_REGISTRY.__doc__ = """
+Registry for ROI heads in a generalized R-CNN model.
+ROIHeads take feature maps and region proposals, and
+perform per-region computation.
+The registered object will be called with `obj(cfg, input_shape)`.
+The call is expected to return an :class:`ROIHeads`.
+"""
+
+
+def build_heads(cfg):
+    """
+    Build REIDHeads defined by `cfg.MODEL.REID_HEADS.NAME`.
+    """
+    head = cfg.MODEL.HEADS.NAME
+    return REID_HEADS_REGISTRY.get(head)(cfg)
diff --git a/modeling/heads/embedding_head.py b/modeling/heads/embedding_head.py
new file mode 100644
index 0000000..1ead80b
--- /dev/null
+++ b/modeling/heads/embedding_head.py
@@ -0,0 +1,97 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch.nn.functional as F
+from torch import nn
+
+from layers import *
+from utils.weight_init import weights_init_kaiming, weights_init_classifier
+from .build import REID_HEADS_REGISTRY
+
+
+@REID_HEADS_REGISTRY.register()
+class EmbeddingHead(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        # fmt: off
+        feat_dim      = cfg.MODEL.BACKBONE.FEAT_DIM
+        embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM
+        num_classes   = cfg.MODEL.HEADS.NUM_CLASSES
+        neck_feat     = cfg.MODEL.HEADS.NECK_FEAT
+        pool_type     = cfg.MODEL.HEADS.POOL_LAYER
+        cls_type      = cfg.MODEL.HEADS.CLS_LAYER
+        with_bnneck   = cfg.MODEL.HEADS.WITH_BNNECK
+        norm_type     = cfg.MODEL.HEADS.NORM
+
+        if pool_type == 'fastavgpool':   self.pool_layer = FastGlobalAvgPool2d()
+        elif pool_type == 'avgpool':     self.pool_layer = nn.AdaptiveAvgPool2d(1)
+        elif pool_type == 'maxpool':     self.pool_layer = nn.AdaptiveMaxPool2d(1)
+        elif pool_type == 'gempoolP':    self.pool_layer = GeneralizedMeanPoolingP()
+        elif pool_type == 'gempool':     self.pool_layer = GeneralizedMeanPooling()
+        elif pool_type == "avgmaxpool":  self.pool_layer = AdaptiveAvgMaxPool2d()
+        elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d()
+        elif pool_type == "identity":    self.pool_layer = nn.Identity()
+        elif pool_type == "flatten":     self.pool_layer = Flatten()
+        else:                            raise KeyError(f"{pool_type} is not supported!")
+        # fmt: on
+
+        self.neck_feat = neck_feat
+
+        bottleneck = []
+        if embedding_dim > 0:
+            bottleneck.append(nn.Conv2d(feat_dim, embedding_dim, 1, 1, bias=False))
+            feat_dim = embedding_dim
+
+        if with_bnneck:
+            bottleneck.append(get_norm(norm_type, feat_dim, bias_freeze=True))
+
+        self.bottleneck = nn.Sequential(*bottleneck)
+
+        # identity classification layer
+        # fmt: off
+        if cls_type == 'linear':          self.classifier = nn.Linear(feat_dim, num_classes, bias=False)
+        elif cls_type == 'arcSoftmax':    self.classifier = ArcSoftmax(cfg, feat_dim, num_classes)
+        elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, feat_dim, num_classes)
+        elif cls_type == 'amSoftmax':     self.classifier = AMSoftmax(cfg, feat_dim, num_classes)
+        else:                             raise KeyError(f"{cls_type} is not supported!")
+        # fmt: on
+
+        self.bottleneck.apply(weights_init_kaiming)
+        self.classifier.apply(weights_init_classifier)
+
+    def forward(self, features, targets=None):
+        """
+        See :class:`ReIDHeads.forward`.
+        """
+        global_feat = self.pool_layer(features)
+        bn_feat = self.bottleneck(global_feat)
+        bn_feat = bn_feat[..., 0, 0]
+
+        # Evaluation
+        # fmt: off
+        if not self.training: return bn_feat
+        # fmt: on
+
+        # Training
+        if self.classifier.__class__.__name__ == 'Linear':
+            cls_outputs = self.classifier(bn_feat)
+            pred_class_logits = F.linear(bn_feat, self.classifier.weight)
+        else:
+            cls_outputs = self.classifier(bn_feat, targets)
+            pred_class_logits = self.classifier.s * F.linear(F.normalize(bn_feat),
+                                                             F.normalize(self.classifier.weight))
+
+        # fmt: off
+        if self.neck_feat == "before":  feat = global_feat[..., 0, 0]
+        elif self.neck_feat == "after": feat = bn_feat
+        else:                           raise KeyError(f"{self.neck_feat} is invalid for MODEL.HEADS.NECK_FEAT")
+        # fmt: on
+
+        return {
+            "cls_outputs": cls_outputs,
+            "pred_class_logits": pred_class_logits,
+            "features": feat,
+        }
diff --git a/modeling/heads/linear_head.py b/modeling/heads/linear_head.py
new file mode 100644
index 0000000..9c34835
--- /dev/null
+++ b/modeling/heads/linear_head.py
@@ -0,0 +1,50 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from layers import *
+from modeling.losses import *
+from .build import REID_HEADS_REGISTRY
+from utils.weight_init import weights_init_classifier
+
+
+@REID_HEADS_REGISTRY.register()
+class LinearHead(nn.Module):
+    def __init__(self, cfg, in_feat, num_classes, pool_layer):
+        super().__init__()
+        self.pool_layer = pool_layer
+
+        # identity classification layer
+        cls_type = cfg.MODEL.HEADS.CLS_LAYER
+        if cls_type == 'linear':    self.classifier = nn.Linear(in_feat, num_classes, bias=False)
+        elif cls_type == 'arcface': self.classifier = Arcface(cfg, in_feat, num_classes)
+        elif cls_type == 'circle':  self.classifier = Circle(cfg, in_feat, num_classes)
+        else:
+            raise KeyError(f"{cls_type} is invalid, please choose from "
+                           f"'linear', 'arcface' and 'circle'.")
+
+        self.classifier.apply(weights_init_classifier)
+
+    def forward(self, features, targets=None):
+        """
+        See :class:`ReIDHeads.forward`.
+        """
+        global_feat = self.pool_layer(features)
+        global_feat = global_feat[..., 0, 0]
+
+        # Evaluation
+        if not self.training: return global_feat
+
+        # Training
+        try:
+            cls_outputs = self.classifier(global_feat)
+            pred_class_logits = cls_outputs.detach()
+        except TypeError:
+            cls_outputs = self.classifier(global_feat, targets)
+            pred_class_logits = F.linear(F.normalize(global_feat.detach()), F.normalize(self.classifier.weight.detach()))
+        # Log prediction accuracy
+        CrossEntropyLoss.log_accuracy(pred_class_logits, targets)
+
+        return cls_outputs, global_feat
diff --git a/modeling/heads/reduction_head.py b/modeling/heads/reduction_head.py
new file mode 100644
index 0000000..3e29c84
--- /dev/null
+++ b/modeling/heads/reduction_head.py
@@ -0,0 +1,73 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from layers import *
+from modeling.losses import *
+from utils.weight_init import weights_init_kaiming, weights_init_classifier
+from .build import REID_HEADS_REGISTRY
+
+
+@REID_HEADS_REGISTRY.register()
+class ReductionHead(nn.Module):
+    def __init__(self, cfg, in_feat, num_classes, pool_layer):
+        super().__init__()
+        self._cfg = cfg
+        reduction_dim = cfg.MODEL.HEADS.REDUCTION_DIM
+        self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT
+
+        self.pool_layer = pool_layer
+
+        self.bottleneck = nn.Sequential(
+            nn.Conv2d(in_feat, reduction_dim, 1, 1, bias=False),
+            get_norm(cfg.MODEL.HEADS.NORM, reduction_dim, cfg.MODEL.HEADS.NORM_SPLIT),
+            nn.LeakyReLU(0.1, inplace=True),
+        )
+
+        self.bnneck = get_norm(cfg.MODEL.HEADS.NORM, reduction_dim, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True)
+
+        self.bottleneck.apply(weights_init_kaiming)
+        self.bnneck.apply(weights_init_kaiming)
+
+        # identity classification layer
+        cls_type = cfg.MODEL.HEADS.CLS_LAYER
+        if cls_type == 'linear':    self.classifier = nn.Linear(reduction_dim, num_classes, bias=False)
+        elif cls_type == 'arcface': self.classifier = Arcface(cfg, reduction_dim, num_classes)
+        elif cls_type == 'circle':  self.classifier = Circle(cfg, reduction_dim, num_classes)
+        else:
+            raise KeyError(f"{cls_type} is invalid, please choose from "
+                           f"'linear', 'arcface' and 'circle'.")
+
+        self.classifier.apply(weights_init_classifier)
+
+    def forward(self, features, targets=None):
+        """
+        See :class:`ReIDHeads.forward`.
+        """
+        features = self.pool_layer(features)
+        global_feat = self.bottleneck(features)
+        bn_feat = self.bnneck(global_feat)
+        bn_feat = bn_feat[..., 0, 0]
+
+        # Evaluation
+        if not self.training: return bn_feat
+
+        # Training
+        try:
+            cls_outputs = self.classifier(bn_feat)
+            pred_class_logits = cls_outputs.detach()
+        except TypeError:
+            cls_outputs = self.classifier(bn_feat, targets)
+            pred_class_logits = F.linear(F.normalize(bn_feat.detach()), F.normalize(self.classifier.weight.detach()))
+        # Log prediction accuracy
+        CrossEntropyLoss.log_accuracy(pred_class_logits, targets)
+
+        if self.neck_feat == "before":  feat = global_feat[..., 0, 0]
+        elif self.neck_feat == "after": feat = bn_feat
+        else:
+            raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
+
+        return cls_outputs, feat
+
diff --git a/modeling/losses/__init__.py b/modeling/losses/__init__.py
new file mode 100644
index 0000000..977e5c3
--- /dev/null
+++ b/modeling/losses/__init__.py
@@ -0,0 +1,9 @@
+# encoding: utf-8
+"""
+@author:  l1aoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from .cross_entroy_loss import CrossEntropyLoss
+from .focal_loss import FocalLoss
+from .metric_loss import TripletLoss, CircleLoss
diff --git a/modeling/losses/__pycache__/__init__.cpython-37.pyc b/modeling/losses/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..9a28f08
--- /dev/null
+++ b/modeling/losses/__pycache__/__init__.cpython-37.pyc
Binary files differ
diff --git a/modeling/losses/__pycache__/__init__.cpython-38.pyc b/modeling/losses/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..a6c71c7
--- /dev/null
+++ b/modeling/losses/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/modeling/losses/__pycache__/cross_entroy_loss.cpython-37.pyc b/modeling/losses/__pycache__/cross_entroy_loss.cpython-37.pyc
new file mode 100644
index 0000000..543c429
--- /dev/null
+++ b/modeling/losses/__pycache__/cross_entroy_loss.cpython-37.pyc
Binary files differ
diff --git a/modeling/losses/__pycache__/cross_entroy_loss.cpython-38.pyc b/modeling/losses/__pycache__/cross_entroy_loss.cpython-38.pyc
new file mode 100644
index 0000000..a43b6f3
--- /dev/null
+++ b/modeling/losses/__pycache__/cross_entroy_loss.cpython-38.pyc
Binary files differ
diff --git a/modeling/losses/__pycache__/focal_loss.cpython-37.pyc b/modeling/losses/__pycache__/focal_loss.cpython-37.pyc
new file mode 100644
index 0000000..3440861
--- /dev/null
+++ b/modeling/losses/__pycache__/focal_loss.cpython-37.pyc
Binary files differ
diff --git a/modeling/losses/__pycache__/focal_loss.cpython-38.pyc b/modeling/losses/__pycache__/focal_loss.cpython-38.pyc
new file mode 100644
index 0000000..bf84802
--- /dev/null
+++ b/modeling/losses/__pycache__/focal_loss.cpython-38.pyc
Binary files differ
diff --git a/modeling/losses/__pycache__/metric_loss.cpython-37.pyc b/modeling/losses/__pycache__/metric_loss.cpython-37.pyc
new file mode 100644
index 0000000..cad5fb8
--- /dev/null
+++ b/modeling/losses/__pycache__/metric_loss.cpython-37.pyc
Binary files differ
diff --git a/modeling/losses/__pycache__/metric_loss.cpython-38.pyc b/modeling/losses/__pycache__/metric_loss.cpython-38.pyc
new file mode 100644
index 0000000..a7b70f6
--- /dev/null
+++ b/modeling/losses/__pycache__/metric_loss.cpython-38.pyc
Binary files differ
diff --git a/modeling/losses/cross_entroy_loss.py b/modeling/losses/cross_entroy_loss.py
new file mode 100644
index 0000000..2108acd
--- /dev/null
+++ b/modeling/losses/cross_entroy_loss.py
@@ -0,0 +1,62 @@
+# encoding: utf-8
+"""
+@author:  l1aoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+import torch
+import torch.nn.functional as F
+
+from utils.events import get_event_storage
+
+
+class CrossEntropyLoss(object):
+    """
+    A class that stores information and compute losses about outputs of a Baseline head.
+    """
+
+    def __init__(self, cfg):
+        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
+        self._eps = cfg.MODEL.LOSSES.CE.EPSILON
+        self._alpha = cfg.MODEL.LOSSES.CE.ALPHA
+        self._scale = cfg.MODEL.LOSSES.CE.SCALE
+
+    @staticmethod
+    def log_accuracy(pred_class_logits, gt_classes, topk=(1,)):
+        """
+        Log the accuracy metrics to EventStorage.
+        """
+        bsz = pred_class_logits.size(0)
+        maxk = max(topk)
+        _, pred_class = pred_class_logits.topk(maxk, 1, True, True)
+        pred_class = pred_class.t()
+        correct = pred_class.eq(gt_classes.view(1, -1).expand_as(pred_class))
+
+        ret = []
+        for k in topk:
+            correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
+            ret.append(correct_k.mul_(1. / bsz))
+
+        storage = get_event_storage()
+        storage.put_scalar("cls_accuracy", ret[0])
+
+    def __call__(self, pred_class_logits, gt_classes):
+        """
+        Compute the softmax cross entropy loss for box classification.
+        Returns:
+            scalar Tensor
+        """
+        if self._eps >= 0:
+            smooth_param = self._eps
+        else:
+            # adaptive lsr
+            soft_label = F.softmax(pred_class_logits, dim=1)
+            smooth_param = self._alpha * soft_label[torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1)
+
+        log_probs = F.log_softmax(pred_class_logits, dim=1)
+        with torch.no_grad():
+            targets = torch.ones_like(log_probs)
+            targets *= smooth_param / (self._num_classes - 1)
+            targets.scatter_(1, gt_classes.data.unsqueeze(1), (1 - smooth_param))
+
+        loss = (-targets * log_probs).mean(0).sum()
+        return loss * self._scale
diff --git a/modeling/losses/focal_loss.py b/modeling/losses/focal_loss.py
new file mode 100644
index 0000000..c520594
--- /dev/null
+++ b/modeling/losses/focal_loss.py
@@ -0,0 +1,110 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+import torch.nn.functional as F
+
+
+# based on:
+# https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py
+
+def focal_loss(
+        input: torch.Tensor,
+        target: torch.Tensor,
+        alpha: float,
+        gamma: float = 2.0,
+        reduction: str = 'mean', ) -> torch.Tensor:
+    r"""Function that computes Focal loss.
+    See :class:`fastreid.modeling.losses.FocalLoss` for details.
+    """
+    if not torch.is_tensor(input):
+        raise TypeError("Input type is not a torch.Tensor. Got {}"
+                        .format(type(input)))
+
+    if not len(input.shape) >= 2:
+        raise ValueError("Invalid input shape, we expect BxCx*. Got: {}"
+                         .format(input.shape))
+
+    if input.size(0) != target.size(0):
+        raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
+                         .format(input.size(0), target.size(0)))
+
+    n = input.size(0)
+    out_size = (n,) + input.size()[2:]
+    if target.size()[1:] != input.size()[2:]:
+        raise ValueError('Expected target size {}, got {}'.format(
+            out_size, target.size()))
+
+    if not input.device == target.device:
+        raise ValueError(
+            "input and target must be in the same device. Got: {}".format(
+                input.device, target.device))
+
+    # compute softmax over the classes axis
+    input_soft = F.softmax(input, dim=1)
+
+    # create the labels one hot tensor
+    target_one_hot = F.one_hot(target, num_classes=input.shape[1])
+
+    # compute the actual focal loss
+    weight = torch.pow(-input_soft + 1., gamma)
+
+    focal = -alpha * weight * torch.log(input_soft)
+    loss_tmp = torch.sum(target_one_hot * focal, dim=1)
+
+    if reduction == 'none':
+        loss = loss_tmp
+    elif reduction == 'mean':
+        loss = torch.mean(loss_tmp)
+    elif reduction == 'sum':
+        loss = torch.sum(loss_tmp)
+    else:
+        raise NotImplementedError("Invalid reduction mode: {}"
+                                  .format(reduction))
+    return loss
+
+
+class FocalLoss(object):
+    r"""Criterion that computes Focal loss.
+    According to [1], the Focal loss is computed as follows:
+    .. math::
+        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
+    where:
+       - :math:`p_t` is the model's estimated probability for each class.
+    Arguments:
+        alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
+        gamma (float): Focusing parameter :math:`\gamma >= 0`.
+        reduction (str, optional): Specifies the reduction to apply to the
+         output: 鈥榥one鈥� | 鈥榤ean鈥� | 鈥榮um鈥�. 鈥榥one鈥�: no reduction will be applied,
+         鈥榤ean鈥�: the sum of the output will be divided by the number of elements
+         in the output, 鈥榮um鈥�: the output will be summed. Default: 鈥榥one鈥�.
+    Shape:
+        - Input: :math:`(N, C, *)` where C = number of classes.
+        - Target: :math:`(N, *)` where each value is
+          :math:`0 鈮� targets[i] 鈮� C鈭�1`.
+    Examples:
+        >>> N = 5  # num_classes
+        >>> loss = FocalLoss(cfg)
+        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
+        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
+        >>> output = loss(input, target)
+        >>> output.backward()
+    References:
+        [1] https://arxiv.org/abs/1708.02002
+    """
+
+    # def __init__(self, alpha: float, gamma: float = 2.0,
+    #              reduction: str = 'none') -> None:
+    def __init__(self, cfg):
+        self._alpha: float = cfg.MODEL.LOSSES.FL.ALPHA
+        self._gamma: float = cfg.MODEL.LOSSES.FL.GAMMA
+        self._scale: float = cfg.MODEL.LOSSES.FL.SCALE
+
+    def __call__(self, pred_class_logits: torch.Tensor, _, gt_classes: torch.Tensor) -> dict:
+        loss = focal_loss(pred_class_logits, gt_classes, self._alpha, self._gamma)
+        return {
+            'loss_focal': loss * self._scale,
+        }
diff --git a/modeling/losses/metric_loss.py b/modeling/losses/metric_loss.py
new file mode 100644
index 0000000..edcf636
--- /dev/null
+++ b/modeling/losses/metric_loss.py
@@ -0,0 +1,215 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+import torch.nn.functional as F
+
+from utils import comm
+
+
+# utils
+@torch.no_grad()
+def concat_all_gather(tensor):
+    """
+    Performs all_gather operation on the provided tensors.
+    *** Warning ***: torch.distributed.all_gather has no gradient.
+    """
+    tensors_gather = [torch.ones_like(tensor)
+                      for _ in range(torch.distributed.get_world_size())]
+    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+    output = torch.cat(tensors_gather, dim=0)
+    return output
+
+
+def normalize(x, axis=-1):
+    """Normalizing to unit length along the specified dimension.
+    Args:
+      x: pytorch Variable
+    Returns:
+      x: pytorch Variable, same shape as input
+    """
+    x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
+    return x
+
+
+def euclidean_dist(x, y):
+    m, n = x.size(0), y.size(0)
+    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
+    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
+    dist = xx + yy
+    dist.addmm_(1, -2, x, y.t())
+    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
+    return dist
+
+
+def cosine_dist(x, y):
+    bs1, bs2 = x.size(0), y.size(0)
+    frac_up = torch.matmul(x, y.transpose(0, 1))
+    frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \
+                (torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1)
+    cosine = frac_up / frac_down
+    return 1 - cosine
+
+
+def softmax_weights(dist, mask):
+    max_v = torch.max(dist * mask, dim=1, keepdim=True)[0]
+    diff = dist - max_v
+    Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6  # avoid division by zero
+    W = torch.exp(diff) * mask / Z
+    return W
+
+
+def hard_example_mining(dist_mat, is_pos, is_neg):
+    """For each anchor, find the hardest positive and negative sample.
+    Args:
+      dist_mat: pair wise distance between samples, shape [N, M]
+      is_pos: positive index with shape [N, M]
+      is_neg: negative index with shape [N, M]
+    Returns:
+      dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
+      dist_an: pytorch Variable, distance(anchor, negative); shape [N]
+      p_inds: pytorch LongTensor, with shape [N];
+        indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
+      n_inds: pytorch LongTensor, with shape [N];
+        indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
+    NOTE: Only consider the case in which all labels have same num of samples,
+      thus we can cope with all anchors in parallel.
+    """
+
+    assert len(dist_mat.size()) == 2
+    N = dist_mat.size(0)
+
+    # `dist_ap` means distance(anchor, positive)
+    # both `dist_ap` and `relative_p_inds` with shape [N, 1]
+    # pos_dist = dist_mat[is_pos].contiguous().view(N, -1)
+    # ap_weight = F.softmax(pos_dist, dim=1)
+    # dist_ap = torch.sum(ap_weight * pos_dist, dim=1)
+    dist_ap, relative_p_inds = torch.max(
+        dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
+    # `dist_an` means distance(anchor, negative)
+    # both `dist_an` and `relative_n_inds` with shape [N, 1]
+    dist_an, relative_n_inds = torch.min(
+        dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
+    # neg_dist = dist_mat[is_neg].contiguous().view(N, -1)
+    # an_weight = F.softmax(-neg_dist, dim=1)
+    # dist_an = torch.sum(an_weight * neg_dist, dim=1)
+
+    # shape [N]
+    dist_ap = dist_ap.squeeze(1)
+    dist_an = dist_an.squeeze(1)
+
+    return dist_ap, dist_an
+
+
+def weighted_example_mining(dist_mat, is_pos, is_neg):
+    """For each anchor, find the weighted positive and negative sample.
+    Args:
+      dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
+      is_pos:
+      is_neg:
+    Returns:
+      dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
+      dist_an: pytorch Variable, distance(anchor, negative); shape [N]
+    """
+    assert len(dist_mat.size()) == 2
+
+    is_pos = is_pos.float()
+    is_neg = is_neg.float()
+    dist_ap = dist_mat * is_pos
+    dist_an = dist_mat * is_neg
+
+    weights_ap = softmax_weights(dist_ap, is_pos)
+    weights_an = softmax_weights(-dist_an, is_neg)
+
+    dist_ap = torch.sum(dist_ap * weights_ap, dim=1)
+    dist_an = torch.sum(dist_an * weights_an, dim=1)
+
+    return dist_ap, dist_an
+
+
+class TripletLoss(object):
+    """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
+    Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
+    Loss for Person Re-Identification'."""
+
+    def __init__(self, cfg):
+        self._margin = cfg.MODEL.LOSSES.TRI.MARGIN
+        self._normalize_feature = cfg.MODEL.LOSSES.TRI.NORM_FEAT
+        self._scale = cfg.MODEL.LOSSES.TRI.SCALE
+        self._hard_mining = cfg.MODEL.LOSSES.TRI.HARD_MINING
+
+    def __call__(self, embedding, targets):
+        if self._normalize_feature:
+            embedding = normalize(embedding, axis=-1)
+
+        # For distributed training, gather all features from different process.
+        if comm.get_world_size() > 1:
+            all_embedding = concat_all_gather(embedding)
+            all_targets = concat_all_gather(targets)
+        else:
+            all_embedding = embedding
+            all_targets = targets
+
+        dist_mat = euclidean_dist(embedding, all_embedding)
+
+        N, M = dist_mat.size()
+        is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t())
+        is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())
+
+        if self._hard_mining:
+            dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
+        else:
+            dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)
+
+        y = dist_an.new().resize_as_(dist_an).fill_(1)
+
+        if self._margin > 0:
+            loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=self._margin)
+        else:
+            loss = F.soft_margin_loss(dist_an - dist_ap, y)
+            if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
+
+        return loss * self._scale
+
+
+class CircleLoss(object):
+    def __init__(self, cfg):
+        self._scale = cfg.MODEL.LOSSES.CIRCLE.SCALE
+
+        self.m = cfg.MODEL.LOSSES.CIRCLE.MARGIN
+        self.s = cfg.MODEL.LOSSES.CIRCLE.ALPHA
+
+    def __call__(self, embedding, targets):
+        embedding = F.normalize(embedding, dim=1)
+
+        if comm.get_world_size() > 1:
+            all_embedding = concat_all_gather(embedding)
+            all_targets = concat_all_gather(targets)
+        else:
+            all_embedding = embedding
+            all_targets = targets
+
+        dist_mat = torch.matmul(embedding, all_embedding.t())
+
+        N, M = dist_mat.size()
+        is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t())
+        is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())
+
+        s_p = dist_mat[is_pos].contiguous().view(N, -1)
+        s_n = dist_mat[is_neg].contiguous().view(N, -1)
+
+        alpha_p = F.relu(-s_p.detach() + 1 + self.m)
+        alpha_n = F.relu(s_n.detach() + self.m)
+        delta_p = 1 - self.m
+        delta_n = self.m
+
+        logit_p = - self.s * alpha_p * (s_p - delta_p)
+        logit_n = self.s * alpha_n * (s_n - delta_n)
+
+        loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()
+
+        return loss * self._scale
diff --git a/modeling/meta_arch/__init__.py b/modeling/meta_arch/__init__.py
new file mode 100644
index 0000000..84ae987
--- /dev/null
+++ b/modeling/meta_arch/__init__.py
@@ -0,0 +1,12 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from .build import META_ARCH_REGISTRY, build_model
+
+
+# import all the meta_arch, so they will be registered
+from .baseline import Baseline
+from .mgn import MGN
diff --git a/modeling/meta_arch/__pycache__/__init__.cpython-37.pyc b/modeling/meta_arch/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..9edd4bc
--- /dev/null
+++ b/modeling/meta_arch/__pycache__/__init__.cpython-37.pyc
Binary files differ
diff --git a/modeling/meta_arch/__pycache__/__init__.cpython-38.pyc b/modeling/meta_arch/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..272307a
--- /dev/null
+++ b/modeling/meta_arch/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/modeling/meta_arch/__pycache__/baseline.cpython-37.pyc b/modeling/meta_arch/__pycache__/baseline.cpython-37.pyc
new file mode 100644
index 0000000..bc9d4b2
--- /dev/null
+++ b/modeling/meta_arch/__pycache__/baseline.cpython-37.pyc
Binary files differ
diff --git a/modeling/meta_arch/__pycache__/baseline.cpython-38.pyc b/modeling/meta_arch/__pycache__/baseline.cpython-38.pyc
new file mode 100644
index 0000000..382735e
--- /dev/null
+++ b/modeling/meta_arch/__pycache__/baseline.cpython-38.pyc
Binary files differ
diff --git a/modeling/meta_arch/__pycache__/build.cpython-37.pyc b/modeling/meta_arch/__pycache__/build.cpython-37.pyc
new file mode 100644
index 0000000..3b47628
--- /dev/null
+++ b/modeling/meta_arch/__pycache__/build.cpython-37.pyc
Binary files differ
diff --git a/modeling/meta_arch/__pycache__/build.cpython-38.pyc b/modeling/meta_arch/__pycache__/build.cpython-38.pyc
new file mode 100644
index 0000000..a27b826
--- /dev/null
+++ b/modeling/meta_arch/__pycache__/build.cpython-38.pyc
Binary files differ
diff --git a/modeling/meta_arch/__pycache__/mgn.cpython-37.pyc b/modeling/meta_arch/__pycache__/mgn.cpython-37.pyc
new file mode 100644
index 0000000..9e4a0a6
--- /dev/null
+++ b/modeling/meta_arch/__pycache__/mgn.cpython-37.pyc
Binary files differ
diff --git a/modeling/meta_arch/__pycache__/mgn.cpython-38.pyc b/modeling/meta_arch/__pycache__/mgn.cpython-38.pyc
new file mode 100644
index 0000000..ad31752
--- /dev/null
+++ b/modeling/meta_arch/__pycache__/mgn.cpython-38.pyc
Binary files differ
diff --git a/modeling/meta_arch/baseline.py b/modeling/meta_arch/baseline.py
new file mode 100644
index 0000000..4e4e518
--- /dev/null
+++ b/modeling/meta_arch/baseline.py
@@ -0,0 +1,119 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+from torch import nn
+
+from modeling.backbones import build_backbone
+from modeling.heads import build_heads
+from modeling.losses import *
+from .build import META_ARCH_REGISTRY
+
+
+@META_ARCH_REGISTRY.register()
+class Baseline(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self._cfg = cfg
+        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
+        self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
+        self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
+
+        # backbone
+        self.backbone = build_backbone(cfg)
+
+        # head
+        self.heads = build_heads(cfg)
+
+    @property
+    def device(self):
+        return self.pixel_mean.device
+
+    def forward(self, batched_inputs):
+        images = self.preprocess_image(batched_inputs)
+        features = self.backbone(images)
+
+        if self.training:
+            assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
+            targets = batched_inputs["targets"].to(self.device)
+
+            # PreciseBN flag, When do preciseBN on different dataset, the number of classes in new dataset
+            # may be larger than that in the original dataset, so the circle/arcface will
+            # throw an error. We just set all the targets to 0 to avoid this problem.
+            if targets.sum() < 0: targets.zero_()
+
+            outputs = self.heads(features, targets)
+            return {
+                "outputs": outputs,
+                "targets": targets,
+            }
+        else:
+            outputs = self.heads(features)
+            return outputs
+
+    def preprocess_image(self, images):
+        r"""
+        Normalize and batch the input images.
+        """
+        # images = images.to(self.device)
+        # if isinstance(batched_inputs, dict):
+        #     images = batched_inputs["images"].to(self.device)
+        # elif isinstance(batched_inputs, torch.Tensor):
+        #     images = batched_inputs.to(self.device)
+        # else:
+        #     raise TypeError("batched_inputs must be dict or torch.Tensor, but get {}".format(type(batched_inputs)))
+
+        # print(images)
+        print("-----------------------------------------------------------------------------")
+        images.sub_(self.pixel_mean).div_(self.pixel_std)
+        return images
+
+    def losses(self, outs):
+        r"""
+        Compute loss from modeling's outputs, the loss function input arguments
+        must be the same as the outputs of the model forwarding.
+        """
+        # fmt: off
+        outputs           = outs["outputs"]
+        gt_labels         = outs["targets"]
+        # model predictions
+        pred_class_logits = outputs['pred_class_logits'].detach()
+        cls_outputs       = outputs['cls_outputs']
+        pred_features     = outputs['features']
+        # fmt: on
+
+        # Log prediction accuracy
+        log_accuracy(pred_class_logits, gt_labels)
+
+        loss_dict = {}
+        loss_names = self._cfg.MODEL.LOSSES.NAME
+
+        if "CrossEntropyLoss" in loss_names:
+            loss_dict['loss_cls'] = cross_entropy_loss(
+                cls_outputs,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.CE.EPSILON,
+                self._cfg.MODEL.LOSSES.CE.ALPHA,
+            ) * self._cfg.MODEL.LOSSES.CE.SCALE
+
+        if "TripletLoss" in loss_names:
+            loss_dict['loss_triplet'] = triplet_loss(
+                pred_features,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.TRI.MARGIN,
+                self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
+                self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
+            ) * self._cfg.MODEL.LOSSES.TRI.SCALE
+
+        if "CircleLoss" in loss_names:
+            loss_dict['loss_circle'] = circle_loss(
+                pred_features,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.CIRCLE.MARGIN,
+                self._cfg.MODEL.LOSSES.CIRCLE.ALPHA,
+            ) * self._cfg.MODEL.LOSSES.CIRCLE.SCALE
+
+        return loss_dict
diff --git a/modeling/meta_arch/build.py b/modeling/meta_arch/build.py
new file mode 100644
index 0000000..395b5c7
--- /dev/null
+++ b/modeling/meta_arch/build.py
@@ -0,0 +1,21 @@
+# encoding: utf-8
+from utils.registry import Registry
+
+META_ARCH_REGISTRY = Registry("META_ARCH")  # noqa F401 isort:skip
+META_ARCH_REGISTRY.__doc__ = """
+Registry for meta-architectures, i.e. the whole model.
+The registered object will be called with `obj(cfg)`
+and expected to return a `nn.Module` object.
+"""
+
+
+def build_model(cfg):
+    """
+    Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
+    Note that it does not load any weights from ``cfg``.
+    """
+    meta_arch = cfg.MODEL.META_ARCHITECTURE
+    model = META_ARCH_REGISTRY.get(meta_arch)(cfg)
+    return model
+
+
diff --git a/modeling/meta_arch/mgn.py b/modeling/meta_arch/mgn.py
new file mode 100644
index 0000000..6050b55
--- /dev/null
+++ b/modeling/meta_arch/mgn.py
@@ -0,0 +1,280 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+import copy
+
+import torch
+from torch import nn
+
+from layers import get_norm
+from modeling.backbones import build_backbone
+from modeling.backbones.resnet import Bottleneck
+from modeling.heads import build_heads
+from modeling.losses import *
+from .build import META_ARCH_REGISTRY
+
+
+@META_ARCH_REGISTRY.register()
+class MGN(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self._cfg = cfg
+        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
+        self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
+        self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
+
+        # fmt: off
+        # backbone
+        bn_norm    = cfg.MODEL.BACKBONE.NORM
+        with_se    = cfg.MODEL.BACKBONE.WITH_SE
+        # fmt :on
+
+        backbone = build_backbone(cfg)
+        self.backbone = nn.Sequential(
+            backbone.conv1,
+            backbone.bn1,
+            backbone.relu,
+            backbone.maxpool,
+            backbone.layer1,
+            backbone.layer2,
+            backbone.layer3[0]
+        )
+        res_conv4 = nn.Sequential(*backbone.layer3[1:])
+        res_g_conv5 = backbone.layer4
+
+        res_p_conv5 = nn.Sequential(
+            Bottleneck(1024, 512, bn_norm, False, with_se, downsample=nn.Sequential(
+                nn.Conv2d(1024, 2048, 1, bias=False), get_norm(bn_norm, 2048))),
+            Bottleneck(2048, 512, bn_norm, False, with_se),
+            Bottleneck(2048, 512, bn_norm, False, with_se))
+        res_p_conv5.load_state_dict(backbone.layer4.state_dict())
+
+        # branch1
+        self.b1 = nn.Sequential(
+            copy.deepcopy(res_conv4),
+            copy.deepcopy(res_g_conv5)
+        )
+        self.b1_head = build_heads(cfg)
+
+        # branch2
+        self.b2 = nn.Sequential(
+            copy.deepcopy(res_conv4),
+            copy.deepcopy(res_p_conv5)
+        )
+        self.b2_head = build_heads(cfg)
+        self.b21_head = build_heads(cfg)
+        self.b22_head = build_heads(cfg)
+
+        # branch3
+        self.b3 = nn.Sequential(
+            copy.deepcopy(res_conv4),
+            copy.deepcopy(res_p_conv5)
+        )
+        self.b3_head = build_heads(cfg)
+        self.b31_head = build_heads(cfg)
+        self.b32_head = build_heads(cfg)
+        self.b33_head = build_heads(cfg)
+
+    @property
+    def device(self):
+        return self.pixel_mean.device
+
+    def forward(self, batched_inputs):
+        images = self.preprocess_image(batched_inputs)
+        features = self.backbone(images)  # (bs, 2048, 16, 8)
+
+        # branch1
+        b1_feat = self.b1(features)
+
+        # branch2
+        b2_feat = self.b2(features)
+        b21_feat, b22_feat = torch.chunk(b2_feat, 2, dim=2)
+
+        # branch3
+        b3_feat = self.b3(features)
+        b31_feat, b32_feat, b33_feat = torch.chunk(b3_feat, 3, dim=2)
+
+        if self.training:
+            assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
+            targets = batched_inputs["targets"].long().to(self.device)
+
+            if targets.sum() < 0: targets.zero_()
+
+            b1_outputs = self.b1_head(b1_feat, targets)
+            b2_outputs = self.b2_head(b2_feat, targets)
+            b21_outputs = self.b21_head(b21_feat, targets)
+            b22_outputs = self.b22_head(b22_feat, targets)
+            b3_outputs = self.b3_head(b3_feat, targets)
+            b31_outputs = self.b31_head(b31_feat, targets)
+            b32_outputs = self.b32_head(b32_feat, targets)
+            b33_outputs = self.b33_head(b33_feat, targets)
+
+            return {
+                "b1_outputs": b1_outputs,
+                "b2_outputs": b2_outputs,
+                "b21_outputs": b21_outputs,
+                "b22_outputs": b22_outputs,
+                "b3_outputs": b3_outputs,
+                "b31_outputs": b31_outputs,
+                "b32_outputs": b32_outputs,
+                "b33_outputs": b33_outputs,
+                "targets": targets,
+            }
+        else:
+            b1_pool_feat = self.b1_head(b1_feat)
+            b2_pool_feat = self.b2_head(b2_feat)
+            b21_pool_feat = self.b21_head(b21_feat)
+            b22_pool_feat = self.b22_head(b22_feat)
+            b3_pool_feat = self.b3_head(b3_feat)
+            b31_pool_feat = self.b31_head(b31_feat)
+            b32_pool_feat = self.b32_head(b32_feat)
+            b33_pool_feat = self.b33_head(b33_feat)
+
+            pred_feat = torch.cat([b1_pool_feat, b2_pool_feat, b3_pool_feat, b21_pool_feat,
+                                   b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
+            return pred_feat
+
+    def preprocess_image(self, batched_inputs):
+        r"""
+        Normalize and batch the input images.
+        """
+        if isinstance(batched_inputs, dict):
+            images = batched_inputs["images"].to(self.device)
+        elif isinstance(batched_inputs, torch.Tensor):
+            images = batched_inputs.to(self.device)
+        else:
+            raise TypeError("batched_inputs must be dict or torch.Tensor, but get {}".format(type(batched_inputs)))
+
+        images.sub_(self.pixel_mean).div_(self.pixel_std)
+        return images
+
+    def losses(self, outs):
+        # fmt: off
+        b1_outputs        = outs["b1_outputs"]
+        b2_outputs        = outs["b2_outputs"]
+        b21_outputs       = outs["b21_outputs"]
+        b22_outputs       = outs["b22_outputs"]
+        b3_outputs        = outs["b3_outputs"]
+        b31_outputs       = outs["b31_outputs"]
+        b32_outputs       = outs["b32_outputs"]
+        b33_outputs       = outs["b33_outputs"]
+        gt_labels         = outs["targets"]
+        # model predictions
+        pred_class_logits = b1_outputs['pred_class_logits'].detach()
+        b1_logits         = b1_outputs['cls_outputs']
+        b2_logits         = b2_outputs['cls_outputs']
+        b21_logits        = b21_outputs['cls_outputs']
+        b22_logits        = b22_outputs['cls_outputs']
+        b3_logits         = b3_outputs['cls_outputs']
+        b31_logits        = b31_outputs['cls_outputs']
+        b32_logits        = b32_outputs['cls_outputs']
+        b33_logits        = b33_outputs['cls_outputs']
+        b1_pool_feat      = b1_outputs['features']
+        b2_pool_feat      = b2_outputs['features']
+        b3_pool_feat      = b3_outputs['features']
+        b21_pool_feat     = b21_outputs['features']
+        b22_pool_feat     = b22_outputs['features']
+        b31_pool_feat     = b31_outputs['features']
+        b32_pool_feat     = b32_outputs['features']
+        b33_pool_feat     = b33_outputs['features']
+        # fmt: on
+
+        # Log prediction accuracy
+        log_accuracy(pred_class_logits, gt_labels)
+
+        b22_pool_feat = torch.cat((b21_pool_feat, b22_pool_feat), dim=1)
+        b33_pool_feat = torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1)
+
+        loss_dict = {}
+        loss_names = self._cfg.MODEL.LOSSES.NAME
+
+        if "CrossEntropyLoss" in loss_names:
+            loss_dict['loss_cls_b1'] = cross_entropy_loss(
+                b1_logits,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.CE.EPSILON,
+                self._cfg.MODEL.LOSSES.CE.ALPHA,
+            ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
+            loss_dict['loss_cls_b2'] = cross_entropy_loss(
+                b2_logits,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.CE.EPSILON,
+                self._cfg.MODEL.LOSSES.CE.ALPHA,
+            ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
+            loss_dict['loss_cls_b21'] = cross_entropy_loss(
+                b21_logits,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.CE.EPSILON,
+                self._cfg.MODEL.LOSSES.CE.ALPHA,
+            ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
+            loss_dict['loss_cls_b22'] = cross_entropy_loss(
+                b22_logits,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.CE.EPSILON,
+                self._cfg.MODEL.LOSSES.CE.ALPHA,
+            ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
+            loss_dict['loss_cls_b3'] = cross_entropy_loss(
+                b3_logits,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.CE.EPSILON,
+                self._cfg.MODEL.LOSSES.CE.ALPHA,
+            ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
+            loss_dict['loss_cls_b31'] = cross_entropy_loss(
+                b31_logits,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.CE.EPSILON,
+                self._cfg.MODEL.LOSSES.CE.ALPHA,
+            ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
+            loss_dict['loss_cls_b32'] = cross_entropy_loss(
+                b32_logits,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.CE.EPSILON,
+                self._cfg.MODEL.LOSSES.CE.ALPHA,
+            ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
+            loss_dict['loss_cls_b33'] = cross_entropy_loss(
+                b33_logits,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.CE.EPSILON,
+                self._cfg.MODEL.LOSSES.CE.ALPHA,
+            ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
+
+        if "TripletLoss" in loss_names:
+            loss_dict['loss_triplet_b1'] = triplet_loss(
+                b1_pool_feat,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.TRI.MARGIN,
+                self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
+                self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
+            ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
+            loss_dict['loss_triplet_b2'] = triplet_loss(
+                b2_pool_feat,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.TRI.MARGIN,
+                self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
+                self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
+            ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
+            loss_dict['loss_triplet_b3'] = triplet_loss(
+                b3_pool_feat,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.TRI.MARGIN,
+                self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
+                self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
+            ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
+            loss_dict['loss_triplet_b22'] = triplet_loss(
+                b22_pool_feat,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.TRI.MARGIN,
+                self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
+                self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
+            ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
+            loss_dict['loss_triplet_b33'] = triplet_loss(
+                b33_pool_feat,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.TRI.MARGIN,
+                self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
+                self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
+            ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
+
+        return loss_dict
diff --git a/modul/model.pt b/modul/model.pt
deleted file mode 100644
index 31ded02..0000000
--- a/modul/model.pt
+++ /dev/null
Binary files differ
diff --git a/reid_feature.cpp b/reid_feature.cpp
deleted file mode 100644
index fdf1e88..0000000
--- a/reid_feature.cpp
+++ /dev/null
@@ -1,122 +0,0 @@
-//
-// Created by Scheaven on 2020/1/3.
-//
-
-#include "reid_feature.h"
-#include <cuda_runtime_api.h>
-#include <torch/torch.h>
-
-bool ReID_Feature::ReID_init(int gpu_id)
-{
-    if(gpu_id == -1){
-        this->module = torch::jit::load(MODEL_PATH);
-        this->module.to(torch::kCPU);
-        this->module.eval();
-        this->is_gpu = false;
-    }else if(torch::cuda::is_available() && torch::cuda::device_count() >= gpu_id)
-    {
-        cudaSetDevice(gpu_id);
-        cout << "model loading::" << HUMAN_FEATS << endl;
-        this->module = torch::jit::load(MODEL_PATH, torch::Device(torch::DeviceType::CUDA,gpu_id));
-        this->module.to(torch::kCUDA);
-        this->module.eval();
-        this->is_gpu = true;
-    }else{
-        return false;
-    }
-    return true;
-}
-
-int ReID_Feature::ReID_size()
-{
-    int size = 2048;
-    return size;
-
-}
-
-bool ReID_Feature::ReID_extractor(float *pBuf, float *pFeature)
-{
-    auto input_tensor = torch::from_blob(pBuf, {1, 256, 128, 3});
-    input_tensor = input_tensor.permute({0, 3, 1, 2});
-    input_tensor[0][0] = input_tensor[0][0].sub_(0.485).div_(0.229);
-    input_tensor[0][1] = input_tensor[0][1].sub_(0.456).div_(0.224);
-    input_tensor[0][2] = input_tensor[0][2].sub_(0.406).div_(0.225);
-    if(this->is_gpu)
-        input_tensor = input_tensor.to(at::kCUDA);
-    torch::Tensor human_feat =this->module.forward({input_tensor}).toTensor();
-
-//    for (int k = 0; k < 20; ++k) {
-//        cout << "--extractor---human_feats------" <<human_feat[0][k+2000]<< endl;
-//    }
-    torch::Tensor query_feat;
-    if(this->is_gpu)
-        query_feat = human_feat.cpu();
-    else
-        query_feat = human_feat;
-
-    auto foo_one = query_feat.accessor<float,2>();
-
-    ReID_Utils RET;
-
-    float f_size = -0.727412;
-    for (int64_t i = 0; i < foo_one.size(0); i++) {
-        auto a1 = foo_one[i];
-        for (int64_t j = 0; j < a1.size(0); j++) {
-            pFeature[j] = a1[j];
-        }
-    }
-
-//    cout << "---- end 11-------" << pFeature[0] << endl;
-    return true;
-}
-
-float ReID_Feature::ReID_Compare(float *pFeature1, float *pFeature2)
-{
-    torch::Tensor query_feat = torch::zeros({1,2048});
-    torch::Tensor gallery_feat = torch::zeros({1,2048});
-
-    for (int i = 0; i < 2048; i++)
-    {
-        query_feat[0][i] = pFeature1[i];
-        gallery_feat[0][i] = pFeature2[i];
-    }
-
-    if(this->is_gpu)
-    {
-        query_feat = query_feat.cuda();
-        gallery_feat = gallery_feat.cuda();
-    }
-
-//    cout << "-----------------after-----------" << endl;
-//    cout << query_feat<< endl;
-
-//    for (int k = 0; k < 20; ++k) {
-//        cout << "-query_feat----1111111111------" <<query_feat[0][k+2000]<< endl;
-//    }
-//
-//    cout << "-----------------asdf-----------" << endl;
-////    cout << gallery_feat[0][0]<< endl;
-//    for (int k = 0; k < 20; ++k) {
-//        cout << "-gallery_feat----22222222------" <<gallery_feat[0][k+2000]<< endl;
-//    }
-
-    torch::Tensor a_similarity = torch::cosine_similarity(query_feat, gallery_feat);
-
-    if(this->is_gpu)
-        a_similarity = a_similarity.cpu();
-
-    auto foo_one = a_similarity.accessor<float,1>();
-//    cout << ":::::::::-" << endl;
-
-    float f_distance = foo_one[0];
-
-    return f_distance;
-
-}
-
-void ReID_Feature::ReID_Release()
-{
-        prinf("release");
-//    this->module = nullptr;//鍔犺浇妯″瀷
-//    return true;
-}
\ No newline at end of file
diff --git a/reid_feature.h b/reid_feature.h
deleted file mode 100644
index 8a98bf2..0000000
--- a/reid_feature.h
+++ /dev/null
@@ -1,32 +0,0 @@
-//
-// Created by Scheaven on 2020/1/3.
-//
-
-#ifndef INC_03_REID_STRONG_BASELINE_REID_FEATURE_H
-#define INC_03_REID_STRONG_BASELINE_REID_FEATURE_H
-#include <torch/script.h>
-#include <opencv2/opencv.hpp>
-#include <fstream>
-#include <string>
-#include <iomanip>
-#include <stdlib.h>
-#include <vector>
-#include <iostream>
-
-
-struct ReID_Feature {
-private:
-    bool is_gpu;
-    //    auto feat;
-    torch::jit::script::Module module;
-public:
-    bool ReID_init(int gpu_id);
-    int ReID_size();
-//    static unsigned char * extractor(unsigned char *pBuf, unsigned char *pFeature);
-    bool ReID_extractor(float *pBuf, float * pFeature);
-    float ReID_Compare(float *pFeature1, float *pFeature2);
-    void ReID_Release();
-};
-
-
-#endif //INC_03_REID_STRONG_BASELINE_REID_FEATURE_H
diff --git a/reid_utils.cpp b/reid_utils.cpp
deleted file mode 100644
index 17927ba..0000000
--- a/reid_utils.cpp
+++ /dev/null
@@ -1,24 +0,0 @@
-//
-// Created by Scheaven on 2020/1/3.
-//
-
-#include "reid_utils.h"
-
-
-void *ReID_Utils::normalize(unsigned char *vsrc, int w, int h, int chan){
-	float *data = (float*)malloc(h*w*chan*sizeof(float));
-	int size = w*h;
-	int size2 = size*2;
-
-	unsigned char *srcData = (unsigned char*)vsrc;
-
-	for(int i = 0;i<size;i++){
-        *(data) = *(srcData + 2) /255.0f;
-        *(data+size) = *(srcData + 1) /255.0f;
-        *(data+size2) = *(srcData) /255.0f;
-        data++;
-        srcData+=3;
-	}
-
-	return data;
-}
diff --git a/reid_utils.h b/reid_utils.h
deleted file mode 100644
index 47c249b..0000000
--- a/reid_utils.h
+++ /dev/null
@@ -1,22 +0,0 @@
-//
-// Created by Scheaven on 2020/1/3.
-//
-
-#ifndef INC_03_REID_STRONG_BASELINE_REID_UTILS_H
-#define INC_03_REID_STRONG_BASELINE_REID_UTILS_H
-#include <torch/script.h>
-#include <fstream>
-#include <string>
-#include <iomanip>
-#include <stdlib.h>
-#include <vector>
-#include <iostream>
-
-using namespace std;
-struct ReID_Utils{
-    public:
-        float *normalize(unsigned char *vsrc, int w, int h, int chan);
-};
-
-
-#endif //INC_03_REID_STRONG_BASELINE_REID_UTILS_H
diff --git a/test.cpp b/test.cpp
deleted file mode 100644
index 72640b9..0000000
--- a/test.cpp
+++ /dev/null
@@ -1,86 +0,0 @@
-//
-// Created by Scheaven on 2019/12/27.
-//
-#include <torch/script.h> // One-stop header.
-#include <iostream>
-#include <memory>
-#include <vector>
-#include <string>
-#include <opencv2/core/core.hpp>
-#include <opencv2/highgui/highgui.hpp>
-#include "opencv2/imgproc/imgproc.hpp"
-#include "opencv2/opencv.hpp"
-#include "opencv2/videoio.hpp"
-#include "reid_feature.h"
-
-
-using namespace std;
-//using namespace cv;
-
-int main(int argc, const char* argv[]) {
-//    if (argc != 2) {
-//        std::cerr << "usage: reid-app <image path>\n";;
-//        return -1;
-//    }
-
-//    torch::jit::script::Module module;
-//    char cam_id = 'A';
-//    ReID_Tracker Tracker;
-
-
-
-    /*鍒濆鍖�*/
-    int gpu_id = 0;
-    ReID_Feature R_Feater;
-    bool n_flog = R_Feater.ReID_init(0);
-    ReID_Utils r_util;
-//    ReID_Feature R_Feater(gpu_id);
-
-    /*opencv鍔犺浇鍥剧墖淇℃伅*/
-    cv::Mat human_img = cv::imread("./03.jpg");
-    cv::Mat human_img2 = cv::imread("./01.jpg");
-    if (human_img.data == nullptr)
-    {
-        cerr<<"===鍥剧墖鏂囦欢涓嶅瓨鍦�"<<endl;
-        return 0;
-    }else
-    {
-        //cv::namedWindow("Display", CV_WINDOW_AUTOSIZE);
-        try {
-            float pFeature1[2048];
-            /*杞箟鍥剧墖淇℃伅鏍煎紡*/
-            // cv::cvtColor(human_img, human_img, cv::COLOR_RGB2BGR);
-            // human_img.convertTo(human_img, CV_32FC3, 1.0f / 255.0f);
-            float *my_img_data = r_util.normalize(human_img.data, human_img.cols, human_img.rows, 3);
-            bool ex_flag1 = R_Feater.ReID_extractor(my_img_data, pFeature1);
-//            for (int k = 0; k < 20; ++k) {
-//                cout << "-----11111111111------" <<pFeature1[k+2000]<< endl;
-//            }
-
-            float pFeature2[2048];
-            // cv::cvtColor(human_img2, human_img2, cv::COLOR_RGB2BGR);
-            // human_img2.convertTo(human_img2, CV_32FC3, 1.0f / 255.0f);
-            float *my_img_data2 = r_util.normalize(human_img.data, human_img.cols, human_img.rows, 3);
-            bool ex_flag2 = R_Feater.ReID_extractor(my_img_data2, pFeature2);
-//            for (int k = 0; k < 20; ++k) {
-//                cout << "-----2222222222------" <<pFeature2[k+2000]<< endl;
-//            }
-
-            /*璁$畻鐩镐技搴�*/
-            cout << "--attention_distance human-" << endl;
-            float result = R_Feater.ReID_Compare(pFeature1, pFeature2);
-
-//            Tracker.storager(1,human_img);
-        }
-        catch (const c10::Error& e) {
-            std::cerr << "error loading the model\n";
-            return -1;
-        }
-
-        std::cout << "ok\n";
-        //cout<<  human_img  <<endl;
-        //cv::imshow("0909",human_img);
-        //cv::waitKey(0);
-    }
-
-}
diff --git a/tools/03_check_onnx.py b/tools/03_check_onnx.py
new file mode 100644
index 0000000..4467222
--- /dev/null
+++ b/tools/03_check_onnx.py
@@ -0,0 +1,168 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import torch
+import sys
+sys.path.append('.')
+
+import time
+
+import pycuda.autoinit
+import pycuda.driver as cuda
+import tensorrt as trt
+import torch
+import time
+from PIL import Image
+import cv2,os
+import torchvision
+import numpy as np
+
+max_batch_size = 1
+onnx_model_path = "/data/disk1/workspace/06_reid/01_fast_reid/02_fast_reid_inference/fastreid.onnx"
+TRT_LOGGER = trt.Logger()
+
+# class HostDeviceMem(object):
+#     def init(self, host_mem, device_mem):
+#         # """
+#         # host_mem: cpu memory
+#         # device_mem: gpu memory
+#         # """
+#         print("-----------11-----------")
+#         self.host = host_mem
+#         self.device = device_mem
+
+#     def init():
+#         # """
+#         # host_mem: cpu memory
+#         # device_mem: gpu memory
+#         # """
+#         print("---------22-------------")
+
+#     def __str__(self):
+#         return "Host:\n" + str(self.host)+"\nDevice:\n"+str(self.device)
+#     def __repr__(self):
+#         return self.__str__()
+
+def get_img_np_nchw(filename):
+    image = cv2.imread(filename)
+    image_cv = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    image_cv = cv2.resize(image_cv, (256, 128))
+    miu = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
+    std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
+    img_np = np.array(image_cv, dtype=np.float)/255.
+    img_np = img_np.transpose((2, 0, 1))
+    img_np -= miu
+    img_np /= std
+    img_np_nchw = img_np[np.newaxis]
+    img_np_nchw = np.tile(img_np_nchw,(max_batch_size, 1, 1, 1))
+    return img_np_nchw
+
+
+class HostDeviceMem(object):
+    def __init__(self, host_mem, device_mem):
+        # """
+        # host_mem: cpu memory
+        # device_mem: gpu memory
+        # """
+        self.host = host_mem
+        self.device = device_mem
+        print()
+
+    def __str__(self):
+        return "Host:\n" + str(self.host)+"\nDevice:\n"+str(self.device)
+
+    def __repr__(self):
+        return self.__str__()
+
+def allocate_buffers(engine):
+    inputs, outputs, bindings = [], [], []
+    stream = cuda.Stream()
+    for binding in engine:
+        size = trt.volume(engine.get_binding_shape(binding))
+        dtype = trt.nptype(engine.get_binding_dtype(binding))
+        host_mem = cuda.pagelocked_empty(size, dtype)
+        device_mem = cuda.mem_alloc(host_mem.nbytes)
+        bindings.append(int(device_mem))
+        #append to the appropriate list
+        if engine.binding_is_input(binding):
+            inputs.append(HostDeviceMem(host_mem, device_mem))
+        else:
+            outputs.append(HostDeviceMem(host_mem, device_mem))
+    return inputs, outputs, bindings, stream
+
+def get_engine(max_batch_size=1, onnx_file_path="", engine_file_path="", fp16_mode=False, save_engine=True):
+    if os.path.exists(engine_file_path):
+        print("Reading engine from file: {}".format(engine_file_path))
+        with open(engine_file_path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
+            return runtime.deserialize_cuda_engine(f.read()) # 鍙嶅簭鍒楀寲
+    else:
+
+        explicit_batch = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+        # In TensorRT 7.0, the ONNX parser only supports full-dimensions mode, meaning that your network definition must be created with the explicitBatch flag set. For more information, see Working With Dynamic Shapes.
+
+        with trt.Builder(TRT_LOGGER) as builder, \
+            builder.create_network(explicit_batch) as network,  \
+            trt.OnnxParser(network, TRT_LOGGER) as parser:
+
+            config = builder.create_builder_config()
+            config.max_workspace_size = 1<<30
+            builder.max_batch_size = max_batch_size # 鎵ц鏃舵渶澶у彲浠ヤ娇鐢ㄧ殑batchsize
+            builder.fp16_mode = fp16_mode
+
+            if not os.path.exists(onnx_file_path):
+                quit("ONNX file {} not found!".format(onnx_file_path))
+            print('loading onnx file from path {} ...'.format(onnx_file_path))
+            with open(onnx_file_path, 'rb') as model: # 浜屽�煎寲鐨勭綉缁滅粨鏋滃拰鍙傛暟
+                print("Begining onnx file parsing")
+                parser.parse(model.read())
+
+            print("Completed parsing of onnx file")
+            print("Building an engine from file{}' this may take a while...".format(onnx_file_path))
+
+            #################
+            print(network.get_layer(network.num_layers-1).get_output(0).shape)
+            engine=builder.build_engin(network, config)
+            print("Completed creating Engine")
+            if save_engine:
+                with open(engine_file_path, 'wb') as f:
+                    f.write(engine.serialize())  # 搴忓垪鍖�
+            return engine
+
+def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
+    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
+    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
+    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
+    # gpu to cpu
+    # Synchronize the stream
+    stream.synchronize()
+    return [out.host for out in outputs]
+
+
+def postprocess_the_outputs(h_outputs, shape_of_output):
+    h_outputs = h_outputs.reshape(*shape_of_output)
+    return h_outputs
+
+if __name__ == '__main__':
+    img_np_nchw = get_img_np_nchw("/data/disk1/project/data/01_reid/0_1.png").astype(np.float32)
+    fp16_mode = True
+    trt_engine_path ="./human_feature{}.trt".format(fp16_mode)
+
+    engine = get_engine(max_batch_size, onnx_model_path, trt_engine_path, fp16_mode)
+
+    context = engine.create_execution_context()
+    inputs, outputs, bindings, stream = allocate_buffers(engine)
+
+    shape_of_output = (max_batch_size, 2048)
+
+    inputs[0].host = img_np_nchw
+
+    t1 = time.time()
+    trt_outputs = do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream, batch_size = max_batch_size)
+    t2 = time.time()
+    print(trt_outputs, trt_outputs[0].shape)
+
+    feat = postprocess_the_outputs(trt_outputs[0], shape_of_output)
+    print('TensorRT ok')
+    print("Inference time with the TensorRT engine: {}".format(t2-t1))
+
+
+
diff --git a/tools/03_py2onnx.py b/tools/03_py2onnx.py
new file mode 100644
index 0000000..1090f23
--- /dev/null
+++ b/tools/03_py2onnx.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Date    : 2021-04-29 18:20:17
+# @Author  : Scheaven (snow_mail@foxmail.com)
+# @Link    : www.github.com
+# @Version : $Id$
+
+import torch
+import sys
+sys.path.append('.')
+
+from data.data_utils import read_image
+from predictor import ReID_Model
+from config import get_cfg
+from data.transforms.build import build_transforms
+from engine.defaults import default_argument_parser, default_setup
+import time
+
+def setup(args):
+    """
+    Create configs and perform basic setups.
+    """
+    cfg = get_cfg()
+    cfg.merge_from_file(args.config_file)
+    cfg.merge_from_list(args.opts)
+    cfg.freeze()
+    default_setup(cfg, args)
+    return cfg
+
+if __name__ == '__main__':
+    args = default_argument_parser().parse_args()
+    cfg = setup(args)
+    cfg.defrost()
+    cfg.MODEL.BACKBONE.PRETRAIN = False
+
+    model = ReID_Model(cfg)
+
+
+    model.torch2onnx()
+
diff --git a/tools/04_trt_inference.py b/tools/04_trt_inference.py
new file mode 100644
index 0000000..15cb9d3
--- /dev/null
+++ b/tools/04_trt_inference.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import argparse
+import glob
+import os
+import sys
+
+import cv2
+import numpy as np
+import tqdm
+
+sys.path.append("/export/home/lxy/runtimelib-tensorrt-tiny/build")
+
+import pytrt
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(description="trt model inference")
+
+    parser.add_argument(
+        "--model_path",
+        default="outputs/trt_model/baseline.engine",
+        help="trt model path"
+    )
+    parser.add_argument(
+        "--input",
+        nargs="+",
+        help="A list of space separated input images; "
+             "or a single glob pattern such as 'directory/*.jpg'",
+    )
+    parser.add_argument(
+        "--output",
+        default="trt_output",
+        help="path to save trt model inference results"
+    )
+    parser.add_argument(
+        "--output-name",
+        help="tensorRT model output name"
+    )
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=256,
+        help="height of image"
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=128,
+        help="width of image"
+    )
+    return parser
+
+
+def preprocess(image_path, image_height, image_width):
+    original_image = cv2.imread(image_path)
+    # the model expects RGB inputs
+    original_image = original_image[:, :, ::-1]
+
+    # Apply pre-processing to image.
+    img = cv2.resize(original_image, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
+    img = img.astype("float32").transpose(2, 0, 1)[np.newaxis]  # (1, 3, h, w)
+    return img
+
+
+def normalize(nparray, order=2, axis=-1):
+    """Normalize a N-D numpy array along the specified axis."""
+    norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
+    return nparray / (norm + np.finfo(np.float32).eps)
+
+
+if __name__ == "__main__":
+    args = get_parser().parse_args()
+
+    trt = pytrt.Trt()
+
+    onnxModel = ""
+    engineFile = args.model_path
+    customOutput = []
+    maxBatchSize = 1
+    calibratorData = []
+    mode = 2
+    trt.CreateEngine(onnxModel, engineFile, customOutput, maxBatchSize, mode, calibratorData)
+
+    if not os.path.exists(args.output): os.makedirs(args.output)
+
+    if args.input:
+        if os.path.isdir(args.input[0]):
+            args.input = glob.glob(os.path.expanduser(args.input[0]))
+            assert args.input, "The input path(s) was not found"
+        for path in tqdm.tqdm(args.input):
+            input_numpy_array = preprocess(path, args.height, args.width)
+            trt.DoInference(input_numpy_array)
+            feat = trt.GetOutput(args.output_name)
+            feat = normalize(feat, axis=1)
+            np.save(os.path.join(args.output, path.replace('.jpg', '.npy').split('/')[-1]), feat)
diff --git a/tools/__init__.py b/tools/__init__.py
new file mode 100644
index 0000000..f6f724d
--- /dev/null
+++ b/tools/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time    : 2020/10/26 11:47
+# @Author  : Scheaven
+# @File    :  __init__.py.py
+# @description: 
\ No newline at end of file
diff --git a/tools/__pycache__/predictor.cpython-37.pyc b/tools/__pycache__/predictor.cpython-37.pyc
new file mode 100644
index 0000000..2fbb3f2
--- /dev/null
+++ b/tools/__pycache__/predictor.cpython-37.pyc
Binary files differ
diff --git a/tools/__pycache__/predictor.cpython-38.pyc b/tools/__pycache__/predictor.cpython-38.pyc
new file mode 100644
index 0000000..50ee609
--- /dev/null
+++ b/tools/__pycache__/predictor.cpython-38.pyc
Binary files differ
diff --git a/tools/inference_net.py b/tools/inference_net.py
new file mode 100644
index 0000000..bc32670
--- /dev/null
+++ b/tools/inference_net.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time    : 2020/10/26 11:48
+# @Author  : Scheaven
+# @File    :  inference_net.py
+# @description:
+
+import torch
+import sys
+sys.path.append('.')
+
+from data.data_utils import read_image
+from predictor import ReID_Model
+from config import get_cfg
+from data.transforms.build import build_transforms
+from engine.defaults import default_argument_parser, default_setup
+import time
+
+def setup(args):
+    """
+    Create configs and perform basic setups.
+    """
+    cfg = get_cfg()
+    cfg.merge_from_file(args.config_file)
+    cfg.merge_from_list(args.opts)
+    cfg.freeze()
+    default_setup(cfg, args)
+    return cfg
+
+if __name__ == '__main__':
+    args = default_argument_parser().parse_args()
+    cfg = setup(args)
+    cfg.defrost()
+    cfg.MODEL.BACKBONE.PRETRAIN = False
+
+    model = ReID_Model(cfg)
+
+    test_transforms = build_transforms(cfg, is_train=False)
+    # print (args.img_a1)
+    img_a1 = read_image(args.img_a1)
+    img_a2 = read_image(args.img_a2)
+    img_b1 = read_image(args.img_b1)
+    img_b2 = read_image(args.img_b2)
+
+    img_a1 = test_transforms(img_a1)
+    img_a2 = test_transforms(img_a2)
+    img_b1 = test_transforms(img_b1)
+    img_b2 = test_transforms(img_b2)
+
+    out = torch.zeros((2, *img_a1.size()), dtype=img_a1.dtype)
+    out[0] += img_a1
+    out[1] += img_a2
+    t1 = time.time()
+    qurey_feat = model.run_on_image(out)
+    t2 = time.time()
+    print("t2-t1:", t2-t1)
+
+    similarity1 = torch.cosine_similarity(qurey_feat[0], qurey_feat[1], dim=0)
+    t3 = time.time()
+    print("t2-t1::", t3-t2, similarity1)
diff --git a/tools/predictor.py b/tools/predictor.py
new file mode 100644
index 0000000..952f132
--- /dev/null
+++ b/tools/predictor.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time    : 2020/10/26 15:50
+# @Author  : Scheaven
+# @File    :  predictor.py
+# @description:
+import cv2, io
+import torch
+import torch.nn.functional as F
+
+from modeling.meta_arch.build import build_model
+from utils.checkpoint import Checkpointer
+
+import onnx
+import torch
+from onnxsim import simplify
+from torch.onnx import OperatorExportTypes
+
+class ReID_Model(object):
+    def __init__(self, cfg):
+        self.cfg = cfg
+
+        self.predictor = DefaultPredictor(cfg)
+
+    def run_on_image(self, original_image):
+
+        predictions = self.predictor(original_image)
+        return predictions
+
+    def torch2onnx(self):
+        predictions = self.predictor.to_onnx()
+
+class DefaultPredictor:
+
+    def __init__(self, cfg):
+        self.cfg = cfg.clone()
+        self.cfg.defrost()
+        self.cfg.MODEL.BACKBONE.PRETRAIN = False
+        self.model = build_model(self.cfg)
+        for param in self.model.parameters():
+            param.requires_grad = False
+        self.model.cuda()
+        self.model.eval()
+
+        Checkpointer(self.model).load(cfg.MODEL.WEIGHTS)
+
+    def __call__(self, image):
+        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
+            images = image.cuda()
+            self.model.eval()
+            predictions = self.model(images)
+            torch.set_printoptions(edgeitems=2048)
+            print("------------\n", predictions)
+            pred_feat = F.normalize(predictions)
+            pred_feat = pred_feat.cpu().data
+            return pred_feat
+
+    def to_onnx(self):
+        inputs = torch.randn(1, 3, self.cfg.INPUT.SIZE_TEST[0], self.cfg.INPUT.SIZE_TEST[1]).cuda()
+        onnx_model = self.export_onnx_model(self.model, inputs)
+
+        model_simp, check = simplify(onnx_model)
+
+        model_simp = self.remove_initializer_from_input(model_simp)
+
+        assert check, "Simplified ONNX model could not be validated"
+
+        onnx.save_model(model_simp, f"fastreid.onnx")
+
+
+    def export_onnx_model(self, model, inputs):
+        assert isinstance(model, torch.nn.Module)
+        def _check_eval(module):
+            assert not module.training
+
+        model.apply(_check_eval)
+
+        with torch.no_grad():
+            with io.BytesIO() as f:
+                torch.onnx.export(
+                    model,
+                    inputs,
+                    f,
+                    operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
+                )
+                onnx_model = onnx.load_from_string(f.getvalue())
+
+        # Apply ONNX's Optimization
+        all_passes = onnx.optimizer.get_available_passes()
+        passes = ["extract_constant_to_initializer", "eliminate_unused_initializer", "fuse_bn_into_conv"]
+        assert all(p in all_passes for p in passes)
+        onnx_model = onnx.optimizer.optimize(onnx_model, passes)
+        return onnx_model
+
+
+    def remove_initializer_from_input(self, model):
+        if model.ir_version < 4:
+            print(
+                'Model with ir_version below 4 requires to include initilizer in graph input'
+            )
+            return
+
+        inputs = model.graph.input
+        name_to_input = {}
+        for input in inputs:
+            name_to_input[input.name] = input
+
+        for initializer in model.graph.initializer:
+            if initializer.name in name_to_input:
+                inputs.remove(name_to_input[initializer.name])
+
+        return model
+
+
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..9db4c43
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time    : 2020/10/26 14:53
+# @Author  : Scheaven
+# @File    :  __init__.py.py
+# @description: 
\ No newline at end of file
diff --git a/utils/__pycache__/__init__.cpython-37.pyc b/utils/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..c328618
--- /dev/null
+++ b/utils/__pycache__/__init__.cpython-37.pyc
Binary files differ
diff --git a/utils/__pycache__/__init__.cpython-38.pyc b/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..5a4265d
--- /dev/null
+++ b/utils/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/checkpoint.cpython-37.pyc b/utils/__pycache__/checkpoint.cpython-37.pyc
new file mode 100644
index 0000000..19b7364
--- /dev/null
+++ b/utils/__pycache__/checkpoint.cpython-37.pyc
Binary files differ
diff --git a/utils/__pycache__/checkpoint.cpython-38.pyc b/utils/__pycache__/checkpoint.cpython-38.pyc
new file mode 100644
index 0000000..6f785ad
--- /dev/null
+++ b/utils/__pycache__/checkpoint.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/collect_env.cpython-38.pyc b/utils/__pycache__/collect_env.cpython-38.pyc
new file mode 100644
index 0000000..92bc83c
--- /dev/null
+++ b/utils/__pycache__/collect_env.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/comm.cpython-37.pyc b/utils/__pycache__/comm.cpython-37.pyc
new file mode 100644
index 0000000..00e5e78
--- /dev/null
+++ b/utils/__pycache__/comm.cpython-37.pyc
Binary files differ
diff --git a/utils/__pycache__/comm.cpython-38.pyc b/utils/__pycache__/comm.cpython-38.pyc
new file mode 100644
index 0000000..8979735
--- /dev/null
+++ b/utils/__pycache__/comm.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/env.cpython-38.pyc b/utils/__pycache__/env.cpython-38.pyc
new file mode 100644
index 0000000..8427005
--- /dev/null
+++ b/utils/__pycache__/env.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/events.cpython-37.pyc b/utils/__pycache__/events.cpython-37.pyc
new file mode 100644
index 0000000..64fe6a2
--- /dev/null
+++ b/utils/__pycache__/events.cpython-37.pyc
Binary files differ
diff --git a/utils/__pycache__/events.cpython-38.pyc b/utils/__pycache__/events.cpython-38.pyc
new file mode 100644
index 0000000..91f67bc
--- /dev/null
+++ b/utils/__pycache__/events.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/file_io.cpython-37.pyc b/utils/__pycache__/file_io.cpython-37.pyc
new file mode 100644
index 0000000..12cd498
--- /dev/null
+++ b/utils/__pycache__/file_io.cpython-37.pyc
Binary files differ
diff --git a/utils/__pycache__/file_io.cpython-38.pyc b/utils/__pycache__/file_io.cpython-38.pyc
new file mode 100644
index 0000000..1934ab0
--- /dev/null
+++ b/utils/__pycache__/file_io.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/history_buffer.cpython-37.pyc b/utils/__pycache__/history_buffer.cpython-37.pyc
new file mode 100644
index 0000000..dc54f13
--- /dev/null
+++ b/utils/__pycache__/history_buffer.cpython-37.pyc
Binary files differ
diff --git a/utils/__pycache__/history_buffer.cpython-38.pyc b/utils/__pycache__/history_buffer.cpython-38.pyc
new file mode 100644
index 0000000..27b103e
--- /dev/null
+++ b/utils/__pycache__/history_buffer.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/logger.cpython-38.pyc b/utils/__pycache__/logger.cpython-38.pyc
new file mode 100644
index 0000000..2bed12c
--- /dev/null
+++ b/utils/__pycache__/logger.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/registry.cpython-37.pyc b/utils/__pycache__/registry.cpython-37.pyc
new file mode 100644
index 0000000..bd377c1
--- /dev/null
+++ b/utils/__pycache__/registry.cpython-37.pyc
Binary files differ
diff --git a/utils/__pycache__/registry.cpython-38.pyc b/utils/__pycache__/registry.cpython-38.pyc
new file mode 100644
index 0000000..f01b9ce
--- /dev/null
+++ b/utils/__pycache__/registry.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/weight_init.cpython-37.pyc b/utils/__pycache__/weight_init.cpython-37.pyc
new file mode 100644
index 0000000..04292a6
--- /dev/null
+++ b/utils/__pycache__/weight_init.cpython-37.pyc
Binary files differ
diff --git a/utils/__pycache__/weight_init.cpython-38.pyc b/utils/__pycache__/weight_init.cpython-38.pyc
new file mode 100644
index 0000000..8833ded
--- /dev/null
+++ b/utils/__pycache__/weight_init.cpython-38.pyc
Binary files differ
diff --git a/utils/checkpoint.py b/utils/checkpoint.py
new file mode 100644
index 0000000..a7900b6
--- /dev/null
+++ b/utils/checkpoint.py
@@ -0,0 +1,403 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import collections
+import copy
+import logging
+import os
+from collections import defaultdict
+from typing import Any
+
+import numpy as np
+import torch
+import torch.nn as nn
+from termcolor import colored
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+from .file_io import PathManager
+
+
+class Checkpointer(object):
+    """
+    A checkpointer that can save/load model as well as extra checkpointable
+    objects.
+    """
+
+    def __init__(
+            self,
+            model: nn.Module,
+            save_dir: str = "",
+            *,
+            save_to_disk: bool = True,
+            **checkpointables: object,
+    ):
+        """
+        Args:
+            model (nn.Module): model.
+            save_dir (str): a directory to save and find checkpoints.
+            save_to_disk (bool): if True, save checkpoint to disk, otherwise
+                disable saving for this checkpointer.
+            checkpointables (object): any checkpointable objects, i.e., objects
+                that have the `state_dict()` and `load_state_dict()` method. For
+                example, it can be used like
+                `Checkpointer(model, "dir", optimizer=optimizer)`.
+        """
+        if isinstance(model, (DistributedDataParallel, DataParallel)):
+            model = model.module
+        self.model = model
+        self.checkpointables = copy.copy(checkpointables)
+        self.logger = logging.getLogger(__name__)
+        self.save_dir = save_dir
+        self.save_to_disk = save_to_disk
+
+    def save(self, name: str, **kwargs: dict):
+        """
+        Dump model and checkpointables to a file.
+        Args:
+            name (str): name of the file.
+            kwargs (dict): extra arbitrary data to save.
+        """
+        if not self.save_dir or not self.save_to_disk:
+            return
+
+        data = {}
+        data["model"] = self.model.state_dict()
+        for key, obj in self.checkpointables.items():
+            data[key] = obj.state_dict()
+        data.update(kwargs)
+
+        basename = "{}.pth".format(name)
+        save_file = os.path.join(self.save_dir, basename)
+        assert os.path.basename(save_file) == basename, basename
+        self.logger.info("Saving checkpoint to {}".format(save_file))
+        with PathManager.open(save_file, "wb") as f:
+            torch.save(data, f)
+        self.tag_last_checkpoint(basename)
+
+    def load(self, path: str):
+        """
+        Load from the given checkpoint. When path points to network file, this
+        function has to be called on all ranks.
+        Args:
+            path (str): path or url to the checkpoint. If empty, will not load
+                anything.
+        Returns:
+            dict:
+                extra data loaded from the checkpoint that has not been
+                processed. For example, those saved with
+                :meth:`.save(**extra_data)`.
+        """
+        if not path:
+            # no checkpoint provided
+            self.logger.info(
+                "No checkpoint found. Training model from scratch"
+            )
+            return {}
+        self.logger.info("Loading checkpoint from {}".format(path))
+        if not os.path.isfile(path):
+            path = PathManager.get_local_path(path)
+            assert os.path.isfile(path), "Checkpoint {} not found!".format(path)
+
+        checkpoint = self._load_file(path)
+        self._load_model(checkpoint)
+        for key, obj in self.checkpointables.items():
+            if key in checkpoint:
+                self.logger.info("Loading {} from {}".format(key, path))
+                obj.load_state_dict(checkpoint.pop(key))
+
+        # return any further checkpoint data
+        return checkpoint
+
+    def has_checkpoint(self):
+        """
+        Returns:
+            bool: whether a checkpoint exists in the target directory.
+        """
+        save_file = os.path.join(self.save_dir, "last_checkpoint")
+        return PathManager.exists(save_file)
+
+    def get_checkpoint_file(self):
+        """
+        Returns:
+            str: The latest checkpoint file in target directory.
+        """
+        save_file = os.path.join(self.save_dir, "last_checkpoint")
+        try:
+            with PathManager.open(save_file, "r") as f:
+                last_saved = f.read().strip()
+        except IOError:
+            # if file doesn't exist, maybe because it has just been
+            # deleted by a separate process
+            return ""
+        return os.path.join(self.save_dir, last_saved)
+
+    def get_all_checkpoint_files(self):
+        """
+        Returns:
+            list: All available checkpoint files (.pth files) in target
+                directory.
+        """
+        all_model_checkpoints = [
+            os.path.join(self.save_dir, file)
+            for file in PathManager.ls(self.save_dir)
+            if PathManager.isfile(os.path.join(self.save_dir, file))
+               and file.endswith(".pth")
+        ]
+        return all_model_checkpoints
+
+    def resume_or_load(self, path: str, *, resume: bool = True):
+        """
+        If `resume` is True, this method attempts to resume from the last
+        checkpoint, if exists. Otherwise, load checkpoint from the given path.
+        This is useful when restarting an interrupted training job.
+        Args:
+            path (str): path to the checkpoint.
+            resume (bool): if True, resume from the last checkpoint if it exists.
+        Returns:
+            same as :meth:`load`.
+        """
+        if resume and self.has_checkpoint():
+            path = self.get_checkpoint_file()
+        return self.load(path)
+
+    def tag_last_checkpoint(self, last_filename_basename: str):
+        """
+        Tag the last checkpoint.
+        Args:
+            last_filename_basename (str): the basename of the last filename.
+        """
+        save_file = os.path.join(self.save_dir, "last_checkpoint")
+        with PathManager.open(save_file, "w") as f:
+            f.write(last_filename_basename)
+
+    def _load_file(self, f: str):
+        """
+        Load a checkpoint file. Can be overwritten by subclasses to support
+        different formats.
+        Args:
+            f (str): a locally mounted file path.
+        Returns:
+            dict: with keys "model" and optionally others that are saved by
+                the checkpointer dict["model"] must be a dict which maps strings
+                to torch.Tensor or numpy arrays.
+        """
+        return torch.load(f, map_location=torch.device("cpu"))
+
+    def _load_model(self, checkpoint: Any):
+        """
+        Load weights from a checkpoint.
+        Args:
+            checkpoint (Any): checkpoint contains the weights.
+        """
+        checkpoint_state_dict = checkpoint.pop("model")
+        self._convert_ndarray_to_tensor(checkpoint_state_dict)
+
+        # if the state_dict comes from a model that was wrapped in a
+        # DataParallel or DistributedDataParallel during serialization,
+        # remove the "module" prefix before performing the matching.
+        _strip_prefix_if_present(checkpoint_state_dict, "module.")
+
+        # work around https://github.com/pytorch/pytorch/issues/24139
+        model_state_dict = self.model.state_dict()
+        for k in list(checkpoint_state_dict.keys()):
+            if k in model_state_dict:
+                shape_model = tuple(model_state_dict[k].shape)
+                shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+                if shape_model != shape_checkpoint:
+                    self.logger.warning(
+                        "'{}' has shape {} in the checkpoint but {} in the "
+                        "model! Skipped.".format(
+                            k, shape_checkpoint, shape_model
+                        )
+                    )
+                    checkpoint_state_dict.pop(k)
+
+        incompatible = self.model.load_state_dict(
+            checkpoint_state_dict, strict=False
+        )
+        if incompatible.missing_keys:
+            self.logger.info(
+                get_missing_parameters_message(incompatible.missing_keys)
+            )
+        if incompatible.unexpected_keys:
+            self.logger.info(
+                get_unexpected_parameters_message(incompatible.unexpected_keys)
+            )
+
+    def _convert_ndarray_to_tensor(self, state_dict: dict):
+        """
+        In-place convert all numpy arrays in the state_dict to torch tensor.
+        Args:
+            state_dict (dict): a state-dict to be loaded to the model.
+        """
+        # model could be an OrderedDict with _metadata attribute
+        # (as returned by Pytorch's state_dict()). We should preserve these
+        # properties.
+        for k in list(state_dict.keys()):
+            v = state_dict[k]
+            if not isinstance(v, np.ndarray) and not isinstance(
+                    v, torch.Tensor
+            ):
+                raise ValueError(
+                    "Unsupported type found in checkpoint! {}: {}".format(
+                        k, type(v)
+                    )
+                )
+            if not isinstance(v, torch.Tensor):
+                state_dict[k] = torch.from_numpy(v)
+
+
+class PeriodicCheckpointer:
+    """
+    Save checkpoints periodically. When `.step(iteration)` is called, it will
+    execute `checkpointer.save` on the given checkpointer, if iteration is a
+    multiple of period or if `max_iter` is reached.
+    """
+
+    def __init__(self, checkpointer: Any, period: int, max_iter: int = None):
+        """
+        Args:
+            checkpointer (Any): the checkpointer object used to save
+            checkpoints.
+            period (int): the period to save checkpoint.
+            max_iter (int): maximum number of iterations. When it is reached,
+                a checkpoint named "model_final" will be saved.
+        """
+        self.checkpointer = checkpointer
+        self.period = int(period)
+        self.max_iter = max_iter
+
+    def step(self, iteration: int, **kwargs: Any):
+        """
+        Perform the appropriate action at the given iteration.
+        Args:
+            iteration (int): the current iteration, ranged in [0, max_iter-1].
+            kwargs (Any): extra data to save, same as in
+                :meth:`Checkpointer.save`.
+        """
+        iteration = int(iteration)
+        additional_state = {"iteration": iteration}
+        additional_state.update(kwargs)
+        if (iteration + 1) % self.period == 0:
+            self.checkpointer.save(
+                "model_{:07d}".format(iteration), **additional_state
+            )
+        if iteration >= self.max_iter - 1:
+            self.checkpointer.save("model_final", **additional_state)
+
+    def save(self, name: str, **kwargs: Any):
+        """
+        Same argument as :meth:`Checkpointer.save`.
+        Use this method to manually save checkpoints outside the schedule.
+        Args:
+            name (str): file name.
+            kwargs (Any): extra data to save, same as in
+                :meth:`Checkpointer.save`.
+        """
+        self.checkpointer.save(name, **kwargs)
+
+
+def get_missing_parameters_message(keys: list):
+    """
+    Get a logging-friendly message to report parameter names (keys) that are in
+    the model but not found in a checkpoint.
+    Args:
+        keys (list[str]): List of keys that were not found in the checkpoint.
+    Returns:
+        str: message.
+    """
+    groups = _group_checkpoint_keys(keys)
+    msg = "Some model parameters are not in the checkpoint:\n"
+    msg += "\n".join(
+        "  " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
+    )
+    return msg
+
+
+def get_unexpected_parameters_message(keys: list):
+    """
+    Get a logging-friendly message to report parameter names (keys) that are in
+    the checkpoint but not found in the model.
+    Args:
+        keys (list[str]): List of keys that were not found in the model.
+    Returns:
+        str: message.
+    """
+    groups = _group_checkpoint_keys(keys)
+    msg = "The checkpoint contains parameters not used by the model:\n"
+    msg += "\n".join(
+        "  " + colored(k + _group_to_str(v), "magenta")
+        for k, v in groups.items()
+    )
+    return msg
+
+
+def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):
+    """
+    Strip the prefix in metadata, if any.
+    Args:
+        state_dict (OrderedDict): a state-dict to be loaded to the model.
+        prefix (str): prefix.
+    """
+    keys = sorted(state_dict.keys())
+    if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
+        return
+
+    for key in keys:
+        newkey = key[len(prefix):]
+        state_dict[newkey] = state_dict.pop(key)
+
+    # also strip the prefix in metadata, if any..
+    try:
+        metadata = state_dict._metadata
+    except AttributeError:
+        pass
+    else:
+        for key in list(metadata.keys()):
+            # for the metadata dict, the key can be:
+            # '': for the DDP module, which we want to remove.
+            # 'module': for the actual model.
+            # 'module.xx.xx': for the rest.
+
+            if len(key) == 0:
+                continue
+            newkey = key[len(prefix):]
+            metadata[newkey] = metadata.pop(key)
+
+
+def _group_checkpoint_keys(keys: list):
+    """
+    Group keys based on common prefixes. A prefix is the string up to the final
+    "." in each key.
+    Args:
+        keys (list[str]): list of parameter names, i.e. keys in the model
+            checkpoint dict.
+    Returns:
+        dict[list]: keys with common prefixes are grouped into lists.
+    """
+    groups = defaultdict(list)
+    for key in keys:
+        pos = key.rfind(".")
+        if pos >= 0:
+            head, tail = key[:pos], [key[pos + 1:]]
+        else:
+            head, tail = key, []
+        groups[head].extend(tail)
+    return groups
+
+
+def _group_to_str(group: list):
+    """
+    Format a group of parameter name suffixes into a loggable string.
+    Args:
+        group (list[str]): list of parameter name suffixes.
+    Returns:
+        str: formated string.
+    """
+    if len(group) == 0:
+        return ""
+
+    if len(group) == 1:
+        return "." + group[0]
+
+    return ".{" + ", ".join(group) + "}"
diff --git a/utils/collect_env.py b/utils/collect_env.py
new file mode 100644
index 0000000..ba5d2ce
--- /dev/null
+++ b/utils/collect_env.py
@@ -0,0 +1,158 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+# based on
+# https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/collect_env.py
+import importlib
+import os
+import re
+import subprocess
+import sys
+from collections import defaultdict
+
+import PIL
+import numpy as np
+import torch
+import torchvision
+from tabulate import tabulate
+
+__all__ = ["collect_env_info"]
+
+
+def collect_torch_env():
+    try:
+        import torch.__config__
+
+        return torch.__config__.show()
+    except ImportError:
+        # compatible with older versions of pytorch
+        from torch.utils.collect_env import get_pretty_env_info
+
+        return get_pretty_env_info()
+
+
+def get_env_module():
+    var_name = "FASTREID_ENV_MODULE"
+    return var_name, os.environ.get(var_name, "<not set>")
+
+
+def detect_compute_compatibility(CUDA_HOME, so_file):
+    try:
+        cuobjdump = os.path.join(CUDA_HOME, "bin", "cuobjdump")
+        if os.path.isfile(cuobjdump):
+            output = subprocess.check_output(
+                "'{}' --list-elf '{}'".format(cuobjdump, so_file), shell=True
+            )
+            output = output.decode("utf-8").strip().split("\n")
+            sm = []
+            for line in output:
+                line = re.findall(r"\.sm_[0-9]*\.", line)[0]
+                sm.append(line.strip("."))
+            sm = sorted(set(sm))
+            return ", ".join(sm)
+        else:
+            return so_file + "; cannot find cuobjdump"
+    except Exception:
+        # unhandled failure
+        return so_file
+
+
+def collect_env_info():
+    has_gpu = torch.cuda.is_available()  # true for both CUDA & ROCM
+    torch_version = torch.__version__
+
+    # NOTE: the use of CUDA_HOME and ROCM_HOME requires the CUDA/ROCM build deps, though in
+    # theory detectron2 should be made runnable with only the corresponding runtimes
+    from torch.utils.cpp_extension import CUDA_HOME
+
+    has_rocm = False
+    if tuple(map(int, torch_version.split(".")[:2])) >= (1, 5):
+        from torch.utils.cpp_extension import ROCM_HOME
+
+        if (getattr(torch.version, "hip", None) is not None) and (ROCM_HOME is not None):
+            has_rocm = True
+    has_cuda = has_gpu and (not has_rocm)
+
+    data = []
+    data.append(("sys.platform", sys.platform))
+    data.append(("Python", sys.version.replace("\n", "")))
+    data.append(("numpy", np.__version__))
+
+    try:
+        import fastreid  # noqa
+
+        data.append(
+            ("fastreid", fastreid.__version__ + " @" + os.path.dirname(fastreid.__file__))
+        )
+    except ImportError:
+        data.append(("fastreid", "failed to import"))
+
+    data.append(get_env_module())
+    data.append(("PyTorch", torch_version + " @" + os.path.dirname(torch.__file__)))
+    data.append(("PyTorch debug build", torch.version.debug))
+
+    data.append(("GPU available", has_gpu))
+    if has_gpu:
+        devices = defaultdict(list)
+        for k in range(torch.cuda.device_count()):
+            devices[torch.cuda.get_device_name(k)].append(str(k))
+        for name, devids in devices.items():
+            data.append(("GPU " + ",".join(devids), name))
+
+        if has_rocm:
+            data.append(("ROCM_HOME", str(ROCM_HOME)))
+        else:
+            data.append(("CUDA_HOME", str(CUDA_HOME)))
+
+            cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
+            if cuda_arch_list:
+                data.append(("TORCH_CUDA_ARCH_LIST", cuda_arch_list))
+    data.append(("Pillow", PIL.__version__))
+
+    try:
+        data.append(
+            (
+                "torchvision",
+                str(torchvision.__version__) + " @" + os.path.dirname(torchvision.__file__),
+            )
+        )
+        if has_cuda:
+            try:
+                torchvision_C = importlib.util.find_spec("torchvision._C").origin
+                msg = detect_compute_compatibility(CUDA_HOME, torchvision_C)
+                data.append(("torchvision arch flags", msg))
+            except ImportError:
+                data.append(("torchvision._C", "failed to find"))
+    except AttributeError:
+        data.append(("torchvision", "unknown"))
+
+    try:
+        import fvcore
+
+        data.append(("fvcore", fvcore.__version__))
+    except ImportError:
+        pass
+
+    try:
+        import cv2
+
+        data.append(("cv2", cv2.__version__))
+    except ImportError:
+        pass
+    env_str = tabulate(data) + "\n"
+    env_str += collect_torch_env()
+    return env_str
+
+
+if __name__ == "__main__":
+    try:
+        import detectron2  # noqa
+    except ImportError:
+        print(collect_env_info())
+    else:
+        from fastreid.utils.collect_env import collect_env_info
+
+        print(collect_env_info())
diff --git a/utils/comm.py b/utils/comm.py
new file mode 100644
index 0000000..06cc7f0
--- /dev/null
+++ b/utils/comm.py
@@ -0,0 +1,255 @@
+"""
+This file contains primitives for multi-gpu communication.
+This is useful when doing distributed training.
+"""
+
+import functools
+import logging
+import numpy as np
+import pickle
+import torch
+import torch.distributed as dist
+
+_LOCAL_PROCESS_GROUP = None
+"""
+A torch process group which only includes processes that on the same machine as the current process.
+This variable is set when processes are spawned by `launch()` in "engine/launch.py".
+"""
+
+
+def get_world_size() -> int:
+    if not dist.is_available():
+        return 1
+    if not dist.is_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank() -> int:
+    if not dist.is_available():
+        return 0
+    if not dist.is_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def get_local_rank() -> int:
+    """
+    Returns:
+        The rank of the current process within the local (per-machine) process group.
+    """
+    if not dist.is_available():
+        return 0
+    if not dist.is_initialized():
+        return 0
+    assert _LOCAL_PROCESS_GROUP is not None
+    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_local_size() -> int:
+    """
+    Returns:
+        The size of the per-machine process group,
+        i.e. the number of processes per machine.
+    """
+    if not dist.is_available():
+        return 1
+    if not dist.is_initialized():
+        return 1
+    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
+
+
+def is_main_process() -> bool:
+    return get_rank() == 0
+
+
+def synchronize():
+    """
+    Helper function to synchronize (barrier) among all processes when
+    using distributed training
+    """
+    if not dist.is_available():
+        return
+    if not dist.is_initialized():
+        return
+    world_size = dist.get_world_size()
+    if world_size == 1:
+        return
+    dist.barrier()
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+    """
+    Return a process group based on gloo backend, containing all the ranks
+    The result is cached.
+    """
+    if dist.get_backend() == "nccl":
+        return dist.new_group(backend="gloo")
+    else:
+        return dist.group.WORLD
+
+
+def _serialize_to_tensor(data, group):
+    backend = dist.get_backend(group)
+    assert backend in ["gloo", "nccl"]
+    device = torch.device("cpu" if backend == "gloo" else "cuda")
+
+    buffer = pickle.dumps(data)
+    if len(buffer) > 1024 ** 3:
+        logger = logging.getLogger(__name__)
+        logger.warning(
+            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
+                get_rank(), len(buffer) / (1024 ** 3), device
+            )
+        )
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to(device=device)
+    return tensor
+
+
+def _pad_to_largest_tensor(tensor, group):
+    """
+    Returns:
+        list[int]: size of the tensor, on each rank
+        Tensor: padded tensor that has the max size
+    """
+    world_size = dist.get_world_size(group=group)
+    assert (
+            world_size >= 1
+    ), "comm.gather/all_gather must be called from ranks within the given group!"
+    local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
+    size_list = [
+        torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
+    ]
+    dist.all_gather(size_list, local_size, group=group)
+    size_list = [int(size.item()) for size in size_list]
+
+    max_size = max(size_list)
+
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    if local_size != max_size:
+        padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
+        tensor = torch.cat((tensor, padding), dim=0)
+    return size_list, tensor
+
+
+def all_gather(data, group=None):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors).
+    Args:
+        data: any picklable object
+        group: a torch process group. By default, will use a group which
+            contains all ranks on gloo backend.
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    if get_world_size() == 1:
+        return [data]
+    if group is None:
+        group = _get_global_gloo_group()
+    if dist.get_world_size(group) == 1:
+        return [data]
+
+    tensor = _serialize_to_tensor(data, group)
+
+    size_list, tensor = _pad_to_largest_tensor(tensor, group)
+    max_size = max(size_list)
+
+    # receiving Tensor from all ranks
+    tensor_list = [
+        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+    ]
+    dist.all_gather(tensor_list, tensor, group=group)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+
+    return data_list
+
+
+def gather(data, dst=0, group=None):
+    """
+    Run gather on arbitrary picklable data (not necessarily tensors).
+    Args:
+        data: any picklable object
+        dst (int): destination rank
+        group: a torch process group. By default, will use a group which
+            contains all ranks on gloo backend.
+    Returns:
+        list[data]: on dst, a list of data gathered from each rank. Otherwise,
+            an empty list.
+    """
+    if get_world_size() == 1:
+        return [data]
+    if group is None:
+        group = _get_global_gloo_group()
+    if dist.get_world_size(group=group) == 1:
+        return [data]
+    rank = dist.get_rank(group=group)
+
+    tensor = _serialize_to_tensor(data, group)
+    size_list, tensor = _pad_to_largest_tensor(tensor, group)
+
+    # receiving Tensor from all ranks
+    if rank == dst:
+        max_size = max(size_list)
+        tensor_list = [
+            torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+        ]
+        dist.gather(tensor, tensor_list, dst=dst, group=group)
+
+        data_list = []
+        for size, tensor in zip(size_list, tensor_list):
+            buffer = tensor.cpu().numpy().tobytes()[:size]
+            data_list.append(pickle.loads(buffer))
+        return data_list
+    else:
+        dist.gather(tensor, [], dst=dst, group=group)
+        return []
+
+
+def shared_random_seed():
+    """
+    Returns:
+        int: a random number that is the same across all workers.
+            If workers need a shared RNG, they can use this shared seed to
+            create one.
+    All workers must call this function, otherwise it will deadlock.
+    """
+    ints = np.random.randint(2 ** 31)
+    all_ints = all_gather(ints)
+    return all_ints[0]
+
+
+def reduce_dict(input_dict, average=True):
+    """
+    Reduce the values in the dictionary from all processes so that process with rank
+    0 has the reduced results.
+    Args:
+        input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
+        average (bool): whether to do average or sum
+    Returns:
+        a dict with the same keys as input_dict, after reduction.
+    """
+    world_size = get_world_size()
+    if world_size < 2:
+        return input_dict
+    with torch.no_grad():
+        names = []
+        values = []
+        # sort the keys so that they are consistent across processes
+        for k in sorted(input_dict.keys()):
+            names.append(k)
+            values.append(input_dict[k])
+        values = torch.stack(values, dim=0)
+        dist.reduce(values, dst=0)
+        if dist.get_rank() == 0 and average:
+            # only main process gets accumulated, so only divide by
+            # world_size in this case
+            values /= world_size
+        reduced_dict = {k: v for k, v in zip(names, values)}
+    return reduced_dict
diff --git a/utils/env.py b/utils/env.py
new file mode 100644
index 0000000..72f6afe
--- /dev/null
+++ b/utils/env.py
@@ -0,0 +1,119 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import importlib
+import importlib.util
+import logging
+import numpy as np
+import os
+import random
+import sys
+from datetime import datetime
+import torch
+
+__all__ = ["seed_all_rng"]
+
+
+TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
+"""
+PyTorch version as a tuple of 2 ints. Useful for comparison.
+"""
+
+
+def seed_all_rng(seed=None):
+    """
+    Set the random seed for the RNG in torch, numpy and python.
+    Args:
+        seed (int): if None, will use a strong random seed.
+    """
+    if seed is None:
+        seed = (
+            os.getpid()
+            + int(datetime.now().strftime("%S%f"))
+            + int.from_bytes(os.urandom(2), "big")
+        )
+        logger = logging.getLogger(__name__)
+        logger.info("Using a generated random seed {}".format(seed))
+    np.random.seed(seed)
+    torch.set_rng_state(torch.manual_seed(seed).get_state())
+    random.seed(seed)
+
+
+# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
+def _import_file(module_name, file_path, make_importable=False):
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    if make_importable:
+        sys.modules[module_name] = module
+    return module
+
+
+def _configure_libraries():
+    """
+    Configurations for some libraries.
+    """
+    # An environment option to disable `import cv2` globally,
+    # in case it leads to negative performance impact
+    disable_cv2 = int(os.environ.get("DETECTRON2_DISABLE_CV2", False))
+    if disable_cv2:
+        sys.modules["cv2"] = None
+    else:
+        # Disable opencl in opencv since its interaction with cuda often has negative effects
+        # This envvar is supported after OpenCV 3.4.0
+        os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled"
+        try:
+            import cv2
+
+            if int(cv2.__version__.split(".")[0]) >= 3:
+                cv2.ocl.setUseOpenCL(False)
+        except ImportError:
+            pass
+
+    def get_version(module, digit=2):
+        return tuple(map(int, module.__version__.split(".")[:digit]))
+
+    # fmt: off
+    assert get_version(torch) >= (1, 4), "Requires torch>=1.4"
+    import yaml
+    assert get_version(yaml) >= (5, 1), "Requires pyyaml>=5.1"
+    # fmt: on
+
+
+_ENV_SETUP_DONE = False
+
+
+def setup_environment():
+    """Perform environment setup work. The default setup is a no-op, but this
+    function allows the user to specify a Python source file or a module in
+    the $FASTREID_ENV_MODULE environment variable, that performs
+    custom setup work that may be necessary to their computing environment.
+    """
+    global _ENV_SETUP_DONE
+    if _ENV_SETUP_DONE:
+        return
+    _ENV_SETUP_DONE = True
+
+    _configure_libraries()
+
+    custom_module_path = os.environ.get("FASTREID_ENV_MODULE")
+
+    if custom_module_path:
+        setup_custom_environment(custom_module_path)
+    else:
+        # The default setup is a no-op
+        pass
+
+
+def setup_custom_environment(custom_module):
+    """
+    Load custom environment setup by importing a Python source file or a
+    module, and run the setup function.
+    """
+    if custom_module.endswith(".py"):
+        module = _import_file("fastreid.utils.env.custom_module", custom_module)
+    else:
+        module = importlib.import_module(custom_module)
+    assert hasattr(module, "setup_environment") and callable(module.setup_environment), (
+        "Custom environment module defined in {} does not have the "
+        "required callable attribute 'setup_environment'."
+    ).format(custom_module)
+    module.setup_environment()
\ No newline at end of file
diff --git a/utils/events.py b/utils/events.py
new file mode 100644
index 0000000..bc1120a
--- /dev/null
+++ b/utils/events.py
@@ -0,0 +1,445 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import datetime
+import json
+import logging
+import os
+import time
+from collections import defaultdict
+from contextlib import contextmanager
+import torch
+from .file_io import PathManager
+from .history_buffer import HistoryBuffer
+
+__all__ = [
+    "get_event_storage",
+    "JSONWriter",
+    "TensorboardXWriter",
+    "CommonMetricPrinter",
+    "EventStorage",
+]
+
+_CURRENT_STORAGE_STACK = []
+
+
+def get_event_storage():
+    """
+    Returns:
+        The :class:`EventStorage` object that's currently being used.
+        Throws an error if no :class:`EventStorage` is currently enabled.
+    """
+    assert len(
+        _CURRENT_STORAGE_STACK
+    ), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
+    return _CURRENT_STORAGE_STACK[-1]
+
+
+class EventWriter:
+    """
+    Base class for writers that obtain events from :class:`EventStorage` and process them.
+    """
+
+    def write(self):
+        raise NotImplementedError
+
+    def close(self):
+        pass
+
+
+class JSONWriter(EventWriter):
+    """
+    Write scalars to a json file.
+    It saves scalars as one json per line (instead of a big json) for easy parsing.
+    Examples parsing such a json file:
+    ::
+        $ cat metrics.json | jq -s '.[0:2]'
+        [
+          {
+            "data_time": 0.008433341979980469,
+            "iteration": 20,
+            "loss": 1.9228371381759644,
+            "loss_box_reg": 0.050025828182697296,
+            "loss_classifier": 0.5316952466964722,
+            "loss_mask": 0.7236229181289673,
+            "loss_rpn_box": 0.0856662318110466,
+            "loss_rpn_cls": 0.48198649287223816,
+            "lr": 0.007173333333333333,
+            "time": 0.25401854515075684
+          },
+          {
+            "data_time": 0.007216215133666992,
+            "iteration": 40,
+            "loss": 1.282649278640747,
+            "loss_box_reg": 0.06222952902317047,
+            "loss_classifier": 0.30682939291000366,
+            "loss_mask": 0.6970193982124329,
+            "loss_rpn_box": 0.038663312792778015,
+            "loss_rpn_cls": 0.1471673548221588,
+            "lr": 0.007706666666666667,
+            "time": 0.2490077018737793
+          }
+        ]
+        $ cat metrics.json | jq '.loss_mask'
+        0.7126231789588928
+        0.689423680305481
+        0.6776131987571716
+        ...
+    """
+
+    def __init__(self, json_file, window_size=20):
+        """
+        Args:
+            json_file (str): path to the json file. New data will be appended if the file exists.
+            window_size (int): the window size of median smoothing for the scalars whose
+                `smoothing_hint` are True.
+        """
+        self._file_handle = PathManager.open(json_file, "a")
+        self._window_size = window_size
+        self._last_write = -1
+
+    def write(self):
+        storage = get_event_storage()
+        to_save = defaultdict(dict)
+
+        for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
+            # keep scalars that have not been written
+            if iter <= self._last_write:
+                continue
+            to_save[iter][k] = v
+        all_iters = sorted(to_save.keys())
+        self._last_write = max(all_iters)
+
+        for itr, scalars_per_iter in to_save.items():
+            scalars_per_iter["iteration"] = itr
+            self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n")
+        self._file_handle.flush()
+        try:
+            os.fsync(self._file_handle.fileno())
+        except AttributeError:
+            pass
+
+    def close(self):
+        self._file_handle.close()
+
+
+class TensorboardXWriter(EventWriter):
+    """
+    Write all scalars to a tensorboard file.
+    """
+
+    def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
+        """
+        Args:
+            log_dir (str): the directory to save the output events
+            window_size (int): the scalars will be median-smoothed by this window size
+            kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
+        """
+        self._window_size = window_size
+        from torch.utils.tensorboard import SummaryWriter
+
+        self._writer = SummaryWriter(log_dir, **kwargs)
+        self._last_write = -1
+
+    def write(self):
+        storage = get_event_storage()
+        new_last_write = self._last_write
+        for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
+            if iter > self._last_write:
+                self._writer.add_scalar(k, v, iter)
+                new_last_write = max(new_last_write, iter)
+        self._last_write = new_last_write
+
+        # storage.put_{image,histogram} is only meant to be used by
+        # tensorboard writer. So we access its internal fields directly from here.
+        if len(storage._vis_data) >= 1:
+            for img_name, img, step_num in storage._vis_data:
+                self._writer.add_image(img_name, img, step_num)
+            # Storage stores all image data and rely on this writer to clear them.
+            # As a result it assumes only one writer will use its image data.
+            # An alternative design is to let storage store limited recent
+            # data (e.g. only the most recent image) that all writers can access.
+            # In that case a writer may not see all image data if its period is long.
+            storage.clear_images()
+
+        if len(storage._histograms) >= 1:
+            for params in storage._histograms:
+                self._writer.add_histogram_raw(**params)
+            storage.clear_histograms()
+
+    def close(self):
+        if hasattr(self, "_writer"):  # doesn't exist when the code fails at import
+            self._writer.close()
+
+
+class CommonMetricPrinter(EventWriter):
+    """
+    Print **common** metrics to the terminal, including
+    iteration time, ETA, memory, all losses, and the learning rate.
+    It also applies smoothing using a window of 20 elements.
+    It's meant to print common metrics in common ways.
+    To print something in more customized ways, please implement a similar printer by yourself.
+    """
+
+    def __init__(self, max_iter):
+        """
+        Args:
+            max_iter (int): the maximum number of iterations to train.
+                Used to compute ETA.
+        """
+        self.logger = logging.getLogger(__name__)
+        self._max_iter = max_iter
+        self._last_write = None
+
+    def write(self):
+        storage = get_event_storage()
+        iteration = storage.iter
+
+        try:
+            data_time = storage.history("data_time").avg(20)
+        except KeyError:
+            # they may not exist in the first few iterations (due to warmup)
+            # or when SimpleTrainer is not used
+            data_time = None
+
+        eta_string = None
+        try:
+            iter_time = storage.history("time").global_avg()
+            eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration)
+            storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
+            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+        except KeyError:
+            iter_time = None
+            # estimate eta on our own - more noisy
+            if self._last_write is not None:
+                estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
+                    iteration - self._last_write[0]
+                )
+                eta_seconds = estimate_iter_time * (self._max_iter - iteration)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+            self._last_write = (iteration, time.perf_counter())
+
+        try:
+            lr = "{:.2e}".format(storage.history("lr").latest())
+        except KeyError:
+            lr = "N/A"
+
+        if torch.cuda.is_available():
+            max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
+        else:
+            max_mem_mb = None
+
+        # NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
+        self.logger.info(
+            " {eta}iter: {iter}  {losses}  {time}{data_time}lr: {lr}  {memory}".format(
+                eta=f"eta: {eta_string}  " if eta_string else "",
+                iter=iteration,
+                losses="  ".join(
+                    [
+                        "{}: {:.4g}".format(k, v.median(20))
+                        for k, v in storage.histories().items()
+                        if "loss" in k
+                    ]
+                ),
+                time="time: {:.4f}  ".format(iter_time) if iter_time is not None else "",
+                data_time="data_time: {:.4f}  ".format(data_time) if data_time is not None else "",
+                lr=lr,
+                memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "",
+            )
+        )
+
+
+class EventStorage:
+    """
+    The user-facing class that provides metric storage functionalities.
+    In the future we may add support for storing / logging other types of data if needed.
+    """
+
+    def __init__(self, start_iter=0):
+        """
+        Args:
+            start_iter (int): the iteration number to start with
+        """
+        self._history = defaultdict(HistoryBuffer)
+        self._smoothing_hints = {}
+        self._latest_scalars = {}
+        self._iter = start_iter
+        self._current_prefix = ""
+        self._vis_data = []
+        self._histograms = []
+
+    def put_image(self, img_name, img_tensor):
+        """
+        Add an `img_tensor` associated with `img_name`, to be shown on
+        tensorboard.
+        Args:
+            img_name (str): The name of the image to put into tensorboard.
+            img_tensor (torch.Tensor or numpy.array): An `uint8` or `float`
+                Tensor of shape `[channel, height, width]` where `channel` is
+                3. The image format should be RGB. The elements in img_tensor
+                can either have values in [0, 1] (float32) or [0, 255] (uint8).
+                The `img_tensor` will be visualized in tensorboard.
+        """
+        self._vis_data.append((img_name, img_tensor, self._iter))
+
+    def put_scalar(self, name, value, smoothing_hint=True):
+        """
+        Add a scalar `value` to the `HistoryBuffer` associated with `name`.
+        Args:
+            smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
+                smoothed when logged. The hint will be accessible through
+                :meth:`EventStorage.smoothing_hints`.  A writer may ignore the hint
+                and apply custom smoothing rule.
+                It defaults to True because most scalars we save need to be smoothed to
+                provide any useful signal.
+        """
+        name = self._current_prefix + name
+        history = self._history[name]
+        value = float(value)
+        history.update(value, self._iter)
+        self._latest_scalars[name] = (value, self._iter)
+
+        existing_hint = self._smoothing_hints.get(name)
+        if existing_hint is not None:
+            assert (
+                existing_hint == smoothing_hint
+            ), "Scalar {} was put with a different smoothing_hint!".format(name)
+        else:
+            self._smoothing_hints[name] = smoothing_hint
+
+    def put_scalars(self, *, smoothing_hint=True, **kwargs):
+        """
+        Put multiple scalars from keyword arguments.
+        Examples:
+            storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
+        """
+        for k, v in kwargs.items():
+            self.put_scalar(k, v, smoothing_hint=smoothing_hint)
+
+    def put_histogram(self, hist_name, hist_tensor, bins=1000):
+        """
+        Create a histogram from a tensor.
+        Args:
+            hist_name (str): The name of the histogram to put into tensorboard.
+            hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted
+                into a histogram.
+            bins (int): Number of histogram bins.
+        """
+        ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item()
+
+        # Create a histogram with PyTorch
+        hist_counts = torch.histc(hist_tensor, bins=bins)
+        hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32)
+
+        # Parameter for the add_histogram_raw function of SummaryWriter
+        hist_params = dict(
+            tag=hist_name,
+            min=ht_min,
+            max=ht_max,
+            num=len(hist_tensor),
+            sum=float(hist_tensor.sum()),
+            sum_squares=float(torch.sum(hist_tensor ** 2)),
+            bucket_limits=hist_edges[1:].tolist(),
+            bucket_counts=hist_counts.tolist(),
+            global_step=self._iter,
+        )
+        self._histograms.append(hist_params)
+
+    def history(self, name):
+        """
+        Returns:
+            HistoryBuffer: the scalar history for name
+        """
+        ret = self._history.get(name, None)
+        if ret is None:
+            raise KeyError("No history metric available for {}!".format(name))
+        return ret
+
+    def histories(self):
+        """
+        Returns:
+            dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
+        """
+        return self._history
+
+    def latest(self):
+        """
+        Returns:
+            dict[str -> (float, int)]: mapping from the name of each scalar to the most
+                recent value and the iteration number its added.
+        """
+        return self._latest_scalars
+
+    def latest_with_smoothing_hint(self, window_size=20):
+        """
+        Similar to :meth:`latest`, but the returned values
+        are either the un-smoothed original latest value,
+        or a median of the given window_size,
+        depend on whether the smoothing_hint is True.
+        This provides a default behavior that other writers can use.
+        """
+        result = {}
+        for k, (v, itr) in self._latest_scalars.items():
+            result[k] = (
+                self._history[k].median(window_size) if self._smoothing_hints[k] else v,
+                itr,
+            )
+        return result
+
+    def smoothing_hints(self):
+        """
+        Returns:
+            dict[name -> bool]: the user-provided hint on whether the scalar
+                is noisy and needs smoothing.
+        """
+        return self._smoothing_hints
+
+    def step(self):
+        """
+        User should call this function at the beginning of each iteration, to
+        notify the storage of the start of a new iteration.
+        The storage will then be able to associate the new data with the
+        correct iteration number.
+        """
+        self._iter += 1
+
+    @property
+    def iter(self):
+        return self._iter
+
+    @property
+    def iteration(self):
+        # for backward compatibility
+        return self._iter
+
+    def __enter__(self):
+        _CURRENT_STORAGE_STACK.append(self)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        assert _CURRENT_STORAGE_STACK[-1] == self
+        _CURRENT_STORAGE_STACK.pop()
+
+    @contextmanager
+    def name_scope(self, name):
+        """
+        Yields:
+            A context within which all the events added to this storage
+            will be prefixed by the name scope.
+        """
+        old_prefix = self._current_prefix
+        self._current_prefix = name.rstrip("/") + "/"
+        yield
+        self._current_prefix = old_prefix
+
+    def clear_images(self):
+        """
+        Delete all the stored images for visualization. This should be called
+        after images are written to tensorboard.
+        """
+        self._vis_data = []
+
+    def clear_histograms(self):
+        """
+        Delete all the stored histograms for visualization.
+        This should be called after histograms are written to tensorboard.
+        """
+        self._histograms = []
\ No newline at end of file
diff --git a/utils/file_io.py b/utils/file_io.py
new file mode 100644
index 0000000..8533fe8
--- /dev/null
+++ b/utils/file_io.py
@@ -0,0 +1,520 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import errno
+import logging
+import os
+import shutil
+from collections import OrderedDict
+from typing import (
+    IO,
+    Any,
+    Callable,
+    Dict,
+    List,
+    MutableMapping,
+    Optional,
+    Union,
+)
+
+__all__ = ["PathManager", "get_cache_dir"]
+
+
+def get_cache_dir(cache_dir: Optional[str] = None) -> str:
+    """
+    Returns a default directory to cache static files
+    (usually downloaded from Internet), if None is provided.
+    Args:
+        cache_dir (None or str): if not None, will be returned as is.
+            If None, returns the default cache directory as:
+        1) $FVCORE_CACHE, if set
+        2) otherwise ~/.torch/fvcore_cache
+    """
+    if cache_dir is None:
+        cache_dir = os.path.expanduser(
+            os.getenv("FVCORE_CACHE", "~/.torch/fvcore_cache")
+        )
+    return cache_dir
+
+
+class PathHandler:
+    """
+    PathHandler is a base class that defines common I/O functionality for a URI
+    protocol. It routes I/O for a generic URI which may look like "protocol://*"
+    or a canonical filepath "/foo/bar/baz".
+    """
+
+    _strict_kwargs_check = True
+
+    def _check_kwargs(self, kwargs: Dict[str, Any]) -> None:
+        """
+        Checks if the given arguments are empty. Throws a ValueError if strict
+        kwargs checking is enabled and args are non-empty. If strict kwargs
+        checking is disabled, only a warning is logged.
+        Args:
+            kwargs (Dict[str, Any])
+        """
+        if self._strict_kwargs_check:
+            if len(kwargs) > 0:
+                raise ValueError("Unused arguments: {}".format(kwargs))
+        else:
+            logger = logging.getLogger(__name__)
+            for k, v in kwargs.items():
+                logger.warning(
+                    "[PathManager] {}={} argument ignored".format(k, v)
+                )
+
+    def _get_supported_prefixes(self) -> List[str]:
+        """
+        Returns:
+            List[str]: the list of URI prefixes this PathHandler can support
+        """
+        raise NotImplementedError()
+
+    def _get_local_path(self, path: str, **kwargs: Any) -> str:
+        """
+        Get a filepath which is compatible with native Python I/O such as `open`
+        and `os.path`.
+        If URI points to a remote resource, this function may download and cache
+        the resource to local disk. In this case, this function is meant to be
+        used with read-only resources.
+        Args:
+            path (str): A URI supported by this PathHandler
+        Returns:
+            local_path (str): a file path which exists on the local file system
+        """
+        raise NotImplementedError()
+
+    def _open(
+            self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
+    ) -> Union[IO[str], IO[bytes]]:
+        """
+        Open a stream to a URI, similar to the built-in `open`.
+        Args:
+            path (str): A URI supported by this PathHandler
+            mode (str): Specifies the mode in which the file is opened. It defaults
+                to 'r'.
+            buffering (int): An optional integer used to set the buffering policy.
+                Pass 0 to switch buffering off and an integer >= 1 to indicate the
+                size in bytes of a fixed-size chunk buffer. When no buffering
+                argument is given, the default buffering policy depends on the
+                underlying I/O implementation.
+        Returns:
+            file: a file-like object.
+        """
+        raise NotImplementedError()
+
+    def _copy(
+            self,
+            src_path: str,
+            dst_path: str,
+            overwrite: bool = False,
+            **kwargs: Any,
+    ) -> bool:
+        """
+        Copies a source path to a destination path.
+        Args:
+            src_path (str): A URI supported by this PathHandler
+            dst_path (str): A URI supported by this PathHandler
+            overwrite (bool): Bool flag for forcing overwrite of existing file
+        Returns:
+            status (bool): True on success
+        """
+        raise NotImplementedError()
+
+    def _exists(self, path: str, **kwargs: Any) -> bool:
+        """
+        Checks if there is a resource at the given URI.
+        Args:
+            path (str): A URI supported by this PathHandler
+        Returns:
+            bool: true if the path exists
+        """
+        raise NotImplementedError()
+
+    def _isfile(self, path: str, **kwargs: Any) -> bool:
+        """
+        Checks if the resource at the given URI is a file.
+        Args:
+            path (str): A URI supported by this PathHandler
+        Returns:
+            bool: true if the path is a file
+        """
+        raise NotImplementedError()
+
+    def _isdir(self, path: str, **kwargs: Any) -> bool:
+        """
+        Checks if the resource at the given URI is a directory.
+        Args:
+            path (str): A URI supported by this PathHandler
+        Returns:
+            bool: true if the path is a directory
+        """
+        raise NotImplementedError()
+
+    def _ls(self, path: str, **kwargs: Any) -> List[str]:
+        """
+        List the contents of the directory at the provided URI.
+        Args:
+            path (str): A URI supported by this PathHandler
+        Returns:
+            List[str]: list of contents in given path
+        """
+        raise NotImplementedError()
+
+    def _mkdirs(self, path: str, **kwargs: Any) -> None:
+        """
+        Recursive directory creation function. Like mkdir(), but makes all
+        intermediate-level directories needed to contain the leaf directory.
+        Similar to the native `os.makedirs`.
+        Args:
+            path (str): A URI supported by this PathHandler
+        """
+        raise NotImplementedError()
+
+    def _rm(self, path: str, **kwargs: Any) -> None:
+        """
+        Remove the file (not directory) at the provided URI.
+        Args:
+            path (str): A URI supported by this PathHandler
+        """
+        raise NotImplementedError()
+
+
+class NativePathHandler(PathHandler):
+    """
+    Handles paths that can be accessed using Python native system calls. This
+    handler uses `open()` and `os.*` calls on the given path.
+    """
+
+    def _get_local_path(self, path: str, **kwargs: Any) -> str:
+        self._check_kwargs(kwargs)
+        return path
+
+    def _open(
+            self,
+            path: str,
+            mode: str = "r",
+            buffering: int = -1,
+            encoding: Optional[str] = None,
+            errors: Optional[str] = None,
+            newline: Optional[str] = None,
+            closefd: bool = True,
+            opener: Optional[Callable] = None,
+            **kwargs: Any,
+    ) -> Union[IO[str], IO[bytes]]:
+        """
+        Open a path.
+        Args:
+            path (str): A URI supported by this PathHandler
+            mode (str): Specifies the mode in which the file is opened. It defaults
+                to 'r'.
+            buffering (int): An optional integer used to set the buffering policy.
+                Pass 0 to switch buffering off and an integer >= 1 to indicate the
+                size in bytes of a fixed-size chunk buffer. When no buffering
+                argument is given, the default buffering policy works as follows:
+                    * Binary files are buffered in fixed-size chunks; the size of
+                    the buffer is chosen using a heuristic trying to determine the
+                    underlying device鈥檚 鈥渂lock size鈥� and falling back on
+                    io.DEFAULT_BUFFER_SIZE. On many systems, the buffer will
+                    typically be 4096 or 8192 bytes long.
+            encoding (Optional[str]): the name of the encoding used to decode or
+                encode the file. This should only be used in text mode.
+            errors (Optional[str]): an optional string that specifies how encoding
+                and decoding errors are to be handled. This cannot be used in binary
+                mode.
+            newline (Optional[str]): controls how universal newlines mode works
+                (it only applies to text mode). It can be None, '', '\n', '\r',
+                and '\r\n'.
+            closefd (bool): If closefd is False and a file descriptor rather than
+                a filename was given, the underlying file descriptor will be kept
+                open when the file is closed. If a filename is given closefd must
+                be True (the default) otherwise an error will be raised.
+            opener (Optional[Callable]): A custom opener can be used by passing
+                a callable as opener. The underlying file descriptor for the file
+                object is then obtained by calling opener with (file, flags).
+                opener must return an open file descriptor (passing os.open as opener
+                results in functionality similar to passing None).
+            See https://docs.python.org/3/library/functions.html#open for details.
+        Returns:
+            file: a file-like object.
+        """
+        self._check_kwargs(kwargs)
+        return open(  # type: ignore
+            path,
+            mode,
+            buffering=buffering,
+            encoding=encoding,
+            errors=errors,
+            newline=newline,
+            closefd=closefd,
+            opener=opener,
+        )
+
+    def _copy(
+            self,
+            src_path: str,
+            dst_path: str,
+            overwrite: bool = False,
+            **kwargs: Any,
+    ) -> bool:
+        """
+        Copies a source path to a destination path.
+        Args:
+            src_path (str): A URI supported by this PathHandler
+            dst_path (str): A URI supported by this PathHandler
+            overwrite (bool): Bool flag for forcing overwrite of existing file
+        Returns:
+            status (bool): True on success
+        """
+        self._check_kwargs(kwargs)
+
+        if os.path.exists(dst_path) and not overwrite:
+            logger = logging.getLogger(__name__)
+            logger.error("Destination file {} already exists.".format(dst_path))
+            return False
+
+        try:
+            shutil.copyfile(src_path, dst_path)
+            return True
+        except Exception as e:
+            logger = logging.getLogger(__name__)
+            logger.error("Error in file copy - {}".format(str(e)))
+            return False
+
+    def _exists(self, path: str, **kwargs: Any) -> bool:
+        self._check_kwargs(kwargs)
+        return os.path.exists(path)
+
+    def _isfile(self, path: str, **kwargs: Any) -> bool:
+        self._check_kwargs(kwargs)
+        return os.path.isfile(path)
+
+    def _isdir(self, path: str, **kwargs: Any) -> bool:
+        self._check_kwargs(kwargs)
+        return os.path.isdir(path)
+
+    def _ls(self, path: str, **kwargs: Any) -> List[str]:
+        self._check_kwargs(kwargs)
+        return os.listdir(path)
+
+    def _mkdirs(self, path: str, **kwargs: Any) -> None:
+        self._check_kwargs(kwargs)
+        try:
+            os.makedirs(path, exist_ok=True)
+        except OSError as e:
+            # EEXIST it can still happen if multiple processes are creating the dir
+            if e.errno != errno.EEXIST:
+                raise
+
+    def _rm(self, path: str, **kwargs: Any) -> None:
+        self._check_kwargs(kwargs)
+        os.remove(path)
+
+
+class PathManager:
+    """
+    A class for users to open generic paths or translate generic paths to file names.
+    """
+
+    _PATH_HANDLERS: MutableMapping[str, PathHandler] = OrderedDict()
+    _NATIVE_PATH_HANDLER = NativePathHandler()
+
+    @staticmethod
+    def __get_path_handler(path: str) -> PathHandler:
+        """
+        Finds a PathHandler that supports the given path. Falls back to the native
+        PathHandler if no other handler is found.
+        Args:
+            path (str): URI path to resource
+        Returns:
+            handler (PathHandler)
+        """
+        for p in PathManager._PATH_HANDLERS.keys():
+            if path.startswith(p):
+                return PathManager._PATH_HANDLERS[p]
+        return PathManager._NATIVE_PATH_HANDLER
+
+    @staticmethod
+    def open(
+            path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
+    ) -> Union[IO[str], IO[bytes]]:
+        """
+        Open a stream to a URI, similar to the built-in `open`.
+        Args:
+            path (str): A URI supported by this PathHandler
+            mode (str): Specifies the mode in which the file is opened. It defaults
+                to 'r'.
+            buffering (int): An optional integer used to set the buffering policy.
+                Pass 0 to switch buffering off and an integer >= 1 to indicate the
+                size in bytes of a fixed-size chunk buffer. When no buffering
+                argument is given, the default buffering policy depends on the
+                underlying I/O implementation.
+        Returns:
+            file: a file-like object.
+        """
+        return PathManager.__get_path_handler(path)._open(  # type: ignore
+            path, mode, buffering=buffering, **kwargs
+        )
+
+    @staticmethod
+    def copy(
+            src_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any
+    ) -> bool:
+        """
+        Copies a source path to a destination path.
+        Args:
+            src_path (str): A URI supported by this PathHandler
+            dst_path (str): A URI supported by this PathHandler
+            overwrite (bool): Bool flag for forcing overwrite of existing file
+        Returns:
+            status (bool): True on success
+        """
+
+        # Copying across handlers is not supported.
+        assert PathManager.__get_path_handler(  # type: ignore
+            src_path
+        ) == PathManager.__get_path_handler(dst_path)
+        return PathManager.__get_path_handler(src_path)._copy(
+            src_path, dst_path, overwrite, **kwargs
+        )
+
+    @staticmethod
+    def get_local_path(path: str, **kwargs: Any) -> str:
+        """
+        Get a filepath which is compatible with native Python I/O such as `open`
+        and `os.path`.
+        If URI points to a remote resource, this function may download and cache
+        the resource to local disk.
+        Args:
+            path (str): A URI supported by this PathHandler
+        Returns:
+            local_path (str): a file path which exists on the local file system
+        """
+        return PathManager.__get_path_handler(  # type: ignore
+            path
+        )._get_local_path(path, **kwargs)
+
+    @staticmethod
+    def exists(path: str, **kwargs: Any) -> bool:
+        """
+        Checks if there is a resource at the given URI.
+        Args:
+            path (str): A URI supported by this PathHandler
+        Returns:
+            bool: true if the path exists
+        """
+        return PathManager.__get_path_handler(path)._exists(  # type: ignore
+            path, **kwargs
+        )
+
+    @staticmethod
+    def isfile(path: str, **kwargs: Any) -> bool:
+        """
+        Checks if there the resource at the given URI is a file.
+        Args:
+            path (str): A URI supported by this PathHandler
+        Returns:
+            bool: true if the path is a file
+        """
+        return PathManager.__get_path_handler(path)._isfile(  # type: ignore
+            path, **kwargs
+        )
+
+    @staticmethod
+    def isdir(path: str, **kwargs: Any) -> bool:
+        """
+        Checks if the resource at the given URI is a directory.
+        Args:
+            path (str): A URI supported by this PathHandler
+        Returns:
+            bool: true if the path is a directory
+        """
+        return PathManager.__get_path_handler(path)._isdir(  # type: ignore
+            path, **kwargs
+        )
+
+    @staticmethod
+    def ls(path: str, **kwargs: Any) -> List[str]:
+        """
+        List the contents of the directory at the provided URI.
+        Args:
+            path (str): A URI supported by this PathHandler
+        Returns:
+            List[str]: list of contents in given path
+        """
+        return PathManager.__get_path_handler(path)._ls(  # type: ignore
+            path, **kwargs
+        )
+
+    @staticmethod
+    def mkdirs(path: str, **kwargs: Any) -> None:
+        """
+        Recursive directory creation function. Like mkdir(), but makes all
+        intermediate-level directories needed to contain the leaf directory.
+        Similar to the native `os.makedirs`.
+        Args:
+            path (str): A URI supported by this PathHandler
+        """
+        return PathManager.__get_path_handler(path)._mkdirs(  # type: ignore
+            path, **kwargs
+        )
+
+    @staticmethod
+    def rm(path: str, **kwargs: Any) -> None:
+        """
+        Remove the file (not directory) at the provided URI.
+        Args:
+            path (str): A URI supported by this PathHandler
+        """
+        return PathManager.__get_path_handler(path)._rm(  # type: ignore
+            path, **kwargs
+        )
+
+    @staticmethod
+    def register_handler(handler: PathHandler) -> None:
+        """
+        Register a path handler associated with `handler._get_supported_prefixes`
+        URI prefixes.
+        Args:
+            handler (PathHandler)
+        """
+        assert isinstance(handler, PathHandler), handler
+        for prefix in handler._get_supported_prefixes():
+            assert prefix not in PathManager._PATH_HANDLERS
+            PathManager._PATH_HANDLERS[prefix] = handler
+
+        # Sort path handlers in reverse order so longer prefixes take priority,
+        # eg: http://foo/bar before http://foo
+        PathManager._PATH_HANDLERS = OrderedDict(
+            sorted(
+                PathManager._PATH_HANDLERS.items(),
+                key=lambda t: t[0],
+                reverse=True,
+            )
+        )
+
+    @staticmethod
+    def set_strict_kwargs_checking(enable: bool) -> None:
+        """
+        Toggles strict kwargs checking. If enabled, a ValueError is thrown if any
+        unused parameters are passed to a PathHandler function. If disabled, only
+        a warning is given.
+        With a centralized file API, there's a tradeoff of convenience and
+        correctness delegating arguments to the proper I/O layers. An underlying
+        `PathHandler` may support custom arguments which should not be statically
+        exposed on the `PathManager` function. For example, a custom `HTTPURLHandler`
+        may want to expose a `cache_timeout` argument for `open()` which specifies
+        how old a locally cached resource can be before it's refetched from the
+        remote server. This argument would not make sense for a `NativePathHandler`.
+        If strict kwargs checking is disabled, `cache_timeout` can be passed to
+        `PathManager.open` which will forward the arguments to the underlying
+        handler. By default, checking is enabled since it is innately unsafe:
+        multiple `PathHandler`s could reuse arguments with different semantic
+        meanings or types.
+        Args:
+            enable (bool)
+        """
+        PathManager._NATIVE_PATH_HANDLER._strict_kwargs_check = enable
+        for handler in PathManager._PATH_HANDLERS.values():
+            handler._strict_kwargs_check = enable
diff --git a/utils/history_buffer.py b/utils/history_buffer.py
new file mode 100644
index 0000000..033c2b0
--- /dev/null
+++ b/utils/history_buffer.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import numpy as np
+from typing import List, Tuple
+
+
+class HistoryBuffer:
+    """
+    Track a series of scalar values and provide access to smoothed values over a
+    window or the global average of the series.
+    """
+
+    def __init__(self, max_length: int = 1000000):
+        """
+        Args:
+            max_length: maximal number of values that can be stored in the
+                buffer. When the capacity of the buffer is exhausted, old
+                values will be removed.
+        """
+        self._max_length: int = max_length
+        self._data: List[Tuple[float, float]] = []  # (value, iteration) pairs
+        self._count: int = 0
+        self._global_avg: float = 0
+
+    def update(self, value: float, iteration: float = None):
+        """
+        Add a new scalar value produced at certain iteration. If the length
+        of the buffer exceeds self._max_length, the oldest element will be
+        removed from the buffer.
+        """
+        if iteration is None:
+            iteration = self._count
+        if len(self._data) == self._max_length:
+            self._data.pop(0)
+        self._data.append((value, iteration))
+
+        self._count += 1
+        self._global_avg += (value - self._global_avg) / self._count
+
+    def latest(self):
+        """
+        Return the latest scalar value added to the buffer.
+        """
+        return self._data[-1][0]
+
+    def median(self, window_size: int):
+        """
+        Return the median of the latest `window_size` values in the buffer.
+        """
+        return np.median([x[0] for x in self._data[-window_size:]])
+
+    def avg(self, window_size: int):
+        """
+        Return the mean of the latest `window_size` values in the buffer.
+        """
+        return np.mean([x[0] for x in self._data[-window_size:]])
+
+    def global_avg(self):
+        """
+        Return the mean of all the elements in the buffer. Note that this
+        includes those getting removed due to limited buffer storage.
+        """
+        return self._global_avg
+
+    def values(self):
+        """
+        Returns:
+            list[(number, iteration)]: content of the current buffer.
+        """
+        return self._data
diff --git a/utils/logger.py b/utils/logger.py
new file mode 100644
index 0000000..01e3da7
--- /dev/null
+++ b/utils/logger.py
@@ -0,0 +1,209 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import functools
+import logging
+import os
+import sys
+import time
+from collections import Counter
+from .file_io import PathManager
+from termcolor import colored
+
+
+class _ColorfulFormatter(logging.Formatter):
+    def __init__(self, *args, **kwargs):
+        self._root_name = kwargs.pop("root_name") + "."
+        self._abbrev_name = kwargs.pop("abbrev_name", "")
+        if len(self._abbrev_name):
+            self._abbrev_name = self._abbrev_name + "."
+        super(_ColorfulFormatter, self).__init__(*args, **kwargs)
+
+    def formatMessage(self, record):
+        record.name = record.name.replace(self._root_name, self._abbrev_name)
+        log = super(_ColorfulFormatter, self).formatMessage(record)
+        if record.levelno == logging.WARNING:
+            prefix = colored("WARNING", "red", attrs=["blink"])
+        elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
+            prefix = colored("ERROR", "red", attrs=["blink", "underline"])
+        else:
+            return log
+        return prefix + " " + log
+
+
+@functools.lru_cache()  # so that calling setup_logger multiple times won't add many handlers
+def setup_logger(
+        output=None, distributed_rank=0, *, color=True, name="fastreid", abbrev_name=None
+):
+    """
+    Args:
+        output (str): a file name or a directory to save log. If None, will not save log file.
+            If ends with ".txt" or ".log", assumed to be a file name.
+            Otherwise, logs will be saved to `output/log.txt`.
+        name (str): the root module name of this logger
+        abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
+            Set to "" to not log the root module in logs.
+            By default, will abbreviate "detectron2" to "d2" and leave other
+            modules unchanged.
+    """
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.DEBUG)
+    logger.propagate = False
+
+    if abbrev_name is None:
+        abbrev_name = "d2" if name == "detectron2" else name
+
+    plain_formatter = logging.Formatter(
+        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
+    )
+    # stdout logging: master only
+    if distributed_rank == 0:
+        ch = logging.StreamHandler(stream=sys.stdout)
+        ch.setLevel(logging.DEBUG)
+        if color:
+            formatter = _ColorfulFormatter(
+                colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
+                datefmt="%m/%d %H:%M:%S",
+                root_name=name,
+                abbrev_name=str(abbrev_name),
+            )
+        else:
+            formatter = plain_formatter
+        ch.setFormatter(formatter)
+        logger.addHandler(ch)
+
+    # file logging: all workers
+    if output is not None:
+        if output.endswith(".txt") or output.endswith(".log"):
+            filename = output
+        else:
+            filename = os.path.join(output, "log.txt")
+        if distributed_rank > 0:
+            filename = filename + ".rank{}".format(distributed_rank)
+        PathManager.mkdirs(os.path.dirname(filename))
+
+        fh = logging.StreamHandler(_cached_log_stream(filename))
+        fh.setLevel(logging.DEBUG)
+        fh.setFormatter(plain_formatter)
+        logger.addHandler(fh)
+
+    return logger
+
+
+# cache the opened file object, so that different calls to `setup_logger`
+# with the same file name can safely write to the same file.
+@functools.lru_cache(maxsize=None)
+def _cached_log_stream(filename):
+    return PathManager.open(filename, "a")
+
+
+"""
+Below are some other convenient logging methods.
+They are mainly adopted from
+https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py
+"""
+
+
+def _find_caller():
+    """
+    Returns:
+        str: module name of the caller
+        tuple: a hashable key to be used to identify different callers
+    """
+    frame = sys._getframe(2)
+    while frame:
+        code = frame.f_code
+        if os.path.join("utils", "logger.") not in code.co_filename:
+            mod_name = frame.f_globals["__name__"]
+            if mod_name == "__main__":
+                mod_name = "detectron2"
+            return mod_name, (code.co_filename, frame.f_lineno, code.co_name)
+        frame = frame.f_back
+
+
+_LOG_COUNTER = Counter()
+_LOG_TIMER = {}
+
+
+def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
+    """
+    Log only for the first n times.
+    Args:
+        lvl (int): the logging level
+        msg (str):
+        n (int):
+        name (str): name of the logger to use. Will use the caller's module by default.
+        key (str or tuple[str]): the string(s) can be one of "caller" or
+            "message", which defines how to identify duplicated logs.
+            For example, if called with `n=1, key="caller"`, this function
+            will only log the first call from the same caller, regardless of
+            the message content.
+            If called with `n=1, key="message"`, this function will log the
+            same content only once, even if they are called from different places.
+            If called with `n=1, key=("caller", "message")`, this function
+            will not log only if the same caller has logged the same message before.
+    """
+    if isinstance(key, str):
+        key = (key,)
+    assert len(key) > 0
+
+    caller_module, caller_key = _find_caller()
+    hash_key = ()
+    if "caller" in key:
+        hash_key = hash_key + caller_key
+    if "message" in key:
+        hash_key = hash_key + (msg,)
+
+    _LOG_COUNTER[hash_key] += 1
+    if _LOG_COUNTER[hash_key] <= n:
+        logging.getLogger(name or caller_module).log(lvl, msg)
+
+
+def log_every_n(lvl, msg, n=1, *, name=None):
+    """
+    Log once per n times.
+    Args:
+        lvl (int): the logging level
+        msg (str):
+        n (int):
+        name (str): name of the logger to use. Will use the caller's module by default.
+    """
+    caller_module, key = _find_caller()
+    _LOG_COUNTER[key] += 1
+    if n == 1 or _LOG_COUNTER[key] % n == 1:
+        logging.getLogger(name or caller_module).log(lvl, msg)
+
+
+def log_every_n_seconds(lvl, msg, n=1, *, name=None):
+    """
+    Log no more than once per n seconds.
+    Args:
+        lvl (int): the logging level
+        msg (str):
+        n (int):
+        name (str): name of the logger to use. Will use the caller's module by default.
+    """
+    caller_module, key = _find_caller()
+    last_logged = _LOG_TIMER.get(key, None)
+    current_time = time.time()
+    if last_logged is None or current_time - last_logged >= n:
+        logging.getLogger(name or caller_module).log(lvl, msg)
+        _LOG_TIMER[key] = current_time
+
+# def create_small_table(small_dict):
+#     """
+#     Create a small table using the keys of small_dict as headers. This is only
+#     suitable for small dictionaries.
+#     Args:
+#         small_dict (dict): a result dictionary of only a few items.
+#     Returns:
+#         str: the table as a string.
+#     """
+#     keys, values = tuple(zip(*small_dict.items()))
+#     table = tabulate(
+#         [values],
+#         headers=keys,
+#         tablefmt="pipe",
+#         floatfmt=".3f",
+#         stralign="center",
+#         numalign="center",
+#     )
+#     return table
diff --git a/utils/registry.py b/utils/registry.py
new file mode 100644
index 0000000..ad5376b
--- /dev/null
+++ b/utils/registry.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+from typing import Dict, Optional
+
+
+class Registry(object):
+    """
+    The registry that provides name -> object mapping, to support third-party
+    users' custom modules.
+    To create a registry (e.g. a backbone registry):
+    .. code-block:: python
+        BACKBONE_REGISTRY = Registry('BACKBONE')
+    To register an object:
+    .. code-block:: python
+        @BACKBONE_REGISTRY.register()
+        class MyBackbone():
+            ...
+    Or:
+    .. code-block:: python
+        BACKBONE_REGISTRY.register(MyBackbone)
+    """
+
+    def __init__(self, name: str) -> None:
+        """
+        Args:
+            name (str): the name of this registry
+        """
+        self._name: str = name
+        self._obj_map: Dict[str, object] = {}
+
+    def _do_register(self, name: str, obj: object) -> None:
+        assert (
+                name not in self._obj_map
+        ), "An object named '{}' was already registered in '{}' registry!".format(
+            name, self._name
+        )
+        self._obj_map[name] = obj
+
+    def register(self, obj: object = None) -> Optional[object]:
+        """
+        Register the given object under the the name `obj.__name__`.
+        Can be used as either a decorator or not. See docstring of this class for usage.
+        """
+        if obj is None:
+            # used as a decorator
+            def deco(func_or_class: object) -> object:
+                name = func_or_class.__name__  # pyre-ignore
+                self._do_register(name, func_or_class)
+                return func_or_class
+
+            return deco
+
+        # used as a function call
+        name = obj.__name__  # pyre-ignore
+        self._do_register(name, obj)
+
+    def get(self, name: str) -> object:
+        ret = self._obj_map.get(name)
+        if ret is None:
+            raise KeyError(
+                "No object named '{}' found in '{}' registry!".format(
+                    name, self._name
+                )
+            )
+        return ret
diff --git a/utils/weight_init.py b/utils/weight_init.py
new file mode 100644
index 0000000..82fb262
--- /dev/null
+++ b/utils/weight_init.py
@@ -0,0 +1,37 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import math
+from torch import nn
+
+__all__ = [
+    'weights_init_classifier',
+    'weights_init_kaiming',
+]
+
+
+def weights_init_kaiming(m):
+    classname = m.__class__.__name__
+    if classname.find('Linear') != -1:
+        nn.init.normal_(m.weight, 0, 0.01)
+        if m.bias is not None:
+            nn.init.constant_(m.bias, 0.0)
+    elif classname.find('Conv') != -1:
+        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+        if m.bias is not None:
+            nn.init.constant_(m.bias, 0.0)
+    elif classname.find('BatchNorm') != -1:
+        if m.affine:
+            nn.init.normal_(m.weight, 1.0, 0.02)
+            nn.init.constant_(m.bias, 0.0)
+
+
+def weights_init_classifier(m):
+    classname = m.__class__.__name__
+    if classname.find('Linear') != -1:
+        nn.init.normal_(m.weight, std=0.001)
+        if m.bias is not None:
+            nn.init.constant_(m.bias, 0.0)

--
Gitblit v1.8.0