From 291deeb1fcf45dbf39a24aa72a213ff3fd6b3405 Mon Sep 17 00:00:00 2001
From: Scheaven <xuepengqiang>
Date: Sat, 18 Sep 2021 14:22:50 +0800
Subject: [PATCH] update
---
utils/checkpoint.py | 403 ++
modeling/meta_arch/__pycache__/build.cpython-37.pyc | 0
modeling/meta_arch/build.py | 21
tools/__pycache__/predictor.cpython-38.pyc | 0
layers/activation.py | 59
modeling/meta_arch/__pycache__/__init__.cpython-37.pyc | 0
modeling/losses/__pycache__/__init__.cpython-37.pyc | 0
config/__pycache__/__init__.cpython-37.pyc | 0
modeling/heads/__pycache__/build.cpython-37.pyc | 0
modeling/backbones/resnet.py | 359 ++
modeling/heads/embedding_head.py | 97
engine/defaults.py | 115
layers/__pycache__/batch_drop.cpython-38.pyc | 0
modeling/heads/__pycache__/__init__.cpython-38.pyc | 0
tools/03_check_onnx.py | 168 +
utils/env.py | 119
utils/logger.py | 209 +
data/__pycache__/data_utils.cpython-37.pyc | 0
engine/__init__.py | 6
modeling/backbones/__pycache__/__init__.cpython-37.pyc | 0
layers/arcface.py | 54
layers/__pycache__/batch_norm.cpython-37.pyc | 0
modeling/losses/__pycache__/focal_loss.cpython-38.pyc | 0
utils/__pycache__/file_io.cpython-38.pyc | 0
data/transforms/__pycache__/build.cpython-38.pyc | 0
modeling/heads/bnneck_head.py | 61
modeling/backbones/resnest.py | 411 ++
config/__pycache__/defaults.cpython-38.pyc | 0
layers/context_block.py | 113
data/transforms/autoaugment.py | 812 +++++
modeling/heads/build.py | 25
utils/comm.py | 255 +
modeling/meta_arch/__pycache__/baseline.cpython-37.pyc | 0
modeling/backbones/__pycache__/osnet.cpython-37.pyc | 0
modeling/losses/__pycache__/cross_entroy_loss.cpython-38.pyc | 0
data/transforms/transforms.py | 204 +
layers/__pycache__/non_local.cpython-37.pyc | 0
modeling/heads/__pycache__/embedding_head.cpython-38.pyc | 0
utils/__pycache__/__init__.cpython-37.pyc | 0
utils/file_io.py | 520 +++
modeling/heads/__pycache__/attr_head.cpython-37.pyc | 0
modeling/backbones/__pycache__/resnet.cpython-38.pyc | 0
layers/__pycache__/frn.cpython-37.pyc | 0
modeling/backbones/__pycache__/resnest.cpython-38.pyc | 0
data/transforms/__pycache__/autoaugment.cpython-38.pyc | 0
layers/splat.py | 97
modeling/heads/linear_head.py | 50
utils/__pycache__/registry.cpython-37.pyc | 0
modeling/meta_arch/__pycache__/mgn.cpython-38.pyc | 0
utils/__pycache__/__init__.cpython-38.pyc | 0
tools/predictor.py | 114
layers/__pycache__/context_block.cpython-38.pyc | 0
modeling/backbones/__pycache__/resnet.cpython-37.pyc | 0
utils/__pycache__/checkpoint.cpython-37.pyc | 0
tools/03_py2onnx.py | 40
layers/__pycache__/splat.cpython-37.pyc | 0
modeling/heads/reduction_head.py | 73
layers/__pycache__/arcface.cpython-37.pyc | 0
data/transforms/build.py | 73
modeling/backbones/__pycache__/resnext.cpython-38.pyc | 0
utils/__pycache__/history_buffer.cpython-37.pyc | 0
modeling/__pycache__/__init__.cpython-37.pyc | 0
utils/weight_init.py | 37
layers/__pycache__/se_layer.cpython-38.pyc | 0
modeling/meta_arch/__pycache__/baseline.cpython-38.pyc | 0
utils/__pycache__/events.cpython-38.pyc | 0
modeling/__init__.py | 6
utils/__pycache__/comm.cpython-38.pyc | 0
data/transforms/__pycache__/functional.cpython-38.pyc | 0
layers/__pycache__/pooling.cpython-37.pyc | 0
modeling/losses/__pycache__/metric_loss.cpython-37.pyc | 0
layers/__pycache__/activation.cpython-38.pyc | 0
layers/non_local.py | 54
tools/inference_net.py | 60
layers/__pycache__/batch_drop.cpython-37.pyc | 0
config/__pycache__/config.cpython-37.pyc | 0
modeling/meta_arch/__pycache__/__init__.cpython-38.pyc | 0
layers/__pycache__/__init__.cpython-37.pyc | 0
data/__pycache__/__init__.cpython-38.pyc | 0
modeling/meta_arch/__init__.py | 12
modeling/meta_arch/mgn.py | 280 +
utils/__pycache__/weight_init.cpython-37.pyc | 0
layers/__pycache__/circle.cpython-37.pyc | 0
utils/collect_env.py | 158 +
data/transforms/__pycache__/transforms.cpython-38.pyc | 0
modeling/heads/attr_head.py | 77
data/__pycache__/data_utils.cpython-38.pyc | 0
data/data_utils.py | 45
modeling/losses/__init__.py | 9
config/__pycache__/__init__.cpython-38.pyc | 0
layers/__pycache__/frn.cpython-38.pyc | 0
modeling/heads/__pycache__/attr_head.cpython-38.pyc | 0
modeling/losses/focal_loss.py | 110
tools/__pycache__/predictor.cpython-37.pyc | 0
modeling/meta_arch/baseline.py | 119
modeling/meta_arch/__pycache__/build.cpython-38.pyc | 0
modeling/losses/__pycache__/metric_loss.cpython-38.pyc | 0
modeling/backbones/__pycache__/__init__.cpython-38.pyc | 0
data/__init__.py | 7
modeling/heads/__pycache__/build.cpython-38.pyc | 0
layers/__pycache__/non_local.cpython-38.pyc | 0
config/defaults.py | 273 +
data/transforms/__pycache__/__init__.cpython-38.pyc | 0
modeling/heads/__pycache__/__init__.cpython-37.pyc | 0
utils/__pycache__/collect_env.cpython-38.pyc | 0
modeling/backbones/__pycache__/resnest.cpython-37.pyc | 0
modeling/backbones/resnext.py | 198 +
utils/registry.py | 66
config/__init__.py | 9
tools/__init__.py | 6
utils/__pycache__/registry.cpython-38.pyc | 0
README.md | 48
modeling/losses/metric_loss.py | 215 +
modeling/backbones/__pycache__/build.cpython-38.pyc | 0
layers/circle.py | 42
modeling/losses/__pycache__/__init__.cpython-38.pyc | 0
utils/events.py | 445 +++
modeling/heads/__pycache__/embedding_head.cpython-37.pyc | 0
modeling/losses/__pycache__/cross_entroy_loss.cpython-37.pyc | 0
modeling/backbones/__pycache__/osnet.cpython-38.pyc | 0
layers/__pycache__/se_layer.cpython-37.pyc | 0
modeling/heads/__init__.py | 12
utils/history_buffer.py | 71
utils/__pycache__/logger.cpython-38.pyc | 0
modeling/backbones/__pycache__/build.cpython-37.pyc | 0
utils/__pycache__/weight_init.cpython-38.pyc | 0
data/transforms/__init__.py | 10
modeling/backbones/build.py | 28
tools/04_trt_inference.py | 96
modeling/backbones/osnet.py | 487 +++
config/__pycache__/config.cpython-38.pyc | 0
layers/batch_norm.py | 208 +
engine/__pycache__/__init__.cpython-38.pyc | 0
layers/frn.py | 199 +
layers/__pycache__/arcface.cpython-38.pyc | 0
layers/__pycache__/context_block.cpython-37.pyc | 0
layers/__pycache__/splat.cpython-38.pyc | 0
modeling/losses/cross_entroy_loss.py | 62
config/config.py | 161 +
layers/pooling.py | 79
utils/__pycache__/history_buffer.cpython-38.pyc | 0
modeling/__pycache__/__init__.cpython-38.pyc | 0
utils/__pycache__/env.cpython-38.pyc | 0
utils/__pycache__/file_io.cpython-37.pyc | 0
layers/__pycache__/__init__.cpython-38.pyc | 0
layers/__pycache__/activation.cpython-37.pyc | 0
data/__pycache__/__init__.cpython-37.pyc | 0
layers/__pycache__/batch_norm.cpython-38.pyc | 0
modeling/losses/__pycache__/focal_loss.cpython-37.pyc | 0
utils/__pycache__/comm.cpython-37.pyc | 0
layers/__pycache__/pooling.cpython-38.pyc | 0
/dev/null | 86
layers/__pycache__/circle.cpython-38.pyc | 0
utils/__init__.py | 6
utils/__pycache__/checkpoint.cpython-38.pyc | 0
modeling/meta_arch/__pycache__/mgn.cpython-37.pyc | 0
utils/__pycache__/events.cpython-37.pyc | 0
modeling/backbones/__pycache__/resnext.cpython-37.pyc | 0
engine/__pycache__/defaults.cpython-38.pyc | 0
layers/batch_drop.py | 32
data/transforms/functional.py | 190 +
layers/__init__.py | 17
layers/se_layer.py | 25
modeling/backbones/__init__.py | 12
 164 files changed, 8427 insertions(+), 88 deletions(-)
diff --git a/Makefile b/Makefile
deleted file mode 100644
index 4eec39b..0000000
--- a/Makefile
+++ /dev/null
@@ -1,93 +0,0 @@
-GPU=1
-CUDNN=1
-DEBUG=1
-
-ARCH= -gencode arch=compute_30,code=sm_30 \
- -gencode arch=compute_35,code=sm_35 \
- -gencode arch=compute_50,code=[sm_50,compute_50] \
- -gencode arch=compute_52,code=[sm_52,compute_52]
-# -gencode arch=compute_20,code=[sm_20,sm_21] \ This one is deprecated?
-
-# This is what I use, uncomment if you know your arch and want to specify
-# ARCH= -gencode arch=compute_52,code=compute_52
-
-VPATH=.
-EXEC=reid
-OBJDIR=./obj/
-
-CC=gcc
-CPP=g++
-NVCC=nvcc
-AR=ar
-ARFLAGS=rcs
-OPTS=-Ofast
-LDFLAGS= -lm -pthread
-CFLAGS=-Wall -Wno-unused-result -Wno-unknown-pragmas -Wfatal-errors -fPIC
-
-CFLAGS+= -I/home/disk1/s_opt/opencv/include/opencv4
-LDFLAGS+= -L/home/disk1/s_opt/opencv/lib -lopencv_core -lopencv_imgcodecs -lopencv_highgui -lopencv_imgproc
-
-
-COMMON = -std=c++11
-#flag = -Wl,-rpath=/home/disk1/s_opt/libtorch/lib
-#COMMON += -D_GLIBCXX_USE_CXX11_ABI=0 -I/home/disk1/s_opt/libtorch/include -I/home/disk1/s_opt/libtorch/include/torch/csrc/api/include
-#LDFLAGS+= -L/home/disk1/s_opt/libtorch/lib -ltorch -lc10 -lc10_cuda -lcudart -lgomp -lnvToolsExt
-
-flag = -Wl,-rpath=/home/disk1/s_opt/libtorch_CPP11/libtorch/lib
-COMMON += -I/home/disk1/s_opt/libtorch_CPP11/libtorch/include -I/home/disk1/s_opt/libtorch_CPP11/libtorch/include/torch/csrc/api/include
-LDFLAGS+= -L/home/disk1/s_opt/libtorch_CPP11/libtorch/lib -ltorch -lc10 -lc10_cuda -lcudart -lgomp -lnvToolsExt
-
-#COMMON = -std=c++11
-#COMMON= -I/home/disk1/data/s_software/CPP11_torch/libtorch/include -std=c++11
-#COMMON= -D_GLIBCXX_USE_CXX11_ABI=0 -I/home/disk1/data/s_software/CPP11_torch/libtorch/include -std=c++11
-#LDFLAGS+= -L/home/disk1/data/s_software/CPP11_torch/libtorch/lib -ltorch -lc10 -lc10_cuda -lcudart -lgomp -lnvToolsExt
-
-
-ifeq ($(DEBUG), 1)
-OPTS=-O0 -g
-endif
-
-CFLAGS+=$(OPTS)
-
-
-COMMON+= -DGPU -I/usr/local/cuda/include/
-
-ifeq ($(CUDNN), 1)
-COMMON+= -DCUDNN
-CFLAGS+= -DCUDNN
-LDFLAGS+= -lcudnn
-endif
-
-OBJ=test.o reid_feature.o
-
-ifeq ($(GPU), 1)
-LDFLAGS+= -lstdc++
-OBJ+=
-endif
-
-OBJS = $(addprefix $(OBJDIR), $(OBJ))
-DEPS = $(wildcard */*.h) Makefile
-
-all: obj $(EXEC)
-#all: obj results $(SLIB) $(ALIB) $(EXEC)
-
-
-$(EXEC): $(OBJS)
- $(CC) $(COMMON) $(CFLAGS) $^ -o $@ $(LDFLAGS) $(ALIB) $(flag)
-
-$(OBJDIR)%.o: %.cpp $(DEPS)
- $(CPP) $(COMMON) $(CFLAGS) -c $< -o $@
-
-$(OBJDIR)%.o: %.c $(DEPS)
- $(CC) $(COMMON) $(CFLAGS) -c $< -o $@
-
-$(OBJDIR)%.o: %.cu $(DEPS)
- $(NVCC) $(ARCH) $(COMMON) --compiler-options "$(CFLAGS)" -c $< -o $@
-
-obj:
- mkdir -p obj
-
-.PHONY: clean
-
-clean:
- rm -rf $(OBJS) $(EXEC) $(EXECOBJ) $(OBJDIR)/*
diff --git a/README.md b/README.md
index f4c8be8..7ea9697 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,48 @@
-## reID
+Dataset documentation:
+https://github.com/JDAI-CV/fast-reid/tree/master/datasets
+Usage is roughly as follows:
-Person re-identification
+Training:
+
+./tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml MODEL.DEVICE "cuda:0"
+
+Evaluation:
+python tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml --eval-only \
+MODEL.WEIGHTS /path/to/checkpoint_file MODEL.DEVICE "cuda:0"
+
+Testing:
+python tools/test_net.py --img_a1 a_1.jpg --img_a2 a_2.jpg --img_b1 b_1.jpg --img_b2 b_2.jpg --config-file ./configs/Market1501/bagtricks_R101-ibn.yml MODEL.WEIGHTS ../market_bot_R101-ibn.pth MODEL.DEVICE "cuda:6"
+
+Inference:
+python tools/inference_net.py --img_a1 /data/disk1/workspace/05_DarknetSort/02_humanID/1.jpg --img_a2 /data/disk1/workspace/05_DarknetSort/02_humanID/21.jpg --img_b1 b_1.jpg --img_b2 b_2.jpg --config-file ../01_fast-reid/fast-reid/configs/Market1501/bagtricks_R101-ibn.yml MODEL.WEIGHTS ../market_bot_R101-ibn.pth MODEL.DEVICE "cuda:6"
+
+
+Convert to ONNX:
+python tools/03_py2onnx.py --config-file ../fast-reid-master/configs/Market1501/bagtricks_R101-ibn.yml MODEL.WEIGHTS /data/disk1/project/model_dump/01_reid/market_bot_R101-ibn.pth MODEL.DEVICE "cuda:0"
+
+python tools/03_py2onnx.py --config-file ./configs/Market1501/bagtricks_R101-ibn.yml MODEL.WEIGHTS ../market_bot_R101-ibn.pth MODEL.DEVICE "cuda:0"
+
+Convert to TensorRT:
+python tools/03_check_onnx.py
+
+./trtexec --onnx=/data/disk1/project/data/01_reid/02_changchuang/human_direction2.onnx --shapes=the_input:1x224x224x3 --workspace=4096 --saveEngine=/data/disk1/project/data/01_reid/02_changchuang/person_count_engine.trt
+
+
+Run TensorRT inference:
+python tools/04_trt_inference.py --model_path 4batch_fp16_True.trt
+
+
+Installation requirements (required packages):
+
+Linux or macOS with Python ≥ 3.6
+PyTorch ≥ 1.6
+torchvision that matches the PyTorch installation. You can install them together at pytorch.org to make sure of this.
+yacs
+Cython (optional to compile evaluation code)
+tensorboard (needed for visualization): pip install tensorboard
+gdown (for automatically downloading pre-trained models)
+sklearn
+termcolor
+tabulate
+faiss: pip install faiss-cpu
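
For reference, a minimal sketch of sanity-checking an exported ONNX model with onnxruntime. The model path, input shape, and embedding size below are illustrative assumptions, not values from this repo; see tools/03_check_onnx.py in this patch for the actual check.

    import numpy as np
    import onnxruntime as ort

    # Illustrative path; the real one is whatever 03_py2onnx.py wrote.
    sess = ort.InferenceSession("reid.onnx", providers=["CPUExecutionProvider"])
    input_name = sess.get_inputs()[0].name        # query the graph instead of hard-coding
    dummy = np.random.rand(1, 3, 256, 128).astype(np.float32)  # (N, C, H, W), per INPUT.SIZE_TEST
    feat = sess.run(None, {input_name: dummy})[0]
    print(feat.shape)                             # e.g. (1, 2048) for the default backbone
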
diff --git a/config/__init__.py b/config/__init__.py
new file mode 100644
index 0000000..37e5cb6
--- /dev/null
+++ b/config/__init__.py
@@ -0,0 +1,9 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2020/10/26 17:40
+# @Author : Scheaven
+# @File : __init__.py.py
+# @description:
+
+from .config import CfgNode, get_cfg
+from .defaults import _C as cfg
diff --git a/config/__pycache__/__init__.cpython-37.pyc b/config/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..c77849f
--- /dev/null
+++ b/config/__pycache__/__init__.cpython-37.pyc
Binary files differ
diff --git a/config/__pycache__/__init__.cpython-38.pyc b/config/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..f8290d7
--- /dev/null
+++ b/config/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/config/__pycache__/config.cpython-37.pyc b/config/__pycache__/config.cpython-37.pyc
new file mode 100644
index 0000000..1d6dbd8
--- /dev/null
+++ b/config/__pycache__/config.cpython-37.pyc
Binary files differ
diff --git a/config/__pycache__/config.cpython-38.pyc b/config/__pycache__/config.cpython-38.pyc
new file mode 100644
index 0000000..4f03814
--- /dev/null
+++ b/config/__pycache__/config.cpython-38.pyc
Binary files differ
diff --git a/config/__pycache__/defaults.cpython-38.pyc b/config/__pycache__/defaults.cpython-38.pyc
new file mode 100644
index 0000000..8b85b95
--- /dev/null
+++ b/config/__pycache__/defaults.cpython-38.pyc
Binary files differ
diff --git a/config/config.py b/config/config.py
new file mode 100644
index 0000000..85831f5
--- /dev/null
+++ b/config/config.py
@@ -0,0 +1,161 @@
+# encoding: utf-8
+"""
+@author: l1aoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import logging
+import os
+from typing import Any
+
+import yaml
+from yacs.config import CfgNode as _CfgNode
+
+from utils.file_io import PathManager
+
+BASE_KEY = "_BASE_"
+
+
+class CfgNode(_CfgNode):
+ """
+ Our own extended version of :class:`yacs.config.CfgNode`.
+ It contains the following extra features:
+ 1. The :meth:`merge_from_file` method supports the "_BASE_" key,
+ which allows the new CfgNode to inherit all the attributes from the
+ base configuration file.
+ 2. Keys that start with "COMPUTED_" are treated as insertion-only
+ "computed" attributes. They can be inserted regardless of whether
+ the CfgNode is frozen or not.
+ 3. With "allow_unsafe=True", it supports pyyaml tags that evaluate
+ expressions in config. See examples in
+ https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types
+ Note that this may lead to arbitrary code execution: you must not
+ load a config file from untrusted sources before manually inspecting
+ the content of the file.
+ """
+
+ @staticmethod
+ def load_yaml_with_base(filename: str, allow_unsafe: bool = False):
+ """
+ Just like `yaml.load(open(filename))`, but inherit attributes from its
+ `_BASE_`.
+ Args:
+ filename (str): the file name of the current config. Will be used to
+ find the base config file.
+ allow_unsafe (bool): whether to allow loading the config file with
+ `yaml.unsafe_load`.
+ Returns:
+ (dict): the loaded yaml
+ """
+ with PathManager.open(filename, "r") as f:
+ try:
+ cfg = yaml.safe_load(f)
+ except yaml.constructor.ConstructorError:
+ if not allow_unsafe:
+ raise
+ logger = logging.getLogger(__name__)
+ logger.warning(
+ "Loading config {} with yaml.unsafe_load. Your machine may "
+ "be at risk if the file contains malicious content.".format(
+ filename
+ )
+ )
+ f.close()
+ with open(filename, "r") as f:
+ cfg = yaml.unsafe_load(f)
+
+ def merge_a_into_b(a, b):
+ # merge dict a into dict b. values in a will overwrite b.
+ for k, v in a.items():
+ if isinstance(v, dict) and k in b:
+ assert isinstance(
+ b[k], dict
+ ), "Cannot inherit key '{}' from base!".format(k)
+ merge_a_into_b(v, b[k])
+ else:
+ b[k] = v
+
+ if BASE_KEY in cfg:
+ base_cfg_file = cfg[BASE_KEY]
+ if base_cfg_file.startswith("~"):
+ base_cfg_file = os.path.expanduser(base_cfg_file)
+ if not any(
+ map(base_cfg_file.startswith, ["/", "https://", "http://"])
+ ):
+ # the path to base cfg is relative to the config file itself.
+ base_cfg_file = os.path.join(
+ os.path.dirname(filename), base_cfg_file
+ )
+ base_cfg = CfgNode.load_yaml_with_base(
+ base_cfg_file, allow_unsafe=allow_unsafe
+ )
+ del cfg[BASE_KEY]
+
+ merge_a_into_b(cfg, base_cfg)
+ return base_cfg
+ return cfg
+
+ def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = False):
+ """
+ Merge configs from a given yaml file.
+ Args:
+ cfg_filename: the file name of the yaml config.
+ allow_unsafe: whether to allow loading the config file with
+ `yaml.unsafe_load`.
+ """
+ loaded_cfg = CfgNode.load_yaml_with_base(
+ cfg_filename, allow_unsafe=allow_unsafe
+ )
+ loaded_cfg = type(self)(loaded_cfg)
+ self.merge_from_other_cfg(loaded_cfg)
+
+ # Forward the following calls to base, but with a check on the BASE_KEY.
+ def merge_from_other_cfg(self, cfg_other):
+ """
+ Args:
+ cfg_other (CfgNode): configs to merge from.
+ """
+ assert (
+ BASE_KEY not in cfg_other
+ ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
+
+ return super().merge_from_other_cfg(cfg_other)
+
+ def merge_from_list(self, cfg_list: list):
+ """
+ Args:
+ cfg_list (list): list of configs to merge from.
+ """
+ keys = set(cfg_list[0::2])
+ assert (
+ BASE_KEY not in keys
+ ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
+ return super().merge_from_list(cfg_list)
+
+ def __setattr__(self, name: str, val: Any):
+ if name.startswith("COMPUTED_"):
+ if name in self:
+ old_val = self[name]
+ if old_val == val:
+ return
+ raise KeyError(
+ "Computed attributed '{}' already exists "
+ "with a different value! old={}, new={}.".format(
+ name, old_val, val
+ )
+ )
+ self[name] = val
+ else:
+ super().__setattr__(name, val)
+
+
+def get_cfg() -> CfgNode:
+ """
+ Get a copy of the default config.
+ Returns:
+ a fastreid CfgNode instance.
+ """
+ from .defaults import _C
+
+ return _C.clone()
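
As an illustration of the "_BASE_" inheritance implemented above, a minimal sketch; the file names and keys are made up for the example.

    import os
    import tempfile

    from config import CfgNode

    tmp = tempfile.mkdtemp()
    with open(os.path.join(tmp, "base.yml"), "w") as f:
        f.write("MODEL:\n  DEVICE: cuda\n  BACKBONE:\n    DEPTH: 50x\n")
    with open(os.path.join(tmp, "child.yml"), "w") as f:
        # _BASE_ is resolved relative to the directory of the child file.
        f.write("_BASE_: base.yml\nMODEL:\n  BACKBONE:\n    DEPTH: 101x\n")

    merged = CfgNode.load_yaml_with_base(os.path.join(tmp, "child.yml"))
    assert merged["MODEL"]["BACKBONE"]["DEPTH"] == "101x"  # child overrides base
    assert merged["MODEL"]["DEVICE"] == "cuda"             # inherited from base
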
diff --git a/config/defaults.py b/config/defaults.py
new file mode 100644
index 0000000..61651a5
--- /dev/null
+++ b/config/defaults.py
@@ -0,0 +1,273 @@
+from .config import CfgNode as CN
+
+# -----------------------------------------------------------------------------
+# Convention about Training / Test specific parameters
+# -----------------------------------------------------------------------------
+# Whenever an argument can be either used for training or for testing, the
+# corresponding name will be post-fixed by a _TRAIN for a training parameter,
+# or _TEST for a test-specific parameter.
+# For example, the number of images during training will be
+# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
+# IMAGES_PER_BATCH_TEST
+
+# -----------------------------------------------------------------------------
+# Config definition
+# -----------------------------------------------------------------------------
+
+_C = CN()
+
+# -----------------------------------------------------------------------------
+# MODEL
+# -----------------------------------------------------------------------------
+_C.MODEL = CN()
+_C.MODEL.DEVICE = "cuda"
+_C.MODEL.META_ARCHITECTURE = 'Baseline'
+_C.MODEL.FREEZE_LAYERS = ['']
+
+# ---------------------------------------------------------------------------- #
+# Backbone options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.BACKBONE = CN()
+
+_C.MODEL.BACKBONE.NAME = "build_resnet_backbone"
+_C.MODEL.BACKBONE.DEPTH = "50x"
+_C.MODEL.BACKBONE.LAST_STRIDE = 1
+# Backbone feature dimension
+_C.MODEL.BACKBONE.FEAT_DIM = 2048
+# Normalization method for the convolution layers.
+_C.MODEL.BACKBONE.NORM = "BN"
+# Whether to use IBN blocks in the backbone
+_C.MODEL.BACKBONE.WITH_IBN = False
+# Whether to use SE blocks in the backbone
+_C.MODEL.BACKBONE.WITH_SE = False
+# Whether to use non-local blocks in the backbone
+_C.MODEL.BACKBONE.WITH_NL = False
+# Whether to load ImageNet-pretrained weights
+_C.MODEL.BACKBONE.PRETRAIN = True
+# Pretrained model path
+_C.MODEL.BACKBONE.PRETRAIN_PATH = ''
+
+# ---------------------------------------------------------------------------- #
+# REID HEADS options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.HEADS = CN()
+_C.MODEL.HEADS.NAME = "EmbeddingHead"
+# Normalization method for the convolution layers.
+_C.MODEL.HEADS.NORM = "BN"
+# Number of identities
+_C.MODEL.HEADS.NUM_CLASSES = 0
+# Embedding dimension in head
+_C.MODEL.HEADS.EMBEDDING_DIM = 0
+# Whether to use BNNeck in the embedding head
+_C.MODEL.HEADS.WITH_BNNECK = True
+# Whether the triplet loss uses the feature before or after the BNNeck
+_C.MODEL.HEADS.NECK_FEAT = "before" # options: before, after
+# Pooling layer type
+_C.MODEL.HEADS.POOL_LAYER = "avgpool"
+
+# Classification layer type
+_C.MODEL.HEADS.CLS_LAYER = "linear" # "arcSoftmax" or "circleSoftmax"
+
+# Margin and Scale for margin-based classification layer
+_C.MODEL.HEADS.MARGIN = 0.15
+_C.MODEL.HEADS.SCALE = 128
+
+# ---------------------------------------------------------------------------- #
+# REID LOSSES options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.LOSSES = CN()
+_C.MODEL.LOSSES.NAME = ("CrossEntropyLoss",)
+
+# Cross Entropy Loss options
+_C.MODEL.LOSSES.CE = CN()
+# if epsilon == 0, no label smoothing regularization is applied;
+# if epsilon == -1, adaptive label smoothing regularization is applied
+_C.MODEL.LOSSES.CE.EPSILON = 0.0
+_C.MODEL.LOSSES.CE.ALPHA = 0.2
+_C.MODEL.LOSSES.CE.SCALE = 1.0
+
+# Triplet Loss options
+_C.MODEL.LOSSES.TRI = CN()
+_C.MODEL.LOSSES.TRI.MARGIN = 0.3
+_C.MODEL.LOSSES.TRI.NORM_FEAT = False
+_C.MODEL.LOSSES.TRI.HARD_MINING = True
+_C.MODEL.LOSSES.TRI.SCALE = 1.0
+
+# Circle Loss options
+_C.MODEL.LOSSES.CIRCLE = CN()
+_C.MODEL.LOSSES.CIRCLE.MARGIN = 0.25
+_C.MODEL.LOSSES.CIRCLE.ALPHA = 128
+_C.MODEL.LOSSES.CIRCLE.SCALE = 1.0
+
+# Focal Loss options
+_C.MODEL.LOSSES.FL = CN()
+_C.MODEL.LOSSES.FL.ALPHA = 0.25
+_C.MODEL.LOSSES.FL.GAMMA = 2
+_C.MODEL.LOSSES.FL.SCALE = 1.0
+
+# Path to a checkpoint file to be loaded to the model. You can find available models in the model zoo.
+_C.MODEL.WEIGHTS = ""
+
+# Values to be used for image normalization
+_C.MODEL.PIXEL_MEAN = [0.485*255, 0.456*255, 0.406*255]
+# Values to be used for image normalization
+_C.MODEL.PIXEL_STD = [0.229*255, 0.224*255, 0.225*255]
+
+
+# -----------------------------------------------------------------------------
+# INPUT
+# -----------------------------------------------------------------------------
+_C.INPUT = CN()
+# Size of the image during training
+_C.INPUT.SIZE_TRAIN = [256, 128]
+# Size of the image during test
+_C.INPUT.SIZE_TEST = [256, 128]  # given as (h, w)
+
+# Random probability for image horizontal flip
+_C.INPUT.DO_FLIP = True
+_C.INPUT.FLIP_PROB = 0.5
+
+# Value of padding size
+_C.INPUT.DO_PAD = True
+_C.INPUT.PADDING_MODE = 'constant'
+_C.INPUT.PADDING = 10
+# Random color jitter
+_C.INPUT.CJ = CN()
+_C.INPUT.CJ.ENABLED = False
+_C.INPUT.CJ.PROB = 0.8
+_C.INPUT.CJ.BRIGHTNESS = 0.15
+_C.INPUT.CJ.CONTRAST = 0.15
+_C.INPUT.CJ.SATURATION = 0.1
+_C.INPUT.CJ.HUE = 0.1
+# Auto augmentation
+_C.INPUT.DO_AUTOAUG = False
+# Augmix augmentation
+_C.INPUT.DO_AUGMIX = False
+# Random Erasing
+_C.INPUT.REA = CN()
+_C.INPUT.REA.ENABLED = False
+_C.INPUT.REA.PROB = 0.5
+_C.INPUT.REA.MEAN = [0.596*255, 0.558*255, 0.497*255] # [0.485*255, 0.456*255, 0.406*255]
+# Random Patch
+_C.INPUT.RPT = CN()
+_C.INPUT.RPT.ENABLED = False
+_C.INPUT.RPT.PROB = 0.5
+
+# -----------------------------------------------------------------------------
+# Dataset
+# -----------------------------------------------------------------------------
+_C.DATASETS = CN()
+# List of the dataset names for training
+_C.DATASETS.NAMES = ("Market1501",)
+# List of the dataset names for testing
+_C.DATASETS.TESTS = ("Market1501",)
+# Combine the train set and test set for joint training
+_C.DATASETS.COMBINEALL = False
+
+# -----------------------------------------------------------------------------
+# DataLoader
+# -----------------------------------------------------------------------------
+_C.DATALOADER = CN()
+# P/K Sampler for data loading
+_C.DATALOADER.PK_SAMPLER = True
+# Naive sampler, which doesn't consider balanced identity sampling
+_C.DATALOADER.NAIVE_WAY = False
+# Number of instance for each person
+_C.DATALOADER.NUM_INSTANCE = 4
+_C.DATALOADER.NUM_WORKERS = 8
+
+# ---------------------------------------------------------------------------- #
+# Solver
+# ---------------------------------------------------------------------------- #
+_C.SOLVER = CN()
+
+# AUTOMATIC MIXED PRECISION
+_C.SOLVER.AMP_ENABLED = False
+
+# Optimizer
+_C.SOLVER.OPT = "Adam"
+
+_C.SOLVER.MAX_ITER = 120
+
+_C.SOLVER.BASE_LR = 3e-4
+_C.SOLVER.BIAS_LR_FACTOR = 1.
+_C.SOLVER.HEADS_LR_FACTOR = 1.
+
+_C.SOLVER.MOMENTUM = 0.9
+
+_C.SOLVER.WEIGHT_DECAY = 0.0005
+_C.SOLVER.WEIGHT_DECAY_BIAS = 0.
+
+# Multi-step learning rate options
+_C.SOLVER.SCHED = "WarmupMultiStepLR"
+_C.SOLVER.GAMMA = 0.1
+_C.SOLVER.STEPS = [30, 55]
+
+# Cosine annealing learning rate options
+_C.SOLVER.DELAY_ITERS = 0
+_C.SOLVER.ETA_MIN_LR = 3e-7
+
+# Warmup options
+_C.SOLVER.WARMUP_FACTOR = 0.1
+_C.SOLVER.WARMUP_ITERS = 10
+_C.SOLVER.WARMUP_METHOD = "linear"
+
+_C.SOLVER.FREEZE_ITERS = 0
+
+# SWA options
+_C.SOLVER.SWA = CN()
+_C.SOLVER.SWA.ENABLED = False
+_C.SOLVER.SWA.ITER = 10
+_C.SOLVER.SWA.PERIOD = 2
+_C.SOLVER.SWA.LR_FACTOR = 10.
+_C.SOLVER.SWA.ETA_MIN_LR = 3.5e-6
+_C.SOLVER.SWA.LR_SCHED = False
+
+_C.SOLVER.CHECKPOINT_PERIOD = 20
+
+# Number of images per batch across all machines.
+# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
+# see 2 images per batch
+_C.SOLVER.IMS_PER_BATCH = 64
+
+# ---------------------------------------------------------------------------- #
+# Test options
+# ---------------------------------------------------------------------------- #
+_C.TEST = CN()
+
+_C.TEST.EVAL_PERIOD = 20
+
+# Number of images per batch in one process.
+_C.TEST.IMS_PER_BATCH = 64
+_C.TEST.METRIC = "cosine"
+_C.TEST.ROC_ENABLED = False
+
+# Average query expansion
+_C.TEST.AQE = CN()
+_C.TEST.AQE.ENABLED = False
+_C.TEST.AQE.ALPHA = 3.0
+_C.TEST.AQE.QE_TIME = 1
+_C.TEST.AQE.QE_K = 5
+
+# Re-rank
+_C.TEST.RERANK = CN()
+_C.TEST.RERANK.ENABLED = False
+_C.TEST.RERANK.K1 = 20
+_C.TEST.RERANK.K2 = 6
+_C.TEST.RERANK.LAMBDA = 0.3
+
+# Precise batchnorm
+_C.TEST.PRECISE_BN = CN()
+_C.TEST.PRECISE_BN.ENABLED = False
+_C.TEST.PRECISE_BN.DATASET = 'Market1501'
+_C.TEST.PRECISE_BN.NUM_ITER = 300
+
+# ---------------------------------------------------------------------------- #
+# Misc options
+# ---------------------------------------------------------------------------- #
+_C.OUTPUT_DIR = "logs/"
+
+# Benchmark different cudnn algorithms.
+# If input images have very different sizes, this option will have large overhead
+# for about 10k iterations. It usually hurts total time, but can benefit for certain models.
+# If input images have the same or similar sizes, benchmark is often helpful.
+_C.CUDNN_BENCHMARK = False
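
A short sketch of how these defaults are consumed; the weight path is illustrative, and the key/value pairs mirror the command-line overrides shown in the README.

    from config import get_cfg

    cfg = get_cfg()  # a clone of the _C tree defined above
    cfg.merge_from_list([
        "MODEL.DEVICE", "cuda:0",
        "MODEL.WEIGHTS", "./market_bot_R101-ibn.pth",  # illustrative path
    ])
    cfg.freeze()
    print(cfg.MODEL.HEADS.NAME, cfg.TEST.METRIC)  # EmbeddingHead cosine
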
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000..dbf5adc
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1,7 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2020/10/26 16:24
+# @Author : Scheaven
+# @File : __init__.py.py
+# @description:
+
diff --git a/data/__pycache__/__init__.cpython-37.pyc b/data/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..36ec137
--- /dev/null
+++ b/data/__pycache__/__init__.cpython-37.pyc
Binary files differ
diff --git a/data/__pycache__/__init__.cpython-38.pyc b/data/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..dda59a4
--- /dev/null
+++ b/data/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/data/__pycache__/data_utils.cpython-37.pyc b/data/__pycache__/data_utils.cpython-37.pyc
new file mode 100644
index 0000000..332f74a
--- /dev/null
+++ b/data/__pycache__/data_utils.cpython-37.pyc
Binary files differ
diff --git a/data/__pycache__/data_utils.cpython-38.pyc b/data/__pycache__/data_utils.cpython-38.pyc
new file mode 100644
index 0000000..da72b54
--- /dev/null
+++ b/data/__pycache__/data_utils.cpython-38.pyc
Binary files differ
diff --git a/data/data_utils.py b/data/data_utils.py
new file mode 100644
index 0000000..3b65111
--- /dev/null
+++ b/data/data_utils.py
@@ -0,0 +1,45 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+import numpy as np
+from PIL import Image, ImageOps
+
+from utils.file_io import PathManager
+
+
+def read_image(file_name, format=None):
+ """
+ Read an image into the given format.
+ Will apply rotation and flipping if the image has such exif information.
+ Args:
+ file_name (str): image file path
+ format (str): one of the supported image modes in PIL, or "BGR"
+ Returns:
+ image (np.ndarray): an HWC image
+ """
+ with PathManager.open(file_name, "rb") as f:
+ image = Image.open(f)
+
+ # capture and ignore this bug: https://github.com/python-pillow/Pillow/issues/3973
+ try:
+ image = ImageOps.exif_transpose(image)
+ except Exception:
+ pass
+
+ if format is not None:
+ # PIL only supports RGB, so convert to RGB and flip channels over below
+ conversion_format = format
+ if format == "BGR":
+ conversion_format = "RGB"
+ image = image.convert(conversion_format)
+ image = np.asarray(image)
+ if format == "BGR":
+ # flip channels if needed
+ image = image[:, :, ::-1]
+ # PIL squeezes out the channel dimension for "L", so make it HWC
+ if format == "L":
+ image = np.expand_dims(image, -1)
+ image = Image.fromarray(image)
+ return image
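
A minimal usage sketch for read_image; the file name is a placeholder.

    import numpy as np
    from data.data_utils import read_image

    img = read_image("query.jpg", format="RGB")  # PIL.Image, EXIF orientation applied
    arr = np.asarray(img)                        # HWC uint8 array
    print(arr.shape)
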
diff --git a/data/transforms/__init__.py b/data/transforms/__init__.py
new file mode 100644
index 0000000..45c7d13
--- /dev/null
+++ b/data/transforms/__init__.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2020/10/26 16:24
+# @Author : Scheaven
+# @File : __init__.py.py
+# @description:
+
+from .build import build_transforms
+from .transforms import *
+from .autoaugment import *
\ No newline at end of file
diff --git a/data/transforms/__pycache__/__init__.cpython-38.pyc b/data/transforms/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..ae2715d
--- /dev/null
+++ b/data/transforms/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/data/transforms/__pycache__/autoaugment.cpython-38.pyc b/data/transforms/__pycache__/autoaugment.cpython-38.pyc
new file mode 100644
index 0000000..78532f6
--- /dev/null
+++ b/data/transforms/__pycache__/autoaugment.cpython-38.pyc
Binary files differ
diff --git a/data/transforms/__pycache__/build.cpython-38.pyc b/data/transforms/__pycache__/build.cpython-38.pyc
new file mode 100644
index 0000000..642959c
--- /dev/null
+++ b/data/transforms/__pycache__/build.cpython-38.pyc
Binary files differ
diff --git a/data/transforms/__pycache__/functional.cpython-38.pyc b/data/transforms/__pycache__/functional.cpython-38.pyc
new file mode 100644
index 0000000..d7e6a43
--- /dev/null
+++ b/data/transforms/__pycache__/functional.cpython-38.pyc
Binary files differ
diff --git a/data/transforms/__pycache__/transforms.cpython-38.pyc b/data/transforms/__pycache__/transforms.cpython-38.pyc
new file mode 100644
index 0000000..9a7430f
--- /dev/null
+++ b/data/transforms/__pycache__/transforms.cpython-38.pyc
Binary files differ
diff --git a/data/transforms/autoaugment.py b/data/transforms/autoaugment.py
new file mode 100644
index 0000000..487e70d
--- /dev/null
+++ b/data/transforms/autoaugment.py
@@ -0,0 +1,812 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+""" AutoAugment, RandAugment, and AugMix for PyTorch
+This code implements the searched ImageNet policies with various tweaks and improvements and
+does not include any of the search code.
+AA and RA Implementation adapted from:
+ https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
+AugMix adapted from:
+ https://github.com/google-research/augmix
+Papers:
+ AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
+ Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
+ RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
+ AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
+Hacked together by Ross Wightman
+"""
+import math
+import random
+import re
+
+import PIL
+import numpy as np
+from PIL import Image, ImageOps, ImageEnhance
+
+_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
+
+_FILL = (128, 128, 128)
+
+# This signifies the max integer that the controller RNN could predict for the
+# augmentation scheme.
+_MAX_LEVEL = 10.
+
+_HPARAMS_DEFAULT = dict(
+ translate_const=57,
+ img_mean=_FILL,
+)
+
+_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
+
+
+def _interpolation(kwargs):
+ interpolation = kwargs.pop('resample', Image.BILINEAR)
+ if isinstance(interpolation, (list, tuple)):
+ return random.choice(interpolation)
+ else:
+ return interpolation
+
+
+def _check_args_tf(kwargs):
+ if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
+ kwargs.pop('fillcolor')
+ kwargs['resample'] = _interpolation(kwargs)
+
+
+def shear_x(img, factor, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
+
+
+def shear_y(img, factor, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
+
+
+def translate_x_rel(img, pct, **kwargs):
+ pixels = pct * img.size[0]
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+
+
+def translate_y_rel(img, pct, **kwargs):
+ pixels = pct * img.size[1]
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+
+
+def translate_x_abs(img, pixels, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+
+
+def translate_y_abs(img, pixels, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+
+
+def rotate(img, degrees, **kwargs):
+ _check_args_tf(kwargs)
+ if _PIL_VER >= (5, 2):
+ return img.rotate(degrees, **kwargs)
+ elif _PIL_VER >= (5, 0):
+ w, h = img.size
+ post_trans = (0, 0)
+ rotn_center = (w / 2.0, h / 2.0)
+ angle = -math.radians(degrees)
+ matrix = [
+ round(math.cos(angle), 15),
+ round(math.sin(angle), 15),
+ 0.0,
+ round(-math.sin(angle), 15),
+ round(math.cos(angle), 15),
+ 0.0,
+ ]
+
+ def transform(x, y, matrix):
+ (a, b, c, d, e, f) = matrix
+ return a * x + b * y + c, d * x + e * y + f
+
+ matrix[2], matrix[5] = transform(
+ -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
+ )
+ matrix[2] += rotn_center[0]
+ matrix[5] += rotn_center[1]
+ return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
+ else:
+ return img.rotate(degrees, resample=kwargs['resample'])
+
+
+def auto_contrast(img, **__):
+ return ImageOps.autocontrast(img)
+
+
+def invert(img, **__):
+ return ImageOps.invert(img)
+
+
+def equalize(img, **__):
+ return ImageOps.equalize(img)
+
+
+def solarize(img, thresh, **__):
+ return ImageOps.solarize(img, thresh)
+
+
+def solarize_add(img, add, thresh=128, **__):
+ lut = []
+ for i in range(256):
+ if i < thresh:
+ lut.append(min(255, i + add))
+ else:
+ lut.append(i)
+ if img.mode in ("L", "RGB"):
+ if img.mode == "RGB" and len(lut) == 256:
+ lut = lut + lut + lut
+ return img.point(lut)
+ else:
+ return img
+
+
+def posterize(img, bits_to_keep, **__):
+ if bits_to_keep >= 8:
+ return img
+ return ImageOps.posterize(img, bits_to_keep)
+
+
+def contrast(img, factor, **__):
+ return ImageEnhance.Contrast(img).enhance(factor)
+
+
+def color(img, factor, **__):
+ return ImageEnhance.Color(img).enhance(factor)
+
+
+def brightness(img, factor, **__):
+ return ImageEnhance.Brightness(img).enhance(factor)
+
+
+def sharpness(img, factor, **__):
+ return ImageEnhance.Sharpness(img).enhance(factor)
+
+
+def _randomly_negate(v):
+ """With 50% prob, negate the value"""
+ return -v if random.random() > 0.5 else v
+
+
+def _rotate_level_to_arg(level, _hparams):
+ # range [-30, 30]
+ level = (level / _MAX_LEVEL) * 30.
+ level = _randomly_negate(level)
+ return level,
+
+
+def _enhance_level_to_arg(level, _hparams):
+ # range [0.1, 1.9]
+ return (level / _MAX_LEVEL) * 1.8 + 0.1,
+
+
+def _enhance_increasing_level_to_arg(level, _hparams):
+ # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
+ # range [0.1, 1.9]
+ level = (level / _MAX_LEVEL) * .9
+ level = 1.0 + _randomly_negate(level)
+ return level,
+
+
+def _shear_level_to_arg(level, _hparams):
+ # range [-0.3, 0.3]
+ level = (level / _MAX_LEVEL) * 0.3
+ level = _randomly_negate(level)
+ return level,
+
+
+def _translate_abs_level_to_arg(level, hparams):
+ translate_const = hparams['translate_const']
+ level = (level / _MAX_LEVEL) * float(translate_const)
+ level = _randomly_negate(level)
+ return level,
+
+
+def _translate_rel_level_to_arg(level, hparams):
+ # default range [-0.45, 0.45]
+ translate_pct = hparams.get('translate_pct', 0.45)
+ level = (level / _MAX_LEVEL) * translate_pct
+ level = _randomly_negate(level)
+ return level,
+
+
+def _posterize_level_to_arg(level, _hparams):
+ # As per Tensorflow TPU EfficientNet impl
+ # range [0, 4], 'keep 0 up to 4 MSB of original image'
+ # intensity/severity of augmentation decreases with level
+ return int((level / _MAX_LEVEL) * 4),
+
+
+def _posterize_increasing_level_to_arg(level, hparams):
+ # As per Tensorflow models research and UDA impl
+ # range [4, 0], 'keep 4 down to 0 MSB of original image',
+ # intensity/severity of augmentation increases with level
+ return 4 - _posterize_level_to_arg(level, hparams)[0],
+
+
+def _posterize_original_level_to_arg(level, _hparams):
+ # As per original AutoAugment paper description
+ # range [4, 8], 'keep 4 up to 8 MSB of image'
+ # intensity/severity of augmentation decreases with level
+ return int((level / _MAX_LEVEL) * 4) + 4,
+
+
+def _solarize_level_to_arg(level, _hparams):
+ # range [0, 256]
+ # intensity/severity of augmentation decreases with level
+ return int((level / _MAX_LEVEL) * 256),
+
+
+def _solarize_increasing_level_to_arg(level, _hparams):
+ # range [0, 256]
+ # intensity/severity of augmentation increases with level
+ return 256 - _solarize_level_to_arg(level, _hparams)[0],
+
+
+def _solarize_add_level_to_arg(level, _hparams):
+ # range [0, 110]
+ return int((level / _MAX_LEVEL) * 110),
+
+
+LEVEL_TO_ARG = {
+ 'AutoContrast': None,
+ 'Equalize': None,
+ 'Invert': None,
+ 'Rotate': _rotate_level_to_arg,
+ # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
+ 'Posterize': _posterize_level_to_arg,
+ 'PosterizeIncreasing': _posterize_increasing_level_to_arg,
+ 'PosterizeOriginal': _posterize_original_level_to_arg,
+ 'Solarize': _solarize_level_to_arg,
+ 'SolarizeIncreasing': _solarize_increasing_level_to_arg,
+ 'SolarizeAdd': _solarize_add_level_to_arg,
+ 'Color': _enhance_level_to_arg,
+ 'ColorIncreasing': _enhance_increasing_level_to_arg,
+ 'Contrast': _enhance_level_to_arg,
+ 'ContrastIncreasing': _enhance_increasing_level_to_arg,
+ 'Brightness': _enhance_level_to_arg,
+ 'BrightnessIncreasing': _enhance_increasing_level_to_arg,
+ 'Sharpness': _enhance_level_to_arg,
+ 'SharpnessIncreasing': _enhance_increasing_level_to_arg,
+ 'ShearX': _shear_level_to_arg,
+ 'ShearY': _shear_level_to_arg,
+ 'TranslateX': _translate_abs_level_to_arg,
+ 'TranslateY': _translate_abs_level_to_arg,
+ 'TranslateXRel': _translate_rel_level_to_arg,
+ 'TranslateYRel': _translate_rel_level_to_arg,
+}
+
+NAME_TO_OP = {
+ 'AutoContrast': auto_contrast,
+ 'Equalize': equalize,
+ 'Invert': invert,
+ 'Rotate': rotate,
+ 'Posterize': posterize,
+ 'PosterizeIncreasing': posterize,
+ 'PosterizeOriginal': posterize,
+ 'Solarize': solarize,
+ 'SolarizeIncreasing': solarize,
+ 'SolarizeAdd': solarize_add,
+ 'Color': color,
+ 'ColorIncreasing': color,
+ 'Contrast': contrast,
+ 'ContrastIncreasing': contrast,
+ 'Brightness': brightness,
+ 'BrightnessIncreasing': brightness,
+ 'Sharpness': sharpness,
+ 'SharpnessIncreasing': sharpness,
+ 'ShearX': shear_x,
+ 'ShearY': shear_y,
+ 'TranslateX': translate_x_abs,
+ 'TranslateY': translate_y_abs,
+ 'TranslateXRel': translate_x_rel,
+ 'TranslateYRel': translate_y_rel,
+}
+
+
+class AugmentOp:
+
+ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
+ hparams = hparams or _HPARAMS_DEFAULT
+ self.aug_fn = NAME_TO_OP[name]
+ self.level_fn = LEVEL_TO_ARG[name]
+ self.prob = prob
+ self.magnitude = magnitude
+ self.hparams = hparams.copy()
+ self.kwargs = dict(
+ fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
+ resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
+ )
+
+ # If magnitude_std is > 0, we introduce some randomness
+ # in the usually fixed policy and sample magnitude from a normal distribution
+ # with mean `magnitude` and std-dev of `magnitude_std`.
+ # NOTE This is my own hack, being tested, not in papers or reference impls.
+ self.magnitude_std = self.hparams.get('magnitude_std', 0)
+
+ def __call__(self, img):
+ if self.prob < 1.0 and random.random() > self.prob:
+ return img
+ magnitude = self.magnitude
+ if self.magnitude_std and self.magnitude_std > 0:
+ magnitude = random.gauss(magnitude, self.magnitude_std)
+ magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
+ level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
+ return self.aug_fn(img, *level_args, **self.kwargs)
+
+
+def auto_augment_policy_v0(hparams):
+ # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
+ policy = [
+ [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
+ [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+ [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
+ [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+ [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+ [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+ [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+ [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
+ [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+ [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
+ [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
+ [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+ [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
+ [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+ [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
+ [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
+ [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
+ [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
+ [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
+ [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+ [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+ [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
+ [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], # This results in a black image with TPU posterize
+ [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+ [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
+ ]
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+ return pc
+
+
+def auto_augment_policy_v0r(hparams):
+ # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
+ # in Google research implementation (number of bits discarded increases with magnitude)
+ policy = [
+ [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
+ [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+ [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
+ [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+ [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+ [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+ [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+ [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
+ [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+ [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
+ [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
+ [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+ [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
+ [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+ [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
+ [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
+ [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
+ [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
+ [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
+ [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+ [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+ [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
+ [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
+ [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+ [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
+ ]
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+ return pc
+
+
+def auto_augment_policy_original(hparams):
+ # ImageNet policy from https://arxiv.org/abs/1805.09501
+ policy = [
+ [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+ [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+ [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
+ [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
+ [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
+ [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
+ [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
+ [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
+ [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
+ [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+ [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
+ [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
+ [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
+ [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
+ [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+ ]
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+ return pc
+
+
+def auto_augment_policy_originalr(hparams):
+ # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
+ policy = [
+ [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+ [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+ [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
+ [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
+ [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
+ [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
+ [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
+ [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
+ [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
+ [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+ [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
+ [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
+ [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
+ [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
+ [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+ ]
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+ return pc
+
+
+def auto_augment_policy(name="original"):
+ hparams = _HPARAMS_DEFAULT
+ if name == 'original':
+ return auto_augment_policy_original(hparams)
+ elif name == 'originalr':
+ return auto_augment_policy_originalr(hparams)
+ elif name == 'v0':
+ return auto_augment_policy_v0(hparams)
+ elif name == 'v0r':
+ return auto_augment_policy_v0r(hparams)
+ else:
+ assert False, 'Unknown AA policy (%s)' % name
+
+
+class AutoAugment:
+
+ def __init__(self, total_iter):
+ self.total_iter = total_iter
+ self.gamma = 0
+ self.policy = auto_augment_policy()
+
+ def __call__(self, img):
+ if random.uniform(0, 1) > self.gamma:
+ sub_policy = random.choice(self.policy)
+ self.gamma = min(1.0, self.gamma + 1.0 / self.total_iter)
+ for op in sub_policy:
+ img = op(img)
+ return img
+ else:
+ return img
+
+
+def auto_augment_transform(config_str, hparams):
+ """
+ Create a AutoAugment transform
+ :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
+ dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
+ The remaining sections, not order specific, determine
+ 'mstd' - float std deviation of magnitude noise applied
+ Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
+ :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
+ :return: A PyTorch compatible Transform
+ """
+ config = config_str.split('-')
+ policy_name = config[0]
+ config = config[1:]
+ for c in config:
+ cs = re.split(r'(\d.*)', c)
+ if len(cs) < 2:
+ continue
+ key, val = cs[:2]
+ if key == 'mstd':
+ # noise param injected via hparams for now
+ hparams.setdefault('magnitude_std', float(val))
+ else:
+ assert False, 'Unknown AutoAugment config section'
+ aa_policy = auto_augment_policy(policy_name)
+ return AutoAugment(aa_policy)
+
+
+_RAND_TRANSFORMS = [
+ 'AutoContrast',
+ 'Equalize',
+ 'Invert',
+ 'Rotate',
+ 'Posterize',
+ 'Solarize',
+ 'SolarizeAdd',
+ 'Color',
+ 'Contrast',
+ 'Brightness',
+ 'Sharpness',
+ 'ShearX',
+ 'ShearY',
+ 'TranslateXRel',
+ 'TranslateYRel',
+ # 'Cutout' # NOTE I've implemented this as random erasing separately
+]
+
+_RAND_INCREASING_TRANSFORMS = [
+ 'AutoContrast',
+ 'Equalize',
+ 'Invert',
+ 'Rotate',
+ 'PosterizeIncreasing',
+ 'SolarizeIncreasing',
+ 'SolarizeAdd',
+ 'ColorIncreasing',
+ 'ContrastIncreasing',
+ 'BrightnessIncreasing',
+ 'SharpnessIncreasing',
+ 'ShearX',
+ 'ShearY',
+ 'TranslateXRel',
+ 'TranslateYRel',
+ # 'Cutout' # NOTE I've implemented this as random erasing separately
+]
+
+# These experimental weights are based loosely on the relative improvements mentioned in the paper.
+# They may not result in increased performance, but could likely be tuned to do so.
+_RAND_CHOICE_WEIGHTS_0 = {
+ 'Rotate': 0.3,
+ 'ShearX': 0.2,
+ 'ShearY': 0.2,
+ 'TranslateXRel': 0.1,
+ 'TranslateYRel': 0.1,
+ 'Color': .025,
+ 'Sharpness': 0.025,
+ 'AutoContrast': 0.025,
+ 'Solarize': .005,
+ 'SolarizeAdd': .005,
+ 'Contrast': .005,
+ 'Brightness': .005,
+ 'Equalize': .005,
+ 'Posterize': 0,
+ 'Invert': 0,
+}
+
+
+def _select_rand_weights(weight_idx=0, transforms=None):
+ transforms = transforms or _RAND_TRANSFORMS
+ assert weight_idx == 0 # only one set of weights currently
+ rand_weights = _RAND_CHOICE_WEIGHTS_0
+ probs = [rand_weights[k] for k in transforms]
+ probs /= np.sum(probs)
+ return probs
+
+
+def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
+ hparams = hparams or _HPARAMS_DEFAULT
+ transforms = transforms or _RAND_TRANSFORMS
+ return [AugmentOp(
+ name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
+
+
+class RandAugment:
+ def __init__(self, ops, num_layers=2, choice_weights=None):
+ self.ops = ops
+ self.num_layers = num_layers
+ self.choice_weights = choice_weights
+
+ def __call__(self, img):
+ # no replacement when using weighted choice
+ ops = np.random.choice(
+ self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
+ for op in ops:
+ img = op(img)
+ return img
+
+
+def rand_augment_transform(config_str, hparams):
+ """
+ Create a RandAugment transform
+ :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
+ dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
+ sections, not order specific, determine
+ 'm' - integer magnitude of rand augment
+ 'n' - integer num layers (number of transform ops selected per image)
+ 'w' - integer probability weight index (index of a set of weights to influence choice of op)
+ 'mstd' - float std deviation of magnitude noise applied
+ 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
+ Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
+ 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
+ :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
+ :return: A PyTorch compatible Transform
+ """
+ magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
+ num_layers = 2 # default to 2 ops per image
+ weight_idx = None # default to no probability weights for op choice
+ transforms = _RAND_TRANSFORMS
+ config = config_str.split('-')
+ assert config[0] == 'rand'
+ config = config[1:]
+ for c in config:
+ cs = re.split(r'(\d.*)', c)
+ if len(cs) < 2:
+ continue
+ key, val = cs[:2]
+ if key == 'mstd':
+ # noise param injected via hparams for now
+ hparams.setdefault('magnitude_std', float(val))
+ elif key == 'inc':
+ if bool(val):
+ transforms = _RAND_INCREASING_TRANSFORMS
+ elif key == 'm':
+ magnitude = int(val)
+ elif key == 'n':
+ num_layers = int(val)
+ elif key == 'w':
+ weight_idx = int(val)
+ else:
+ assert False, 'Unknown RandAugment config section'
+ ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
+ choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
+ return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
+
+
+_AUGMIX_TRANSFORMS = [
+ 'AutoContrast',
+ 'ColorIncreasing', # not in paper
+ 'ContrastIncreasing', # not in paper
+ 'BrightnessIncreasing', # not in paper
+ 'SharpnessIncreasing', # not in paper
+ 'Equalize',
+ 'Rotate',
+ 'PosterizeIncreasing',
+ 'SolarizeIncreasing',
+ 'ShearX',
+ 'ShearY',
+ 'TranslateXRel',
+ 'TranslateYRel',
+]
+
+
+def augmix_ops(magnitude=10, hparams=None, transforms=None):
+ hparams = hparams or _HPARAMS_DEFAULT
+ transforms = transforms or _AUGMIX_TRANSFORMS
+ return [AugmentOp(
+ name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
+
+
+class AugMixAugment:
+ """ AugMix Transform
+ Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
+ From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
+ https://arxiv.org/abs/1912.02781
+ """
+
+ def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
+ self.ops = ops
+ self.alpha = alpha
+ self.width = width
+ self.depth = depth
+ self.blended = blended # blended mode is faster but not well tested
+
+ def _calc_blended_weights(self, ws, m):
+ ws = ws * m
+ cump = 1.
+ rws = []
+ for w in ws[::-1]:
+ alpha = w / cump
+ cump *= (1 - alpha)
+ rws.append(alpha)
+ return np.array(rws[::-1], dtype=np.float32)
+
+ def _apply_blended(self, img, mixing_weights, m):
+ # This is my first crack at implementing a slightly faster mixed augmentation. Instead
+ # of accumulating the mix for each chain in a Numpy array and then blending with original,
+ # it recomputes the blending coefficients and applies one PIL image blend per chain.
+ # TODO the results appear in the right ballpark but they differ by more than rounding.
+ img_orig = img.copy()
+ ws = self._calc_blended_weights(mixing_weights, m)
+ for w in ws:
+ depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
+ ops = np.random.choice(self.ops, depth, replace=True)
+ img_aug = img_orig # no ops are in-place, deep copy not necessary
+ for op in ops:
+ img_aug = op(img_aug)
+ img = Image.blend(img, img_aug, w)
+ return img
+
+ def _apply_basic(self, img, mixing_weights, m):
+ # This is a literal adaptation of the paper/official implementation without normalizations and
+ # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to
+ # typical augmentation transforms; it could use a GPU / Kornia implementation.
+ img_shape = img.size[0], img.size[1], len(img.getbands())
+ mixed = np.zeros(img_shape, dtype=np.float32)
+ for mw in mixing_weights:
+ depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
+ ops = np.random.choice(self.ops, depth, replace=True)
+ img_aug = img # no ops are in-place, deep copy not necessary
+ for op in ops:
+ img_aug = op(img_aug)
+ mixed += mw * np.asarray(img_aug, dtype=np.float32)
+ np.clip(mixed, 0, 255., out=mixed)
+ mixed = Image.fromarray(mixed.astype(np.uint8))
+ return Image.blend(img, mixed, m)
+
+ def __call__(self, img):
+ mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
+ m = np.float32(np.random.beta(self.alpha, self.alpha))
+ if self.blended:
+ mixed = self._apply_blended(img, mixing_weights, m)
+ else:
+ mixed = self._apply_basic(img, mixing_weights, m)
+ return mixed
+
+
+def augment_and_mix_transform(config_str, hparams):
+ """ Create AugMix PyTorch transform
+ :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
+    dashes ('-'). The first section defines the variant (currently only 'augmix'). The remaining
+    sections, which are not order specific, determine
+ 'm' - integer magnitude (severity) of augmentation mix (default: 3)
+ 'w' - integer width of augmentation chain (default: 3)
+ 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
+ 'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
+ 'mstd' - float std deviation of magnitude noise applied (default: 0)
+ Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
+ :param hparams: Other hparams (kwargs) for the Augmentation transforms
+ :return: A PyTorch compatible Transform
+ """
+ magnitude = 3
+ width = 3
+ depth = -1
+ alpha = 1.
+ blended = False
+ config = config_str.split('-')
+ assert config[0] == 'augmix'
+ config = config[1:]
+ for c in config:
+ cs = re.split(r'(\d.*)', c)
+ if len(cs) < 2:
+ continue
+ key, val = cs[:2]
+ if key == 'mstd':
+ # noise param injected via hparams for now
+ hparams.setdefault('magnitude_std', float(val))
+ elif key == 'm':
+ magnitude = int(val)
+ elif key == 'w':
+ width = int(val)
+ elif key == 'd':
+ depth = int(val)
+ elif key == 'a':
+ alpha = float(val)
+ elif key == 'b':
+            blended = bool(int(val))  # parse via int so 'b0' reads as False
+ else:
+ assert False, 'Unknown AugMix config section'
+ ops = augmix_ops(magnitude=magnitude, hparams=hparams)
+ return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
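+
+# Example (sketch): hparams may be an empty dict; 'mstd' sections would fill it.
+# tfm = augment_and_mix_transform('augmix-m5-w4-d2', {})
+# mixed = tfm(pil_img)  # AugMixAugment with severity 5, width 4, depth 2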
diff --git a/data/transforms/build.py b/data/transforms/build.py
new file mode 100644
index 0000000..7351aed
--- /dev/null
+++ b/data/transforms/build.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2020/10/26 16:25
+# @Author : Scheaven
+# @File : build.py
+# @description:
+
+import torchvision.transforms as T
+
+from .transforms import *
+from .autoaugment import AutoAugment
+
+
+def build_transforms(cfg, is_train=True):
+ res = []
+
+ if is_train:
+ size_train = cfg.INPUT.SIZE_TRAIN
+
+ # augmix augmentation
+ do_augmix = cfg.INPUT.DO_AUGMIX
+
+ # auto augmentation
+ do_autoaug = cfg.INPUT.DO_AUTOAUG
+ total_iter = cfg.SOLVER.MAX_ITER
+
+        # horizontal flip
+ do_flip = cfg.INPUT.DO_FLIP
+ flip_prob = cfg.INPUT.FLIP_PROB
+
+ # padding
+ do_pad = cfg.INPUT.DO_PAD
+ padding = cfg.INPUT.PADDING
+ padding_mode = cfg.INPUT.PADDING_MODE
+
+ # color jitter
+ do_cj = cfg.INPUT.CJ.ENABLED
+ cj_prob = cfg.INPUT.CJ.PROB
+ cj_brightness = cfg.INPUT.CJ.BRIGHTNESS
+ cj_contrast = cfg.INPUT.CJ.CONTRAST
+ cj_saturation = cfg.INPUT.CJ.SATURATION
+ cj_hue = cfg.INPUT.CJ.HUE
+
+ # random erasing
+ do_rea = cfg.INPUT.REA.ENABLED
+ rea_prob = cfg.INPUT.REA.PROB
+ rea_mean = cfg.INPUT.REA.MEAN
+ # random patch
+ do_rpt = cfg.INPUT.RPT.ENABLED
+ rpt_prob = cfg.INPUT.RPT.PROB
+
+ if do_autoaug:
+ res.append(AutoAugment(total_iter))
+ res.append(T.Resize(size_train, interpolation=3))
+ if do_flip:
+ res.append(T.RandomHorizontalFlip(p=flip_prob))
+ if do_pad:
+ res.extend([T.Pad(padding, padding_mode=padding_mode),
+ T.RandomCrop(size_train)])
+ if do_cj:
+ res.append(T.RandomApply([T.ColorJitter(cj_brightness, cj_contrast, cj_saturation, cj_hue)], p=cj_prob))
+ if do_augmix:
+ res.append(AugMix())
+ if do_rea:
+ res.append(RandomErasing(probability=rea_prob, mean=rea_mean))
+ if do_rpt:
+ res.append(RandomPatch(prob_happen=rpt_prob))
+ else:
+ size_test = cfg.INPUT.SIZE_TEST
+ res.append(T.Resize(size_test, interpolation=3))
+ res.append(ToTensor())
+ res.append(T.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255]))
+ return T.Compose(res)
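+
+# Example (sketch): assuming `cfg` is the project config node carrying the
+# INPUT/SOLVER keys read above:
+# train_tfms = build_transforms(cfg, is_train=True)  # returns a T.Compose
+# tensor = train_tfms(pil_img)  # CHW float tensor normalized with 0-255 stats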
diff --git a/data/transforms/functional.py b/data/transforms/functional.py
new file mode 100644
index 0000000..359bea8
--- /dev/null
+++ b/data/transforms/functional.py
@@ -0,0 +1,190 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import numpy as np
+import torch
+from PIL import Image, ImageOps, ImageEnhance
+
+
+def to_tensor(pic):
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+
+ See ``ToTensor`` for more details.
+
+ Args:
+ pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+
+ Returns:
+ Tensor: Converted image.
+ """
+ if isinstance(pic, np.ndarray):
+ assert len(pic.shape) in (2, 3)
+ # handle numpy array
+ if pic.ndim == 2:
+ pic = pic[:, :, None]
+
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
+ # backward compatibility
+ if isinstance(img, torch.ByteTensor):
+ return img.float()
+ else:
+ return img
+
+ # handle PIL Image
+ if pic.mode == 'I':
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+ elif pic.mode == 'I;16':
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+ elif pic.mode == 'F':
+ img = torch.from_numpy(np.array(pic, np.float32, copy=False))
+ elif pic.mode == '1':
+ img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
+ else:
+ img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+ # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK
+ if pic.mode == 'YCbCr':
+ nchannel = 3
+ elif pic.mode == 'I;16':
+ nchannel = 1
+ else:
+ nchannel = len(pic.mode)
+ img = img.view(pic.size[1], pic.size[0], nchannel)
+ # put it from HWC to CHW format
+ # yikes, this transpose takes 80% of the loading time/CPU
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
+ if isinstance(img, torch.ByteTensor):
+ return img.float()
+ else:
+ return img
+
+
+def int_parameter(level, maxval):
+ """Helper function to scale `val` between 0 and maxval .
+ Args:
+ level: Level of the operation that will be between [0, `PARAMETER_MAX`].
+ maxval: Maximum value that the operation can have. This will be scaled to
+ level/PARAMETER_MAX.
+ Returns:
+ An int that results from scaling `maxval` according to `level`.
+ """
+ return int(level * maxval / 10)
+
+
+def float_parameter(level, maxval):
+ """Helper function to scale `val` between 0 and maxval.
+ Args:
+ level: Level of the operation that will be between [0, `PARAMETER_MAX`].
+ maxval: Maximum value that the operation can have. This will be scaled to
+ level/PARAMETER_MAX.
+ Returns:
+ A float that results from scaling `maxval` according to `level`.
+ """
+ return float(level) * maxval / 10.
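+
+# For example, int_parameter(5, 30) == 15 and float_parameter(5, 0.3) == 0.15:
+# a level of 5 out of PARAMETER_MAX = 10 gives half of the maximum magnitude.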
+
+
+def sample_level(n):
+ return np.random.uniform(low=0.1, high=n)
+
+
+def autocontrast(pil_img, *args):
+ return ImageOps.autocontrast(pil_img)
+
+
+def equalize(pil_img, *args):
+ return ImageOps.equalize(pil_img)
+
+
+def posterize(pil_img, level, *args):
+ level = int_parameter(sample_level(level), 4)
+ return ImageOps.posterize(pil_img, 4 - level)
+
+
+def rotate(pil_img, level, *args):
+ degrees = int_parameter(sample_level(level), 30)
+ if np.random.uniform() > 0.5:
+ degrees = -degrees
+ return pil_img.rotate(degrees, resample=Image.BILINEAR)
+
+
+def solarize(pil_img, level, *args):
+ level = int_parameter(sample_level(level), 256)
+ return ImageOps.solarize(pil_img, 256 - level)
+
+
+def shear_x(pil_img, level, image_size):
+ level = float_parameter(sample_level(level), 0.3)
+ if np.random.uniform() > 0.5:
+ level = -level
+ return pil_img.transform(image_size,
+ Image.AFFINE, (1, level, 0, 0, 1, 0),
+ resample=Image.BILINEAR)
+
+
+def shear_y(pil_img, level, image_size):
+ level = float_parameter(sample_level(level), 0.3)
+ if np.random.uniform() > 0.5:
+ level = -level
+ return pil_img.transform(image_size,
+ Image.AFFINE, (1, 0, 0, level, 1, 0),
+ resample=Image.BILINEAR)
+
+
+def translate_x(pil_img, level, image_size):
+ level = int_parameter(sample_level(level), image_size[0] / 3)
+ if np.random.random() > 0.5:
+ level = -level
+ return pil_img.transform(image_size,
+ Image.AFFINE, (1, 0, level, 0, 1, 0),
+ resample=Image.BILINEAR)
+
+
+def translate_y(pil_img, level, image_size):
+ level = int_parameter(sample_level(level), image_size[1] / 3)
+ if np.random.random() > 0.5:
+ level = -level
+ return pil_img.transform(image_size,
+ Image.AFFINE, (1, 0, 0, 0, 1, level),
+ resample=Image.BILINEAR)
+
+
+# operation that overlaps with ImageNet-C's test set
+def color(pil_img, level, *args):
+ level = float_parameter(sample_level(level), 1.8) + 0.1
+ return ImageEnhance.Color(pil_img).enhance(level)
+
+
+# operation that overlaps with ImageNet-C's test set
+def contrast(pil_img, level, *args):
+ level = float_parameter(sample_level(level), 1.8) + 0.1
+ return ImageEnhance.Contrast(pil_img).enhance(level)
+
+
+# operation that overlaps with ImageNet-C's test set
+def brightness(pil_img, level, *args):
+ level = float_parameter(sample_level(level), 1.8) + 0.1
+ return ImageEnhance.Brightness(pil_img).enhance(level)
+
+
+# operation that overlaps with ImageNet-C's test set
+def sharpness(pil_img, level, *args):
+ level = float_parameter(sample_level(level), 1.8) + 0.1
+ return ImageEnhance.Sharpness(pil_img).enhance(level)
+
+
+augmentations_reid = [
+ autocontrast, equalize, posterize, shear_x, shear_y,
+ color, contrast, brightness, sharpness
+]
+
+augmentations = [
+ autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
+ translate_x, translate_y
+]
+
+augmentations_all = [
+ autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
+ translate_x, translate_y, color, contrast, brightness, sharpness
+]
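+
+# Example (sketch): every op shares the signature (pil_img, level, image_size);
+# the geometric ops consume image_size while the rest absorb it via *args:
+# out = shear_x(pil_img, 5, pil_img.size)  # shear factor in (0.003, 0.15],
+# with the sign flipped half the time.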
diff --git a/data/transforms/transforms.py b/data/transforms/transforms.py
new file mode 100644
index 0000000..2e2def2
--- /dev/null
+++ b/data/transforms/transforms.py
@@ -0,0 +1,204 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+__all__ = ['ToTensor', 'RandomErasing', 'RandomPatch', 'AugMix',]
+
+import math
+import random
+from collections import deque
+
+import numpy as np
+from PIL import Image
+
+from .functional import to_tensor, augmentations_reid
+
+
+class ToTensor(object):
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+
+ Converts a PIL Image or numpy.ndarray (H x W x C) in the range
+ [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 255.0]
+ if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
+ or if the numpy.ndarray has dtype = np.uint8
+
+    In other cases, tensors are returned without scaling.
+ """
+
+ def __call__(self, pic):
+ """
+ Args:
+ pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+
+ Returns:
+ Tensor: Converted image.
+ """
+ return to_tensor(pic)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '()'
+
+
+class RandomErasing(object):
+ """ Randomly selects a rectangle region in an image and erases its pixels.
+ 'Random Erasing Data Augmentation' by Zhong et al.
+ See https://arxiv.org/pdf/1708.04896.pdf
+ Args:
+ probability: The probability that the Random Erasing operation will be performed.
+ sl: Minimum proportion of erased area against input image.
+ sh: Maximum proportion of erased area against input image.
+ r1: Minimum aspect ratio of erased area.
+ mean: Erasing value.
+ """
+
+    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3,
+                 mean=(255 * 0.49735, 255 * 0.4822, 255 * 0.4465)):
+        # fill values live in the same 0-255 range as the float-converted image
+ self.probability = probability
+ self.mean = mean
+ self.sl = sl
+ self.sh = sh
+ self.r1 = r1
+
+ def __call__(self, img):
+ img = np.asarray(img, dtype=np.float32).copy()
+ if random.uniform(0, 1) > self.probability:
+ return img
+
+ for attempt in range(100):
+ area = img.shape[0] * img.shape[1]
+ target_area = random.uniform(self.sl, self.sh) * area
+ aspect_ratio = random.uniform(self.r1, 1 / self.r1)
+
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ if w < img.shape[1] and h < img.shape[0]:
+ x1 = random.randint(0, img.shape[0] - h)
+ y1 = random.randint(0, img.shape[1] - w)
+ if img.shape[2] == 3:
+ img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
+ img[x1:x1 + h, y1:y1 + w, 1] = self.mean[1]
+ img[x1:x1 + h, y1:y1 + w, 2] = self.mean[2]
+ else:
+ img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
+ return img
+ return img
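+
+# Example (sketch): whether or not a region is erased, the transform returns a
+# float32 HWC ndarray, which the downstream ToTensor handles:
+# rea = RandomErasing(probability=0.5)
+# out = rea(pil_img)  # up to sh=40% of the area filled with the channel means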
+
+
+class RandomPatch(object):
+ """Random patch data augmentation.
+    There is a patch pool that stores randomly extracted patches from person images.
+ For each input image, RandomPatch
+ 1) extracts a random patch and stores the patch in the patch pool;
+ 2) randomly selects a patch from the patch pool and pastes it on the
+ input (at random position) to simulate occlusion.
+ Reference:
+ - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
+ - Zhou et al. Learning Generalisable Omni-Scale Representations
+ for Person Re-Identification. arXiv preprint, 2019.
+ """
+
+ def __init__(self, prob_happen=0.5, pool_capacity=50000, min_sample_size=100,
+ patch_min_area=0.01, patch_max_area=0.5, patch_min_ratio=0.1,
+ prob_rotate=0.5, prob_flip_leftright=0.5,
+ ):
+ self.prob_happen = prob_happen
+
+ self.patch_min_area = patch_min_area
+ self.patch_max_area = patch_max_area
+ self.patch_min_ratio = patch_min_ratio
+
+ self.prob_rotate = prob_rotate
+ self.prob_flip_leftright = prob_flip_leftright
+
+ self.patchpool = deque(maxlen=pool_capacity)
+ self.min_sample_size = min_sample_size
+
+ def generate_wh(self, W, H):
+ area = W * H
+ for attempt in range(100):
+ target_area = random.uniform(self.patch_min_area, self.patch_max_area) * area
+ aspect_ratio = random.uniform(self.patch_min_ratio, 1. / self.patch_min_ratio)
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
+ if w < W and h < H:
+ return w, h
+ return None, None
+
+ def transform_patch(self, patch):
+ if random.uniform(0, 1) > self.prob_flip_leftright:
+ patch = patch.transpose(Image.FLIP_LEFT_RIGHT)
+ if random.uniform(0, 1) > self.prob_rotate:
+ patch = patch.rotate(random.randint(-10, 10))
+ return patch
+
+ def __call__(self, img):
+ if isinstance(img, np.ndarray):
+ img = Image.fromarray(img.astype(np.uint8))
+
+ W, H = img.size # original image size
+
+ # collect new patch
+ w, h = self.generate_wh(W, H)
+ if w is not None and h is not None:
+ x1 = random.randint(0, W - w)
+ y1 = random.randint(0, H - h)
+ new_patch = img.crop((x1, y1, x1 + w, y1 + h))
+ self.patchpool.append(new_patch)
+
+ if len(self.patchpool) < self.min_sample_size:
+ return img
+
+ if random.uniform(0, 1) > self.prob_happen:
+ return img
+
+ # paste a randomly selected patch on a random position
+ patch = random.sample(self.patchpool, 1)[0]
+ patchW, patchH = patch.size
+ x1 = random.randint(0, W - patchW)
+ y1 = random.randint(0, H - patchH)
+ patch = self.transform_patch(patch)
+ img.paste(patch, (x1, y1))
+
+ return img
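+
+# Example (sketch): the pool must accumulate min_sample_size patches before any
+# pasting happens, so early calls mostly return the input unchanged:
+# rp = RandomPatch(prob_happen=0.5, min_sample_size=100)
+# out = rp(pil_img)  # PIL.Image in, PIL.Image out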
+
+
+class AugMix(object):
+ """ Perform AugMix augmentation and compute mixture.
+ Args:
+ aug_prob_coeff: Probability distribution coefficients.
+ mixture_width: Number of augmentation chains to mix per augmented example.
+        mixture_depth: Depth of augmentation chains. -1 denotes stochastic depth in [1, 3].
+        severity: Severity of underlying augmentation operators (between 1 and 10).
+ """
+
+ def __init__(self, aug_prob_coeff=1, mixture_width=3, mixture_depth=-1, severity=1):
+ self.aug_prob_coeff = aug_prob_coeff
+ self.mixture_width = mixture_width
+ self.mixture_depth = mixture_depth
+ self.severity = severity
+ self.aug_list = augmentations_reid
+
+ def __call__(self, image):
+ """Perform AugMix augmentations and compute mixture.
+ Returns:
+ mixed: Augmented and mixed image.
+ """
+ ws = np.float32(
+ np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width))
+ m = np.float32(np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff))
+
+ image = np.asarray(image, dtype=np.float32).copy()
+ mix = np.zeros_like(image)
+ h, w = image.shape[0], image.shape[1]
+ for i in range(self.mixture_width):
+ image_aug = Image.fromarray(image.copy().astype(np.uint8))
+ depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(1, 4)
+ for _ in range(depth):
+ op = np.random.choice(self.aug_list)
+ image_aug = op(image_aug, self.severity, (w, h))
+ mix += ws[i] * np.asarray(image_aug, dtype=np.float32)
+
+ mixed = (1 - m) * image + m * mix
+ return mixed
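+
+# Example (sketch): Dirichlet-weighted mix of mixture_width chains, blended
+# with the original image by a Beta-sampled weight m:
+# augmix = AugMix(aug_prob_coeff=1, mixture_width=3, mixture_depth=-1, severity=1)
+# mixed = augmix(pil_img)  # float32 HWC ndarray: (1 - m) * image + m * mix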
diff --git a/engine/__init__.py b/engine/__init__.py
new file mode 100644
index 0000000..77c1ec2
--- /dev/null
+++ b/engine/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2020/10/26 13:32
+# @Author : Scheaven
+# @File : __init__.py.py
+# @description:
\ No newline at end of file
diff --git a/engine/__pycache__/__init__.cpython-38.pyc b/engine/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..58e7ffd
--- /dev/null
+++ b/engine/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/engine/__pycache__/defaults.cpython-38.pyc b/engine/__pycache__/defaults.cpython-38.pyc
new file mode 100644
index 0000000..f07dd58
--- /dev/null
+++ b/engine/__pycache__/defaults.cpython-38.pyc
Binary files differ
diff --git a/engine/defaults.py b/engine/defaults.py
new file mode 100644
index 0000000..9c1261e
--- /dev/null
+++ b/engine/defaults.py
@@ -0,0 +1,115 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2020/10/26 13:38
+# @Author : Scheaven
+# @File : defaults.py
+# @description:
+
+"""
+    This file contains components with some default boilerplate logic users may need
+    in training / testing. They will not work for everyone, but many users may find them useful.
+ The behavior of functions/classes in this file is subject to change,
+ since they are meant to represent the "common default behavior" people need in their projects.
+"""
+import argparse
+import logging
+import os
+import sys
+from utils import comm
+from utils.env import seed_all_rng
+from utils.file_io import PathManager
+from utils.logger import setup_logger
+from utils.collect_env import collect_env_info
+from collections import OrderedDict
+
+import torch
+# import torch.nn.functional as F
+# from torch.nn.parallel import DistributedDataParallel
+
+def default_argument_parser():
+ """
+ Create a parser with some common arguments used by fastreid users.
+ Returns:
+ argparse.ArgumentParser:
+ """
+ parser = argparse.ArgumentParser(description="fastreid Training")
+ parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
+ parser.add_argument(
+ "--finetune",
+ action="store_true",
+ help="whether to attempt to finetune from the trained model",
+ )
+ parser.add_argument(
+ "--resume",
+ action="store_true",
+ help="whether to attempt to resume from the checkpoint directory",
+ )
+ parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
+ parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
+ parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
+ parser.add_argument("--img_a1", default="1.jpg", help="input image")
+ parser.add_argument("--img_a2", default="2.jpg", help="input image2")
+ parser.add_argument("--img_b1", default="1.jpg", help="input image")
+ parser.add_argument("--img_b2", default="2.jpg", help="input image2")
+ parser.add_argument(
+ "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
+ )
+
+ # PyTorch still may leave orphan processes in multi-gpu training.
+ # Therefore we use a deterministic way to obtain port,
+ # so that users are aware of orphan processes by seeing the port occupied.
+ port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
+ parser.add_argument("--dist-url", default="tcp://127.0.0.1:{}".format(port))
+ parser.add_argument(
+ "opts",
+ help="Modify config options using the command-line",
+ default=None,
+ nargs=argparse.REMAINDER,
+ )
+ return parser
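+
+# Example (sketch; the config path below is illustrative only):
+# args = default_argument_parser().parse_args(
+#     ["--config-file", "configs/reid.yml", "--num-gpus", "1"])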
+
+def default_setup(cfg, args):
+ """
+ Perform some basic common setups at the beginning of a job, including:
+    1. Set up the fastreid logger
+ 2. Log basic information about environment, cmdline arguments, and config
+ 3. Backup the config to the output directory
+ Args:
+ cfg (CfgNode): the full config to be used
+        args (argparse.Namespace): the command line arguments to be logged
+ """
+ output_dir = cfg.OUTPUT_DIR
+ if comm.is_main_process() and output_dir:
+ PathManager.mkdirs(output_dir)
+
+ rank = comm.get_rank()
+ setup_logger(output_dir, distributed_rank=rank, name="fvcore")
+ logger = setup_logger(output_dir, distributed_rank=rank)
+
+ logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
+ logger.info("Environment info:\n" + collect_env_info())
+
+ logger.info("Command line arguments: " + str(args))
+ if hasattr(args, "config_file") and args.config_file != "":
+ logger.info(
+ "Contents of args.config_file={}:\n{}".format(
+ args.config_file, PathManager.open(args.config_file, "r").read()
+ )
+ )
+
+ logger.info("Running with full config:\n{}".format(cfg))
+ if comm.is_main_process() and output_dir:
+ # Note: some of our scripts may expect the existence of
+ # config.yaml in output directory
+ path = os.path.join(output_dir, "config.yaml")
+ with PathManager.open(path, "w") as f:
+ f.write(cfg.dump())
+ logger.info("Full config saved to {}".format(os.path.abspath(path)))
+
+ # make sure each worker has a different, yet deterministic seed if specified
+ seed_all_rng()
+
+    # cudnn benchmark has large overhead. It shouldn't be used considering the
+    # small size of the typical validation set.
+ if not (hasattr(args, "eval_only") and args.eval_only):
+ torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK
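+
+# Example (sketch; assumes a get_cfg() helper from this project's config
+# package builds the default CfgNode):
+# args = default_argument_parser().parse_args()
+# cfg = get_cfg()
+# cfg.merge_from_file(args.config_file)
+# default_setup(cfg, args)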
diff --git a/layers/__init__.py b/layers/__init__.py
new file mode 100644
index 0000000..2e67b13
--- /dev/null
+++ b/layers/__init__.py
@@ -0,0 +1,17 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from .activation import *
+from .arcface import Arcface
+from .batch_drop import BatchDrop
+from .batch_norm import *
+from .circle import Circle
+from .context_block import ContextBlock
+from .frn import FRN, TLU
+from .non_local import Non_local
+from .pooling import *
+from .se_layer import SELayer
+from .splat import SplAtConv2d
diff --git a/layers/__pycache__/__init__.cpython-37.pyc b/layers/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..502959c
--- /dev/null
+++ b/layers/__pycache__/__init__.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/__init__.cpython-38.pyc b/layers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..5537f1f
--- /dev/null
+++ b/layers/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/activation.cpython-37.pyc b/layers/__pycache__/activation.cpython-37.pyc
new file mode 100644
index 0000000..8490b01
--- /dev/null
+++ b/layers/__pycache__/activation.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/activation.cpython-38.pyc b/layers/__pycache__/activation.cpython-38.pyc
new file mode 100644
index 0000000..9d59e01
--- /dev/null
+++ b/layers/__pycache__/activation.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/arcface.cpython-37.pyc b/layers/__pycache__/arcface.cpython-37.pyc
new file mode 100644
index 0000000..a01ae38
--- /dev/null
+++ b/layers/__pycache__/arcface.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/arcface.cpython-38.pyc b/layers/__pycache__/arcface.cpython-38.pyc
new file mode 100644
index 0000000..f62c979
--- /dev/null
+++ b/layers/__pycache__/arcface.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/batch_drop.cpython-37.pyc b/layers/__pycache__/batch_drop.cpython-37.pyc
new file mode 100644
index 0000000..c8c49cf
--- /dev/null
+++ b/layers/__pycache__/batch_drop.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/batch_drop.cpython-38.pyc b/layers/__pycache__/batch_drop.cpython-38.pyc
new file mode 100644
index 0000000..e0b0d5e
--- /dev/null
+++ b/layers/__pycache__/batch_drop.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/batch_norm.cpython-37.pyc b/layers/__pycache__/batch_norm.cpython-37.pyc
new file mode 100644
index 0000000..944c044
--- /dev/null
+++ b/layers/__pycache__/batch_norm.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/batch_norm.cpython-38.pyc b/layers/__pycache__/batch_norm.cpython-38.pyc
new file mode 100644
index 0000000..4b45f1d
--- /dev/null
+++ b/layers/__pycache__/batch_norm.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/circle.cpython-37.pyc b/layers/__pycache__/circle.cpython-37.pyc
new file mode 100644
index 0000000..a4aa86b
--- /dev/null
+++ b/layers/__pycache__/circle.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/circle.cpython-38.pyc b/layers/__pycache__/circle.cpython-38.pyc
new file mode 100644
index 0000000..1b31d9e
--- /dev/null
+++ b/layers/__pycache__/circle.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/context_block.cpython-37.pyc b/layers/__pycache__/context_block.cpython-37.pyc
new file mode 100644
index 0000000..9782731
--- /dev/null
+++ b/layers/__pycache__/context_block.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/context_block.cpython-38.pyc b/layers/__pycache__/context_block.cpython-38.pyc
new file mode 100644
index 0000000..23f3cfa
--- /dev/null
+++ b/layers/__pycache__/context_block.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/frn.cpython-37.pyc b/layers/__pycache__/frn.cpython-37.pyc
new file mode 100644
index 0000000..677945a
--- /dev/null
+++ b/layers/__pycache__/frn.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/frn.cpython-38.pyc b/layers/__pycache__/frn.cpython-38.pyc
new file mode 100644
index 0000000..26d8109
--- /dev/null
+++ b/layers/__pycache__/frn.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/non_local.cpython-37.pyc b/layers/__pycache__/non_local.cpython-37.pyc
new file mode 100644
index 0000000..84b216d
--- /dev/null
+++ b/layers/__pycache__/non_local.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/non_local.cpython-38.pyc b/layers/__pycache__/non_local.cpython-38.pyc
new file mode 100644
index 0000000..70f1863
--- /dev/null
+++ b/layers/__pycache__/non_local.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/pooling.cpython-37.pyc b/layers/__pycache__/pooling.cpython-37.pyc
new file mode 100644
index 0000000..72e8113
--- /dev/null
+++ b/layers/__pycache__/pooling.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/pooling.cpython-38.pyc b/layers/__pycache__/pooling.cpython-38.pyc
new file mode 100644
index 0000000..db3c40b
--- /dev/null
+++ b/layers/__pycache__/pooling.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/se_layer.cpython-37.pyc b/layers/__pycache__/se_layer.cpython-37.pyc
new file mode 100644
index 0000000..a33eee6
--- /dev/null
+++ b/layers/__pycache__/se_layer.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/se_layer.cpython-38.pyc b/layers/__pycache__/se_layer.cpython-38.pyc
new file mode 100644
index 0000000..15871c0
--- /dev/null
+++ b/layers/__pycache__/se_layer.cpython-38.pyc
Binary files differ
diff --git a/layers/__pycache__/splat.cpython-37.pyc b/layers/__pycache__/splat.cpython-37.pyc
new file mode 100644
index 0000000..bbcd0ec
--- /dev/null
+++ b/layers/__pycache__/splat.cpython-37.pyc
Binary files differ
diff --git a/layers/__pycache__/splat.cpython-38.pyc b/layers/__pycache__/splat.cpython-38.pyc
new file mode 100644
index 0000000..0b7289f
--- /dev/null
+++ b/layers/__pycache__/splat.cpython-38.pyc
Binary files differ
diff --git a/layers/activation.py b/layers/activation.py
new file mode 100644
index 0000000..3e0aea6
--- /dev/null
+++ b/layers/activation.py
@@ -0,0 +1,59 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: liaoxingyu5@jd.com
+"""
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = [
+ 'Mish',
+ 'Swish',
+ 'MemoryEfficientSwish',
+ 'GELU']
+
+
+class Mish(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ # inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!)
+ return x * (torch.tanh(F.softplus(x)))
+
+
+class Swish(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class SwishImplementation(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, i):
+ result = i * torch.sigmoid(i)
+ ctx.save_for_backward(i)
+ return result
+
+ @staticmethod
+ def backward(ctx, grad_output):
+        i = ctx.saved_tensors[0]
+ sigmoid_i = torch.sigmoid(i)
+ return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+
+class MemoryEfficientSwish(nn.Module):
+ def forward(self, x):
+ return SwishImplementation.apply(x)
+
+
+class GELU(nn.Module):
+ """
+    Paper section 3.4, last paragraph: note that BERT uses GELU instead of ReLU
+ """
+
+ def forward(self, x):
+ return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
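+
+# Example (sketch): all four activations are stateless drop-in nn.Modules:
+# act = MemoryEfficientSwish()  # its Function recomputes sigmoid in backward
+# y = act(torch.randn(4, 8))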
diff --git a/layers/arcface.py b/layers/arcface.py
new file mode 100644
index 0000000..be7f9f6
--- /dev/null
+++ b/layers/arcface.py
@@ -0,0 +1,54 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Parameter
+
+
+class Arcface(nn.Module):
+ def __init__(self, cfg, in_feat, num_classes):
+ super().__init__()
+ self.in_feat = in_feat
+ self._num_classes = num_classes
+ self._s = cfg.MODEL.HEADS.SCALE
+ self._m = cfg.MODEL.HEADS.MARGIN
+
+ self.cos_m = math.cos(self._m)
+ self.sin_m = math.sin(self._m)
+ self.threshold = math.cos(math.pi - self._m)
+ self.mm = math.sin(math.pi - self._m) * self._m
+
+ self.weight = Parameter(torch.Tensor(num_classes, in_feat))
+ self.register_buffer('t', torch.zeros(1))
+
+ def forward(self, features, targets):
+ # get cos(theta)
+ cos_theta = F.linear(F.normalize(features), F.normalize(self.weight))
+ cos_theta = cos_theta.clamp(-1, 1) # for numerical stability
+
+ target_logit = cos_theta[torch.arange(0, features.size(0)), targets].view(-1, 1)
+
+ sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
+ cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m # cos(target+margin)
+ mask = cos_theta > cos_theta_m
+ final_target_logit = torch.where(target_logit > self.threshold, cos_theta_m, target_logit - self.mm)
+
+ hard_example = cos_theta[mask]
+ with torch.no_grad():
+ self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t
+ cos_theta[mask] = hard_example * (self.t + hard_example)
+ cos_theta.scatter_(1, targets.view(-1, 1).long(), final_target_logit)
+ pred_class_logits = cos_theta * self._s
+ return pred_class_logits
+
+ def extra_repr(self):
+ return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
+ self.in_feat, self._num_classes, self._s, self._m
+ )
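+
+# Worked example (sketch): with m = 0.5 (cos_m ~ 0.8776, sin_m ~ 0.4794), a
+# target logit cos(theta) = 0.8 (so sin(theta) = 0.6) above the threshold maps
+# to cos(theta + m) = 0.8 * 0.8776 - 0.6 * 0.4794 ~ 0.414 before scaling by s.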
diff --git a/layers/batch_drop.py b/layers/batch_drop.py
new file mode 100644
index 0000000..5c25697
--- /dev/null
+++ b/layers/batch_drop.py
@@ -0,0 +1,32 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import random
+
+from torch import nn
+
+
+class BatchDrop(nn.Module):
+ """ref: https://github.com/daizuozhuo/batch-dropblock-network/blob/master/models/networks.py
+ batch drop mask
+ """
+
+ def __init__(self, h_ratio, w_ratio):
+ super(BatchDrop, self).__init__()
+ self.h_ratio = h_ratio
+ self.w_ratio = w_ratio
+
+ def forward(self, x):
+ if self.training:
+ h, w = x.size()[-2:]
+ rh = round(self.h_ratio * h)
+ rw = round(self.w_ratio * w)
+ sx = random.randint(0, h - rh)
+ sy = random.randint(0, w - rw)
+ mask = x.new_ones(x.size())
+ mask[:, :, sx:sx + rh, sy:sy + rw] = 0
+ x = x * mask
+ return x
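+
+# Example (sketch): in training mode a single random h_ratio x w_ratio block is
+# zeroed, and the same region is dropped for every sample in the batch:
+# bd = BatchDrop(h_ratio=0.33, w_ratio=1.0)
+# y = bd(torch.randn(2, 256, 24, 8))  # one ~8-row stripe of zeros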
diff --git a/layers/batch_norm.py b/layers/batch_norm.py
new file mode 100644
index 0000000..e0e88e3
--- /dev/null
+++ b/layers/batch_norm.py
@@ -0,0 +1,208 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import logging
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+__all__ = [
+ "BatchNorm",
+ "IBN",
+ "GhostBatchNorm",
+ "FrozenBatchNorm",
+ "SyncBatchNorm",
+ "get_norm",
+]
+
+
+class BatchNorm(nn.BatchNorm2d):
+ def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
+ bias_init=0.0, **kwargs):
+ super().__init__(num_features, eps=eps, momentum=momentum)
+ if weight_init is not None: nn.init.constant_(self.weight, weight_init)
+ if bias_init is not None: nn.init.constant_(self.bias, bias_init)
+ self.weight.requires_grad_(not weight_freeze)
+ self.bias.requires_grad_(not bias_freeze)
+
+
+class SyncBatchNorm(nn.SyncBatchNorm):
+ def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
+ bias_init=0.0):
+ super().__init__(num_features, eps=eps, momentum=momentum)
+ if weight_init is not None: nn.init.constant_(self.weight, weight_init)
+ if bias_init is not None: nn.init.constant_(self.bias, bias_init)
+ self.weight.requires_grad_(not weight_freeze)
+ self.bias.requires_grad_(not bias_freeze)
+
+
+class IBN(nn.Module):
+ def __init__(self, planes, bn_norm, **kwargs):
+ super(IBN, self).__init__()
+ half1 = int(planes / 2)
+ self.half = half1
+ half2 = planes - half1
+ self.IN = nn.InstanceNorm2d(half1, affine=True)
+ self.BN = get_norm(bn_norm, half2, **kwargs)
+
+ def forward(self, x):
+ split = torch.split(x, self.half, 1)
+ out1 = self.IN(split[0].contiguous())
+ out2 = self.BN(split[1].contiguous())
+ out = torch.cat((out1, out2), 1)
+ return out
+
+
+class GhostBatchNorm(BatchNorm):
+ def __init__(self, num_features, num_splits=1, **kwargs):
+ super().__init__(num_features, **kwargs)
+ self.num_splits = num_splits
+ self.register_buffer('running_mean', torch.zeros(num_features))
+ self.register_buffer('running_var', torch.ones(num_features))
+
+ def forward(self, input):
+ N, C, H, W = input.shape
+ if self.training or not self.track_running_stats:
+ self.running_mean = self.running_mean.repeat(self.num_splits)
+ self.running_var = self.running_var.repeat(self.num_splits)
+ outputs = F.batch_norm(
+ input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var,
+ self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
+ True, self.momentum, self.eps).view(N, C, H, W)
+ self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0)
+ self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0)
+ return outputs
+ else:
+ return F.batch_norm(
+ input, self.running_mean, self.running_var,
+ self.weight, self.bias, False, self.momentum, self.eps)
+
+
+class FrozenBatchNorm(BatchNorm):
+ """
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
+ It contains non-trainable buffers called
+ "weight" and "bias", "running_mean", "running_var",
+ initialized to perform identity transformation.
+ The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
+ which are computed from the original four parameters of BN.
+ The affine transform `x * weight + bias` will perform the equivalent
+ computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
+ When loading a backbone model from Caffe2, "running_mean" and "running_var"
+ will be left unchanged as identity transformation.
+ Other pre-trained backbone models may contain all 4 parameters.
+ The forward is implemented by `F.batch_norm(..., training=False)`.
+ """
+
+ _version = 3
+
+ def __init__(self, num_features, eps=1e-5, **kwargs):
+ super().__init__(num_features, weight_freeze=True, bias_freeze=True, **kwargs)
+ self.num_features = num_features
+ self.eps = eps
+
+ def forward(self, x):
+ if x.requires_grad:
+ # When gradients are needed, F.batch_norm will use extra memory
+ # because its backward op computes gradients for weight/bias as well.
+ scale = self.weight * (self.running_var + self.eps).rsqrt()
+ bias = self.bias - self.running_mean * scale
+ scale = scale.reshape(1, -1, 1, 1)
+ bias = bias.reshape(1, -1, 1, 1)
+ return x * scale + bias
+ else:
+ # When gradients are not needed, F.batch_norm is a single fused op
+            # and provides more optimization opportunities.
+ return F.batch_norm(
+ x,
+ self.running_mean,
+ self.running_var,
+ self.weight,
+ self.bias,
+ training=False,
+ eps=self.eps,
+ )
+
+ def _load_from_state_dict(
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ ):
+ version = local_metadata.get("version", None)
+
+ if version is None or version < 2:
+ # No running_mean/var in early versions
+            # This silences the warnings
+ if prefix + "running_mean" not in state_dict:
+ state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
+ if prefix + "running_var" not in state_dict:
+ state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)
+
+ if version is not None and version < 3:
+ logger = logging.getLogger(__name__)
+ logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip(".")))
+ # In version < 3, running_var are used without +eps.
+ state_dict[prefix + "running_var"] -= self.eps
+
+ super()._load_from_state_dict(
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ )
+
+ def __repr__(self):
+ return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)
+
+ @classmethod
+ def convert_frozen_batchnorm(cls, module):
+ """
+ Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.
+ Args:
+ module (torch.nn.Module):
+ Returns:
+ If module is BatchNorm/SyncBatchNorm, returns a new module.
+ Otherwise, in-place convert module and return it.
+ Similar to convert_sync_batchnorm in
+ https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
+ """
+ bn_module = nn.modules.batchnorm
+ bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
+ res = module
+ if isinstance(module, bn_module):
+ res = cls(module.num_features)
+ if module.affine:
+ res.weight.data = module.weight.data.clone().detach()
+ res.bias.data = module.bias.data.clone().detach()
+ res.running_mean.data = module.running_mean.data
+ res.running_var.data = module.running_var.data
+ res.eps = module.eps
+ else:
+ for name, child in module.named_children():
+ new_child = cls.convert_frozen_batchnorm(child)
+ if new_child is not child:
+ res.add_module(name, new_child)
+ return res
+
+
+def get_norm(norm, out_channels, num_splits=1, **kwargs):
+    """
+    Args:
+        norm (str or callable): either one of BN, GhostBN, FrozenBN, GN or SyncBN;
+            or a callable that takes a channel number and returns
+            the normalization layer as a nn.Module
+        out_channels: number of channels for normalization layer
+        num_splits: number of virtual sub-batches, consumed by GhostBN only
+
+    Returns:
+        nn.Module or None: the normalization layer
+    """
+    if isinstance(norm, str):
+        if len(norm) == 0:
+            return None
+        norm = {
+            "BN": BatchNorm,
+            # Non_local and SplAtConv2d pass num_splits as the third positional
+            # argument, so accept it here and forward it to GhostBatchNorm
+            "GhostBN": lambda channels, **args: GhostBatchNorm(channels, num_splits, **args),
+            "FrozenBN": FrozenBatchNorm,
+            "GN": lambda channels, **args: nn.GroupNorm(32, channels),
+            "syncBN": SyncBatchNorm,
+        }[norm]
+    return norm(out_channels, **kwargs)
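+
+# Example (sketch): get_norm("BN", 64) -> BatchNorm(64), while get_norm("GN", 64)
+# -> nn.GroupNorm(32, 64); an empty string returns None so callers can disable
+# normalization. Extra kwargs (e.g. weight_freeze=True) reach the constructor.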
diff --git a/layers/circle.py b/layers/circle.py
new file mode 100644
index 0000000..6182b60
--- /dev/null
+++ b/layers/circle.py
@@ -0,0 +1,42 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Parameter
+
+
+class Circle(nn.Module):
+ def __init__(self, cfg, in_feat, num_classes):
+ super().__init__()
+ self.in_feat = in_feat
+ self._num_classes = num_classes
+ self._s = cfg.MODEL.HEADS.SCALE
+ self._m = cfg.MODEL.HEADS.MARGIN
+
+ self.weight = Parameter(torch.Tensor(num_classes, in_feat))
+
+ def forward(self, features, targets):
+ sim_mat = F.linear(F.normalize(features), F.normalize(self.weight))
+ alpha_p = F.relu(-sim_mat.detach() + 1 + self._m)
+ alpha_n = F.relu(sim_mat.detach() + self._m)
+ delta_p = 1 - self._m
+ delta_n = self._m
+
+ s_p = self._s * alpha_p * (sim_mat - delta_p)
+ s_n = self._s * alpha_n * (sim_mat - delta_n)
+
+ targets = F.one_hot(targets, num_classes=self._num_classes)
+
+ pred_class_logits = targets * s_p + (1.0 - targets) * s_n
+
+ return pred_class_logits
+
+ def extra_repr(self):
+ return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
+ self.in_feat, self._num_classes, self._s, self._m
+ )
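+
+# Worked example (sketch): with m = 0.25 (delta_p = 0.75, delta_n = 0.25), a
+# positive pair at similarity 0.9 gets alpha_p = relu(-0.9 + 1.25) = 0.35 and
+# logit s * 0.35 * (0.9 - 0.75): already-separated pairs receive small
+# gradients while hard ones are up-weighted.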
diff --git a/layers/context_block.py b/layers/context_block.py
new file mode 100644
index 0000000..7b1098a
--- /dev/null
+++ b/layers/context_block.py
@@ -0,0 +1,113 @@
+# copy from https://github.com/xvjiarui/GCNet/blob/master/mmdet/ops/gcb/context_block.py
+
+import torch
+from torch import nn
+
+__all__ = ['ContextBlock']
+
+
+def last_zero_init(m):
+ if isinstance(m, nn.Sequential):
+ nn.init.constant_(m[-1].weight, val=0)
+ if hasattr(m[-1], 'bias') and m[-1].bias is not None:
+ nn.init.constant_(m[-1].bias, 0)
+ else:
+ nn.init.constant_(m.weight, val=0)
+ if hasattr(m, 'bias') and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+
+class ContextBlock(nn.Module):
+
+ def __init__(self,
+ inplanes,
+ ratio,
+ pooling_type='att',
+ fusion_types=('channel_add',)):
+ super(ContextBlock, self).__init__()
+ assert pooling_type in ['avg', 'att']
+ assert isinstance(fusion_types, (list, tuple))
+ valid_fusion_types = ['channel_add', 'channel_mul']
+ assert all([f in valid_fusion_types for f in fusion_types])
+ assert len(fusion_types) > 0, 'at least one fusion should be used'
+ self.inplanes = inplanes
+ self.ratio = ratio
+ self.planes = int(inplanes * ratio)
+ self.pooling_type = pooling_type
+ self.fusion_types = fusion_types
+ if pooling_type == 'att':
+ self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
+ self.softmax = nn.Softmax(dim=2)
+ else:
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ if 'channel_add' in fusion_types:
+ self.channel_add_conv = nn.Sequential(
+ nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
+ nn.LayerNorm([self.planes, 1, 1]),
+ nn.ReLU(inplace=True), # yapf: disable
+ nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
+ else:
+ self.channel_add_conv = None
+ if 'channel_mul' in fusion_types:
+ self.channel_mul_conv = nn.Sequential(
+ nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
+ nn.LayerNorm([self.planes, 1, 1]),
+ nn.ReLU(inplace=True), # yapf: disable
+ nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
+ else:
+ self.channel_mul_conv = None
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if self.pooling_type == 'att':
+ nn.init.kaiming_normal_(self.conv_mask.weight, a=0, mode='fan_in', nonlinearity='relu')
+ if hasattr(self.conv_mask, 'bias') and self.conv_mask.bias is not None:
+ nn.init.constant_(self.conv_mask.bias, 0)
+ self.conv_mask.inited = True
+
+ if self.channel_add_conv is not None:
+ last_zero_init(self.channel_add_conv)
+ if self.channel_mul_conv is not None:
+ last_zero_init(self.channel_mul_conv)
+
+ def spatial_pool(self, x):
+ batch, channel, height, width = x.size()
+ if self.pooling_type == 'att':
+ input_x = x
+ # [N, C, H * W]
+ input_x = input_x.view(batch, channel, height * width)
+ # [N, 1, C, H * W]
+ input_x = input_x.unsqueeze(1)
+ # [N, 1, H, W]
+ context_mask = self.conv_mask(x)
+ # [N, 1, H * W]
+ context_mask = context_mask.view(batch, 1, height * width)
+ # [N, 1, H * W]
+ context_mask = self.softmax(context_mask)
+ # [N, 1, H * W, 1]
+ context_mask = context_mask.unsqueeze(-1)
+ # [N, 1, C, 1]
+ context = torch.matmul(input_x, context_mask)
+ # [N, C, 1, 1]
+ context = context.view(batch, channel, 1, 1)
+ else:
+ # [N, C, 1, 1]
+ context = self.avg_pool(x)
+
+ return context
+
+ def forward(self, x):
+ # [N, C, 1, 1]
+ context = self.spatial_pool(x)
+
+ out = x
+ if self.channel_mul_conv is not None:
+ # [N, C, 1, 1]
+ channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
+ out = out * channel_mul_term
+ if self.channel_add_conv is not None:
+ # [N, C, 1, 1]
+ channel_add_term = self.channel_add_conv(context)
+ out = out + channel_add_term
+
+ return out
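+
+# Example (sketch): a GCNet block that pools one global context vector and adds
+# it back per channel; last_zero_init makes it an identity mapping at init:
+# gcb = ContextBlock(inplanes=256, ratio=1. / 16)
+# y = gcb(torch.randn(2, 256, 16, 8))  # same shape as the input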
diff --git a/layers/frn.py b/layers/frn.py
new file mode 100644
index 0000000..f00a1e4
--- /dev/null
+++ b/layers/frn.py
@@ -0,0 +1,199 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+from torch import nn
+from torch.nn.modules.batchnorm import BatchNorm2d
+from torch.nn import ReLU, LeakyReLU
+from torch.nn.parameter import Parameter
+
+
+class TLU(nn.Module):
+ def __init__(self, num_features):
+ """max(y, tau) = max(y - tau, 0) + tau = ReLU(y - tau) + tau"""
+ super(TLU, self).__init__()
+ self.num_features = num_features
+ self.tau = Parameter(torch.Tensor(num_features))
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.zeros_(self.tau)
+
+ def extra_repr(self):
+ return 'num_features={num_features}'.format(**self.__dict__)
+
+ def forward(self, x):
+ return torch.max(x, self.tau.view(1, self.num_features, 1, 1))
+
+
+class FRN(nn.Module):
+    def __init__(self, num_features, eps=1e-6, is_eps_learnable=False):
+ """
+ weight = gamma, bias = beta
+ beta, gamma:
+ Variables of shape [1, 1, 1, C]. if TensorFlow
+ Variables of shape [1, C, 1, 1]. if PyTorch
+ eps: A scalar constant or learnable variable.
+ """
+ super(FRN, self).__init__()
+
+ self.num_features = num_features
+ self.init_eps = eps
+        self.is_eps_learnable = is_eps_learnable
+
+ self.weight = Parameter(torch.Tensor(num_features))
+ self.bias = Parameter(torch.Tensor(num_features))
+        if is_eps_learnable:
+ self.eps = Parameter(torch.Tensor(1))
+ else:
+ self.register_buffer('eps', torch.Tensor([eps]))
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.ones_(self.weight)
+ nn.init.zeros_(self.bias)
+        if self.is_eps_learnable:
+ nn.init.constant_(self.eps, self.init_eps)
+
+ def extra_repr(self):
+ return 'num_features={num_features}, eps={init_eps}'.format(**self.__dict__)
+
+ def forward(self, x):
+ """
+ 0, 1, 2, 3 -> (B, H, W, C) in TensorFlow
+ 0, 1, 2, 3 -> (B, C, H, W) in PyTorch
+ TensorFlow code
+ nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True)
+ x = x * tf.rsqrt(nu2 + tf.abs(eps))
+ # This Code include TLU function max(y, tau)
+ return tf.maximum(gamma * x + beta, tau)
+ """
+ # Compute the mean norm of activations per channel.
+ nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)
+
+ # Perform FRN.
+ x = x * torch.rsqrt(nu2 + self.eps.abs())
+
+ # Scale and Bias
+ x = self.weight.view(1, self.num_features, 1, 1) * x + self.bias.view(1, self.num_features, 1, 1)
+ # x = self.weight * x + self.bias
+ return x
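+
+# Example (sketch): FRN + TLU replaces the usual BatchNorm2d + ReLU pair and
+# needs no batch statistics, so it is insensitive to batch size:
+# y = TLU(64)(FRN(64)(torch.randn(2, 64, 8, 8)))  # same shape as the input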
+
+
+def bnrelu_to_frn(module):
+ """
+ Convert 'BatchNorm2d + ReLU' to 'FRN + TLU'
+ """
+ mod = module
+ before_name = None
+ before_child = None
+ is_before_bn = False
+
+ for name, child in module.named_children():
+ if is_before_bn and isinstance(child, (ReLU, LeakyReLU)):
+ # Convert BN to FRN
+ if isinstance(before_child, BatchNorm2d):
+ mod.add_module(
+ before_name, FRN(num_features=before_child.num_features))
+ else:
+ raise NotImplementedError()
+
+ # Convert ReLU to TLU
+ mod.add_module(name, TLU(num_features=before_child.num_features))
+ else:
+ mod.add_module(name, bnrelu_to_frn(child))
+
+ before_name = name
+ before_child = child
+ is_before_bn = isinstance(child, BatchNorm2d)
+ return mod
+
+
+def convert(module, flag_name):
+ mod = module
+ before_ch = None
+ for name, child in module.named_children():
+ if hasattr(child, flag_name) and getattr(child, flag_name):
+ if isinstance(child, BatchNorm2d):
+ before_ch = child.num_features
+ mod.add_module(name, FRN(num_features=child.num_features))
+ # TODO bn is no good...
+ if isinstance(child, (ReLU, LeakyReLU)):
+ mod.add_module(name, TLU(num_features=before_ch))
+ else:
+ mod.add_module(name, convert(child, flag_name))
+ return mod
+
+
+def remove_flags(module, flag_name):
+    mod = module
+    for name, child in module.named_children():
+        if hasattr(child, flag_name):
+            delattr(child, flag_name)
+        mod.add_module(name, remove_flags(child, flag_name))
+    return mod
+
+
+def bnrelu_to_frn2(model, input_size=(3, 128, 128), batch_size=2, flag_name='is_convert_frn'):
+    forward_hooks = list()
+ backward_hooks = list()
+
+ is_before_bn = [False]
+
+ def register_forward_hook(module):
+ def hook(self, input, output):
+ if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
+ is_before_bn.append(False)
+ return
+
+            # the hook signature requires input and output even though they are unused
+ is_converted = is_before_bn[-1] and isinstance(self, (ReLU, LeakyReLU))
+ if is_converted:
+ setattr(self, flag_name, True)
+ is_before_bn.append(isinstance(self, BatchNorm2d))
+
+        forward_hooks.append(module.register_forward_hook(hook))
+
+ is_before_relu = [False]
+
+ def register_backward_hook(module):
+ def hook(self, input, output):
+ if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
+ is_before_relu.append(False)
+ return
+ is_converted = is_before_relu[-1] and isinstance(self, BatchNorm2d)
+ if is_converted:
+ setattr(self, flag_name, True)
+ is_before_relu.append(isinstance(self, (ReLU, LeakyReLU)))
+
+ backward_hooks.append(module.register_backward_hook(hook))
+
+ # multiple inputs to the network
+ if isinstance(input_size, tuple):
+ input_size = [input_size]
+
+ # batch_size of 2 for batchnorm
+ x = [torch.rand(batch_size, *in_size) for in_size in input_size]
+
+ # register hook
+ model.apply(register_forward_hook)
+ model.apply(register_backward_hook)
+
+ # make a forward pass
+ output = model(*x)
+    output.sum().backward()  # backward() needs a scalar, so reduce the raw output first
+
+ # remove these hooks
+    for h in forward_hooks:
+ h.remove()
+ for h in backward_hooks:
+ h.remove()
+
+ model = convert(model, flag_name=flag_name)
+ model = remove_flags(model, flag_name=flag_name)
+ return model
diff --git a/layers/non_local.py b/layers/non_local.py
new file mode 100644
index 0000000..876ec43
--- /dev/null
+++ b/layers/non_local.py
@@ -0,0 +1,54 @@
+# encoding: utf-8
+
+
+import torch
+from torch import nn
+from .batch_norm import get_norm
+
+
+class Non_local(nn.Module):
+ def __init__(self, in_channels, bn_norm, num_splits, reduc_ratio=2):
+ super(Non_local, self).__init__()
+
+        self.in_channels = in_channels
+        # shrink the embedding channels by reduc_ratio (e.g. 256 -> 128 with ratio 2)
+        self.inter_channels = in_channels // reduc_ratio
+
+ self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
+ kernel_size=1, stride=1, padding=0)
+
+ self.W = nn.Sequential(
+ nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
+ kernel_size=1, stride=1, padding=0),
+ get_norm(bn_norm, self.in_channels, num_splits),
+ )
+ nn.init.constant_(self.W[1].weight, 0.0)
+ nn.init.constant_(self.W[1].bias, 0.0)
+
+ self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
+ kernel_size=1, stride=1, padding=0)
+
+ self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
+ kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x):
+ '''
+        :param x: (b, c, h, w)
+        :return x: (b, c, h, w)
+ '''
+ batch_size = x.size(0)
+ g_x = self.g(x).view(batch_size, self.inter_channels, -1)
+ g_x = g_x.permute(0, 2, 1)
+
+ theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
+ theta_x = theta_x.permute(0, 2, 1)
+ phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
+ f = torch.matmul(theta_x, phi_x)
+ N = f.size(-1)
+ f_div_C = f / N
+
+ y = torch.matmul(f_div_C, g_x)
+ y = y.permute(0, 2, 1).contiguous()
+ y = y.view(batch_size, self.inter_channels, *x.size()[2:])
+ W_y = self.W(y)
+ z = W_y + x
+ return z
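+
+# Example (sketch): dot-product affinity normalized by N, with a residual
+# connection; the zero-initialized W leaves the input unchanged at init:
+# nl = Non_local(in_channels=256, bn_norm="BN", num_splits=1)
+# y = nl(torch.randn(2, 256, 16, 8))  # same shape as x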
diff --git a/layers/pooling.py b/layers/pooling.py
new file mode 100644
index 0000000..5aec39e
--- /dev/null
+++ b/layers/pooling.py
@@ -0,0 +1,79 @@
+# encoding: utf-8
+"""
+@author: l1aoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Flatten(nn.Module):
+ def forward(self, input):
+ return input.view(input.size(0), -1)
+
+
+class GeneralizedMeanPooling(nn.Module):
+ r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes.
+ The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)`
+ - At p = infinity, one gets Max Pooling
+ - At p = 1, one gets Average Pooling
+ The output is of size H x W, for any input size.
+ The number of output features is equal to the number of input planes.
+ Args:
+ output_size: the target output size of the image of the form H x W.
+ Can be a tuple (H, W) or a single H for a square image H x H
+ H and W can be either a ``int``, or ``None`` which means the size will
+ be the same as that of the input.
+ """
+
+ def __init__(self, norm, output_size=1, eps=1e-6):
+ super(GeneralizedMeanPooling, self).__init__()
+ assert norm > 0
+ self.p = float(norm)
+ self.output_size = output_size
+ self.eps = eps
+
+ def forward(self, x):
+ x = x.clamp(min=self.eps).pow(self.p)
+ return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(' \
+ + str(self.p) + ', ' \
+ + 'output_size=' + str(self.output_size) + ')'
+
+
+class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
+ """ Same, but norm is trainable
+ """
+
+ def __init__(self, norm=3, output_size=1, eps=1e-6):
+ super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps)
+ self.p = nn.Parameter(torch.ones(1) * norm)
+
+
+class AdaptiveAvgMaxPool2d(nn.Module):
+ def __init__(self):
+ super(AdaptiveAvgMaxPool2d, self).__init__()
+ self.avgpool = FastGlobalAvgPool2d()
+
+ def forward(self, x):
+        x_avg = self.avgpool(x)
+ x_max = F.adaptive_max_pool2d(x, 1)
+ x = x_max + x_avg
+ return x
+
+
+class FastGlobalAvgPool2d(nn.Module):
+ def __init__(self, flatten=False):
+ super(FastGlobalAvgPool2d, self).__init__()
+ self.flatten = flatten
+
+ def forward(self, x):
+ if self.flatten:
+ in_size = x.size()
+ return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
+ else:
+ return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
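+
+# Example (sketch): GeM pooling interpolates between average pooling (p = 1)
+# and max pooling (p -> inf); GeneralizedMeanPoolingP learns p (init 3):
+# gem = GeneralizedMeanPoolingP()
+# feat = gem(torch.randn(2, 2048, 16, 8))  # -> (2, 2048, 1, 1)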
diff --git a/layers/se_layer.py b/layers/se_layer.py
new file mode 100644
index 0000000..04e1dc6
--- /dev/null
+++ b/layers/se_layer.py
@@ -0,0 +1,25 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from torch import nn
+
+
+class SELayer(nn.Module):
+ def __init__(self, channel, reduction=16):
+ super(SELayer, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(channel, int(channel / reduction), bias=False),
+ nn.ReLU(inplace=True),
+ nn.Linear(int(channel / reduction), channel, bias=False),
+ nn.Sigmoid()
+ )
+
+ def forward(self, x):
+ b, c, _, _ = x.size()
+ y = self.avg_pool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1, 1)
+ return x * y.expand_as(x)
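+
+# Example (sketch): a squeeze-and-excitation gate that rescales each channel by
+# a learned weight in (0, 1):
+# se = SELayer(channel=256, reduction=16)
+# y = se(torch.randn(2, 256, 16, 8))  # same shape, channels reweighted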
diff --git a/layers/splat.py b/layers/splat.py
new file mode 100644
index 0000000..8a8901d
--- /dev/null
+++ b/layers/splat.py
@@ -0,0 +1,97 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: liaoxingyu5@jd.com
+"""
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import Conv2d, ReLU
+from torch.nn.modules.utils import _pair
+from .batch_norm import get_norm
+
+
+class SplAtConv2d(nn.Module):
+ """Split-Attention Conv2d
+ """
+
+ def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0),
+ dilation=(1, 1), groups=1, bias=True,
+ radix=2, reduction_factor=4,
+ rectify=False, rectify_avg=False, norm_layer=None, num_splits=1,
+ dropblock_prob=0.0, **kwargs):
+ super(SplAtConv2d, self).__init__()
+ padding = _pair(padding)
+ self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
+ self.rectify_avg = rectify_avg
+ inter_channels = max(in_channels * radix // reduction_factor, 32)
+ self.radix = radix
+ self.cardinality = groups
+ self.channels = channels
+ self.dropblock_prob = dropblock_prob
+ if self.rectify:
+ from rfconv import RFConv2d
+ self.conv = RFConv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation,
+ groups=groups * radix, bias=bias, average_mode=rectify_avg, **kwargs)
+ else:
+ self.conv = Conv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation,
+ groups=groups * radix, bias=bias, **kwargs)
+ self.use_bn = norm_layer is not None
+ if self.use_bn:
+ self.bn0 = get_norm(norm_layer, channels * radix, num_splits)
+ self.relu = ReLU(inplace=True)
+ self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
+ if self.use_bn:
+ self.bn1 = get_norm(norm_layer, inter_channels, num_splits)
+ self.fc2 = Conv2d(inter_channels, channels * radix, 1, groups=self.cardinality)
+
+ self.rsoftmax = rSoftMax(radix, groups)
+
+ def forward(self, x):
+ x = self.conv(x)
+ if self.use_bn:
+ x = self.bn0(x)
+        if self.dropblock_prob > 0.0:
+            # NOTE: self.dropblock is never defined in this port; the upstream
+            # ResNeSt code instantiates a DropBlock2D module when dropblock_prob > 0.
+            x = self.dropblock(x)
+ x = self.relu(x)
+
+ batch, rchannel = x.shape[:2]
+ if self.radix > 1:
+ splited = torch.split(x, rchannel // self.radix, dim=1)
+ gap = sum(splited)
+ else:
+ gap = x
+ gap = F.adaptive_avg_pool2d(gap, 1)
+ gap = self.fc1(gap)
+
+ if self.use_bn:
+ gap = self.bn1(gap)
+ gap = self.relu(gap)
+
+ atten = self.fc2(gap)
+ atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
+
+ if self.radix > 1:
+ attens = torch.split(atten, rchannel // self.radix, dim=1)
+ out = sum([att * split for (att, split) in zip(attens, splited)])
+ else:
+ out = atten * x
+ return out.contiguous()
+
+
+class rSoftMax(nn.Module):
+ def __init__(self, radix, cardinality):
+ super().__init__()
+ self.radix = radix
+ self.cardinality = cardinality
+
+ def forward(self, x):
+ batch = x.size(0)
+ if self.radix > 1:
+ x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
+ x = F.softmax(x, dim=1)
+ x = x.reshape(batch, -1)
+ else:
+ x = torch.sigmoid(x)
+ return x
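+
+# Shape walkthrough for rSoftMax (illustrative), with radix=2, cardinality=1 and
+# an attention tensor of shape (B, 2*C, 1, 1) coming from fc2:
+#   view      -> (B, 1, 2, C)   # (batch, cardinality, radix, channels per group)
+#   transpose -> (B, 2, 1, C)   # softmax over dim=1 normalizes across the radix splits
+#   reshape   -> (B, 2*C)       # flattened gates; the caller views them back to (B, -1, 1, 1)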
diff --git a/libtorch.tar.gz b/libtorch.tar.gz
deleted file mode 100644
index e3d2ead..0000000
--- a/libtorch.tar.gz
+++ /dev/null
Binary files differ
diff --git a/modeling/__init__.py b/modeling/__init__.py
new file mode 100644
index 0000000..0873432
--- /dev/null
+++ b/modeling/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2020/10/26 14:48
+# @Author : Scheaven
+# @File    : __init__.py
+# @description:
\ No newline at end of file
diff --git a/modeling/__pycache__/__init__.cpython-37.pyc b/modeling/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..b0d4e9b
--- /dev/null
+++ b/modeling/__pycache__/__init__.cpython-37.pyc
Binary files differ
diff --git a/modeling/__pycache__/__init__.cpython-38.pyc b/modeling/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..7293942
--- /dev/null
+++ b/modeling/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/modeling/backbones/__init__.py b/modeling/backbones/__init__.py
new file mode 100644
index 0000000..ec6f22d
--- /dev/null
+++ b/modeling/backbones/__init__.py
@@ -0,0 +1,12 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from .build import build_backbone, BACKBONE_REGISTRY
+
+from .resnet import build_resnet_backbone
+from .osnet import build_osnet_backbone
+from .resnest import build_resnest_backbone
+from .resnext import build_resnext_backbone
\ No newline at end of file
diff --git a/modeling/backbones/__pycache__/__init__.cpython-37.pyc b/modeling/backbones/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..adfac7c
--- /dev/null
+++ b/modeling/backbones/__pycache__/__init__.cpython-37.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/__init__.cpython-38.pyc b/modeling/backbones/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..2bb2c2f
--- /dev/null
+++ b/modeling/backbones/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/build.cpython-37.pyc b/modeling/backbones/__pycache__/build.cpython-37.pyc
new file mode 100644
index 0000000..ff5bf99
--- /dev/null
+++ b/modeling/backbones/__pycache__/build.cpython-37.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/build.cpython-38.pyc b/modeling/backbones/__pycache__/build.cpython-38.pyc
new file mode 100644
index 0000000..ab9d690
--- /dev/null
+++ b/modeling/backbones/__pycache__/build.cpython-38.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/osnet.cpython-37.pyc b/modeling/backbones/__pycache__/osnet.cpython-37.pyc
new file mode 100644
index 0000000..e567de5
--- /dev/null
+++ b/modeling/backbones/__pycache__/osnet.cpython-37.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/osnet.cpython-38.pyc b/modeling/backbones/__pycache__/osnet.cpython-38.pyc
new file mode 100644
index 0000000..24f2adf
--- /dev/null
+++ b/modeling/backbones/__pycache__/osnet.cpython-38.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/resnest.cpython-37.pyc b/modeling/backbones/__pycache__/resnest.cpython-37.pyc
new file mode 100644
index 0000000..34c7f8a
--- /dev/null
+++ b/modeling/backbones/__pycache__/resnest.cpython-37.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/resnest.cpython-38.pyc b/modeling/backbones/__pycache__/resnest.cpython-38.pyc
new file mode 100644
index 0000000..415db6c
--- /dev/null
+++ b/modeling/backbones/__pycache__/resnest.cpython-38.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/resnet.cpython-37.pyc b/modeling/backbones/__pycache__/resnet.cpython-37.pyc
new file mode 100644
index 0000000..d60009a
--- /dev/null
+++ b/modeling/backbones/__pycache__/resnet.cpython-37.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/resnet.cpython-38.pyc b/modeling/backbones/__pycache__/resnet.cpython-38.pyc
new file mode 100644
index 0000000..ffdab08
--- /dev/null
+++ b/modeling/backbones/__pycache__/resnet.cpython-38.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/resnext.cpython-37.pyc b/modeling/backbones/__pycache__/resnext.cpython-37.pyc
new file mode 100644
index 0000000..b1f70eb
--- /dev/null
+++ b/modeling/backbones/__pycache__/resnext.cpython-37.pyc
Binary files differ
diff --git a/modeling/backbones/__pycache__/resnext.cpython-38.pyc b/modeling/backbones/__pycache__/resnext.cpython-38.pyc
new file mode 100644
index 0000000..72bc6b1
--- /dev/null
+++ b/modeling/backbones/__pycache__/resnext.cpython-38.pyc
Binary files differ
diff --git a/modeling/backbones/build.py b/modeling/backbones/build.py
new file mode 100644
index 0000000..3006786
--- /dev/null
+++ b/modeling/backbones/build.py
@@ -0,0 +1,28 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from utils.registry import Registry
+
+BACKBONE_REGISTRY = Registry("BACKBONE")
+BACKBONE_REGISTRY.__doc__ = """
+Registry for backbones, which extract feature maps from images.
+The registered object must be a callable that accepts a single config
+argument and returns a backbone :class:`nn.Module` instance.
+"""
+
+
+def build_backbone(cfg):
+ """
+ Build a backbone from `cfg.MODEL.BACKBONE.NAME`.
+ Returns:
+ an instance of :class:`Backbone`
+ """
+
+ backbone_name = cfg.MODEL.BACKBONE.NAME
+ backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg)
+ return backbone
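+
+# Registering a new backbone (illustrative; `build_my_backbone` and `MyBackbone`
+# are hypothetical names):
+#
+#   @BACKBONE_REGISTRY.register()
+#   def build_my_backbone(cfg):
+#       return MyBackbone(cfg.MODEL.BACKBONE.DEPTH)
+#
+# Setting cfg.MODEL.BACKBONE.NAME = "build_my_backbone" then makes
+# build_backbone(cfg) construct it.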
diff --git a/modeling/backbones/osnet.py b/modeling/backbones/osnet.py
new file mode 100644
index 0000000..d71615a
--- /dev/null
+++ b/modeling/backbones/osnet.py
@@ -0,0 +1,487 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: liaoxingyu5@jd.com
+"""
+
+# based on:
+# https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/models/osnet.py
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from .build import BACKBONE_REGISTRY
+
+model_urls = {
+ 'osnet_x1_0':
+ 'https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY',
+ 'osnet_x0_75':
+ 'https://drive.google.com/uc?id=1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq',
+ 'osnet_x0_5':
+ 'https://drive.google.com/uc?id=16DGLbZukvVYgINws8u8deSaOqjybZ83i',
+ 'osnet_x0_25':
+ 'https://drive.google.com/uc?id=1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs',
+ 'osnet_ibn_x1_0':
+ 'https://drive.google.com/uc?id=1sr90V6irlYYDd4_4ISU2iruoRG8J__6l'
+}
+
+
+##########
+# Basic layers
+##########
+class ConvLayer(nn.Module):
+ """Convolution layer (conv + bn + relu)."""
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ groups=1,
+ IN=False
+ ):
+ super(ConvLayer, self).__init__()
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ bias=False,
+ groups=groups
+ )
+ if IN:
+ self.bn = nn.InstanceNorm2d(out_channels, affine=True)
+ else:
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ x = self.relu(x)
+ return x
+
+
+class Conv1x1(nn.Module):
+ """1x1 convolution + bn + relu."""
+
+ def __init__(self, in_channels, out_channels, stride=1, groups=1):
+ super(Conv1x1, self).__init__()
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ 1,
+ stride=stride,
+ padding=0,
+ bias=False,
+ groups=groups
+ )
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ x = self.relu(x)
+ return x
+
+
+class Conv1x1Linear(nn.Module):
+ """1x1 convolution + bn (w/o non-linearity)."""
+
+ def __init__(self, in_channels, out_channels, stride=1):
+ super(Conv1x1Linear, self).__init__()
+ self.conv = nn.Conv2d(
+ in_channels, out_channels, 1, stride=stride, padding=0, bias=False
+ )
+ self.bn = nn.BatchNorm2d(out_channels)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ return x
+
+
+class Conv3x3(nn.Module):
+ """3x3 convolution + bn + relu."""
+
+ def __init__(self, in_channels, out_channels, stride=1, groups=1):
+ super(Conv3x3, self).__init__()
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ 3,
+ stride=stride,
+ padding=1,
+ bias=False,
+ groups=groups
+ )
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ x = self.relu(x)
+ return x
+
+
+class LightConv3x3(nn.Module):
+ """Lightweight 3x3 convolution.
+ 1x1 (linear) + dw 3x3 (nonlinear).
+ """
+
+ def __init__(self, in_channels, out_channels):
+ super(LightConv3x3, self).__init__()
+ self.conv1 = nn.Conv2d(
+ in_channels, out_channels, 1, stride=1, padding=0, bias=False
+ )
+ self.conv2 = nn.Conv2d(
+ out_channels,
+ out_channels,
+ 3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=out_channels
+ )
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.conv2(x)
+ x = self.bn(x)
+ x = self.relu(x)
+ return x
+
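+# Parameter-count comparison (illustrative, ignoring BN): mapping 256 -> 256
+# channels, a plain 3x3 conv needs 3*3*256*256 = 589,824 weights, whereas
+# LightConv3x3 needs 256*256 (pointwise) + 3*3*256 (depthwise) = 67,840 --
+# roughly 8.7x fewer, which makes the multi-branch OS blocks below affordable.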
+
+##########
+# Building blocks for omni-scale feature learning
+##########
+class ChannelGate(nn.Module):
+ """A mini-network that generates channel-wise gates conditioned on input tensor."""
+
+ def __init__(
+ self,
+ in_channels,
+ num_gates=None,
+ return_gates=False,
+ gate_activation='sigmoid',
+ reduction=16,
+ layer_norm=False
+ ):
+ super(ChannelGate, self).__init__()
+ if num_gates is None: num_gates = in_channels
+ self.return_gates = return_gates
+
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+
+ self.fc1 = nn.Conv2d(
+ in_channels,
+ in_channels // reduction,
+ kernel_size=1,
+ bias=True,
+ padding=0
+ )
+ self.norm1 = None
+ if layer_norm: self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
+ self.relu = nn.ReLU(inplace=True)
+ self.fc2 = nn.Conv2d(
+ in_channels // reduction,
+ num_gates,
+ kernel_size=1,
+ bias=True,
+ padding=0
+ )
+ if gate_activation == 'sigmoid': self.gate_activation = nn.Sigmoid()
+ elif gate_activation == 'relu': self.gate_activation = nn.ReLU(inplace=True)
+ elif gate_activation == 'linear': self.gate_activation = nn.Identity()
+ else:
+ raise RuntimeError(
+ "Unknown gate activation: {}".format(gate_activation)
+ )
+
+ def forward(self, x):
+ input = x
+ x = self.global_avgpool(x)
+ x = self.fc1(x)
+ if self.norm1 is not None: x = self.norm1(x)
+ x = self.relu(x)
+ x = self.fc2(x)
+ x = self.gate_activation(x)
+ if self.return_gates: return x
+ return input * x
+
+
+class OSBlock(nn.Module):
+ """Omni-scale feature learning block."""
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ IN=False,
+ bottleneck_reduction=4,
+ **kwargs
+ ):
+ super(OSBlock, self).__init__()
+ mid_channels = out_channels // bottleneck_reduction
+ self.conv1 = Conv1x1(in_channels, mid_channels)
+ self.conv2a = LightConv3x3(mid_channels, mid_channels)
+ self.conv2b = nn.Sequential(
+ LightConv3x3(mid_channels, mid_channels),
+ LightConv3x3(mid_channels, mid_channels),
+ )
+ self.conv2c = nn.Sequential(
+ LightConv3x3(mid_channels, mid_channels),
+ LightConv3x3(mid_channels, mid_channels),
+ LightConv3x3(mid_channels, mid_channels),
+ )
+ self.conv2d = nn.Sequential(
+ LightConv3x3(mid_channels, mid_channels),
+ LightConv3x3(mid_channels, mid_channels),
+ LightConv3x3(mid_channels, mid_channels),
+ LightConv3x3(mid_channels, mid_channels),
+ )
+ self.gate = ChannelGate(mid_channels)
+ self.conv3 = Conv1x1Linear(mid_channels, out_channels)
+ self.downsample = None
+ if in_channels != out_channels:
+ self.downsample = Conv1x1Linear(in_channels, out_channels)
+ self.IN = None
+ if IN: self.IN = nn.InstanceNorm2d(out_channels, affine=True)
+ self.relu = nn.ReLU(True)
+
+ def forward(self, x):
+ identity = x
+ x1 = self.conv1(x)
+ x2a = self.conv2a(x1)
+ x2b = self.conv2b(x1)
+ x2c = self.conv2c(x1)
+ x2d = self.conv2d(x1)
+ x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d)
+ x3 = self.conv3(x2)
+ if self.downsample is not None:
+ identity = self.downsample(identity)
+ out = x3 + identity
+ if self.IN is not None:
+ out = self.IN(out)
+ return self.relu(out)
+
+
+##########
+# Network architecture
+##########
+class OSNet(nn.Module):
+ """Omni-Scale Network.
+
+ Reference:
+ - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
+ - Zhou et al. Learning Generalisable Omni-Scale Representations
+ for Person Re-Identification. arXiv preprint, 2019.
+ """
+
+ def __init__(
+ self,
+ blocks,
+ layers,
+ channels,
+ IN=False,
+ **kwargs
+ ):
+ super(OSNet, self).__init__()
+ num_blocks = len(blocks)
+ assert num_blocks == len(layers)
+ assert num_blocks == len(channels) - 1
+
+ # convolutional backbone
+ self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)
+ self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
+ self.conv2 = self._make_layer(
+ blocks[0],
+ layers[0],
+ channels[0],
+ channels[1],
+ reduce_spatial_size=True,
+ IN=IN
+ )
+ self.conv3 = self._make_layer(
+ blocks[1],
+ layers[1],
+ channels[1],
+ channels[2],
+ reduce_spatial_size=True
+ )
+ self.conv4 = self._make_layer(
+ blocks[2],
+ layers[2],
+ channels[2],
+ channels[3],
+ reduce_spatial_size=False
+ )
+ self.conv5 = Conv1x1(channels[3], channels[3])
+
+ self._init_params()
+
+ def _make_layer(
+ self,
+ block,
+ layer,
+ in_channels,
+ out_channels,
+ reduce_spatial_size,
+ IN=False
+ ):
+ layers = []
+
+ layers.append(block(in_channels, out_channels, IN=IN))
+ for i in range(1, layer):
+ layers.append(block(out_channels, out_channels, IN=IN))
+
+ if reduce_spatial_size:
+ layers.append(
+ nn.Sequential(
+ Conv1x1(out_channels, out_channels),
+ nn.AvgPool2d(2, stride=2),
+ )
+ )
+
+ return nn.Sequential(*layers)
+
+ def _init_params(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(
+ m.weight, mode='fan_out', nonlinearity='relu'
+ )
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ elif isinstance(m, nn.BatchNorm1d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.maxpool(x)
+ x = self.conv2(x)
+ x = self.conv3(x)
+ x = self.conv4(x)
+ x = self.conv5(x)
+ return x
+
+
+def init_pretrained_weights(model, key=''):
+ """Initializes model with pretrained weights.
+
+    Layers whose name or size does not match the pretrained weights are left unchanged.
+ """
+ import os
+ import errno
+ import gdown
+ from collections import OrderedDict
+ import warnings
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ def _get_torch_home():
+ ENV_TORCH_HOME = 'TORCH_HOME'
+ ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+ DEFAULT_CACHE_DIR = '~/.cache'
+ torch_home = os.path.expanduser(
+ os.getenv(
+ ENV_TORCH_HOME,
+ os.path.join(
+ os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
+ )
+ )
+ )
+ return torch_home
+
+ torch_home = _get_torch_home()
+ model_dir = os.path.join(torch_home, 'checkpoints')
+ try:
+ os.makedirs(model_dir)
+ except OSError as e:
+ if e.errno == errno.EEXIST:
+ # Directory already exists, ignore.
+ pass
+ else:
+ # Unexpected OSError, re-raise.
+ raise
+ filename = key + '_imagenet.pth'
+ cached_file = os.path.join(model_dir, filename)
+
+ if not os.path.exists(cached_file):
+ gdown.download(model_urls[key], cached_file, quiet=False)
+
+ state_dict = torch.load(cached_file)
+ model_dict = model.state_dict()
+ new_state_dict = OrderedDict()
+ matched_layers, discarded_layers = [], []
+
+ for k, v in state_dict.items():
+ if k.startswith('module.'):
+ k = k[7:] # discard module.
+
+ if k in model_dict and model_dict[k].size() == v.size():
+ new_state_dict[k] = v
+ matched_layers.append(k)
+ else:
+ discarded_layers.append(k)
+
+ model_dict.update(new_state_dict)
+ model.load_state_dict(model_dict)
+
+ if len(matched_layers) == 0:
+ warnings.warn(
+ 'The pretrained weights from "{}" cannot be loaded, '
+ 'please check the key names manually '
+ '(** ignored and continue **)'.format(cached_file)
+ )
+ else:
+ logger.info(
+ 'Successfully loaded imagenet pretrained weights from "{}"'.
+ format(cached_file)
+ )
+ if len(discarded_layers) > 0:
+ logger.info(
+ '** The following layers are discarded '
+ 'due to unmatched keys or layer size: {}'.
+ format(discarded_layers)
+ )
+
+
+@BACKBONE_REGISTRY.register()
+def build_osnet_backbone(cfg):
+ """
+    Create an OSNet instance from config.
+ Returns:
+ OSNet: a :class:`OSNet` instance
+ """
+
+ # fmt: off
+ pretrain = cfg.MODEL.BACKBONE.PRETRAIN
+ with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
+
+ num_blocks_per_stage = [2, 2, 2]
+ num_channels_per_stage = [64, 256, 384, 512]
+ model = OSNet([OSBlock, OSBlock, OSBlock], num_blocks_per_stage, num_channels_per_stage, with_ibn)
+ pretrain_key = 'osnet_ibn_x1_0' if with_ibn else 'osnet_x1_0'
+ if pretrain:
+ init_pretrained_weights(model, pretrain_key)
+ return model
diff --git a/modeling/backbones/resnest.py b/modeling/backbones/resnest.py
new file mode 100644
index 0000000..f5eaa9e
--- /dev/null
+++ b/modeling/backbones/resnest.py
@@ -0,0 +1,411 @@
+# encoding: utf-8
+# based on:
+# https://github.com/zhanghang1989/ResNeSt/blob/master/resnest/torch/resnest.py
+"""ResNeSt models"""
+
+import logging
+import math
+
+import torch
+from torch import nn
+
+from layers import (
+ IBN,
+ Non_local,
+ SplAtConv2d,
+ get_norm,
+)
+
+from utils.checkpoint import get_unexpected_parameters_message, get_missing_parameters_message
+
+from .build import BACKBONE_REGISTRY
+
+_url_format = 'https://hangzh.s3.amazonaws.com/encoding/models/{}-{}.pth'
+
+_model_sha256 = {name: checksum for checksum, name in [
+ ('528c19ca', 'resnest50'),
+ ('22405ba7', 'resnest101'),
+ ('75117900', 'resnest200'),
+ ('0cc87c48', 'resnest269'),
+]}
+
+
+def short_hash(name):
+ if name not in _model_sha256:
+ raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
+ return _model_sha256[name][:8]
+
+
+model_urls = {name: _url_format.format(name, short_hash(name)) for
+ name in _model_sha256.keys()
+ }
+
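+# For example, model_urls['resnest50'] resolves to
+# 'https://hangzh.s3.amazonaws.com/encoding/models/resnest50-528c19ca.pth'.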
+
+class Bottleneck(nn.Module):
+ """ResNet Bottleneck
+ """
+ # pylint: disable=unused-argument
+ expansion = 4
+
+ def __init__(self, inplanes, planes, bn_norm, num_splits, with_ibn=False, stride=1, downsample=None,
+ radix=1, cardinality=1, bottleneck_width=64,
+ avd=False, avd_first=False, dilation=1, is_first=False,
+ rectified_conv=False, rectify_avg=False,
+ dropblock_prob=0.0, last_gamma=False):
+ super(Bottleneck, self).__init__()
+ group_width = int(planes * (bottleneck_width / 64.)) * cardinality
+ self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
+ if with_ibn:
+ self.bn1 = IBN(group_width, bn_norm, num_splits)
+ else:
+ self.bn1 = get_norm(bn_norm, group_width, num_splits)
+ self.dropblock_prob = dropblock_prob
+ self.radix = radix
+ self.avd = avd and (stride > 1 or is_first)
+ self.avd_first = avd_first
+
+ if self.avd:
+ self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
+ stride = 1
+
+ if radix > 1:
+ self.conv2 = SplAtConv2d(
+ group_width, group_width, kernel_size=3,
+ stride=stride, padding=dilation,
+ dilation=dilation, groups=cardinality, bias=False,
+ radix=radix, rectify=rectified_conv,
+ rectify_avg=rectify_avg,
+ norm_layer=bn_norm, num_splits=num_splits,
+ dropblock_prob=dropblock_prob)
+ elif rectified_conv:
+ from rfconv import RFConv2d
+ self.conv2 = RFConv2d(
+ group_width, group_width, kernel_size=3, stride=stride,
+ padding=dilation, dilation=dilation,
+ groups=cardinality, bias=False,
+ average_mode=rectify_avg)
+ self.bn2 = get_norm(bn_norm, group_width, num_splits)
+ else:
+ self.conv2 = nn.Conv2d(
+ group_width, group_width, kernel_size=3, stride=stride,
+ padding=dilation, dilation=dilation,
+ groups=cardinality, bias=False)
+ self.bn2 = get_norm(bn_norm, group_width, num_splits)
+
+ self.conv3 = nn.Conv2d(
+ group_width, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = get_norm(bn_norm, planes * 4, num_splits)
+
+ if last_gamma:
+ from torch.nn.init import zeros_
+ zeros_(self.bn3.weight)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.dilation = dilation
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+        if self.dropblock_prob > 0.0:
+            # NOTE: dropblock1/2/3 are never defined in this port; upstream
+            # ResNeSt instantiates DropBlock2D modules when dropblock_prob > 0.
+            out = self.dropblock1(out)
+ out = self.relu(out)
+
+ if self.avd and self.avd_first:
+ out = self.avd_layer(out)
+
+ out = self.conv2(out)
+ if self.radix == 1:
+ out = self.bn2(out)
+ if self.dropblock_prob > 0.0:
+ out = self.dropblock2(out)
+ out = self.relu(out)
+
+ if self.avd and not self.avd_first:
+ out = self.avd_layer(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+ if self.dropblock_prob > 0.0:
+ out = self.dropblock3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNest(nn.Module):
+ """ResNet Variants ResNest
+ Parameters
+ ----------
+ block : Block
+ Class for the residual block. Options are BasicBlockV1, BottleneckV1.
+ layers : list of int
+ Numbers of layers in each block
+ classes : int, default 1000
+ Number of classification classes.
+ dilated : bool, default False
+ Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
+ typically used in Semantic Segmentation.
+    norm_layer : object
+        Normalization layer used in the backbone network (default: :class:`torch.nn.BatchNorm2d`;
+        a synchronized cross-GPU batch-normalization layer may be substituted).
+ Reference:
+ - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
+ - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
+ """
+
+ # pylint: disable=unused-variable
+ def __init__(self, last_stride, bn_norm, num_splits, with_ibn, with_nl, block, layers, non_layers, radix=1, groups=1,
+ bottleneck_width=64,
+ dilated=False, dilation=1,
+ deep_stem=False, stem_width=64, avg_down=False,
+ rectified_conv=False, rectify_avg=False,
+ avd=False, avd_first=False,
+ final_drop=0.0, dropblock_prob=0,
+ last_gamma=False):
+ self.cardinality = groups
+ self.bottleneck_width = bottleneck_width
+ # ResNet-D params
+ self.inplanes = stem_width * 2 if deep_stem else 64
+ self.avg_down = avg_down
+ self.last_gamma = last_gamma
+ # ResNeSt params
+ self.radix = radix
+ self.avd = avd
+ self.avd_first = avd_first
+
+ super().__init__()
+ self.rectified_conv = rectified_conv
+ self.rectify_avg = rectify_avg
+ if rectified_conv:
+ from rfconv import RFConv2d
+ conv_layer = RFConv2d
+ else:
+ conv_layer = nn.Conv2d
+ conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {}
+ if deep_stem:
+ self.conv1 = nn.Sequential(
+ conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs),
+ get_norm(bn_norm, stem_width, num_splits),
+ nn.ReLU(inplace=True),
+ conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
+ get_norm(bn_norm, stem_width, num_splits),
+ nn.ReLU(inplace=True),
+ conv_layer(stem_width, stem_width * 2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
+ )
+ else:
+ self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
+ bias=False, **conv_kwargs)
+ self.bn1 = get_norm(bn_norm, self.inplanes, num_splits)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, num_splits, with_ibn=with_ibn, is_first=False)
+ self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, num_splits, with_ibn=with_ibn)
+ if dilated or dilation == 4:
+ self.layer3 = self._make_layer(block, 256, layers[2], 1, bn_norm, num_splits, with_ibn=with_ibn,
+ dilation=2, dropblock_prob=dropblock_prob)
+ self.layer4 = self._make_layer(block, 512, layers[3], 1, bn_norm, num_splits, with_ibn=with_ibn,
+ dilation=4, dropblock_prob=dropblock_prob)
+ elif dilation == 2:
+ self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, num_splits, with_ibn=with_ibn,
+ dilation=1, dropblock_prob=dropblock_prob)
+ self.layer4 = self._make_layer(block, 512, layers[3], 1, bn_norm, num_splits, with_ibn=with_ibn,
+ dilation=2, dropblock_prob=dropblock_prob)
+ else:
+ self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, num_splits, with_ibn=with_ibn,
+ dropblock_prob=dropblock_prob)
+ self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, num_splits, with_ibn=with_ibn,
+ dropblock_prob=dropblock_prob)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ if with_nl:
+ self._build_nonlocal(layers, non_layers, bn_norm, num_splits)
+ else:
+ self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
+
+ def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", num_splits=1, with_ibn=False,
+ dilation=1, dropblock_prob=0.0, is_first=True):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ down_layers = []
+ if self.avg_down:
+ if dilation == 1:
+ down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride,
+ ceil_mode=True, count_include_pad=False))
+ else:
+ down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1,
+ ceil_mode=True, count_include_pad=False))
+ down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=1, bias=False))
+ else:
+ down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False))
+ down_layers.append(get_norm(bn_norm, planes * block.expansion, num_splits))
+ downsample = nn.Sequential(*down_layers)
+
+ layers = []
+ if planes == 512:
+ with_ibn = False
+ if dilation == 1 or dilation == 2:
+ layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn, stride, downsample=downsample,
+ radix=self.radix, cardinality=self.cardinality,
+ bottleneck_width=self.bottleneck_width,
+ avd=self.avd, avd_first=self.avd_first,
+ dilation=1, is_first=is_first, rectified_conv=self.rectified_conv,
+ rectify_avg=self.rectify_avg,
+ dropblock_prob=dropblock_prob,
+ last_gamma=self.last_gamma))
+ elif dilation == 4:
+ layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn, stride, downsample=downsample,
+ radix=self.radix, cardinality=self.cardinality,
+ bottleneck_width=self.bottleneck_width,
+ avd=self.avd, avd_first=self.avd_first,
+ dilation=2, is_first=is_first, rectified_conv=self.rectified_conv,
+ rectify_avg=self.rectify_avg,
+ dropblock_prob=dropblock_prob,
+ last_gamma=self.last_gamma))
+ else:
+ raise RuntimeError("=> unknown dilation size: {}".format(dilation))
+
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn,
+ radix=self.radix, cardinality=self.cardinality,
+ bottleneck_width=self.bottleneck_width,
+ avd=self.avd, avd_first=self.avd_first,
+ dilation=dilation, rectified_conv=self.rectified_conv,
+ rectify_avg=self.rectify_avg,
+ dropblock_prob=dropblock_prob,
+ last_gamma=self.last_gamma))
+
+ return nn.Sequential(*layers)
+
+ def _build_nonlocal(self, layers, non_layers, bn_norm, num_splits):
+ self.NL_1 = nn.ModuleList(
+ [Non_local(256, bn_norm, num_splits) for _ in range(non_layers[0])])
+ self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
+ self.NL_2 = nn.ModuleList(
+ [Non_local(512, bn_norm, num_splits) for _ in range(non_layers[1])])
+ self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
+ self.NL_3 = nn.ModuleList(
+ [Non_local(1024, bn_norm, num_splits) for _ in range(non_layers[2])])
+ self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
+ self.NL_4 = nn.ModuleList(
+ [Non_local(2048, bn_norm, num_splits) for _ in range(non_layers[3])])
+ self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
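+
+    # Worked example (illustrative): with layers = [3, 4, 6, 3] and
+    # non_layers = [0, 2, 3, 0], NL_2_idx = [2, 3] and NL_3_idx = [3, 4, 5],
+    # i.e. the non-local blocks are interleaved after the last few residual
+    # blocks of stages 2 and 3.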
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ NL1_counter = 0
+ if len(self.NL_1_idx) == 0:
+ self.NL_1_idx = [-1]
+ for i in range(len(self.layer1)):
+ x = self.layer1[i](x)
+ if i == self.NL_1_idx[NL1_counter]:
+ _, C, H, W = x.shape
+ x = self.NL_1[NL1_counter](x)
+ NL1_counter += 1
+ # Layer 2
+ NL2_counter = 0
+ if len(self.NL_2_idx) == 0:
+ self.NL_2_idx = [-1]
+ for i in range(len(self.layer2)):
+ x = self.layer2[i](x)
+ if i == self.NL_2_idx[NL2_counter]:
+ _, C, H, W = x.shape
+ x = self.NL_2[NL2_counter](x)
+ NL2_counter += 1
+ # Layer 3
+ NL3_counter = 0
+ if len(self.NL_3_idx) == 0:
+ self.NL_3_idx = [-1]
+ for i in range(len(self.layer3)):
+ x = self.layer3[i](x)
+ if i == self.NL_3_idx[NL3_counter]:
+ _, C, H, W = x.shape
+ x = self.NL_3[NL3_counter](x)
+ NL3_counter += 1
+ # Layer 4
+ NL4_counter = 0
+ if len(self.NL_4_idx) == 0:
+ self.NL_4_idx = [-1]
+ for i in range(len(self.layer4)):
+ x = self.layer4[i](x)
+ if i == self.NL_4_idx[NL4_counter]:
+ _, C, H, W = x.shape
+ x = self.NL_4[NL4_counter](x)
+ NL4_counter += 1
+
+ return x
+
+
+@BACKBONE_REGISTRY.register()
+def build_resnest_backbone(cfg):
+ """
+ Create a ResNest instance from config.
+ Returns:
+        ResNest: a :class:`ResNest` instance.
+ """
+
+ # fmt: off
+ pretrain = cfg.MODEL.BACKBONE.PRETRAIN
+ last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
+ bn_norm = cfg.MODEL.BACKBONE.NORM
+ num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
+ with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
+ with_se = cfg.MODEL.BACKBONE.WITH_SE
+ with_nl = cfg.MODEL.BACKBONE.WITH_NL
+ depth = cfg.MODEL.BACKBONE.DEPTH
+
+ num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 200: [3, 24, 36, 3], 269: [3, 30, 48, 8]}[depth]
+    nl_layers_per_stage = {50: [0, 2, 3, 0], 101: [0, 2, 3, 0]}[depth]  # non-local configs exist only for depths 50 and 101
+ stem_width = {50: 32, 101: 64, 200: 64, 269: 64}[depth]
+ model = ResNest(last_stride, bn_norm, num_splits, with_ibn, with_nl, Bottleneck, num_blocks_per_stage,
+ nl_layers_per_stage, radix=2, groups=1, bottleneck_width=64,
+ deep_stem=True, stem_width=stem_width, avg_down=True,
+ avd=True, avd_first=False)
+    if pretrain:
+        # Only the original ResNeSt checkpoints are supported here; an IBN
+        # variant would need its own checkpoint and state-dict remapping
+        # (see the live remapping code in resnext.py).
+        state_dict = torch.hub.load_state_dict_from_url(
+            model_urls['resnest' + str(depth)], progress=True, check_hash=True)
+ incompatible = model.load_state_dict(state_dict, strict=False)
+ logger = logging.getLogger(__name__)
+ if incompatible.missing_keys:
+ logger.info(
+ get_missing_parameters_message(incompatible.missing_keys)
+ )
+ if incompatible.unexpected_keys:
+ logger.info(
+ get_unexpected_parameters_message(incompatible.unexpected_keys)
+ )
+ return model
diff --git a/modeling/backbones/resnet.py b/modeling/backbones/resnet.py
new file mode 100644
index 0000000..5dc3f63
--- /dev/null
+++ b/modeling/backbones/resnet.py
@@ -0,0 +1,359 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import logging
+import math
+
+import torch
+from torch import nn
+
+from layers import (
+ IBN,
+ SELayer,
+ Non_local,
+ get_norm,
+)
+from utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
+from .build import BACKBONE_REGISTRY
+from utils import comm
+
+
+logger = logging.getLogger(__name__)
+model_urls = {
+ '18x': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ '34x': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ '50x': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ '101x': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'ibn_18x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_a-2f571257.pth',
+ 'ibn_34x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_a-94bc1577.pth',
+ 'ibn_50x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth',
+ 'ibn_101x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_a-59ea0ac6.pth',
+ 'se_ibn_101x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/se_resnet101_ibn_a-fabed4e2.pth',
+}
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, bn_norm, with_ibn=False, with_se=False,
+ stride=1, downsample=None, reduction=16):
+ super(BasicBlock, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ if with_ibn:
+ self.bn1 = IBN(planes, bn_norm)
+ else:
+ self.bn1 = get_norm(bn_norm, planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn2 = get_norm(bn_norm, planes)
+ self.relu = nn.ReLU(inplace=True)
+ if with_se:
+ self.se = SELayer(planes, reduction)
+ else:
+ self.se = nn.Identity()
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.se(out)  # apply SE recalibration (identity when with_se=False)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, bn_norm, with_ibn=False, with_se=False,
+ stride=1, downsample=None, reduction=16):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ if with_ibn:
+ self.bn1 = IBN(planes, bn_norm)
+ else:
+ self.bn1 = get_norm(bn_norm, planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = get_norm(bn_norm, planes)
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = get_norm(bn_norm, planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ if with_se:
+ self.se = SELayer(planes * self.expansion, reduction)
+ else:
+ self.se = nn.Identity()
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+ out = self.se(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+ def __init__(self, last_stride, bn_norm, with_ibn, with_se, with_nl, block, layers, non_layers):
+ self.inplanes = 64
+ super().__init__()
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = get_norm(bn_norm, 64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
+ self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, with_ibn, with_se)
+ self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, with_ibn, with_se)
+ self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, with_ibn, with_se)
+ self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, with_se=with_se)
+
+ self.random_init()
+
+ # fmt: off
+ if with_nl: self._build_nonlocal(layers, non_layers, bn_norm)
+ else: self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
+ # fmt: on
+
+ def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", with_ibn=False, with_se=False):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ get_norm(bn_norm, planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, bn_norm, with_ibn, with_se, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes, bn_norm, with_ibn, with_se))
+
+ return nn.Sequential(*layers)
+
+ def _build_nonlocal(self, layers, non_layers, bn_norm):
+ self.NL_1 = nn.ModuleList(
+ [Non_local(256, bn_norm) for _ in range(non_layers[0])])
+ self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
+ self.NL_2 = nn.ModuleList(
+ [Non_local(512, bn_norm) for _ in range(non_layers[1])])
+ self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
+ self.NL_3 = nn.ModuleList(
+ [Non_local(1024, bn_norm) for _ in range(non_layers[2])])
+ self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
+ self.NL_4 = nn.ModuleList(
+ [Non_local(2048, bn_norm) for _ in range(non_layers[3])])
+ self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ NL1_counter = 0
+ if len(self.NL_1_idx) == 0:
+ self.NL_1_idx = [-1]
+ for i in range(len(self.layer1)):
+ x = self.layer1[i](x)
+ if i == self.NL_1_idx[NL1_counter]:
+ _, C, H, W = x.shape
+ x = self.NL_1[NL1_counter](x)
+ NL1_counter += 1
+ # Layer 2
+ NL2_counter = 0
+ if len(self.NL_2_idx) == 0:
+ self.NL_2_idx = [-1]
+ for i in range(len(self.layer2)):
+ x = self.layer2[i](x)
+ if i == self.NL_2_idx[NL2_counter]:
+ _, C, H, W = x.shape
+ x = self.NL_2[NL2_counter](x)
+ NL2_counter += 1
+ # Layer 3
+ NL3_counter = 0
+ if len(self.NL_3_idx) == 0:
+ self.NL_3_idx = [-1]
+ for i in range(len(self.layer3)):
+ x = self.layer3[i](x)
+ if i == self.NL_3_idx[NL3_counter]:
+ _, C, H, W = x.shape
+ x = self.NL_3[NL3_counter](x)
+ NL3_counter += 1
+ # Layer 4
+ NL4_counter = 0
+ if len(self.NL_4_idx) == 0:
+ self.NL_4_idx = [-1]
+ for i in range(len(self.layer4)):
+ x = self.layer4[i](x)
+ if i == self.NL_4_idx[NL4_counter]:
+ _, C, H, W = x.shape
+ x = self.NL_4[NL4_counter](x)
+ NL4_counter += 1
+
+ return x
+
+ def random_init(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+
+def init_pretrained_weights(key):
+ """Initializes model with pretrained weights.
+
+    Layers whose name or size does not match the pretrained weights are left unchanged.
+ """
+ import os
+ import errno
+ import gdown
+
+ def _get_torch_home():
+ ENV_TORCH_HOME = 'TORCH_HOME'
+ ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+ DEFAULT_CACHE_DIR = '~/.cache'
+ torch_home = os.path.expanduser(
+ os.getenv(
+ ENV_TORCH_HOME,
+ os.path.join(
+ os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
+ )
+ )
+ )
+ return torch_home
+
+ torch_home = _get_torch_home()
+ model_dir = os.path.join(torch_home, 'checkpoints')
+ try:
+ os.makedirs(model_dir)
+ except OSError as e:
+ if e.errno == errno.EEXIST:
+ # Directory already exists, ignore.
+ pass
+ else:
+ # Unexpected OSError, re-raise.
+ raise
+
+ filename = model_urls[key].split('/')[-1]
+
+ cached_file = os.path.join(model_dir, filename)
+
+ if not os.path.exists(cached_file):
+ if comm.is_main_process():
+ gdown.download(model_urls[key], cached_file, quiet=False)
+
+ comm.synchronize()
+
+ logger.info(f"Loading pretrained model from {cached_file}")
+ state_dict = torch.load(cached_file, map_location=torch.device('cpu'))
+
+ return state_dict
+
+
+@BACKBONE_REGISTRY.register()
+def build_resnet_backbone(cfg):
+ """
+ Create a ResNet instance from config.
+ Returns:
+ ResNet: a :class:`ResNet` instance.
+ """
+
+ # fmt: off
+ pretrain = cfg.MODEL.BACKBONE.PRETRAIN
+ pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
+ last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
+ bn_norm = cfg.MODEL.BACKBONE.NORM
+ with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
+ with_se = cfg.MODEL.BACKBONE.WITH_SE
+ with_nl = cfg.MODEL.BACKBONE.WITH_NL
+ depth = cfg.MODEL.BACKBONE.DEPTH
+ # fmt: on
+
+ num_blocks_per_stage = {
+ '18x': [2, 2, 2, 2],
+ '34x': [3, 4, 6, 3],
+ '50x': [3, 4, 6, 3],
+ '101x': [3, 4, 23, 3],
+ }[depth]
+
+ nl_layers_per_stage = {
+ '18x': [0, 0, 0, 0],
+ '34x': [0, 0, 0, 0],
+ '50x': [0, 2, 3, 0],
+ '101x': [0, 2, 9, 0]
+ }[depth]
+
+ block = {
+ '18x': BasicBlock,
+ '34x': BasicBlock,
+ '50x': Bottleneck,
+ '101x': Bottleneck
+ }[depth]
+
+ model = ResNet(last_stride, bn_norm, with_ibn, with_se, with_nl, block,
+ num_blocks_per_stage, nl_layers_per_stage)
+ if pretrain:
+        # Load pretrained weights from the given path, if one is specified
+ if pretrain_path:
+ try:
+ state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
+ logger.info(f"Loading pretrained model from {pretrain_path}")
+ except FileNotFoundError as e:
+ logger.info(f'{pretrain_path} is not found! Please check this path.')
+ raise e
+ except KeyError as e:
+ logger.info("State dict keys error! Please check the state dict.")
+ raise e
+ else:
+ key = depth
+ if with_ibn: key = 'ibn_' + key
+ if with_se: key = 'se_' + key
+
+ state_dict = init_pretrained_weights(key)
+
+ incompatible = model.load_state_dict(state_dict, strict=False)
+ if incompatible.missing_keys:
+ logger.info(
+ get_missing_parameters_message(incompatible.missing_keys)
+ )
+ if incompatible.unexpected_keys:
+ logger.info(
+ get_unexpected_parameters_message(incompatible.unexpected_keys)
+ )
+
+ return model
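+
+# Pretrained-key example (illustrative): depth '101x' with with_ibn=True and
+# with_se=True yields key 'se_ibn_101x', which maps to the IBN-Net release URL
+# above; a combination without a matching model_urls entry raises KeyError.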
diff --git a/modeling/backbones/resnext.py b/modeling/backbones/resnext.py
new file mode 100644
index 0000000..b68253d
--- /dev/null
+++ b/modeling/backbones/resnext.py
@@ -0,0 +1,198 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: liaoxingyu5@jd.com
+"""
+
+# based on:
+# https://github.com/XingangPan/IBN-Net/blob/master/models/imagenet/resnext_ibn_a.py
+
+import math
+import logging
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+import torch
+from layers import IBN
+from .build import BACKBONE_REGISTRY
+
+
+class Bottleneck(nn.Module):
+ """
+    ResNeXt bottleneck type C
+ """
+ expansion = 4
+
+ def __init__(self, inplanes, planes, with_ibn, baseWidth, cardinality, stride=1, downsample=None):
+ """ Constructor
+ Args:
+ inplanes: input channel dimensionality
+ planes: output channel dimensionality
+ baseWidth: base width.
+ cardinality: num of convolution groups.
+ stride: conv stride. Replaces pooling layer.
+ """
+ super(Bottleneck, self).__init__()
+
+ D = int(math.floor(planes * (baseWidth / 64)))
+ C = cardinality
+ self.conv1 = nn.Conv2d(inplanes, D * C, kernel_size=1, stride=1, padding=0, bias=False)
+ if with_ibn:
+ self.bn1 = IBN(D * C)
+ else:
+ self.bn1 = nn.BatchNorm2d(D * C)
+ self.conv2 = nn.Conv2d(D * C, D * C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
+ self.bn2 = nn.BatchNorm2d(D * C)
+ self.conv3 = nn.Conv2d(D * C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.downsample = downsample
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
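+# Group-width arithmetic (illustrative): with baseWidth=4 and cardinality=32,
+# the first stage (planes=64) gives D = floor(64 * 4/64) = 4 and D*C = 128
+# grouped channels, matching the standard ResNeXt "32x4d" configuration.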
+
+class ResNeXt(nn.Module):
+ """
+ ResNext optimized for the ImageNet dataset, as specified in
+ https://arxiv.org/pdf/1611.05431.pdf
+ """
+
+ def __init__(self, last_stride, with_ibn, block, layers, baseWidth=4, cardinality=32):
+ """ Constructor
+ Args:
+ baseWidth: baseWidth for ResNeXt.
+ cardinality: number of convolution groups.
+ layers: config of layers, e.g., [3, 4, 6, 3]
+ num_classes: number of classes
+ """
+ super(ResNeXt, self).__init__()
+
+ self.cardinality = cardinality
+ self.baseWidth = baseWidth
+ self.inplanes = 64
+ self.output_size = 64
+
+ self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0], with_ibn=with_ibn)
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, with_ibn=with_ibn)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2, with_ibn=with_ibn)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride, with_ibn=with_ibn)
+
+ self.random_init()
+
+ def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False):
+ """ Stack n bottleneck modules where n is inferred from the depth of the network.
+ Args:
+ block: block type used to construct ResNext
+ planes: number of output channels (need to multiply by block.expansion)
+ blocks: number of blocks to be built
+ stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
+ Returns: a Module consisting of n sequential bottlenecks.
+ """
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ if planes == 512:
+ with_ibn = False
+ layers.append(block(self.inplanes, planes, with_ibn, self.baseWidth, self.cardinality, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes, with_ibn, self.baseWidth, self.cardinality, 1, None))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool1(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ return x
+
+ def random_init(self):
+ self.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.InstanceNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+
+@BACKBONE_REGISTRY.register()
+def build_resnext_backbone(cfg):
+ """
+ Create a ResNeXt instance from config.
+ Returns:
+ ResNeXt: a :class:`ResNeXt` instance.
+ """
+
+ # fmt: off
+ pretrain = cfg.MODEL.BACKBONE.PRETRAIN
+ pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
+ last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
+ with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
+ with_se = cfg.MODEL.BACKBONE.WITH_SE
+ with_nl = cfg.MODEL.BACKBONE.WITH_NL
+ depth = cfg.MODEL.BACKBONE.DEPTH
+
+ num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], }[depth]
+    nl_layers_per_stage = {50: [0, 2, 3, 0], 101: [0, 2, 3, 0]}[depth]  # currently unused: non-local blocks are not wired into this port
+ model = ResNeXt(last_stride, with_ibn, Bottleneck, num_blocks_per_stage)
+ if pretrain:
+        # Only IBN checkpoints loaded from PRETRAIN_PATH are supported here;
+        # the original-ResNeXt branch (model_zoo.load_url) is left out.
+        state_dict = torch.load(pretrain_path)['state_dict']
+        # strip the leading 'module.' prefix added by DataParallel
+ new_state_dict = {}
+ for k in state_dict:
+ new_k = '.'.join(k.split('.')[1:])
+ if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
+ new_state_dict[new_k] = state_dict[k]
+ state_dict = new_state_dict
+ res = model.load_state_dict(state_dict, strict=False)
+ logger = logging.getLogger(__name__)
+    logger.info('missing keys: {}'.format(res.missing_keys))
+    logger.info('unexpected keys: {}'.format(res.unexpected_keys))
+ return model
diff --git a/modeling/heads/__init__.py b/modeling/heads/__init__.py
new file mode 100644
index 0000000..0413982
--- /dev/null
+++ b/modeling/heads/__init__.py
@@ -0,0 +1,12 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from .build import REID_HEADS_REGISTRY, build_heads
+
+# import all the heads, so they will be registered
+from .embedding_head import EmbeddingHead
+from .attr_head import AttrHead
+
diff --git a/modeling/heads/__pycache__/__init__.cpython-37.pyc b/modeling/heads/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..16cd05e
--- /dev/null
+++ b/modeling/heads/__pycache__/__init__.cpython-37.pyc
Binary files differ
diff --git a/modeling/heads/__pycache__/__init__.cpython-38.pyc b/modeling/heads/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..4d162e3
--- /dev/null
+++ b/modeling/heads/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/modeling/heads/__pycache__/attr_head.cpython-37.pyc b/modeling/heads/__pycache__/attr_head.cpython-37.pyc
new file mode 100644
index 0000000..8a0b0f1
--- /dev/null
+++ b/modeling/heads/__pycache__/attr_head.cpython-37.pyc
Binary files differ
diff --git a/modeling/heads/__pycache__/attr_head.cpython-38.pyc b/modeling/heads/__pycache__/attr_head.cpython-38.pyc
new file mode 100644
index 0000000..806b1d5
--- /dev/null
+++ b/modeling/heads/__pycache__/attr_head.cpython-38.pyc
Binary files differ
diff --git a/modeling/heads/__pycache__/build.cpython-37.pyc b/modeling/heads/__pycache__/build.cpython-37.pyc
new file mode 100644
index 0000000..2facead
--- /dev/null
+++ b/modeling/heads/__pycache__/build.cpython-37.pyc
Binary files differ
diff --git a/modeling/heads/__pycache__/build.cpython-38.pyc b/modeling/heads/__pycache__/build.cpython-38.pyc
new file mode 100644
index 0000000..f51780b
--- /dev/null
+++ b/modeling/heads/__pycache__/build.cpython-38.pyc
Binary files differ
diff --git a/modeling/heads/__pycache__/embedding_head.cpython-37.pyc b/modeling/heads/__pycache__/embedding_head.cpython-37.pyc
new file mode 100644
index 0000000..e1cd5a7
--- /dev/null
+++ b/modeling/heads/__pycache__/embedding_head.cpython-37.pyc
Binary files differ
diff --git a/modeling/heads/__pycache__/embedding_head.cpython-38.pyc b/modeling/heads/__pycache__/embedding_head.cpython-38.pyc
new file mode 100644
index 0000000..982513d
--- /dev/null
+++ b/modeling/heads/__pycache__/embedding_head.cpython-38.pyc
Binary files differ
diff --git a/modeling/heads/attr_head.py b/modeling/heads/attr_head.py
new file mode 100644
index 0000000..62bc941
--- /dev/null
+++ b/modeling/heads/attr_head.py
@@ -0,0 +1,77 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+from torch import nn
+
+from layers import *
+from utils.weight_init import weights_init_kaiming, weights_init_classifier
+from .build import REID_HEADS_REGISTRY
+
+
+@REID_HEADS_REGISTRY.register()
+class AttrHead(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ # fmt: off
+ feat_dim = cfg.MODEL.BACKBONE.FEAT_DIM
+ num_classes = cfg.MODEL.HEADS.NUM_CLASSES
+ pool_type = cfg.MODEL.HEADS.POOL_LAYER
+ cls_type = cfg.MODEL.HEADS.CLS_LAYER
+ with_bnneck = cfg.MODEL.HEADS.WITH_BNNECK
+ norm_type = cfg.MODEL.HEADS.NORM
+
+ if pool_type == 'fastavgpool': self.pool_layer = FastGlobalAvgPool2d()
+ elif pool_type == 'avgpool': self.pool_layer = nn.AdaptiveAvgPool2d(1)
+ elif pool_type == 'maxpool': self.pool_layer = nn.AdaptiveMaxPool2d(1)
+ elif pool_type == 'gempoolP': self.pool_layer = GeneralizedMeanPoolingP()
+ elif pool_type == 'gempool': self.pool_layer = GeneralizedMeanPooling()
+ elif pool_type == "avgmaxpool": self.pool_layer = AdaptiveAvgMaxPool2d()
+ elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d()
+ elif pool_type == "identity": self.pool_layer = nn.Identity()
+ elif pool_type == "flatten": self.pool_layer = Flatten()
+ else: raise KeyError(f"{pool_type} is not supported!")
+
+ # Classification layer
+ if cls_type == 'linear': self.classifier = nn.Linear(feat_dim, num_classes, bias=False)
+ elif cls_type == 'arcSoftmax': self.classifier = ArcSoftmax(cfg, feat_dim, num_classes)
+ elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, feat_dim, num_classes)
+ elif cls_type == 'amSoftmax': self.classifier = AMSoftmax(cfg, feat_dim, num_classes)
+ else: raise KeyError(f"{cls_type} is not supported!")
+ # fmt: on
+
+        # Unlike EmbeddingHead, the BN neck here acts on the class logits
+        # (hence BatchNorm1d(num_classes)); the feature-dim bnneck variant is
+        # left commented out for reference.
+        # bottleneck = []
+        # if with_bnneck:
+        #     bottleneck.append(get_norm(norm_type, feat_dim, bias_freeze=True))
+        bottleneck = [nn.BatchNorm1d(num_classes)]
+
+ self.bottleneck = nn.Sequential(*bottleneck)
+
+ self.bottleneck.apply(weights_init_kaiming)
+ self.classifier.apply(weights_init_classifier)
+
+ def forward(self, features, targets=None):
+ """
+ See :class:`ReIDHeads.forward`.
+ """
+ global_feat = self.pool_layer(features)
+ global_feat = global_feat[..., 0, 0]
+
+ classifier_name = self.classifier.__class__.__name__
+ # fmt: off
+ if classifier_name == 'Linear': cls_outputs = self.classifier(global_feat)
+ else: cls_outputs = self.classifier(global_feat, targets)
+ # fmt: on
+
+ cls_outputs = self.bottleneck(cls_outputs)
+
+ if self.training:
+ return {
+ "cls_outputs": cls_outputs,
+ }
+ else:
+ cls_outputs = torch.sigmoid(cls_outputs)
+ return cls_outputs
diff --git a/modeling/heads/bnneck_head.py b/modeling/heads/bnneck_head.py
new file mode 100644
index 0000000..e798a81
--- /dev/null
+++ b/modeling/heads/bnneck_head.py
@@ -0,0 +1,61 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from layers import *
+from modeling.losses import *
+from utils.weight_init import weights_init_kaiming, weights_init_classifier
+from .build import REID_HEADS_REGISTRY
+
+
+@REID_HEADS_REGISTRY.register()
+class BNneckHead(nn.Module):
+ def __init__(self, cfg, in_feat, num_classes, pool_layer):
+ super().__init__()
+ self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT
+ self.pool_layer = pool_layer
+
+ self.bnneck = get_norm(cfg.MODEL.HEADS.NORM, in_feat, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True)
+ self.bnneck.apply(weights_init_kaiming)
+
+ # identity classification layer
+ cls_type = cfg.MODEL.HEADS.CLS_LAYER
+ if cls_type == 'linear': self.classifier = nn.Linear(in_feat, num_classes, bias=False)
+ elif cls_type == 'arcface': self.classifier = Arcface(cfg, in_feat, num_classes)
+ elif cls_type == 'circle': self.classifier = Circle(cfg, in_feat, num_classes)
+ else:
+ raise KeyError(f"{cls_type} is invalid, please choose from "
+ f"'linear', 'arcface' and 'circle'.")
+
+ self.classifier.apply(weights_init_classifier)
+
+ def forward(self, features, targets=None):
+ """
+ See :class:`ReIDHeads.forward`.
+ """
+ global_feat = self.pool_layer(features)
+ bn_feat = self.bnneck(global_feat)
+ bn_feat = bn_feat[..., 0, 0]
+
+ # Evaluation
+ if not self.training: return bn_feat
+
+ # Training
+        try:
+            cls_outputs = self.classifier(bn_feat)
+            pred_class_logits = cls_outputs.detach()
+        except TypeError:
+            # margin-based classifiers (e.g. Arcface/Circle) also require the targets
+            cls_outputs = self.classifier(bn_feat, targets)
+            pred_class_logits = F.linear(F.normalize(bn_feat.detach()), F.normalize(self.classifier.weight.detach()))
+ # Log prediction accuracy
+ CrossEntropyLoss.log_accuracy(pred_class_logits, targets)
+
+ if self.neck_feat == "before":
+ feat = global_feat[..., 0, 0]
+ elif self.neck_feat == "after":
+ feat = bn_feat
+ else:
+ raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
+ return cls_outputs, feat
diff --git a/modeling/heads/build.py b/modeling/heads/build.py
new file mode 100644
index 0000000..b88bef5
--- /dev/null
+++ b/modeling/heads/build.py
@@ -0,0 +1,25 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+
+from utils.registry import Registry
+
+REID_HEADS_REGISTRY = Registry("HEADS")
+REID_HEADS_REGISTRY.__doc__ = """
+Registry for ReID heads in a re-identification model.
+ReID heads take backbone feature maps and produce embeddings,
+plus classification outputs during training.
+The registered object will be called with `obj(cfg)`.
+The call is expected to return an :class:`nn.Module`.
+"""
+
+
+def build_heads(cfg):
+ """
+    Build REIDHeads defined by `cfg.MODEL.HEADS.NAME`.
+ """
+ head = cfg.MODEL.HEADS.NAME
+ return REID_HEADS_REGISTRY.get(head)(cfg)
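
A minimal usage sketch (not part of the patch), assuming a config node whose MODEL.HEADS.NAME matches one of the registered head classes:

    from config import get_cfg
    from modeling.heads import build_heads

    cfg = get_cfg()
    cfg.MODEL.HEADS.NAME = "EmbeddingHead"   # must be a registered class name
    heads = build_heads(cfg)                 # resolves to EmbeddingHead(cfg)
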
diff --git a/modeling/heads/embedding_head.py b/modeling/heads/embedding_head.py
new file mode 100644
index 0000000..1ead80b
--- /dev/null
+++ b/modeling/heads/embedding_head.py
@@ -0,0 +1,97 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch.nn.functional as F
+from torch import nn
+
+from layers import *
+from utils.weight_init import weights_init_kaiming, weights_init_classifier
+from .build import REID_HEADS_REGISTRY
+
+
+@REID_HEADS_REGISTRY.register()
+class EmbeddingHead(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ # fmt: off
+ feat_dim = cfg.MODEL.BACKBONE.FEAT_DIM
+ embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM
+ num_classes = cfg.MODEL.HEADS.NUM_CLASSES
+ neck_feat = cfg.MODEL.HEADS.NECK_FEAT
+ pool_type = cfg.MODEL.HEADS.POOL_LAYER
+ cls_type = cfg.MODEL.HEADS.CLS_LAYER
+ with_bnneck = cfg.MODEL.HEADS.WITH_BNNECK
+ norm_type = cfg.MODEL.HEADS.NORM
+
+ if pool_type == 'fastavgpool': self.pool_layer = FastGlobalAvgPool2d()
+ elif pool_type == 'avgpool': self.pool_layer = nn.AdaptiveAvgPool2d(1)
+ elif pool_type == 'maxpool': self.pool_layer = nn.AdaptiveMaxPool2d(1)
+ elif pool_type == 'gempoolP': self.pool_layer = GeneralizedMeanPoolingP()
+ elif pool_type == 'gempool': self.pool_layer = GeneralizedMeanPooling()
+ elif pool_type == "avgmaxpool": self.pool_layer = AdaptiveAvgMaxPool2d()
+ elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d()
+ elif pool_type == "identity": self.pool_layer = nn.Identity()
+ elif pool_type == "flatten": self.pool_layer = Flatten()
+ else: raise KeyError(f"{pool_type} is not supported!")
+ # fmt: on
+
+ self.neck_feat = neck_feat
+
+ bottleneck = []
+ if embedding_dim > 0:
+ bottleneck.append(nn.Conv2d(feat_dim, embedding_dim, 1, 1, bias=False))
+ feat_dim = embedding_dim
+
+ if with_bnneck:
+ bottleneck.append(get_norm(norm_type, feat_dim, bias_freeze=True))
+
+ self.bottleneck = nn.Sequential(*bottleneck)
+
+ # identity classification layer
+ # fmt: off
+ if cls_type == 'linear': self.classifier = nn.Linear(feat_dim, num_classes, bias=False)
+ elif cls_type == 'arcSoftmax': self.classifier = ArcSoftmax(cfg, feat_dim, num_classes)
+ elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, feat_dim, num_classes)
+ elif cls_type == 'amSoftmax': self.classifier = AMSoftmax(cfg, feat_dim, num_classes)
+ else: raise KeyError(f"{cls_type} is not supported!")
+ # fmt: on
+
+ self.bottleneck.apply(weights_init_kaiming)
+ self.classifier.apply(weights_init_classifier)
+
+ def forward(self, features, targets=None):
+ """
+ See :class:`ReIDHeads.forward`.
+ """
+ global_feat = self.pool_layer(features)
+ bn_feat = self.bottleneck(global_feat)
+ bn_feat = bn_feat[..., 0, 0]
+
+ # Evaluation
+ # fmt: off
+ if not self.training: return bn_feat
+ # fmt: on
+
+ # Training
+ if self.classifier.__class__.__name__ == 'Linear':
+ cls_outputs = self.classifier(bn_feat)
+ pred_class_logits = F.linear(bn_feat, self.classifier.weight)
+ else:
+ cls_outputs = self.classifier(bn_feat, targets)
+ pred_class_logits = self.classifier.s * F.linear(F.normalize(bn_feat),
+ F.normalize(self.classifier.weight))
+
+ # fmt: off
+ if self.neck_feat == "before": feat = global_feat[..., 0, 0]
+ elif self.neck_feat == "after": feat = bn_feat
+ else: raise KeyError(f"{self.neck_feat} is invalid for MODEL.HEADS.NECK_FEAT")
+ # fmt: on
+
+ return {
+ "cls_outputs": cls_outputs,
+ "pred_class_logits": pred_class_logits,
+ "features": feat,
+ }
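
A hedged sketch (illustrative, not part of the patch) of how this head behaves in its two modes, assuming default config values:

    import torch
    from config import get_cfg
    from modeling.heads import build_heads

    cfg = get_cfg()
    head = build_heads(cfg)
    feat_map = torch.randn(4, cfg.MODEL.BACKBONE.FEAT_DIM, 16, 8)  # fake backbone output

    head.eval()
    embeddings = head(feat_map)        # (4, feat_dim): pooled (+BN) inference feature

    head.train()
    targets = torch.zeros(4, dtype=torch.long)
    out = head(feat_map, targets)      # dict with cls_outputs / pred_class_logits / features
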
diff --git a/modeling/heads/linear_head.py b/modeling/heads/linear_head.py
new file mode 100644
index 0000000..9c34835
--- /dev/null
+++ b/modeling/heads/linear_head.py
@@ -0,0 +1,50 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from layers import *
+from modeling.losses import *
+from .build import REID_HEADS_REGISTRY
+from utils.weight_init import weights_init_classifier
+
+
+@REID_HEADS_REGISTRY.register()
+class LinearHead(nn.Module):
+ def __init__(self, cfg, in_feat, num_classes, pool_layer):
+ super().__init__()
+ self.pool_layer = pool_layer
+
+ # identity classification layer
+ cls_type = cfg.MODEL.HEADS.CLS_LAYER
+ if cls_type == 'linear': self.classifier = nn.Linear(in_feat, num_classes, bias=False)
+ elif cls_type == 'arcface': self.classifier = Arcface(cfg, in_feat, num_classes)
+ elif cls_type == 'circle': self.classifier = Circle(cfg, in_feat, num_classes)
+ else:
+ raise KeyError(f"{cls_type} is invalid, please choose from "
+ f"'linear', 'arcface' and 'circle'.")
+
+ self.classifier.apply(weights_init_classifier)
+
+ def forward(self, features, targets=None):
+ """
+ See :class:`ReIDHeads.forward`.
+ """
+ global_feat = self.pool_layer(features)
+ global_feat = global_feat[..., 0, 0]
+
+ # Evaluation
+ if not self.training: return global_feat
+
+ # Training
+ try:
+ cls_outputs = self.classifier(global_feat)
+ pred_class_logits = cls_outputs.detach()
+ except TypeError:
+ cls_outputs = self.classifier(global_feat, targets)
+ pred_class_logits = F.linear(F.normalize(global_feat.detach()), F.normalize(self.classifier.weight.detach()))
+ # Log prediction accuracy
+ CrossEntropyLoss.log_accuracy(pred_class_logits, targets)
+
+ return cls_outputs, global_feat
diff --git a/modeling/heads/reduction_head.py b/modeling/heads/reduction_head.py
new file mode 100644
index 0000000..3e29c84
--- /dev/null
+++ b/modeling/heads/reduction_head.py
@@ -0,0 +1,73 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from layers import *
+from modeling.losses import *
+from utils.weight_init import weights_init_kaiming, weights_init_classifier
+from .build import REID_HEADS_REGISTRY
+
+
+@REID_HEADS_REGISTRY.register()
+class ReductionHead(nn.Module):
+ def __init__(self, cfg, in_feat, num_classes, pool_layer):
+ super().__init__()
+ self._cfg = cfg
+ reduction_dim = cfg.MODEL.HEADS.REDUCTION_DIM
+ self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT
+
+ self.pool_layer = pool_layer
+
+ self.bottleneck = nn.Sequential(
+ nn.Conv2d(in_feat, reduction_dim, 1, 1, bias=False),
+ get_norm(cfg.MODEL.HEADS.NORM, reduction_dim, cfg.MODEL.HEADS.NORM_SPLIT),
+ nn.LeakyReLU(0.1, inplace=True),
+ )
+
+ self.bnneck = get_norm(cfg.MODEL.HEADS.NORM, reduction_dim, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True)
+
+ self.bottleneck.apply(weights_init_kaiming)
+ self.bnneck.apply(weights_init_kaiming)
+
+ # identity classification layer
+ cls_type = cfg.MODEL.HEADS.CLS_LAYER
+ if cls_type == 'linear': self.classifier = nn.Linear(reduction_dim, num_classes, bias=False)
+ elif cls_type == 'arcface': self.classifier = Arcface(cfg, reduction_dim, num_classes)
+ elif cls_type == 'circle': self.classifier = Circle(cfg, reduction_dim, num_classes)
+ else:
+ raise KeyError(f"{cls_type} is invalid, please choose from "
+ f"'linear', 'arcface' and 'circle'.")
+
+ self.classifier.apply(weights_init_classifier)
+
+ def forward(self, features, targets=None):
+ """
+ See :class:`ReIDHeads.forward`.
+ """
+ features = self.pool_layer(features)
+ global_feat = self.bottleneck(features)
+ bn_feat = self.bnneck(global_feat)
+ bn_feat = bn_feat[..., 0, 0]
+
+ # Evaluation
+ if not self.training: return bn_feat
+
+ # Training
+ try:
+ cls_outputs = self.classifier(bn_feat)
+ pred_class_logits = cls_outputs.detach()
+ except TypeError:
+ cls_outputs = self.classifier(bn_feat, targets)
+ pred_class_logits = F.linear(F.normalize(bn_feat.detach()), F.normalize(self.classifier.weight.detach()))
+ # Log prediction accuracy
+ CrossEntropyLoss.log_accuracy(pred_class_logits, targets)
+
+ if self.neck_feat == "before": feat = global_feat[..., 0, 0]
+ elif self.neck_feat == "after": feat = bn_feat
+ else:
+ raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
+
+ return cls_outputs, feat
+
diff --git a/modeling/losses/__init__.py b/modeling/losses/__init__.py
new file mode 100644
index 0000000..977e5c3
--- /dev/null
+++ b/modeling/losses/__init__.py
@@ -0,0 +1,9 @@
+# encoding: utf-8
+"""
+@author: l1aoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from .cross_entroy_loss import CrossEntropyLoss
+from .focal_loss import FocalLoss
+from .metric_loss import TripletLoss, CircleLoss
diff --git a/modeling/losses/__pycache__/__init__.cpython-37.pyc b/modeling/losses/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..9a28f08
--- /dev/null
+++ b/modeling/losses/__pycache__/__init__.cpython-37.pyc
Binary files differ
diff --git a/modeling/losses/__pycache__/__init__.cpython-38.pyc b/modeling/losses/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..a6c71c7
--- /dev/null
+++ b/modeling/losses/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/modeling/losses/__pycache__/cross_entroy_loss.cpython-37.pyc b/modeling/losses/__pycache__/cross_entroy_loss.cpython-37.pyc
new file mode 100644
index 0000000..543c429
--- /dev/null
+++ b/modeling/losses/__pycache__/cross_entroy_loss.cpython-37.pyc
Binary files differ
diff --git a/modeling/losses/__pycache__/cross_entroy_loss.cpython-38.pyc b/modeling/losses/__pycache__/cross_entroy_loss.cpython-38.pyc
new file mode 100644
index 0000000..a43b6f3
--- /dev/null
+++ b/modeling/losses/__pycache__/cross_entroy_loss.cpython-38.pyc
Binary files differ
diff --git a/modeling/losses/__pycache__/focal_loss.cpython-37.pyc b/modeling/losses/__pycache__/focal_loss.cpython-37.pyc
new file mode 100644
index 0000000..3440861
--- /dev/null
+++ b/modeling/losses/__pycache__/focal_loss.cpython-37.pyc
Binary files differ
diff --git a/modeling/losses/__pycache__/focal_loss.cpython-38.pyc b/modeling/losses/__pycache__/focal_loss.cpython-38.pyc
new file mode 100644
index 0000000..bf84802
--- /dev/null
+++ b/modeling/losses/__pycache__/focal_loss.cpython-38.pyc
Binary files differ
diff --git a/modeling/losses/__pycache__/metric_loss.cpython-37.pyc b/modeling/losses/__pycache__/metric_loss.cpython-37.pyc
new file mode 100644
index 0000000..cad5fb8
--- /dev/null
+++ b/modeling/losses/__pycache__/metric_loss.cpython-37.pyc
Binary files differ
diff --git a/modeling/losses/__pycache__/metric_loss.cpython-38.pyc b/modeling/losses/__pycache__/metric_loss.cpython-38.pyc
new file mode 100644
index 0000000..a7b70f6
--- /dev/null
+++ b/modeling/losses/__pycache__/metric_loss.cpython-38.pyc
Binary files differ
diff --git a/modeling/losses/cross_entroy_loss.py b/modeling/losses/cross_entroy_loss.py
new file mode 100644
index 0000000..2108acd
--- /dev/null
+++ b/modeling/losses/cross_entroy_loss.py
@@ -0,0 +1,62 @@
+# encoding: utf-8
+"""
+@author: l1aoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+import torch
+import torch.nn.functional as F
+
+from utils.events import get_event_storage
+
+
+class CrossEntropyLoss(object):
+ """
+    A class that computes the softmax cross-entropy loss (with optional
+    label smoothing) on the outputs of a Baseline head.
+ """
+
+ def __init__(self, cfg):
+ self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
+ self._eps = cfg.MODEL.LOSSES.CE.EPSILON
+ self._alpha = cfg.MODEL.LOSSES.CE.ALPHA
+ self._scale = cfg.MODEL.LOSSES.CE.SCALE
+
+ @staticmethod
+ def log_accuracy(pred_class_logits, gt_classes, topk=(1,)):
+ """
+ Log the accuracy metrics to EventStorage.
+ """
+ bsz = pred_class_logits.size(0)
+ maxk = max(topk)
+ _, pred_class = pred_class_logits.topk(maxk, 1, True, True)
+ pred_class = pred_class.t()
+ correct = pred_class.eq(gt_classes.view(1, -1).expand_as(pred_class))
+
+ ret = []
+ for k in topk:
+ correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
+ ret.append(correct_k.mul_(1. / bsz))
+
+ storage = get_event_storage()
+ storage.put_scalar("cls_accuracy", ret[0])
+
+ def __call__(self, pred_class_logits, gt_classes):
+ """
+ Compute the softmax cross entropy loss for box classification.
+ Returns:
+ scalar Tensor
+ """
+ if self._eps >= 0:
+ smooth_param = self._eps
+ else:
+ # adaptive lsr
+ soft_label = F.softmax(pred_class_logits, dim=1)
+ smooth_param = self._alpha * soft_label[torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1)
+
+ log_probs = F.log_softmax(pred_class_logits, dim=1)
+ with torch.no_grad():
+ targets = torch.ones_like(log_probs)
+ targets *= smooth_param / (self._num_classes - 1)
+ targets.scatter_(1, gt_classes.data.unsqueeze(1), (1 - smooth_param))
+
+ loss = (-targets * log_probs).mean(0).sum()
+ return loss * self._scale
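
A worked toy example (not part of the patch) of the smoothed-target construction above, with a fixed epsilon:

    import torch

    num_classes, eps = 4, 0.1
    gt_classes = torch.tensor([2])
    targets = torch.full((1, num_classes), eps / (num_classes - 1))
    targets.scatter_(1, gt_classes.unsqueeze(1), 1 - eps)
    # targets -> [[0.0333, 0.0333, 0.9000, 0.0333]]; each row still sums to 1
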
diff --git a/modeling/losses/focal_loss.py b/modeling/losses/focal_loss.py
new file mode 100644
index 0000000..c520594
--- /dev/null
+++ b/modeling/losses/focal_loss.py
@@ -0,0 +1,110 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+import torch.nn.functional as F
+
+
+# based on:
+# https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py
+
+def focal_loss(
+ input: torch.Tensor,
+ target: torch.Tensor,
+ alpha: float,
+ gamma: float = 2.0,
+ reduction: str = 'mean', ) -> torch.Tensor:
+ r"""Function that computes Focal loss.
+ See :class:`fastreid.modeling.losses.FocalLoss` for details.
+ """
+ if not torch.is_tensor(input):
+ raise TypeError("Input type is not a torch.Tensor. Got {}"
+ .format(type(input)))
+
+ if not len(input.shape) >= 2:
+ raise ValueError("Invalid input shape, we expect BxCx*. Got: {}"
+ .format(input.shape))
+
+ if input.size(0) != target.size(0):
+ raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
+ .format(input.size(0), target.size(0)))
+
+ n = input.size(0)
+ out_size = (n,) + input.size()[2:]
+ if target.size()[1:] != input.size()[2:]:
+ raise ValueError('Expected target size {}, got {}'.format(
+ out_size, target.size()))
+
+ if not input.device == target.device:
+ raise ValueError(
+ "input and target must be in the same device. Got: {}".format(
+ input.device, target.device))
+
+ # compute softmax over the classes axis
+ input_soft = F.softmax(input, dim=1)
+
+ # create the labels one hot tensor
+ target_one_hot = F.one_hot(target, num_classes=input.shape[1])
+
+ # compute the actual focal loss
+ weight = torch.pow(-input_soft + 1., gamma)
+
+ focal = -alpha * weight * torch.log(input_soft)
+ loss_tmp = torch.sum(target_one_hot * focal, dim=1)
+
+ if reduction == 'none':
+ loss = loss_tmp
+ elif reduction == 'mean':
+ loss = torch.mean(loss_tmp)
+ elif reduction == 'sum':
+ loss = torch.sum(loss_tmp)
+ else:
+ raise NotImplementedError("Invalid reduction mode: {}"
+ .format(reduction))
+ return loss
+
+
+class FocalLoss(object):
+ r"""Criterion that computes Focal loss.
+ According to [1], the Focal loss is computed as follows:
+ .. math::
+ \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
+ where:
+ - :math:`p_t` is the model's estimated probability for each class.
+ Arguments:
+ alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
+ gamma (float): Focusing parameter :math:`\gamma >= 0`.
+ reduction (str, optional): Specifies the reduction to apply to the
+        output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
+        'mean': the sum of the output will be divided by the number of elements
+        in the output, 'sum': the output will be summed. Default: 'none'.
+ Shape:
+ - Input: :math:`(N, C, *)` where C = number of classes.
+ - Target: :math:`(N, *)` where each value is
+          :math:`0 \leq targets[i] \leq C-1`.
+ Examples:
+ >>> N = 5 # num_classes
+ >>> loss = FocalLoss(cfg)
+ >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
+ >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
+ >>> output = loss(input, target)
+ >>> output.backward()
+ References:
+ [1] https://arxiv.org/abs/1708.02002
+ """
+
+ # def __init__(self, alpha: float, gamma: float = 2.0,
+ # reduction: str = 'none') -> None:
+ def __init__(self, cfg):
+ self._alpha: float = cfg.MODEL.LOSSES.FL.ALPHA
+ self._gamma: float = cfg.MODEL.LOSSES.FL.GAMMA
+ self._scale: float = cfg.MODEL.LOSSES.FL.SCALE
+
+ def __call__(self, pred_class_logits: torch.Tensor, _, gt_classes: torch.Tensor) -> dict:
+ loss = focal_loss(pred_class_logits, gt_classes, self._alpha, self._gamma)
+ return {
+ 'loss_focal': loss * self._scale,
+ }
diff --git a/modeling/losses/metric_loss.py b/modeling/losses/metric_loss.py
new file mode 100644
index 0000000..edcf636
--- /dev/null
+++ b/modeling/losses/metric_loss.py
@@ -0,0 +1,215 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+import torch.nn.functional as F
+
+from utils import comm
+
+
+# utils
+@torch.no_grad()
+def concat_all_gather(tensor):
+ """
+ Performs all_gather operation on the provided tensors.
+ *** Warning ***: torch.distributed.all_gather has no gradient.
+ """
+ tensors_gather = [torch.ones_like(tensor)
+ for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+ output = torch.cat(tensors_gather, dim=0)
+ return output
+
+
+def normalize(x, axis=-1):
+ """Normalizing to unit length along the specified dimension.
+ Args:
+ x: pytorch Variable
+ Returns:
+ x: pytorch Variable, same shape as input
+ """
+ x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
+ return x
+
+
+def euclidean_dist(x, y):
+ m, n = x.size(0), y.size(0)
+ xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
+ yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
+ dist = xx + yy
+    dist.addmm_(x, y.t(), beta=1, alpha=-2)
+ dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
+ return dist
+
+
+def cosine_dist(x, y):
+ bs1, bs2 = x.size(0), y.size(0)
+ frac_up = torch.matmul(x, y.transpose(0, 1))
+ frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \
+ (torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1)
+ cosine = frac_up / frac_down
+ return 1 - cosine
+
+
+def softmax_weights(dist, mask):
+ max_v = torch.max(dist * mask, dim=1, keepdim=True)[0]
+ diff = dist - max_v
+ Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6 # avoid division by zero
+ W = torch.exp(diff) * mask / Z
+ return W
+
+
+def hard_example_mining(dist_mat, is_pos, is_neg):
+ """For each anchor, find the hardest positive and negative sample.
+ Args:
+ dist_mat: pair wise distance between samples, shape [N, M]
+ is_pos: positive index with shape [N, M]
+ is_neg: negative index with shape [N, M]
+ Returns:
+ dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
+ dist_an: pytorch Variable, distance(anchor, negative); shape [N]
+ p_inds: pytorch LongTensor, with shape [N];
+ indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
+ n_inds: pytorch LongTensor, with shape [N];
+ indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
+    NOTE: Only consider the case in which all labels have the same number of
+      samples, so that all anchors can be processed in parallel.
+ """
+
+ assert len(dist_mat.size()) == 2
+ N = dist_mat.size(0)
+
+ # `dist_ap` means distance(anchor, positive)
+ # both `dist_ap` and `relative_p_inds` with shape [N, 1]
+ # pos_dist = dist_mat[is_pos].contiguous().view(N, -1)
+ # ap_weight = F.softmax(pos_dist, dim=1)
+ # dist_ap = torch.sum(ap_weight * pos_dist, dim=1)
+ dist_ap, relative_p_inds = torch.max(
+ dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
+ # `dist_an` means distance(anchor, negative)
+ # both `dist_an` and `relative_n_inds` with shape [N, 1]
+ dist_an, relative_n_inds = torch.min(
+ dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
+ # neg_dist = dist_mat[is_neg].contiguous().view(N, -1)
+ # an_weight = F.softmax(-neg_dist, dim=1)
+ # dist_an = torch.sum(an_weight * neg_dist, dim=1)
+
+ # shape [N]
+ dist_ap = dist_ap.squeeze(1)
+ dist_an = dist_an.squeeze(1)
+
+ return dist_ap, dist_an
+
+
+def weighted_example_mining(dist_mat, is_pos, is_neg):
+ """For each anchor, find the weighted positive and negative sample.
+ Args:
+ dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
+ is_pos:
+ is_neg:
+ Returns:
+ dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
+ dist_an: pytorch Variable, distance(anchor, negative); shape [N]
+ """
+ assert len(dist_mat.size()) == 2
+
+ is_pos = is_pos.float()
+ is_neg = is_neg.float()
+ dist_ap = dist_mat * is_pos
+ dist_an = dist_mat * is_neg
+
+ weights_ap = softmax_weights(dist_ap, is_pos)
+ weights_an = softmax_weights(-dist_an, is_neg)
+
+ dist_ap = torch.sum(dist_ap * weights_ap, dim=1)
+ dist_an = torch.sum(dist_an * weights_an, dim=1)
+
+ return dist_ap, dist_an
+
+
+class TripletLoss(object):
+ """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
+ Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
+ Loss for Person Re-Identification'."""
+
+ def __init__(self, cfg):
+ self._margin = cfg.MODEL.LOSSES.TRI.MARGIN
+ self._normalize_feature = cfg.MODEL.LOSSES.TRI.NORM_FEAT
+ self._scale = cfg.MODEL.LOSSES.TRI.SCALE
+ self._hard_mining = cfg.MODEL.LOSSES.TRI.HARD_MINING
+
+ def __call__(self, embedding, targets):
+ if self._normalize_feature:
+ embedding = normalize(embedding, axis=-1)
+
+ # For distributed training, gather all features from different process.
+ if comm.get_world_size() > 1:
+ all_embedding = concat_all_gather(embedding)
+ all_targets = concat_all_gather(targets)
+ else:
+ all_embedding = embedding
+ all_targets = targets
+
+ dist_mat = euclidean_dist(embedding, all_embedding)
+
+ N, M = dist_mat.size()
+ is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t())
+ is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())
+
+ if self._hard_mining:
+ dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
+ else:
+ dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)
+
+ y = dist_an.new().resize_as_(dist_an).fill_(1)
+
+ if self._margin > 0:
+ loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=self._margin)
+ else:
+ loss = F.soft_margin_loss(dist_an - dist_ap, y)
+ if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
+
+ return loss * self._scale
+
+
+class CircleLoss(object):
+ def __init__(self, cfg):
+ self._scale = cfg.MODEL.LOSSES.CIRCLE.SCALE
+
+ self.m = cfg.MODEL.LOSSES.CIRCLE.MARGIN
+ self.s = cfg.MODEL.LOSSES.CIRCLE.ALPHA
+
+ def __call__(self, embedding, targets):
+ embedding = F.normalize(embedding, dim=1)
+
+ if comm.get_world_size() > 1:
+ all_embedding = concat_all_gather(embedding)
+ all_targets = concat_all_gather(targets)
+ else:
+ all_embedding = embedding
+ all_targets = targets
+
+ dist_mat = torch.matmul(embedding, all_embedding.t())
+
+ N, M = dist_mat.size()
+ is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t())
+ is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())
+
+ s_p = dist_mat[is_pos].contiguous().view(N, -1)
+ s_n = dist_mat[is_neg].contiguous().view(N, -1)
+
+ alpha_p = F.relu(-s_p.detach() + 1 + self.m)
+ alpha_n = F.relu(s_n.detach() + self.m)
+ delta_p = 1 - self.m
+ delta_n = self.m
+
+ logit_p = - self.s * alpha_p * (s_p - delta_p)
+ logit_n = self.s * alpha_n * (s_n - delta_n)
+
+ loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()
+
+ return loss * self._scale
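
A toy check (not part of the patch) of hard_example_mining: with labels [0, 0, 1, 1], the hardest positive per anchor is the farthest same-ID sample and the hardest negative is the closest different-ID sample:

    import torch
    from modeling.losses.metric_loss import hard_example_mining

    dist = torch.tensor([[0.0, 0.2, 0.9, 0.8],
                         [0.2, 0.0, 0.7, 0.6],
                         [0.9, 0.7, 0.0, 0.3],
                         [0.8, 0.6, 0.3, 0.0]])
    labels = torch.tensor([0, 0, 1, 1])
    is_pos = labels.view(4, 1).eq(labels.view(1, 4))   # diagonal included, as in TripletLoss
    is_neg = ~is_pos
    dist_ap, dist_an = hard_example_mining(dist, is_pos, is_neg)
    # dist_ap -> [0.2, 0.2, 0.3, 0.3], dist_an -> [0.8, 0.6, 0.7, 0.6]
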
diff --git a/modeling/meta_arch/__init__.py b/modeling/meta_arch/__init__.py
new file mode 100644
index 0000000..84ae987
--- /dev/null
+++ b/modeling/meta_arch/__init__.py
@@ -0,0 +1,12 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from .build import META_ARCH_REGISTRY, build_model
+
+
+# import all the meta_arch, so they will be registered
+from .baseline import Baseline
+from .mgn import MGN
diff --git a/modeling/meta_arch/__pycache__/__init__.cpython-37.pyc b/modeling/meta_arch/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..9edd4bc
--- /dev/null
+++ b/modeling/meta_arch/__pycache__/__init__.cpython-37.pyc
Binary files differ
diff --git a/modeling/meta_arch/__pycache__/__init__.cpython-38.pyc b/modeling/meta_arch/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..272307a
--- /dev/null
+++ b/modeling/meta_arch/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/modeling/meta_arch/__pycache__/baseline.cpython-37.pyc b/modeling/meta_arch/__pycache__/baseline.cpython-37.pyc
new file mode 100644
index 0000000..bc9d4b2
--- /dev/null
+++ b/modeling/meta_arch/__pycache__/baseline.cpython-37.pyc
Binary files differ
diff --git a/modeling/meta_arch/__pycache__/baseline.cpython-38.pyc b/modeling/meta_arch/__pycache__/baseline.cpython-38.pyc
new file mode 100644
index 0000000..382735e
--- /dev/null
+++ b/modeling/meta_arch/__pycache__/baseline.cpython-38.pyc
Binary files differ
diff --git a/modeling/meta_arch/__pycache__/build.cpython-37.pyc b/modeling/meta_arch/__pycache__/build.cpython-37.pyc
new file mode 100644
index 0000000..3b47628
--- /dev/null
+++ b/modeling/meta_arch/__pycache__/build.cpython-37.pyc
Binary files differ
diff --git a/modeling/meta_arch/__pycache__/build.cpython-38.pyc b/modeling/meta_arch/__pycache__/build.cpython-38.pyc
new file mode 100644
index 0000000..a27b826
--- /dev/null
+++ b/modeling/meta_arch/__pycache__/build.cpython-38.pyc
Binary files differ
diff --git a/modeling/meta_arch/__pycache__/mgn.cpython-37.pyc b/modeling/meta_arch/__pycache__/mgn.cpython-37.pyc
new file mode 100644
index 0000000..9e4a0a6
--- /dev/null
+++ b/modeling/meta_arch/__pycache__/mgn.cpython-37.pyc
Binary files differ
diff --git a/modeling/meta_arch/__pycache__/mgn.cpython-38.pyc b/modeling/meta_arch/__pycache__/mgn.cpython-38.pyc
new file mode 100644
index 0000000..ad31752
--- /dev/null
+++ b/modeling/meta_arch/__pycache__/mgn.cpython-38.pyc
Binary files differ
diff --git a/modeling/meta_arch/baseline.py b/modeling/meta_arch/baseline.py
new file mode 100644
index 0000000..4e4e518
--- /dev/null
+++ b/modeling/meta_arch/baseline.py
@@ -0,0 +1,119 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+from torch import nn
+
+from modeling.backbones import build_backbone
+from modeling.heads import build_heads
+from modeling.losses import *
+from .build import META_ARCH_REGISTRY
+
+
+@META_ARCH_REGISTRY.register()
+class Baseline(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self._cfg = cfg
+ assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
+ self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
+ self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
+
+ # backbone
+ self.backbone = build_backbone(cfg)
+
+ # head
+ self.heads = build_heads(cfg)
+
+ @property
+ def device(self):
+ return self.pixel_mean.device
+
+ def forward(self, batched_inputs):
+ images = self.preprocess_image(batched_inputs)
+ features = self.backbone(images)
+
+ if self.training:
+            assert "targets" in batched_inputs, "Person ID annotations are missing in training!"
+ targets = batched_inputs["targets"].to(self.device)
+
+            # PreciseBN flag: when running precise-BN on a different dataset, the number
+            # of classes in the new dataset may be larger than in the original dataset,
+            # so circle/arcface would throw an error. We set all targets to 0 to avoid this.
+ if targets.sum() < 0: targets.zero_()
+
+ outputs = self.heads(features, targets)
+ return {
+ "outputs": outputs,
+ "targets": targets,
+ }
+ else:
+ outputs = self.heads(features)
+ return outputs
+
+ def preprocess_image(self, images):
+ r"""
+ Normalize and batch the input images.
+ """
+        if isinstance(images, dict):
+            images = images["images"].to(self.device)
+        elif isinstance(images, torch.Tensor):
+            images = images.to(self.device)
+        else:
+            raise TypeError("batched_inputs must be dict or torch.Tensor, but got {}".format(type(images)))
+
+        images.sub_(self.pixel_mean).div_(self.pixel_std)
+ return images
+
+ def losses(self, outs):
+ r"""
+ Compute loss from modeling's outputs, the loss function input arguments
+ must be the same as the outputs of the model forwarding.
+ """
+ # fmt: off
+ outputs = outs["outputs"]
+ gt_labels = outs["targets"]
+ # model predictions
+ pred_class_logits = outputs['pred_class_logits'].detach()
+ cls_outputs = outputs['cls_outputs']
+ pred_features = outputs['features']
+ # fmt: on
+
+ # Log prediction accuracy
+ log_accuracy(pred_class_logits, gt_labels)
+
+ loss_dict = {}
+ loss_names = self._cfg.MODEL.LOSSES.NAME
+
+ if "CrossEntropyLoss" in loss_names:
+ loss_dict['loss_cls'] = cross_entropy_loss(
+ cls_outputs,
+ gt_labels,
+ self._cfg.MODEL.LOSSES.CE.EPSILON,
+ self._cfg.MODEL.LOSSES.CE.ALPHA,
+ ) * self._cfg.MODEL.LOSSES.CE.SCALE
+
+ if "TripletLoss" in loss_names:
+ loss_dict['loss_triplet'] = triplet_loss(
+ pred_features,
+ gt_labels,
+ self._cfg.MODEL.LOSSES.TRI.MARGIN,
+ self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
+ self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
+ ) * self._cfg.MODEL.LOSSES.TRI.SCALE
+
+ if "CircleLoss" in loss_names:
+ loss_dict['loss_circle'] = circle_loss(
+ pred_features,
+ gt_labels,
+ self._cfg.MODEL.LOSSES.CIRCLE.MARGIN,
+ self._cfg.MODEL.LOSSES.CIRCLE.ALPHA,
+ ) * self._cfg.MODEL.LOSSES.CIRCLE.SCALE
+
+ return loss_dict
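
A hedged sketch (not part of the patch) of the intended training-side call pattern, assuming default config values and that the functional losses referenced in losses() are importable:

    import torch
    from config import get_cfg
    from modeling.meta_arch import build_model

    cfg = get_cfg()
    model = build_model(cfg)          # resolves cfg.MODEL.META_ARCHITECTURE, e.g. "Baseline"
    model.train()
    batch = {
        "images": torch.randn(4, 3, 256, 128),
        "targets": torch.randint(0, cfg.MODEL.HEADS.NUM_CLASSES, (4,)),
    }
    outs = model(batch)               # {"outputs": head outputs, "targets": labels}
    loss_dict = model.losses(outs)
    total_loss = sum(loss_dict.values())
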
diff --git a/modeling/meta_arch/build.py b/modeling/meta_arch/build.py
new file mode 100644
index 0000000..395b5c7
--- /dev/null
+++ b/modeling/meta_arch/build.py
@@ -0,0 +1,21 @@
+# encoding: utf-8
+from utils.registry import Registry
+
+META_ARCH_REGISTRY = Registry("META_ARCH") # noqa F401 isort:skip
+META_ARCH_REGISTRY.__doc__ = """
+Registry for meta-architectures, i.e. the whole model.
+The registered object will be called with `obj(cfg)`
+and expected to return a `nn.Module` object.
+"""
+
+
+def build_model(cfg):
+ """
+ Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
+ Note that it does not load any weights from ``cfg``.
+ """
+ meta_arch = cfg.MODEL.META_ARCHITECTURE
+ model = META_ARCH_REGISTRY.get(meta_arch)(cfg)
+ return model
+
+
diff --git a/modeling/meta_arch/mgn.py b/modeling/meta_arch/mgn.py
new file mode 100644
index 0000000..6050b55
--- /dev/null
+++ b/modeling/meta_arch/mgn.py
@@ -0,0 +1,280 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+import copy
+
+import torch
+from torch import nn
+
+from layers import get_norm
+from modeling.backbones import build_backbone
+from modeling.backbones.resnet import Bottleneck
+from modeling.heads import build_heads
+from modeling.losses import *
+from .build import META_ARCH_REGISTRY
+
+
+@META_ARCH_REGISTRY.register()
+class MGN(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self._cfg = cfg
+ assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
+ self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
+ self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
+
+ # fmt: off
+ # backbone
+ bn_norm = cfg.MODEL.BACKBONE.NORM
+ with_se = cfg.MODEL.BACKBONE.WITH_SE
+        # fmt: on
+
+ backbone = build_backbone(cfg)
+ self.backbone = nn.Sequential(
+ backbone.conv1,
+ backbone.bn1,
+ backbone.relu,
+ backbone.maxpool,
+ backbone.layer1,
+ backbone.layer2,
+ backbone.layer3[0]
+ )
+ res_conv4 = nn.Sequential(*backbone.layer3[1:])
+ res_g_conv5 = backbone.layer4
+
+ res_p_conv5 = nn.Sequential(
+ Bottleneck(1024, 512, bn_norm, False, with_se, downsample=nn.Sequential(
+ nn.Conv2d(1024, 2048, 1, bias=False), get_norm(bn_norm, 2048))),
+ Bottleneck(2048, 512, bn_norm, False, with_se),
+ Bottleneck(2048, 512, bn_norm, False, with_se))
+ res_p_conv5.load_state_dict(backbone.layer4.state_dict())
+
+ # branch1
+ self.b1 = nn.Sequential(
+ copy.deepcopy(res_conv4),
+ copy.deepcopy(res_g_conv5)
+ )
+ self.b1_head = build_heads(cfg)
+
+ # branch2
+ self.b2 = nn.Sequential(
+ copy.deepcopy(res_conv4),
+ copy.deepcopy(res_p_conv5)
+ )
+ self.b2_head = build_heads(cfg)
+ self.b21_head = build_heads(cfg)
+ self.b22_head = build_heads(cfg)
+
+ # branch3
+ self.b3 = nn.Sequential(
+ copy.deepcopy(res_conv4),
+ copy.deepcopy(res_p_conv5)
+ )
+ self.b3_head = build_heads(cfg)
+ self.b31_head = build_heads(cfg)
+ self.b32_head = build_heads(cfg)
+ self.b33_head = build_heads(cfg)
+
+ @property
+ def device(self):
+ return self.pixel_mean.device
+
+ def forward(self, batched_inputs):
+ images = self.preprocess_image(batched_inputs)
+ features = self.backbone(images) # (bs, 2048, 16, 8)
+
+ # branch1
+ b1_feat = self.b1(features)
+
+ # branch2
+ b2_feat = self.b2(features)
+ b21_feat, b22_feat = torch.chunk(b2_feat, 2, dim=2)
+
+ # branch3
+ b3_feat = self.b3(features)
+ b31_feat, b32_feat, b33_feat = torch.chunk(b3_feat, 3, dim=2)
+
+ if self.training:
+            assert "targets" in batched_inputs, "Person ID annotations are missing in training!"
+ targets = batched_inputs["targets"].long().to(self.device)
+
+ if targets.sum() < 0: targets.zero_()
+
+ b1_outputs = self.b1_head(b1_feat, targets)
+ b2_outputs = self.b2_head(b2_feat, targets)
+ b21_outputs = self.b21_head(b21_feat, targets)
+ b22_outputs = self.b22_head(b22_feat, targets)
+ b3_outputs = self.b3_head(b3_feat, targets)
+ b31_outputs = self.b31_head(b31_feat, targets)
+ b32_outputs = self.b32_head(b32_feat, targets)
+ b33_outputs = self.b33_head(b33_feat, targets)
+
+ return {
+ "b1_outputs": b1_outputs,
+ "b2_outputs": b2_outputs,
+ "b21_outputs": b21_outputs,
+ "b22_outputs": b22_outputs,
+ "b3_outputs": b3_outputs,
+ "b31_outputs": b31_outputs,
+ "b32_outputs": b32_outputs,
+ "b33_outputs": b33_outputs,
+ "targets": targets,
+ }
+ else:
+ b1_pool_feat = self.b1_head(b1_feat)
+ b2_pool_feat = self.b2_head(b2_feat)
+ b21_pool_feat = self.b21_head(b21_feat)
+ b22_pool_feat = self.b22_head(b22_feat)
+ b3_pool_feat = self.b3_head(b3_feat)
+ b31_pool_feat = self.b31_head(b31_feat)
+ b32_pool_feat = self.b32_head(b32_feat)
+ b33_pool_feat = self.b33_head(b33_feat)
+
+ pred_feat = torch.cat([b1_pool_feat, b2_pool_feat, b3_pool_feat, b21_pool_feat,
+ b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
+ return pred_feat
+
+ def preprocess_image(self, batched_inputs):
+ r"""
+ Normalize and batch the input images.
+ """
+ if isinstance(batched_inputs, dict):
+ images = batched_inputs["images"].to(self.device)
+ elif isinstance(batched_inputs, torch.Tensor):
+ images = batched_inputs.to(self.device)
+ else:
+ raise TypeError("batched_inputs must be dict or torch.Tensor, but get {}".format(type(batched_inputs)))
+
+ images.sub_(self.pixel_mean).div_(self.pixel_std)
+ return images
+
+ def losses(self, outs):
+ # fmt: off
+ b1_outputs = outs["b1_outputs"]
+ b2_outputs = outs["b2_outputs"]
+ b21_outputs = outs["b21_outputs"]
+ b22_outputs = outs["b22_outputs"]
+ b3_outputs = outs["b3_outputs"]
+ b31_outputs = outs["b31_outputs"]
+ b32_outputs = outs["b32_outputs"]
+ b33_outputs = outs["b33_outputs"]
+ gt_labels = outs["targets"]
+ # model predictions
+ pred_class_logits = b1_outputs['pred_class_logits'].detach()
+ b1_logits = b1_outputs['cls_outputs']
+ b2_logits = b2_outputs['cls_outputs']
+ b21_logits = b21_outputs['cls_outputs']
+ b22_logits = b22_outputs['cls_outputs']
+ b3_logits = b3_outputs['cls_outputs']
+ b31_logits = b31_outputs['cls_outputs']
+ b32_logits = b32_outputs['cls_outputs']
+ b33_logits = b33_outputs['cls_outputs']
+ b1_pool_feat = b1_outputs['features']
+ b2_pool_feat = b2_outputs['features']
+ b3_pool_feat = b3_outputs['features']
+ b21_pool_feat = b21_outputs['features']
+ b22_pool_feat = b22_outputs['features']
+ b31_pool_feat = b31_outputs['features']
+ b32_pool_feat = b32_outputs['features']
+ b33_pool_feat = b33_outputs['features']
+ # fmt: on
+
+ # Log prediction accuracy
+ log_accuracy(pred_class_logits, gt_labels)
+
+ b22_pool_feat = torch.cat((b21_pool_feat, b22_pool_feat), dim=1)
+ b33_pool_feat = torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1)
+
+ loss_dict = {}
+ loss_names = self._cfg.MODEL.LOSSES.NAME
+
+ if "CrossEntropyLoss" in loss_names:
+ loss_dict['loss_cls_b1'] = cross_entropy_loss(
+ b1_logits,
+ gt_labels,
+ self._cfg.MODEL.LOSSES.CE.EPSILON,
+ self._cfg.MODEL.LOSSES.CE.ALPHA,
+ ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
+ loss_dict['loss_cls_b2'] = cross_entropy_loss(
+ b2_logits,
+ gt_labels,
+ self._cfg.MODEL.LOSSES.CE.EPSILON,
+ self._cfg.MODEL.LOSSES.CE.ALPHA,
+ ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
+ loss_dict['loss_cls_b21'] = cross_entropy_loss(
+ b21_logits,
+ gt_labels,
+ self._cfg.MODEL.LOSSES.CE.EPSILON,
+ self._cfg.MODEL.LOSSES.CE.ALPHA,
+ ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
+ loss_dict['loss_cls_b22'] = cross_entropy_loss(
+ b22_logits,
+ gt_labels,
+ self._cfg.MODEL.LOSSES.CE.EPSILON,
+ self._cfg.MODEL.LOSSES.CE.ALPHA,
+ ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
+ loss_dict['loss_cls_b3'] = cross_entropy_loss(
+ b3_logits,
+ gt_labels,
+ self._cfg.MODEL.LOSSES.CE.EPSILON,
+ self._cfg.MODEL.LOSSES.CE.ALPHA,
+ ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
+ loss_dict['loss_cls_b31'] = cross_entropy_loss(
+ b31_logits,
+ gt_labels,
+ self._cfg.MODEL.LOSSES.CE.EPSILON,
+ self._cfg.MODEL.LOSSES.CE.ALPHA,
+ ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
+ loss_dict['loss_cls_b32'] = cross_entropy_loss(
+ b32_logits,
+ gt_labels,
+ self._cfg.MODEL.LOSSES.CE.EPSILON,
+ self._cfg.MODEL.LOSSES.CE.ALPHA,
+ ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
+ loss_dict['loss_cls_b33'] = cross_entropy_loss(
+ b33_logits,
+ gt_labels,
+ self._cfg.MODEL.LOSSES.CE.EPSILON,
+ self._cfg.MODEL.LOSSES.CE.ALPHA,
+ ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
+
+ if "TripletLoss" in loss_names:
+ loss_dict['loss_triplet_b1'] = triplet_loss(
+ b1_pool_feat,
+ gt_labels,
+ self._cfg.MODEL.LOSSES.TRI.MARGIN,
+ self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
+ self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
+ ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
+ loss_dict['loss_triplet_b2'] = triplet_loss(
+ b2_pool_feat,
+ gt_labels,
+ self._cfg.MODEL.LOSSES.TRI.MARGIN,
+ self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
+ self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
+ ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
+ loss_dict['loss_triplet_b3'] = triplet_loss(
+ b3_pool_feat,
+ gt_labels,
+ self._cfg.MODEL.LOSSES.TRI.MARGIN,
+ self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
+ self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
+ ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
+ loss_dict['loss_triplet_b22'] = triplet_loss(
+ b22_pool_feat,
+ gt_labels,
+ self._cfg.MODEL.LOSSES.TRI.MARGIN,
+ self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
+ self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
+ ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
+ loss_dict['loss_triplet_b33'] = triplet_loss(
+ b33_pool_feat,
+ gt_labels,
+ self._cfg.MODEL.LOSSES.TRI.MARGIN,
+ self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
+ self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
+ ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
+
+ return loss_dict
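
A hedged inference sketch (not part of the patch), assuming an MGN config with 384x128 inputs; the returned feature is the concatenation of all eight branch embeddings:

    import torch
    from config import get_cfg
    from modeling.meta_arch import build_model

    cfg = get_cfg()
    cfg.MODEL.META_ARCHITECTURE = "MGN"
    model = build_model(cfg).eval()
    with torch.no_grad():
        feat = model(torch.randn(1, 3, 384, 128))
    # feat.size(1) == 8 * per-branch embedding dim (branches are concatenated on dim 1)
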
diff --git a/modul/model.pt b/modul/model.pt
deleted file mode 100644
index 31ded02..0000000
--- a/modul/model.pt
+++ /dev/null
Binary files differ
diff --git a/reid_feature.cpp b/reid_feature.cpp
deleted file mode 100644
index fdf1e88..0000000
--- a/reid_feature.cpp
+++ /dev/null
@@ -1,122 +0,0 @@
-//
-// Created by Scheaven on 2020/1/3.
-//
-
-#include "reid_feature.h"
-#include <cuda_runtime_api.h>
-#include <torch/torch.h>
-
-bool ReID_Feature::ReID_init(int gpu_id)
-{
- if(gpu_id == -1){
- this->module = torch::jit::load(MODEL_PATH);
- this->module.to(torch::kCPU);
- this->module.eval();
- this->is_gpu = false;
- }else if(torch::cuda::is_available() && torch::cuda::device_count() >= gpu_id)
- {
- cudaSetDevice(gpu_id);
- cout << "model loading::" << HUMAN_FEATS << endl;
- this->module = torch::jit::load(MODEL_PATH, torch::Device(torch::DeviceType::CUDA,gpu_id));
- this->module.to(torch::kCUDA);
- this->module.eval();
- this->is_gpu = true;
- }else{
- return false;
- }
- return true;
-}
-
-int ReID_Feature::ReID_size()
-{
- int size = 2048;
- return size;
-
-}
-
-bool ReID_Feature::ReID_extractor(float *pBuf, float *pFeature)
-{
- auto input_tensor = torch::from_blob(pBuf, {1, 256, 128, 3});
- input_tensor = input_tensor.permute({0, 3, 1, 2});
- input_tensor[0][0] = input_tensor[0][0].sub_(0.485).div_(0.229);
- input_tensor[0][1] = input_tensor[0][1].sub_(0.456).div_(0.224);
- input_tensor[0][2] = input_tensor[0][2].sub_(0.406).div_(0.225);
- if(this->is_gpu)
- input_tensor = input_tensor.to(at::kCUDA);
- torch::Tensor human_feat =this->module.forward({input_tensor}).toTensor();
-
-// for (int k = 0; k < 20; ++k) {
-// cout << "--extractor---human_feats------" <<human_feat[0][k+2000]<< endl;
-// }
- torch::Tensor query_feat;
- if(this->is_gpu)
- query_feat = human_feat.cpu();
- else
- query_feat = human_feat;
-
- auto foo_one = query_feat.accessor<float,2>();
-
- ReID_Utils RET;
-
- float f_size = -0.727412;
- for (int64_t i = 0; i < foo_one.size(0); i++) {
- auto a1 = foo_one[i];
- for (int64_t j = 0; j < a1.size(0); j++) {
- pFeature[j] = a1[j];
- }
- }
-
-// cout << "---- end 11-------" << pFeature[0] << endl;
- return true;
-}
-
-float ReID_Feature::ReID_Compare(float *pFeature1, float *pFeature2)
-{
- torch::Tensor query_feat = torch::zeros({1,2048});
- torch::Tensor gallery_feat = torch::zeros({1,2048});
-
- for (int i = 0; i < 2048; i++)
- {
- query_feat[0][i] = pFeature1[i];
- gallery_feat[0][i] = pFeature2[i];
- }
-
- if(this->is_gpu)
- {
- query_feat = query_feat.cuda();
- gallery_feat = gallery_feat.cuda();
- }
-
-// cout << "-----------------after-----------" << endl;
-// cout << query_feat<< endl;
-
-// for (int k = 0; k < 20; ++k) {
-// cout << "-query_feat----1111111111------" <<query_feat[0][k+2000]<< endl;
-// }
-//
-// cout << "-----------------asdf-----------" << endl;
-//// cout << gallery_feat[0][0]<< endl;
-// for (int k = 0; k < 20; ++k) {
-// cout << "-gallery_feat----22222222------" <<gallery_feat[0][k+2000]<< endl;
-// }
-
- torch::Tensor a_similarity = torch::cosine_similarity(query_feat, gallery_feat);
-
- if(this->is_gpu)
- a_similarity = a_similarity.cpu();
-
- auto foo_one = a_similarity.accessor<float,1>();
-// cout << ":::::::::-" << endl;
-
- float f_distance = foo_one[0];
-
- return f_distance;
-
-}
-
-void ReID_Feature::ReID_Release()
-{
-    printf("release\n");
-//    this->module = nullptr; // load the model
-// return true;
-}
\ No newline at end of file
diff --git a/reid_feature.h b/reid_feature.h
deleted file mode 100644
index 8a98bf2..0000000
--- a/reid_feature.h
+++ /dev/null
@@ -1,32 +0,0 @@
-//
-// Created by Scheaven on 2020/1/3.
-//
-
-#ifndef INC_03_REID_STRONG_BASELINE_REID_FEATURE_H
-#define INC_03_REID_STRONG_BASELINE_REID_FEATURE_H
-#include <torch/script.h>
-#include <opencv2/opencv.hpp>
-#include <fstream>
-#include <string>
-#include <iomanip>
-#include <stdlib.h>
-#include <vector>
-#include <iostream>
-
-
-struct ReID_Feature {
-private:
- bool is_gpu;
- // auto feat;
- torch::jit::script::Module module;
-public:
- bool ReID_init(int gpu_id);
- int ReID_size();
-// static unsigned char * extractor(unsigned char *pBuf, unsigned char *pFeature);
- bool ReID_extractor(float *pBuf, float * pFeature);
- float ReID_Compare(float *pFeature1, float *pFeature2);
- void ReID_Release();
-};
-
-
-#endif //INC_03_REID_STRONG_BASELINE_REID_FEATURE_H
diff --git a/reid_utils.cpp b/reid_utils.cpp
deleted file mode 100644
index 17927ba..0000000
--- a/reid_utils.cpp
+++ /dev/null
@@ -1,24 +0,0 @@
-//
-// Created by Scheaven on 2020/1/3.
-//
-
-#include "reid_utils.h"
-
-
-void *ReID_Utils::normalize(unsigned char *vsrc, int w, int h, int chan){
- float *data = (float*)malloc(h*w*chan*sizeof(float));
- int size = w*h;
- int size2 = size*2;
-
- unsigned char *srcData = (unsigned char*)vsrc;
-
- for(int i = 0;i<size;i++){
- *(data) = *(srcData + 2) /255.0f;
- *(data+size) = *(srcData + 1) /255.0f;
- *(data+size2) = *(srcData) /255.0f;
- data++;
- srcData+=3;
- }
-
- return data;
-}
diff --git a/reid_utils.h b/reid_utils.h
deleted file mode 100644
index 47c249b..0000000
--- a/reid_utils.h
+++ /dev/null
@@ -1,22 +0,0 @@
-//
-// Created by Scheaven on 2020/1/3.
-//
-
-#ifndef INC_03_REID_STRONG_BASELINE_REID_UTILS_H
-#define INC_03_REID_STRONG_BASELINE_REID_UTILS_H
-#include <torch/script.h>
-#include <fstream>
-#include <string>
-#include <iomanip>
-#include <stdlib.h>
-#include <vector>
-#include <iostream>
-
-using namespace std;
-struct ReID_Utils{
- public:
- float *normalize(unsigned char *vsrc, int w, int h, int chan);
-};
-
-
-#endif //INC_03_REID_STRONG_BASELINE_REID_UTILS_H
diff --git a/test.cpp b/test.cpp
deleted file mode 100644
index 72640b9..0000000
--- a/test.cpp
+++ /dev/null
@@ -1,86 +0,0 @@
-//
-// Created by Scheaven on 2019/12/27.
-//
-#include <torch/script.h> // One-stop header.
-#include <iostream>
-#include <memory>
-#include <vector>
-#include <string>
-#include <opencv2/core/core.hpp>
-#include <opencv2/highgui/highgui.hpp>
-#include "opencv2/imgproc/imgproc.hpp"
-#include "opencv2/opencv.hpp"
-#include "opencv2/videoio.hpp"
-#include "reid_feature.h"
-
-
-using namespace std;
-//using namespace cv;
-
-int main(int argc, const char* argv[]) {
-// if (argc != 2) {
-// std::cerr << "usage: reid-app <image path>\n";;
-// return -1;
-// }
-
-// torch::jit::script::Module module;
-// char cam_id = 'A';
-// ReID_Tracker Tracker;
-
-
-
-    /* initialization */
- int gpu_id = 0;
- ReID_Feature R_Feater;
- bool n_flog = R_Feater.ReID_init(0);
- ReID_Utils r_util;
-// ReID_Feature R_Feater(gpu_id);
-
-    /* load image data with OpenCV */
- cv::Mat human_img = cv::imread("./03.jpg");
- cv::Mat human_img2 = cv::imread("./01.jpg");
- if (human_img.data == nullptr)
- {
-        cerr << "=== image file does not exist" << endl;
- return 0;
- }else
- {
- //cv::namedWindow("Display", CV_WINDOW_AUTOSIZE);
- try {
- float pFeature1[2048];
-            /* convert the image data format */
- // cv::cvtColor(human_img, human_img, cv::COLOR_RGB2BGR);
- // human_img.convertTo(human_img, CV_32FC3, 1.0f / 255.0f);
- float *my_img_data = r_util.normalize(human_img.data, human_img.cols, human_img.rows, 3);
- bool ex_flag1 = R_Feater.ReID_extractor(my_img_data, pFeature1);
-// for (int k = 0; k < 20; ++k) {
-// cout << "-----11111111111------" <<pFeature1[k+2000]<< endl;
-// }
-
- float pFeature2[2048];
- // cv::cvtColor(human_img2, human_img2, cv::COLOR_RGB2BGR);
- // human_img2.convertTo(human_img2, CV_32FC3, 1.0f / 255.0f);
- float *my_img_data2 = r_util.normalize(human_img.data, human_img.cols, human_img.rows, 3);
- bool ex_flag2 = R_Feater.ReID_extractor(my_img_data2, pFeature2);
-// for (int k = 0; k < 20; ++k) {
-// cout << "-----2222222222------" <<pFeature2[k+2000]<< endl;
-// }
-
-            /* compute the similarity */
- cout << "--attention_distance human-" << endl;
- float result = R_Feater.ReID_Compare(pFeature1, pFeature2);
-
-// Tracker.storager(1,human_img);
- }
- catch (const c10::Error& e) {
- std::cerr << "error loading the model\n";
- return -1;
- }
-
- std::cout << "ok\n";
- //cout<< human_img <<endl;
- //cv::imshow("0909",human_img);
- //cv::waitKey(0);
- }
-
-}
diff --git a/tools/03_check_onnx.py b/tools/03_check_onnx.py
new file mode 100644
index 0000000..4467222
--- /dev/null
+++ b/tools/03_check_onnx.py
@@ -0,0 +1,168 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import sys
+sys.path.append('.')
+
+import os
+import time
+
+import cv2
+import numpy as np
+import pycuda.autoinit
+import pycuda.driver as cuda
+import tensorrt as trt
+import torch
+
+max_batch_size = 1
+onnx_model_path = "/data/disk1/workspace/06_reid/01_fast_reid/02_fast_reid_inference/fastreid.onnx"
+TRT_LOGGER = trt.Logger()
+
+
+def get_img_np_nchw(filename):
+ image = cv2.imread(filename)
+ image_cv = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ image_cv = cv2.resize(image_cv, (256, 128))
+ miu = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
+ std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
+    img_np = np.array(image_cv, dtype=np.float32) / 255.
+ img_np = img_np.transpose((2, 0, 1))
+ img_np -= miu
+ img_np /= std
+ img_np_nchw = img_np[np.newaxis]
+ img_np_nchw = np.tile(img_np_nchw,(max_batch_size, 1, 1, 1))
+ return img_np_nchw
+
+
+class HostDeviceMem(object):
+ def __init__(self, host_mem, device_mem):
+        # host_mem: page-locked host (CPU) memory
+        # device_mem: device (GPU) memory
+        self.host = host_mem
+        self.device = device_mem
+
+ def __str__(self):
+ return "Host:\n" + str(self.host)+"\nDevice:\n"+str(self.device)
+
+ def __repr__(self):
+ return self.__str__()
+
+def allocate_buffers(engine):
+ inputs, outputs, bindings = [], [], []
+ stream = cuda.Stream()
+ for binding in engine:
+ size = trt.volume(engine.get_binding_shape(binding))
+ dtype = trt.nptype(engine.get_binding_dtype(binding))
+ host_mem = cuda.pagelocked_empty(size, dtype)
+ device_mem = cuda.mem_alloc(host_mem.nbytes)
+ bindings.append(int(device_mem))
+        # append to the appropriate list
+ if engine.binding_is_input(binding):
+ inputs.append(HostDeviceMem(host_mem, device_mem))
+ else:
+ outputs.append(HostDeviceMem(host_mem, device_mem))
+ return inputs, outputs, bindings, stream
+
+def get_engine(max_batch_size=1, onnx_file_path="", engine_file_path="", fp16_mode=False, save_engine=True):
+ if os.path.exists(engine_file_path):
+ print("Reading engine from file: {}".format(engine_file_path))
+ with open(engine_file_path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
+            return runtime.deserialize_cuda_engine(f.read())  # deserialize the engine
+ else:
+
+ explicit_batch = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+ # In TensorRT 7.0, the ONNX parser only supports full-dimensions mode, meaning that your network definition must be created with the explicitBatch flag set. For more information, see Working With Dynamic Shapes.
+
+ with trt.Builder(TRT_LOGGER) as builder, \
+ builder.create_network(explicit_batch) as network, \
+ trt.OnnxParser(network, TRT_LOGGER) as parser:
+
+            config = builder.create_builder_config()
+            config.max_workspace_size = 1 << 30
+            builder.max_batch_size = max_batch_size  # max batch size usable at execution time
+            if fp16_mode:
+                config.set_flag(trt.BuilderFlag.FP16)
+
+ if not os.path.exists(onnx_file_path):
+ quit("ONNX file {} not found!".format(onnx_file_path))
+ print('loading onnx file from path {} ...'.format(onnx_file_path))
+            with open(onnx_file_path, 'rb') as model:  # binary ONNX graph definition and weights
+                print("Beginning ONNX file parsing")
+ parser.parse(model.read())
+
+ print("Completed parsing of onnx file")
+            print("Building an engine from file {}; this may take a while...".format(onnx_file_path))
+
+ #################
+ print(network.get_layer(network.num_layers-1).get_output(0).shape)
+            engine = builder.build_engine(network, config)
+ print("Completed creating Engine")
+ if save_engine:
+ with open(engine_file_path, 'wb') as f:
+                    f.write(engine.serialize())  # serialize the engine to disk
+ return engine
+
+def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
+ [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
+ context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
+ [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
+ # gpu to cpu
+ # Synchronize the stream
+ stream.synchronize()
+ return [out.host for out in outputs]
+
+
+def postprocess_the_outputs(h_outputs, shape_of_output):
+ h_outputs = h_outputs.reshape(*shape_of_output)
+ return h_outputs
+
+if __name__ == '__main__':
+ img_np_nchw = get_img_np_nchw("/data/disk1/project/data/01_reid/0_1.png").astype(np.float32)
+ fp16_mode = True
+ trt_engine_path ="./human_feature{}.trt".format(fp16_mode)
+
+ engine = get_engine(max_batch_size, onnx_model_path, trt_engine_path, fp16_mode)
+
+ context = engine.create_execution_context()
+ inputs, outputs, bindings, stream = allocate_buffers(engine)
+
+ shape_of_output = (max_batch_size, 2048)
+
+ inputs[0].host = img_np_nchw
+
+ t1 = time.time()
+ trt_outputs = do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream, batch_size = max_batch_size)
+ t2 = time.time()
+ print(trt_outputs, trt_outputs[0].shape)
+
+ feat = postprocess_the_outputs(trt_outputs[0], shape_of_output)
+ print('TensorRT ok')
+ print("Inference time with the TensorRT engine: {}".format(t2-t1))
+
+
+
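Despite its name, this script goes straight from the ONNX file to a TensorRT engine. One way to actually sanity-check the ONNX graph first (a hedged addition, using only the standard onnx package):

    import onnx

    onnx_model = onnx.load(onnx_model_path)   # path defined at the top of this script
    onnx.checker.check_model(onnx_model)      # raises if the graph is malformed
    print(onnx.helper.printable_graph(onnx_model.graph))
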
diff --git a/tools/03_py2onnx.py b/tools/03_py2onnx.py
new file mode 100644
index 0000000..1090f23
--- /dev/null
+++ b/tools/03_py2onnx.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Date : 2021-04-29 18:20:17
+# @Author : Scheaven (snow_mail@foxmail.com)
+# @Link : www.github.com
+# @Version : $Id$
+
+import torch
+import sys
+sys.path.append('.')
+
+from data.data_utils import read_image
+from predictor import ReID_Model
+from config import get_cfg
+from data.transforms.build import build_transforms
+from engine.defaults import default_argument_parser, default_setup
+import time
+
+def setup(args):
+ """
+ Create configs and perform basic setups.
+ """
+ cfg = get_cfg()
+ cfg.merge_from_file(args.config_file)
+ cfg.merge_from_list(args.opts)
+ cfg.freeze()
+ default_setup(cfg, args)
+ return cfg
+
+if __name__ == '__main__':
+ args = default_argument_parser().parse_args()
+ cfg = setup(args)
+ cfg.defrost()
+ cfg.MODEL.BACKBONE.PRETRAIN = False
+
+ model = ReID_Model(cfg)
+
+
+ model.torch2onnx()
+
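The export itself is delegated to ReID_Model.torch2onnx in tools/predictor.py. A hedged sketch of what such a routine typically does (illustrative names, not the author's exact method; predictor.py already imports torch.onnx and onnxsim):

    import onnx
    import torch
    from onnxsim import simplify

    def export_onnx(model, output_path="fastreid.onnx", height=256, width=128):
        model.eval()
        dummy = torch.randn(1, 3, height, width)       # one NCHW example input
        torch.onnx.export(model, dummy, output_path,
                          input_names=["input"], output_names=["output"],
                          opset_version=11)
        sim_model, ok = simplify(onnx.load(output_path))  # fold constants, drop dead nodes
        assert ok, "onnx-simplifier validation failed"
        onnx.save(sim_model, output_path)
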
diff --git a/tools/04_trt_inference.py b/tools/04_trt_inference.py
new file mode 100644
index 0000000..15cb9d3
--- /dev/null
+++ b/tools/04_trt_inference.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import argparse
+import glob
+import os
+import sys
+
+import cv2
+import numpy as np
+import tqdm
+
+sys.path.append("/export/home/lxy/runtimelib-tensorrt-tiny/build")
+
+import pytrt
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(description="trt model inference")
+
+ parser.add_argument(
+ "--model_path",
+ default="outputs/trt_model/baseline.engine",
+ help="trt model path"
+ )
+ parser.add_argument(
+ "--input",
+ nargs="+",
+ help="A list of space separated input images; "
+ "or a single glob pattern such as 'directory/*.jpg'",
+ )
+ parser.add_argument(
+ "--output",
+ default="trt_output",
+ help="path to save trt model inference results"
+ )
+ parser.add_argument(
+ "--output-name",
+ help="tensorRT model output name"
+ )
+ parser.add_argument(
+ "--height",
+ type=int,
+ default=256,
+ help="height of image"
+ )
+ parser.add_argument(
+ "--width",
+ type=int,
+ default=128,
+ help="width of image"
+ )
+ return parser
+
+
+def preprocess(image_path, image_height, image_width):
+ original_image = cv2.imread(image_path)
+ # the model expects RGB inputs
+ original_image = original_image[:, :, ::-1]
+
+ # Apply pre-processing to image.
+ img = cv2.resize(original_image, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
+ img = img.astype("float32").transpose(2, 0, 1)[np.newaxis] # (1, 3, h, w)
+ return img
+
+
+def normalize(nparray, order=2, axis=-1):
+ """Normalize a N-D numpy array along the specified axis."""
+ norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
+ return nparray / (norm + np.finfo(np.float32).eps)
+
+
+if __name__ == "__main__":
+ args = get_parser().parse_args()
+
+ trt = pytrt.Trt()
+
+ onnxModel = ""
+ engineFile = args.model_path
+ customOutput = []
+ maxBatchSize = 1
+ calibratorData = []
+ mode = 2
+ trt.CreateEngine(onnxModel, engineFile, customOutput, maxBatchSize, mode, calibratorData)
+
+ if not os.path.exists(args.output): os.makedirs(args.output)
+
+ if args.input:
+ if os.path.isdir(args.input[0]):
+ args.input = glob.glob(os.path.expanduser(args.input[0]))
+ assert args.input, "The input path(s) was not found"
+ for path in tqdm.tqdm(args.input):
+ input_numpy_array = preprocess(path, args.height, args.width)
+ trt.DoInference(input_numpy_array)
+ feat = trt.GetOutput(args.output_name)
+ feat = normalize(feat, axis=1)
+ np.save(os.path.join(args.output, path.replace('.jpg', '.npy').split('/')[-1]), feat)
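+
+# Usage sketch (paths are illustrative; --output-name must match the output
+# tensor name compiled into the engine):
+#
+#   python tools/04_trt_inference.py \
+#       --model_path outputs/trt_model/baseline.engine \
+#       --input 'query/*.jpg' --output trt_output \
+#       --output-name <output_tensor_name> --height 256 --width 128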
diff --git a/tools/__init__.py b/tools/__init__.py
new file mode 100644
index 0000000..f6f724d
--- /dev/null
+++ b/tools/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2020/10/26 11:47
+# @Author : Scheaven
+# @File : __init__.py
+# @description:
\ No newline at end of file
diff --git a/tools/__pycache__/predictor.cpython-37.pyc b/tools/__pycache__/predictor.cpython-37.pyc
new file mode 100644
index 0000000..2fbb3f2
--- /dev/null
+++ b/tools/__pycache__/predictor.cpython-37.pyc
Binary files differ
diff --git a/tools/__pycache__/predictor.cpython-38.pyc b/tools/__pycache__/predictor.cpython-38.pyc
new file mode 100644
index 0000000..50ee609
--- /dev/null
+++ b/tools/__pycache__/predictor.cpython-38.pyc
Binary files differ
diff --git a/tools/inference_net.py b/tools/inference_net.py
new file mode 100644
index 0000000..bc32670
--- /dev/null
+++ b/tools/inference_net.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2020/10/26 11:48
+# @Author : Scheaven
+# @File : inference_net.py
+# @description:
+
+import torch
+import sys
+sys.path.append('.')
+
+from data.data_utils import read_image
+from predictor import ReID_Model
+from config import get_cfg
+from data.transforms.build import build_transforms
+from engine.defaults import default_argument_parser, default_setup
+import time
+
+def setup(args):
+ """
+ Create configs and perform basic setups.
+ """
+ cfg = get_cfg()
+ cfg.merge_from_file(args.config_file)
+ cfg.merge_from_list(args.opts)
+ cfg.freeze()
+ default_setup(cfg, args)
+ return cfg
+
+if __name__ == '__main__':
+ args = default_argument_parser().parse_args()
+ cfg = setup(args)
+ cfg.defrost()
+ cfg.MODEL.BACKBONE.PRETRAIN = False
+
+ model = ReID_Model(cfg)
+
+ test_transforms = build_transforms(cfg, is_train=False)
+ # print (args.img_a1)
+ img_a1 = read_image(args.img_a1)
+ img_a2 = read_image(args.img_a2)
+ img_b1 = read_image(args.img_b1)
+ img_b2 = read_image(args.img_b2)
+
+ img_a1 = test_transforms(img_a1)
+ img_a2 = test_transforms(img_a2)
+ img_b1 = test_transforms(img_b1)
+ img_b2 = test_transforms(img_b2)
+
+ out = torch.zeros((2, *img_a1.size()), dtype=img_a1.dtype)
+ out[0] += img_a1
+ out[1] += img_a2
+ t1 = time.time()
+ query_feat = model.run_on_image(out)
+ t2 = time.time()
+ print("t2-t1:", t2-t1)
+
+ similarity1 = torch.cosine_similarity(query_feat[0], query_feat[1], dim=0)
+ t3 = time.time()
+ print("t3-t2:", t3-t2, similarity1)
diff --git a/tools/predictor.py b/tools/predictor.py
new file mode 100644
index 0000000..952f132
--- /dev/null
+++ b/tools/predictor.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2020/10/26 15:50
+# @Author : Scheaven
+# @File : predictor.py
+# @description:
+import io
+
+import torch
+import torch.nn.functional as F
+
+from modeling.meta_arch.build import build_model
+from utils.checkpoint import Checkpointer
+
+import onnx
+from onnxsim import simplify
+from torch.onnx import OperatorExportTypes
+
+class ReID_Model(object):
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ self.predictor = DefaultPredictor(cfg)
+
+ def run_on_image(self, original_image):
+
+ predictions = self.predictor(original_image)
+ return predictions
+
+ def torch2onnx(self):
+ self.predictor.to_onnx()
+
+class DefaultPredictor:
+
+ def __init__(self, cfg):
+ self.cfg = cfg.clone()
+ self.cfg.defrost()
+ self.cfg.MODEL.BACKBONE.PRETRAIN = False
+ self.model = build_model(self.cfg)
+ for param in self.model.parameters():
+ param.requires_grad = False
+ self.model.cuda()
+ self.model.eval()
+
+ Checkpointer(self.model).load(cfg.MODEL.WEIGHTS)
+
+ def __call__(self, image):
+ with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
+ images = image.cuda()
+ self.model.eval()
+ predictions = self.model(images)
+ torch.set_printoptions(edgeitems=2048)
+ print("------------\n", predictions)
+ pred_feat = F.normalize(predictions)
+ pred_feat = pred_feat.cpu().data
+ return pred_feat
+
+ def to_onnx(self):
+ inputs = torch.randn(1, 3, self.cfg.INPUT.SIZE_TEST[0], self.cfg.INPUT.SIZE_TEST[1]).cuda()
+ onnx_model = self.export_onnx_model(self.model, inputs)
+
+ model_simp, check = simplify(onnx_model)
+ assert check, "Simplified ONNX model could not be validated"
+
+ model_simp = self.remove_initializer_from_input(model_simp)
+
+ onnx.save_model(model_simp, "fastreid.onnx")
+
+
+ def export_onnx_model(self, model, inputs):
+ assert isinstance(model, torch.nn.Module)
+ def _check_eval(module):
+ assert not module.training
+
+ model.apply(_check_eval)
+
+ with torch.no_grad():
+ with io.BytesIO() as f:
+ torch.onnx.export(
+ model,
+ inputs,
+ f,
+ operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
+ )
+ onnx_model = onnx.load_from_string(f.getvalue())
+
+ # Apply ONNX's Optimization
+ all_passes = onnx.optimizer.get_available_passes()
+ passes = ["extract_constant_to_initializer", "eliminate_unused_initializer", "fuse_bn_into_conv"]
+ assert all(p in all_passes for p in passes)
+ onnx_model = onnx.optimizer.optimize(onnx_model, passes)
+ return onnx_model
+
+
+ def remove_initializer_from_input(self, model):
+ if model.ir_version < 4:
+ print(
+ 'Model with ir_version below 4 requires initializers to be included in the graph input'
+ )
+ return model
+
+ inputs = model.graph.input
+ name_to_input = {}
+ for input in inputs:
+ name_to_input[input.name] = input
+
+ for initializer in model.graph.initializer:
+ if initializer.name in name_to_input:
+ inputs.remove(name_to_input[initializer.name])
+
+ return model
+
+
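+# Embedding-extraction sketch (cfg as built in tools/inference_net.py; `batch`
+# is an (N, C, H, W) float tensor produced by build_transforms):
+#
+#   model = ReID_Model(cfg)
+#   feats = model.run_on_image(batch)  # L2-normalized (N, feat_dim) features
+#   sim = torch.cosine_similarity(feats[0], feats[1], dim=0)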
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..9db4c43
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2020/10/26 14:53
+# @Author : Scheaven
+# @File : __init__.py
+# @description:
\ No newline at end of file
diff --git a/utils/__pycache__/__init__.cpython-37.pyc b/utils/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..c328618
--- /dev/null
+++ b/utils/__pycache__/__init__.cpython-37.pyc
Binary files differ
diff --git a/utils/__pycache__/__init__.cpython-38.pyc b/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..5a4265d
--- /dev/null
+++ b/utils/__pycache__/__init__.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/checkpoint.cpython-37.pyc b/utils/__pycache__/checkpoint.cpython-37.pyc
new file mode 100644
index 0000000..19b7364
--- /dev/null
+++ b/utils/__pycache__/checkpoint.cpython-37.pyc
Binary files differ
diff --git a/utils/__pycache__/checkpoint.cpython-38.pyc b/utils/__pycache__/checkpoint.cpython-38.pyc
new file mode 100644
index 0000000..6f785ad
--- /dev/null
+++ b/utils/__pycache__/checkpoint.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/collect_env.cpython-38.pyc b/utils/__pycache__/collect_env.cpython-38.pyc
new file mode 100644
index 0000000..92bc83c
--- /dev/null
+++ b/utils/__pycache__/collect_env.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/comm.cpython-37.pyc b/utils/__pycache__/comm.cpython-37.pyc
new file mode 100644
index 0000000..00e5e78
--- /dev/null
+++ b/utils/__pycache__/comm.cpython-37.pyc
Binary files differ
diff --git a/utils/__pycache__/comm.cpython-38.pyc b/utils/__pycache__/comm.cpython-38.pyc
new file mode 100644
index 0000000..8979735
--- /dev/null
+++ b/utils/__pycache__/comm.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/env.cpython-38.pyc b/utils/__pycache__/env.cpython-38.pyc
new file mode 100644
index 0000000..8427005
--- /dev/null
+++ b/utils/__pycache__/env.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/events.cpython-37.pyc b/utils/__pycache__/events.cpython-37.pyc
new file mode 100644
index 0000000..64fe6a2
--- /dev/null
+++ b/utils/__pycache__/events.cpython-37.pyc
Binary files differ
diff --git a/utils/__pycache__/events.cpython-38.pyc b/utils/__pycache__/events.cpython-38.pyc
new file mode 100644
index 0000000..91f67bc
--- /dev/null
+++ b/utils/__pycache__/events.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/file_io.cpython-37.pyc b/utils/__pycache__/file_io.cpython-37.pyc
new file mode 100644
index 0000000..12cd498
--- /dev/null
+++ b/utils/__pycache__/file_io.cpython-37.pyc
Binary files differ
diff --git a/utils/__pycache__/file_io.cpython-38.pyc b/utils/__pycache__/file_io.cpython-38.pyc
new file mode 100644
index 0000000..1934ab0
--- /dev/null
+++ b/utils/__pycache__/file_io.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/history_buffer.cpython-37.pyc b/utils/__pycache__/history_buffer.cpython-37.pyc
new file mode 100644
index 0000000..dc54f13
--- /dev/null
+++ b/utils/__pycache__/history_buffer.cpython-37.pyc
Binary files differ
diff --git a/utils/__pycache__/history_buffer.cpython-38.pyc b/utils/__pycache__/history_buffer.cpython-38.pyc
new file mode 100644
index 0000000..27b103e
--- /dev/null
+++ b/utils/__pycache__/history_buffer.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/logger.cpython-38.pyc b/utils/__pycache__/logger.cpython-38.pyc
new file mode 100644
index 0000000..2bed12c
--- /dev/null
+++ b/utils/__pycache__/logger.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/registry.cpython-37.pyc b/utils/__pycache__/registry.cpython-37.pyc
new file mode 100644
index 0000000..bd377c1
--- /dev/null
+++ b/utils/__pycache__/registry.cpython-37.pyc
Binary files differ
diff --git a/utils/__pycache__/registry.cpython-38.pyc b/utils/__pycache__/registry.cpython-38.pyc
new file mode 100644
index 0000000..f01b9ce
--- /dev/null
+++ b/utils/__pycache__/registry.cpython-38.pyc
Binary files differ
diff --git a/utils/__pycache__/weight_init.cpython-37.pyc b/utils/__pycache__/weight_init.cpython-37.pyc
new file mode 100644
index 0000000..04292a6
--- /dev/null
+++ b/utils/__pycache__/weight_init.cpython-37.pyc
Binary files differ
diff --git a/utils/__pycache__/weight_init.cpython-38.pyc b/utils/__pycache__/weight_init.cpython-38.pyc
new file mode 100644
index 0000000..8833ded
--- /dev/null
+++ b/utils/__pycache__/weight_init.cpython-38.pyc
Binary files differ
diff --git a/utils/checkpoint.py b/utils/checkpoint.py
new file mode 100644
index 0000000..a7900b6
--- /dev/null
+++ b/utils/checkpoint.py
@@ -0,0 +1,403 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import collections
+import copy
+import logging
+import os
+from collections import defaultdict
+from typing import Any
+
+import numpy as np
+import torch
+import torch.nn as nn
+from termcolor import colored
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+from .file_io import PathManager
+
+
+class Checkpointer(object):
+ """
+ A checkpointer that can save/load model as well as extra checkpointable
+ objects.
+ """
+
+ def __init__(
+ self,
+ model: nn.Module,
+ save_dir: str = "",
+ *,
+ save_to_disk: bool = True,
+ **checkpointables: object,
+ ):
+ """
+ Args:
+ model (nn.Module): model.
+ save_dir (str): a directory to save and find checkpoints.
+ save_to_disk (bool): if True, save checkpoint to disk, otherwise
+ disable saving for this checkpointer.
+ checkpointables (object): any checkpointable objects, i.e., objects
+ that have the `state_dict()` and `load_state_dict()` method. For
+ example, it can be used like
+ `Checkpointer(model, "dir", optimizer=optimizer)`.
+ """
+ if isinstance(model, (DistributedDataParallel, DataParallel)):
+ model = model.module
+ self.model = model
+ self.checkpointables = copy.copy(checkpointables)
+ self.logger = logging.getLogger(__name__)
+ self.save_dir = save_dir
+ self.save_to_disk = save_to_disk
+
+ def save(self, name: str, **kwargs: dict):
+ """
+ Dump model and checkpointables to a file.
+ Args:
+ name (str): name of the file.
+ kwargs (dict): extra arbitrary data to save.
+ """
+ if not self.save_dir or not self.save_to_disk:
+ return
+
+ data = {}
+ data["model"] = self.model.state_dict()
+ for key, obj in self.checkpointables.items():
+ data[key] = obj.state_dict()
+ data.update(kwargs)
+
+ basename = "{}.pth".format(name)
+ save_file = os.path.join(self.save_dir, basename)
+ assert os.path.basename(save_file) == basename, basename
+ self.logger.info("Saving checkpoint to {}".format(save_file))
+ with PathManager.open(save_file, "wb") as f:
+ torch.save(data, f)
+ self.tag_last_checkpoint(basename)
+
+ def load(self, path: str):
+ """
+ Load from the given checkpoint. When the path points to a network file,
+ this function has to be called on all ranks.
+ Args:
+ path (str): path or url to the checkpoint. If empty, will not load
+ anything.
+ Returns:
+ dict:
+ extra data loaded from the checkpoint that has not been
+ processed. For example, those saved with
+ :meth:`.save(**extra_data)`.
+ """
+ if not path:
+ # no checkpoint provided
+ self.logger.info(
+ "No checkpoint found. Training model from scratch"
+ )
+ return {}
+ self.logger.info("Loading checkpoint from {}".format(path))
+ if not os.path.isfile(path):
+ path = PathManager.get_local_path(path)
+ assert os.path.isfile(path), "Checkpoint {} not found!".format(path)
+
+ checkpoint = self._load_file(path)
+ self._load_model(checkpoint)
+ for key, obj in self.checkpointables.items():
+ if key in checkpoint:
+ self.logger.info("Loading {} from {}".format(key, path))
+ obj.load_state_dict(checkpoint.pop(key))
+
+ # return any further checkpoint data
+ return checkpoint
+
+ def has_checkpoint(self):
+ """
+ Returns:
+ bool: whether a checkpoint exists in the target directory.
+ """
+ save_file = os.path.join(self.save_dir, "last_checkpoint")
+ return PathManager.exists(save_file)
+
+ def get_checkpoint_file(self):
+ """
+ Returns:
+ str: The latest checkpoint file in target directory.
+ """
+ save_file = os.path.join(self.save_dir, "last_checkpoint")
+ try:
+ with PathManager.open(save_file, "r") as f:
+ last_saved = f.read().strip()
+ except IOError:
+ # if file doesn't exist, maybe because it has just been
+ # deleted by a separate process
+ return ""
+ return os.path.join(self.save_dir, last_saved)
+
+ def get_all_checkpoint_files(self):
+ """
+ Returns:
+ list: All available checkpoint files (.pth files) in target
+ directory.
+ """
+ all_model_checkpoints = [
+ os.path.join(self.save_dir, file)
+ for file in PathManager.ls(self.save_dir)
+ if PathManager.isfile(os.path.join(self.save_dir, file))
+ and file.endswith(".pth")
+ ]
+ return all_model_checkpoints
+
+ def resume_or_load(self, path: str, *, resume: bool = True):
+ """
+ If `resume` is True, this method attempts to resume from the last
+ checkpoint, if it exists. Otherwise, it loads the checkpoint from the given path.
+ This is useful when restarting an interrupted training job.
+ Args:
+ path (str): path to the checkpoint.
+ resume (bool): if True, resume from the last checkpoint if it exists.
+ Returns:
+ same as :meth:`load`.
+ """
+ if resume and self.has_checkpoint():
+ path = self.get_checkpoint_file()
+ return self.load(path)
+
+ def tag_last_checkpoint(self, last_filename_basename: str):
+ """
+ Tag the last checkpoint.
+ Args:
+ last_filename_basename (str): the basename of the last filename.
+ """
+ save_file = os.path.join(self.save_dir, "last_checkpoint")
+ with PathManager.open(save_file, "w") as f:
+ f.write(last_filename_basename)
+
+ def _load_file(self, f: str):
+ """
+ Load a checkpoint file. Can be overwritten by subclasses to support
+ different formats.
+ Args:
+ f (str): a locally mounted file path.
+ Returns:
+ dict: with keys "model" and optionally others that are saved by
+ the checkpointer dict["model"] must be a dict which maps strings
+ to torch.Tensor or numpy arrays.
+ """
+ return torch.load(f, map_location=torch.device("cpu"))
+
+ def _load_model(self, checkpoint: Any):
+ """
+ Load weights from a checkpoint.
+ Args:
+ checkpoint (Any): checkpoint contains the weights.
+ """
+ checkpoint_state_dict = checkpoint.pop("model")
+ self._convert_ndarray_to_tensor(checkpoint_state_dict)
+
+ # if the state_dict comes from a model that was wrapped in a
+ # DataParallel or DistributedDataParallel during serialization,
+ # remove the "module" prefix before performing the matching.
+ _strip_prefix_if_present(checkpoint_state_dict, "module.")
+
+ # work around https://github.com/pytorch/pytorch/issues/24139
+ model_state_dict = self.model.state_dict()
+ for k in list(checkpoint_state_dict.keys()):
+ if k in model_state_dict:
+ shape_model = tuple(model_state_dict[k].shape)
+ shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+ if shape_model != shape_checkpoint:
+ self.logger.warning(
+ "'{}' has shape {} in the checkpoint but {} in the "
+ "model! Skipped.".format(
+ k, shape_checkpoint, shape_model
+ )
+ )
+ checkpoint_state_dict.pop(k)
+
+ incompatible = self.model.load_state_dict(
+ checkpoint_state_dict, strict=False
+ )
+ if incompatible.missing_keys:
+ self.logger.info(
+ get_missing_parameters_message(incompatible.missing_keys)
+ )
+ if incompatible.unexpected_keys:
+ self.logger.info(
+ get_unexpected_parameters_message(incompatible.unexpected_keys)
+ )
+
+ def _convert_ndarray_to_tensor(self, state_dict: dict):
+ """
+ In-place convert all numpy arrays in the state_dict to torch tensor.
+ Args:
+ state_dict (dict): a state-dict to be loaded to the model.
+ """
+ # model could be an OrderedDict with _metadata attribute
+ # (as returned by Pytorch's state_dict()). We should preserve these
+ # properties.
+ for k in list(state_dict.keys()):
+ v = state_dict[k]
+ if not isinstance(v, np.ndarray) and not isinstance(
+ v, torch.Tensor
+ ):
+ raise ValueError(
+ "Unsupported type found in checkpoint! {}: {}".format(
+ k, type(v)
+ )
+ )
+ if not isinstance(v, torch.Tensor):
+ state_dict[k] = torch.from_numpy(v)
+
+
+class PeriodicCheckpointer:
+ """
+ Save checkpoints periodically. When `.step(iteration)` is called, it will
+ execute `checkpointer.save` on the given checkpointer, if iteration is a
+ multiple of period or if `max_iter` is reached.
+ """
+
+ def __init__(self, checkpointer: Any, period: int, max_iter: int = None):
+ """
+ Args:
+ checkpointer (Any): the checkpointer object used to save
+ checkpoints.
+ period (int): the period to save checkpoint.
+ max_iter (int): maximum number of iterations. When it is reached,
+ a checkpoint named "model_final" will be saved.
+ """
+ self.checkpointer = checkpointer
+ self.period = int(period)
+ self.max_iter = max_iter
+
+ def step(self, iteration: int, **kwargs: Any):
+ """
+ Perform the appropriate action at the given iteration.
+ Args:
+ iteration (int): the current iteration, ranged in [0, max_iter-1].
+ kwargs (Any): extra data to save, same as in
+ :meth:`Checkpointer.save`.
+ """
+ iteration = int(iteration)
+ additional_state = {"iteration": iteration}
+ additional_state.update(kwargs)
+ if (iteration + 1) % self.period == 0:
+ self.checkpointer.save(
+ "model_{:07d}".format(iteration), **additional_state
+ )
+ if self.max_iter is not None and iteration >= self.max_iter - 1:
+ self.checkpointer.save("model_final", **additional_state)
+
+ def save(self, name: str, **kwargs: Any):
+ """
+ Same argument as :meth:`Checkpointer.save`.
+ Use this method to manually save checkpoints outside the schedule.
+ Args:
+ name (str): file name.
+ kwargs (Any): extra data to save, same as in
+ :meth:`Checkpointer.save`.
+ """
+ self.checkpointer.save(name, **kwargs)
+
+
+def get_missing_parameters_message(keys: list):
+ """
+ Get a logging-friendly message to report parameter names (keys) that are in
+ the model but not found in a checkpoint.
+ Args:
+ keys (list[str]): List of keys that were not found in the checkpoint.
+ Returns:
+ str: message.
+ """
+ groups = _group_checkpoint_keys(keys)
+ msg = "Some model parameters are not in the checkpoint:\n"
+ msg += "\n".join(
+ " " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
+ )
+ return msg
+
+
+def get_unexpected_parameters_message(keys: list):
+ """
+ Get a logging-friendly message to report parameter names (keys) that are in
+ the checkpoint but not found in the model.
+ Args:
+ keys (list[str]): List of keys that were not found in the model.
+ Returns:
+ str: message.
+ """
+ groups = _group_checkpoint_keys(keys)
+ msg = "The checkpoint contains parameters not used by the model:\n"
+ msg += "\n".join(
+ " " + colored(k + _group_to_str(v), "magenta")
+ for k, v in groups.items()
+ )
+ return msg
+
+
+def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):
+ """
+ Strip the given prefix from the keys of a state dict (and its metadata), if all keys carry it.
+ Args:
+ state_dict (OrderedDict): a state-dict to be loaded to the model.
+ prefix (str): prefix.
+ """
+ keys = sorted(state_dict.keys())
+ if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
+ return
+
+ for key in keys:
+ newkey = key[len(prefix):]
+ state_dict[newkey] = state_dict.pop(key)
+
+ # also strip the prefix in metadata, if any.
+ try:
+ metadata = state_dict._metadata
+ except AttributeError:
+ pass
+ else:
+ for key in list(metadata.keys()):
+ # for the metadata dict, the key can be:
+ # '': for the DDP module, which we want to remove.
+ # 'module': for the actual model.
+ # 'module.xx.xx': for the rest.
+
+ if len(key) == 0:
+ continue
+ newkey = key[len(prefix):]
+ metadata[newkey] = metadata.pop(key)
+
+
+def _group_checkpoint_keys(keys: list):
+ """
+ Group keys based on common prefixes. A prefix is the string up to the final
+ "." in each key.
+ Args:
+ keys (list[str]): list of parameter names, i.e. keys in the model
+ checkpoint dict.
+ Returns:
+ dict[list]: keys with common prefixes are grouped into lists.
+ """
+ groups = defaultdict(list)
+ for key in keys:
+ pos = key.rfind(".")
+ if pos >= 0:
+ head, tail = key[:pos], [key[pos + 1:]]
+ else:
+ head, tail = key, []
+ groups[head].extend(tail)
+ return groups
+
+
+def _group_to_str(group: list):
+ """
+ Format a group of parameter name suffixes into a loggable string.
+ Args:
+ group (list[str]): list of parameter name suffixes.
+ Returns:
+ str: formated string.
+ """
+ if len(group) == 0:
+ return ""
+
+ if len(group) == 1:
+ return "." + group[0]
+
+ return ".{" + ", ".join(group) + "}"
diff --git a/utils/collect_env.py b/utils/collect_env.py
new file mode 100644
index 0000000..ba5d2ce
--- /dev/null
+++ b/utils/collect_env.py
@@ -0,0 +1,158 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+# based on
+# https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/collect_env.py
+import importlib
+import os
+import re
+import subprocess
+import sys
+from collections import defaultdict
+
+import PIL
+import numpy as np
+import torch
+import torchvision
+from tabulate import tabulate
+
+__all__ = ["collect_env_info"]
+
+
+def collect_torch_env():
+ try:
+ import torch.__config__
+
+ return torch.__config__.show()
+ except ImportError:
+ # compatible with older versions of pytorch
+ from torch.utils.collect_env import get_pretty_env_info
+
+ return get_pretty_env_info()
+
+
+def get_env_module():
+ var_name = "FASTREID_ENV_MODULE"
+ return var_name, os.environ.get(var_name, "<not set>")
+
+
+def detect_compute_compatibility(CUDA_HOME, so_file):
+ try:
+ cuobjdump = os.path.join(CUDA_HOME, "bin", "cuobjdump")
+ if os.path.isfile(cuobjdump):
+ output = subprocess.check_output(
+ "'{}' --list-elf '{}'".format(cuobjdump, so_file), shell=True
+ )
+ output = output.decode("utf-8").strip().split("\n")
+ sm = []
+ for line in output:
+ line = re.findall(r"\.sm_[0-9]*\.", line)[0]
+ sm.append(line.strip("."))
+ sm = sorted(set(sm))
+ return ", ".join(sm)
+ else:
+ return so_file + "; cannot find cuobjdump"
+ except Exception:
+ # unhandled failure
+ return so_file
+
+
+def collect_env_info():
+ has_gpu = torch.cuda.is_available() # true for both CUDA & ROCM
+ torch_version = torch.__version__
+
+ # NOTE: the use of CUDA_HOME and ROCM_HOME requires the CUDA/ROCM build deps, though in
+ # theory detectron2 should be made runnable with only the corresponding runtimes
+ from torch.utils.cpp_extension import CUDA_HOME
+
+ has_rocm = False
+ if tuple(map(int, torch_version.split(".")[:2])) >= (1, 5):
+ from torch.utils.cpp_extension import ROCM_HOME
+
+ if (getattr(torch.version, "hip", None) is not None) and (ROCM_HOME is not None):
+ has_rocm = True
+ has_cuda = has_gpu and (not has_rocm)
+
+ data = []
+ data.append(("sys.platform", sys.platform))
+ data.append(("Python", sys.version.replace("\n", "")))
+ data.append(("numpy", np.__version__))
+
+ try:
+ import fastreid # noqa
+
+ data.append(
+ ("fastreid", fastreid.__version__ + " @" + os.path.dirname(fastreid.__file__))
+ )
+ except ImportError:
+ data.append(("fastreid", "failed to import"))
+
+ data.append(get_env_module())
+ data.append(("PyTorch", torch_version + " @" + os.path.dirname(torch.__file__)))
+ data.append(("PyTorch debug build", torch.version.debug))
+
+ data.append(("GPU available", has_gpu))
+ if has_gpu:
+ devices = defaultdict(list)
+ for k in range(torch.cuda.device_count()):
+ devices[torch.cuda.get_device_name(k)].append(str(k))
+ for name, devids in devices.items():
+ data.append(("GPU " + ",".join(devids), name))
+
+ if has_rocm:
+ data.append(("ROCM_HOME", str(ROCM_HOME)))
+ else:
+ data.append(("CUDA_HOME", str(CUDA_HOME)))
+
+ cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
+ if cuda_arch_list:
+ data.append(("TORCH_CUDA_ARCH_LIST", cuda_arch_list))
+ data.append(("Pillow", PIL.__version__))
+
+ try:
+ data.append(
+ (
+ "torchvision",
+ str(torchvision.__version__) + " @" + os.path.dirname(torchvision.__file__),
+ )
+ )
+ if has_cuda:
+ try:
+ torchvision_C = importlib.util.find_spec("torchvision._C").origin
+ msg = detect_compute_compatibility(CUDA_HOME, torchvision_C)
+ data.append(("torchvision arch flags", msg))
+ except ImportError:
+ data.append(("torchvision._C", "failed to find"))
+ except AttributeError:
+ data.append(("torchvision", "unknown"))
+
+ try:
+ import fvcore
+
+ data.append(("fvcore", fvcore.__version__))
+ except ImportError:
+ pass
+
+ try:
+ import cv2
+
+ data.append(("cv2", cv2.__version__))
+ except ImportError:
+ pass
+ env_str = tabulate(data) + "\n"
+ env_str += collect_torch_env()
+ return env_str
+
+
+if __name__ == "__main__":
+ try:
+ import detectron2 # noqa
+ except ImportError:
+ print(collect_env_info())
+ else:
+ from fastreid.utils.collect_env import collect_env_info
+
+ print(collect_env_info())
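+
+# Invocation sketch: running `python utils/collect_env.py` from the repo root
+# prints the tabulated environment summary (Python/PyTorch/CUDA versions,
+# visible GPUs, and relevant env variables) produced by collect_env_info().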
diff --git a/utils/comm.py b/utils/comm.py
new file mode 100644
index 0000000..06cc7f0
--- /dev/null
+++ b/utils/comm.py
@@ -0,0 +1,255 @@
+"""
+This file contains primitives for multi-gpu communication.
+This is useful when doing distributed training.
+"""
+
+import functools
+import logging
+import numpy as np
+import pickle
+import torch
+import torch.distributed as dist
+
+_LOCAL_PROCESS_GROUP = None
+"""
+A torch process group which only includes processes that are on the same machine as the current process.
+This variable is set when processes are spawned by `launch()` in "engine/launch.py".
+"""
+
+
+def get_world_size() -> int:
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank() -> int:
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def get_local_rank() -> int:
+ """
+ Returns:
+ The rank of the current process within the local (per-machine) process group.
+ """
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ assert _LOCAL_PROCESS_GROUP is not None
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_local_size() -> int:
+ """
+ Returns:
+ The size of the per-machine process group,
+ i.e. the number of processes per machine.
+ """
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
+
+
+def is_main_process() -> bool:
+ return get_rank() == 0
+
+
+def synchronize():
+ """
+ Helper function to synchronize (barrier) among all processes when
+ using distributed training
+ """
+ if not dist.is_available():
+ return
+ if not dist.is_initialized():
+ return
+ world_size = dist.get_world_size()
+ if world_size == 1:
+ return
+ dist.barrier()
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+ """
+ Return a process group based on gloo backend, containing all the ranks
+ The result is cached.
+ """
+ if dist.get_backend() == "nccl":
+ return dist.new_group(backend="gloo")
+ else:
+ return dist.group.WORLD
+
+
+def _serialize_to_tensor(data, group):
+ backend = dist.get_backend(group)
+ assert backend in ["gloo", "nccl"]
+ device = torch.device("cpu" if backend == "gloo" else "cuda")
+
+ buffer = pickle.dumps(data)
+ if len(buffer) > 1024 ** 3:
+ logger = logging.getLogger(__name__)
+ logger.warning(
+ "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
+ get_rank(), len(buffer) / (1024 ** 3), device
+ )
+ )
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to(device=device)
+ return tensor
+
+
+def _pad_to_largest_tensor(tensor, group):
+ """
+ Returns:
+ list[int]: size of the tensor, on each rank
+ Tensor: padded tensor that has the max size
+ """
+ world_size = dist.get_world_size(group=group)
+ assert (
+ world_size >= 1
+ ), "comm.gather/all_gather must be called from ranks within the given group!"
+ local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
+ size_list = [
+ torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
+ ]
+ dist.all_gather(size_list, local_size, group=group)
+ size_list = [int(size.item()) for size in size_list]
+
+ max_size = max(size_list)
+
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ if local_size != max_size:
+ padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
+ tensor = torch.cat((tensor, padding), dim=0)
+ return size_list, tensor
+
+
+def all_gather(data, group=None):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
+ Args:
+ data: any picklable object
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = _get_global_gloo_group()
+ if dist.get_world_size(group) == 1:
+ return [data]
+
+ tensor = _serialize_to_tensor(data, group)
+
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ tensor_list = [
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+ ]
+ dist.all_gather(tensor_list, tensor, group=group)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
+
+def gather(data, dst=0, group=None):
+ """
+ Run gather on arbitrary picklable data (not necessarily tensors).
+ Args:
+ data: any picklable object
+ dst (int): destination rank
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+ Returns:
+ list[data]: on dst, a list of data gathered from each rank. Otherwise,
+ an empty list.
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = _get_global_gloo_group()
+ if dist.get_world_size(group=group) == 1:
+ return [data]
+ rank = dist.get_rank(group=group)
+
+ tensor = _serialize_to_tensor(data, group)
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
+
+ # receiving Tensor from all ranks
+ if rank == dst:
+ max_size = max(size_list)
+ tensor_list = [
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+ ]
+ dist.gather(tensor, tensor_list, dst=dst, group=group)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+ return data_list
+ else:
+ dist.gather(tensor, [], dst=dst, group=group)
+ return []
+
+
+def shared_random_seed():
+ """
+ Returns:
+ int: a random number that is the same across all workers.
+ If workers need a shared RNG, they can use this shared seed to
+ create one.
+ All workers must call this function, otherwise it will deadlock.
+ """
+ ints = np.random.randint(2 ** 31)
+ all_ints = all_gather(ints)
+ return all_ints[0]
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Reduce the values in the dictionary from all processes so that process with rank
+ 0 has the reduced results.
+ Args:
+ input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
+ average (bool): whether to do average or sum
+ Returns:
+ a dict with the same keys as input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.reduce(values, dst=0)
+ if dist.get_rank() == 0 and average:
+ # only main process gets accumulated, so only divide by
+ # world_size in this case
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
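+
+
+# Usage sketch (only meaningful under torch.distributed; with a single
+# process both helpers are effectively the identity):
+#
+#   losses_per_rank = all_gather({"loss": float(loss)})  # list, one dict per rank
+#   reduced = reduce_dict({"loss": loss})                # rank 0 receives the mean
+#   if is_main_process():
+#       print(reduced)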
diff --git a/utils/env.py b/utils/env.py
new file mode 100644
index 0000000..72f6afe
--- /dev/null
+++ b/utils/env.py
@@ -0,0 +1,119 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import importlib
+import importlib.util
+import logging
+import numpy as np
+import os
+import random
+import sys
+from datetime import datetime
+import torch
+
+__all__ = ["seed_all_rng"]
+
+
+TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
+"""
+PyTorch version as a tuple of 2 ints. Useful for comparison.
+"""
+
+
+def seed_all_rng(seed=None):
+ """
+ Set the random seed for the RNG in torch, numpy and python.
+ Args:
+ seed (int): if None, will use a strong random seed.
+ """
+ if seed is None:
+ seed = (
+ os.getpid()
+ + int(datetime.now().strftime("%S%f"))
+ + int.from_bytes(os.urandom(2), "big")
+ )
+ logger = logging.getLogger(__name__)
+ logger.info("Using a generated random seed {}".format(seed))
+ np.random.seed(seed)
+ torch.set_rng_state(torch.manual_seed(seed).get_state())
+ random.seed(seed)
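+
+# Example: seed_all_rng(42) makes torch/numpy/random reproducible for a run,
+# while seed_all_rng() derives a strong per-process seed from pid/time/urandom.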
+
+
+# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
+def _import_file(module_name, file_path, make_importable=False):
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ if make_importable:
+ sys.modules[module_name] = module
+ return module
+
+
+def _configure_libraries():
+ """
+ Configurations for some libraries.
+ """
+ # An environment option to disable `import cv2` globally,
+ # in case it leads to negative performance impact
+ disable_cv2 = int(os.environ.get("DETECTRON2_DISABLE_CV2", False))
+ if disable_cv2:
+ sys.modules["cv2"] = None
+ else:
+ # Disable opencl in opencv since its interaction with cuda often has negative effects
+ # This envvar is supported after OpenCV 3.4.0
+ os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled"
+ try:
+ import cv2
+
+ if int(cv2.__version__.split(".")[0]) >= 3:
+ cv2.ocl.setUseOpenCL(False)
+ except ImportError:
+ pass
+
+ def get_version(module, digit=2):
+ return tuple(map(int, module.__version__.split(".")[:digit]))
+
+ # fmt: off
+ assert get_version(torch) >= (1, 4), "Requires torch>=1.4"
+ import yaml
+ assert get_version(yaml) >= (5, 1), "Requires pyyaml>=5.1"
+ # fmt: on
+
+
+_ENV_SETUP_DONE = False
+
+
+def setup_environment():
+ """Perform environment setup work. The default setup is a no-op, but this
+ function allows the user to specify a Python source file or a module in
+ the $FASTREID_ENV_MODULE environment variable, that performs
+ custom setup work that may be necessary to their computing environment.
+ """
+ global _ENV_SETUP_DONE
+ if _ENV_SETUP_DONE:
+ return
+ _ENV_SETUP_DONE = True
+
+ _configure_libraries()
+
+ custom_module_path = os.environ.get("FASTREID_ENV_MODULE")
+
+ if custom_module_path:
+ setup_custom_environment(custom_module_path)
+ else:
+ # The default setup is a no-op
+ pass
+
+
+def setup_custom_environment(custom_module):
+ """
+ Load custom environment setup by importing a Python source file or a
+ module, and run the setup function.
+ """
+ if custom_module.endswith(".py"):
+ module = _import_file("fastreid.utils.env.custom_module", custom_module)
+ else:
+ module = importlib.import_module(custom_module)
+ assert hasattr(module, "setup_environment") and callable(module.setup_environment), (
+ "Custom environment module defined in {} does not have the "
+ "required callable attribute 'setup_environment'."
+ ).format(custom_module)
+ module.setup_environment()
\ No newline at end of file
diff --git a/utils/events.py b/utils/events.py
new file mode 100644
index 0000000..bc1120a
--- /dev/null
+++ b/utils/events.py
@@ -0,0 +1,445 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import datetime
+import json
+import logging
+import os
+import time
+from collections import defaultdict
+from contextlib import contextmanager
+import torch
+from .file_io import PathManager
+from .history_buffer import HistoryBuffer
+
+__all__ = [
+ "get_event_storage",
+ "JSONWriter",
+ "TensorboardXWriter",
+ "CommonMetricPrinter",
+ "EventStorage",
+]
+
+_CURRENT_STORAGE_STACK = []
+
+
+def get_event_storage():
+ """
+ Returns:
+ The :class:`EventStorage` object that's currently being used.
+ Throws an error if no :class:`EventStorage` is currently enabled.
+ """
+ assert len(
+ _CURRENT_STORAGE_STACK
+ ), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
+ return _CURRENT_STORAGE_STACK[-1]
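+
+# Usage sketch: storages nest via the context manager, and writers read the
+# innermost one through get_event_storage(), e.g.
+#
+#   with EventStorage(start_iter=0) as storage:
+#       storage.put_scalar("loss", loss_value)
+#       storage.step()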
+
+
+class EventWriter:
+ """
+ Base class for writers that obtain events from :class:`EventStorage` and process them.
+ """
+
+ def write(self):
+ raise NotImplementedError
+
+ def close(self):
+ pass
+
+
+class JSONWriter(EventWriter):
+ """
+ Write scalars to a json file.
+ It saves scalars as one json per line (instead of a big json) for easy parsing.
+ Examples parsing such a json file:
+ ::
+ $ cat metrics.json | jq -s '.[0:2]'
+ [
+ {
+ "data_time": 0.008433341979980469,
+ "iteration": 20,
+ "loss": 1.9228371381759644,
+ "loss_box_reg": 0.050025828182697296,
+ "loss_classifier": 0.5316952466964722,
+ "loss_mask": 0.7236229181289673,
+ "loss_rpn_box": 0.0856662318110466,
+ "loss_rpn_cls": 0.48198649287223816,
+ "lr": 0.007173333333333333,
+ "time": 0.25401854515075684
+ },
+ {
+ "data_time": 0.007216215133666992,
+ "iteration": 40,
+ "loss": 1.282649278640747,
+ "loss_box_reg": 0.06222952902317047,
+ "loss_classifier": 0.30682939291000366,
+ "loss_mask": 0.6970193982124329,
+ "loss_rpn_box": 0.038663312792778015,
+ "loss_rpn_cls": 0.1471673548221588,
+ "lr": 0.007706666666666667,
+ "time": 0.2490077018737793
+ }
+ ]
+ $ cat metrics.json | jq '.loss_mask'
+ 0.7126231789588928
+ 0.689423680305481
+ 0.6776131987571716
+ ...
+ """
+
+ def __init__(self, json_file, window_size=20):
+ """
+ Args:
+ json_file (str): path to the json file. New data will be appended if the file exists.
+ window_size (int): the window size of median smoothing for the scalars whose
+ `smoothing_hint` are True.
+ """
+ self._file_handle = PathManager.open(json_file, "a")
+ self._window_size = window_size
+ self._last_write = -1
+
+ def write(self):
+ storage = get_event_storage()
+ to_save = defaultdict(dict)
+
+ for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
+ # keep scalars that have not been written
+ if iter <= self._last_write:
+ continue
+ to_save[iter][k] = v
+ all_iters = sorted(to_save.keys())
+ self._last_write = max(all_iters)
+
+ for itr, scalars_per_iter in to_save.items():
+ scalars_per_iter["iteration"] = itr
+ self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n")
+ self._file_handle.flush()
+ try:
+ os.fsync(self._file_handle.fileno())
+ except AttributeError:
+ pass
+
+ def close(self):
+ self._file_handle.close()
+
+
+class TensorboardXWriter(EventWriter):
+ """
+ Write all scalars to a tensorboard file.
+ """
+
+ def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
+ """
+ Args:
+ log_dir (str): the directory to save the output events
+ window_size (int): the scalars will be median-smoothed by this window size
+ kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
+ """
+ self._window_size = window_size
+ from torch.utils.tensorboard import SummaryWriter
+
+ self._writer = SummaryWriter(log_dir, **kwargs)
+ self._last_write = -1
+
+ def write(self):
+ storage = get_event_storage()
+ new_last_write = self._last_write
+ for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
+ if iter > self._last_write:
+ self._writer.add_scalar(k, v, iter)
+ new_last_write = max(new_last_write, iter)
+ self._last_write = new_last_write
+
+ # storage.put_{image,histogram} is only meant to be used by
+ # tensorboard writer. So we access its internal fields directly from here.
+ if len(storage._vis_data) >= 1:
+ for img_name, img, step_num in storage._vis_data:
+ self._writer.add_image(img_name, img, step_num)
+ # Storage stores all image data and rely on this writer to clear them.
+ # As a result it assumes only one writer will use its image data.
+ # An alternative design is to let storage store limited recent
+ # data (e.g. only the most recent image) that all writers can access.
+ # In that case a writer may not see all image data if its period is long.
+ storage.clear_images()
+
+ if len(storage._histograms) >= 1:
+ for params in storage._histograms:
+ self._writer.add_histogram_raw(**params)
+ storage.clear_histograms()
+
+ def close(self):
+ if hasattr(self, "_writer"): # doesn't exist when the code fails at import
+ self._writer.close()
+
+
+class CommonMetricPrinter(EventWriter):
+ """
+ Print **common** metrics to the terminal, including
+ iteration time, ETA, memory, all losses, and the learning rate.
+ It also applies smoothing using a window of 20 elements.
+ It's meant to print common metrics in common ways.
+ To print something in more customized ways, please implement a similar printer by yourself.
+ """
+
+ def __init__(self, max_iter):
+ """
+ Args:
+ max_iter (int): the maximum number of iterations to train.
+ Used to compute ETA.
+ """
+ self.logger = logging.getLogger(__name__)
+ self._max_iter = max_iter
+ self._last_write = None
+
+ def write(self):
+ storage = get_event_storage()
+ iteration = storage.iter
+
+ try:
+ data_time = storage.history("data_time").avg(20)
+ except KeyError:
+ # they may not exist in the first few iterations (due to warmup)
+ # or when SimpleTrainer is not used
+ data_time = None
+
+ eta_string = None
+ try:
+ iter_time = storage.history("time").global_avg()
+ eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration)
+ storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ except KeyError:
+ iter_time = None
+ # estimate eta on our own - more noisy
+ if self._last_write is not None:
+ estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
+ iteration - self._last_write[0]
+ )
+ eta_seconds = estimate_iter_time * (self._max_iter - iteration)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ self._last_write = (iteration, time.perf_counter())
+
+ try:
+ lr = "{:.2e}".format(storage.history("lr").latest())
+ except KeyError:
+ lr = "N/A"
+
+ if torch.cuda.is_available():
+ max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
+ else:
+ max_mem_mb = None
+
+ # NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
+ self.logger.info(
+ " {eta}iter: {iter} {losses} {time}{data_time}lr: {lr} {memory}".format(
+ eta=f"eta: {eta_string} " if eta_string else "",
+ iter=iteration,
+ losses=" ".join(
+ [
+ "{}: {:.4g}".format(k, v.median(20))
+ for k, v in storage.histories().items()
+ if "loss" in k
+ ]
+ ),
+ time="time: {:.4f} ".format(iter_time) if iter_time is not None else "",
+ data_time="data_time: {:.4f} ".format(data_time) if data_time is not None else "",
+ lr=lr,
+ memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "",
+ )
+ )
+
+
+class EventStorage:
+ """
+ The user-facing class that provides metric storage functionalities.
+ In the future we may add support for storing / logging other types of data if needed.
+ """
+
+ def __init__(self, start_iter=0):
+ """
+ Args:
+ start_iter (int): the iteration number to start with
+ """
+ self._history = defaultdict(HistoryBuffer)
+ self._smoothing_hints = {}
+ self._latest_scalars = {}
+ self._iter = start_iter
+ self._current_prefix = ""
+ self._vis_data = []
+ self._histograms = []
+
+ def put_image(self, img_name, img_tensor):
+ """
+ Add an `img_tensor` associated with `img_name`, to be shown on
+ tensorboard.
+ Args:
+ img_name (str): The name of the image to put into tensorboard.
+ img_tensor (torch.Tensor or numpy.array): An `uint8` or `float`
+ Tensor of shape `[channel, height, width]` where `channel` is
+ 3. The image format should be RGB. The elements in img_tensor
+ can either have values in [0, 1] (float32) or [0, 255] (uint8).
+ The `img_tensor` will be visualized in tensorboard.
+ """
+ self._vis_data.append((img_name, img_tensor, self._iter))
+
+ def put_scalar(self, name, value, smoothing_hint=True):
+ """
+ Add a scalar `value` to the `HistoryBuffer` associated with `name`.
+ Args:
+ smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
+ smoothed when logged. The hint will be accessible through
+ :meth:`EventStorage.smoothing_hints`. A writer may ignore the hint
+ and apply custom smoothing rule.
+ It defaults to True because most scalars we save need to be smoothed to
+ provide any useful signal.
+ """
+ name = self._current_prefix + name
+ history = self._history[name]
+ value = float(value)
+ history.update(value, self._iter)
+ self._latest_scalars[name] = (value, self._iter)
+
+ existing_hint = self._smoothing_hints.get(name)
+ if existing_hint is not None:
+ assert (
+ existing_hint == smoothing_hint
+ ), "Scalar {} was put with a different smoothing_hint!".format(name)
+ else:
+ self._smoothing_hints[name] = smoothing_hint
+
+ def put_scalars(self, *, smoothing_hint=True, **kwargs):
+ """
+ Put multiple scalars from keyword arguments.
+ Examples:
+ storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
+ """
+ for k, v in kwargs.items():
+ self.put_scalar(k, v, smoothing_hint=smoothing_hint)
+
+ def put_histogram(self, hist_name, hist_tensor, bins=1000):
+ """
+ Create a histogram from a tensor.
+ Args:
+ hist_name (str): The name of the histogram to put into tensorboard.
+ hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted
+ into a histogram.
+ bins (int): Number of histogram bins.
+ """
+ ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item()
+
+ # Create a histogram with PyTorch
+ hist_counts = torch.histc(hist_tensor, bins=bins)
+ hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32)
+
+ # Parameter for the add_histogram_raw function of SummaryWriter
+ hist_params = dict(
+ tag=hist_name,
+ min=ht_min,
+ max=ht_max,
+ num=len(hist_tensor),
+ sum=float(hist_tensor.sum()),
+ sum_squares=float(torch.sum(hist_tensor ** 2)),
+ bucket_limits=hist_edges[1:].tolist(),
+ bucket_counts=hist_counts.tolist(),
+ global_step=self._iter,
+ )
+ self._histograms.append(hist_params)
+
+ def history(self, name):
+ """
+ Returns:
+ HistoryBuffer: the scalar history for name
+ """
+ ret = self._history.get(name, None)
+ if ret is None:
+ raise KeyError("No history metric available for {}!".format(name))
+ return ret
+
+ def histories(self):
+ """
+ Returns:
+ dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
+ """
+ return self._history
+
+ def latest(self):
+ """
+ Returns:
+ dict[str -> (float, int)]: mapping from the name of each scalar to the most
+ recent value and the iteration number at which it was added.
+ """
+ return self._latest_scalars
+
+ def latest_with_smoothing_hint(self, window_size=20):
+ """
+ Similar to :meth:`latest`, but the returned values
+ are either the un-smoothed original latest value,
+ or a median over the given window_size,
+ depending on whether the smoothing_hint is True.
+ This provides a default behavior that other writers can use.
+ """
+ result = {}
+ for k, (v, itr) in self._latest_scalars.items():
+ result[k] = (
+ self._history[k].median(window_size) if self._smoothing_hints[k] else v,
+ itr,
+ )
+ return result
+
+ def smoothing_hints(self):
+ """
+ Returns:
+ dict[name -> bool]: the user-provided hint on whether the scalar
+ is noisy and needs smoothing.
+ """
+ return self._smoothing_hints
+
+ def step(self):
+ """
+ Users should call this function at the beginning of each iteration to
+ notify the storage of the start of a new iteration.
+ The storage will then be able to associate the new data with the
+ correct iteration number.
+ """
+ self._iter += 1
+
+ @property
+ def iter(self):
+ return self._iter
+
+ @property
+ def iteration(self):
+ # for backward compatibility
+ return self._iter
+
+ def __enter__(self):
+ _CURRENT_STORAGE_STACK.append(self)
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ assert _CURRENT_STORAGE_STACK[-1] == self
+ _CURRENT_STORAGE_STACK.pop()
+
+ @contextmanager
+ def name_scope(self, name):
+ """
+ Yields:
+ A context within which all the events added to this storage
+ will be prefixed by the name scope.
+ """
+ old_prefix = self._current_prefix
+ self._current_prefix = name.rstrip("/") + "/"
+ yield
+ self._current_prefix = old_prefix
+
+ def clear_images(self):
+ """
+ Delete all the stored images for visualization. This should be called
+ after images are written to tensorboard.
+ """
+ self._vis_data = []
+
+ def clear_histograms(self):
+ """
+ Delete all the stored histograms for visualization.
+ This should be called after histograms are written to tensorboard.
+ """
+ self._histograms = []
\ No newline at end of file
diff --git a/utils/file_io.py b/utils/file_io.py
new file mode 100644
index 0000000..8533fe8
--- /dev/null
+++ b/utils/file_io.py
@@ -0,0 +1,520 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import errno
+import logging
+import os
+import shutil
+from collections import OrderedDict
+from typing import (
+ IO,
+ Any,
+ Callable,
+ Dict,
+ List,
+ MutableMapping,
+ Optional,
+ Union,
+)
+
+__all__ = ["PathManager", "get_cache_dir"]
+
+
+def get_cache_dir(cache_dir: Optional[str] = None) -> str:
+ """
+ Returns a default directory to cache static files
+ (usually downloaded from Internet), if None is provided.
+ Args:
+ cache_dir (None or str): if not None, will be returned as is.
+ If None, returns the default cache directory as:
+ 1) $FVCORE_CACHE, if set
+ 2) otherwise ~/.torch/fvcore_cache
+ """
+ if cache_dir is None:
+ cache_dir = os.path.expanduser(
+ os.getenv("FVCORE_CACHE", "~/.torch/fvcore_cache")
+ )
+ return cache_dir
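+
+# For example, get_cache_dir() returns $FVCORE_CACHE if that variable is set,
+# otherwise the expanded "~/.torch/fvcore_cache"; get_cache_dir("/tmp/cache")
+# returns the argument unchanged.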
+
+
+class PathHandler:
+ """
+ PathHandler is a base class that defines common I/O functionality for a URI
+ protocol. It routes I/O for a generic URI which may look like "protocol://*"
+ or a canonical filepath "/foo/bar/baz".
+ """
+
+ _strict_kwargs_check = True
+
+ def _check_kwargs(self, kwargs: Dict[str, Any]) -> None:
+ """
+ Checks if the given arguments are empty. Throws a ValueError if strict
+ kwargs checking is enabled and args are non-empty. If strict kwargs
+ checking is disabled, only a warning is logged.
+ Args:
+ kwargs (Dict[str, Any])
+ """
+ if self._strict_kwargs_check:
+ if len(kwargs) > 0:
+ raise ValueError("Unused arguments: {}".format(kwargs))
+ else:
+ logger = logging.getLogger(__name__)
+ for k, v in kwargs.items():
+ logger.warning(
+ "[PathManager] {}={} argument ignored".format(k, v)
+ )
+
+ def _get_supported_prefixes(self) -> List[str]:
+ """
+ Returns:
+ List[str]: the list of URI prefixes this PathHandler can support
+ """
+ raise NotImplementedError()
+
+ def _get_local_path(self, path: str, **kwargs: Any) -> str:
+ """
+ Get a filepath which is compatible with native Python I/O such as `open`
+ and `os.path`.
+ If URI points to a remote resource, this function may download and cache
+ the resource to local disk. In this case, this function is meant to be
+ used with read-only resources.
+ Args:
+ path (str): A URI supported by this PathHandler
+ Returns:
+ local_path (str): a file path which exists on the local file system
+ """
+ raise NotImplementedError()
+
+ def _open(
+ self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
+ ) -> Union[IO[str], IO[bytes]]:
+ """
+ Open a stream to a URI, similar to the built-in `open`.
+ Args:
+ path (str): A URI supported by this PathHandler
+ mode (str): Specifies the mode in which the file is opened. It defaults
+ to 'r'.
+ buffering (int): An optional integer used to set the buffering policy.
+ Pass 0 to switch buffering off and an integer >= 1 to indicate the
+ size in bytes of a fixed-size chunk buffer. When no buffering
+ argument is given, the default buffering policy depends on the
+ underlying I/O implementation.
+ Returns:
+ file: a file-like object.
+ """
+ raise NotImplementedError()
+
+ def _copy(
+ self,
+ src_path: str,
+ dst_path: str,
+ overwrite: bool = False,
+ **kwargs: Any,
+ ) -> bool:
+ """
+ Copies a source path to a destination path.
+ Args:
+ src_path (str): A URI supported by this PathHandler
+ dst_path (str): A URI supported by this PathHandler
+ overwrite (bool): Bool flag for forcing overwrite of existing file
+ Returns:
+ status (bool): True on success
+ """
+ raise NotImplementedError()
+
+ def _exists(self, path: str, **kwargs: Any) -> bool:
+ """
+ Checks if there is a resource at the given URI.
+ Args:
+ path (str): A URI supported by this PathHandler
+ Returns:
+ bool: true if the path exists
+ """
+ raise NotImplementedError()
+
+ def _isfile(self, path: str, **kwargs: Any) -> bool:
+ """
+ Checks if the resource at the given URI is a file.
+ Args:
+ path (str): A URI supported by this PathHandler
+ Returns:
+ bool: true if the path is a file
+ """
+ raise NotImplementedError()
+
+ def _isdir(self, path: str, **kwargs: Any) -> bool:
+ """
+ Checks if the resource at the given URI is a directory.
+ Args:
+ path (str): A URI supported by this PathHandler
+ Returns:
+ bool: true if the path is a directory
+ """
+ raise NotImplementedError()
+
+ def _ls(self, path: str, **kwargs: Any) -> List[str]:
+ """
+ List the contents of the directory at the provided URI.
+ Args:
+ path (str): A URI supported by this PathHandler
+ Returns:
+ List[str]: list of contents in given path
+ """
+ raise NotImplementedError()
+
+ def _mkdirs(self, path: str, **kwargs: Any) -> None:
+ """
+ Recursive directory creation function. Like mkdir(), but makes all
+ intermediate-level directories needed to contain the leaf directory.
+ Similar to the native `os.makedirs`.
+ Args:
+ path (str): A URI supported by this PathHandler
+ """
+ raise NotImplementedError()
+
+ def _rm(self, path: str, **kwargs: Any) -> None:
+ """
+ Remove the file (not directory) at the provided URI.
+ Args:
+ path (str): A URI supported by this PathHandler
+ """
+ raise NotImplementedError()
+
+
+class NativePathHandler(PathHandler):
+ """
+ Handles paths that can be accessed using Python native system calls. This
+ handler uses `open()` and `os.*` calls on the given path.
+ """
+
+ def _get_local_path(self, path: str, **kwargs: Any) -> str:
+ self._check_kwargs(kwargs)
+ return path
+
+ def _open(
+ self,
+ path: str,
+ mode: str = "r",
+ buffering: int = -1,
+ encoding: Optional[str] = None,
+ errors: Optional[str] = None,
+ newline: Optional[str] = None,
+ closefd: bool = True,
+ opener: Optional[Callable] = None,
+ **kwargs: Any,
+ ) -> Union[IO[str], IO[bytes]]:
+ """
+ Open a path.
+ Args:
+ path (str): A URI supported by this PathHandler
+ mode (str): Specifies the mode in which the file is opened. It defaults
+ to 'r'.
+ buffering (int): An optional integer used to set the buffering policy.
+ Pass 0 to switch buffering off and an integer >= 1 to indicate the
+ size in bytes of a fixed-size chunk buffer. When no buffering
+ argument is given, the default buffering policy works as follows:
+ * Binary files are buffered in fixed-size chunks; the size of
+ the buffer is chosen using a heuristic trying to determine the
+                underlying device's "block size" and falling back on
+ io.DEFAULT_BUFFER_SIZE. On many systems, the buffer will
+ typically be 4096 or 8192 bytes long.
+ encoding (Optional[str]): the name of the encoding used to decode or
+ encode the file. This should only be used in text mode.
+ errors (Optional[str]): an optional string that specifies how encoding
+ and decoding errors are to be handled. This cannot be used in binary
+ mode.
+ newline (Optional[str]): controls how universal newlines mode works
+ (it only applies to text mode). It can be None, '', '\n', '\r',
+ and '\r\n'.
+ closefd (bool): If closefd is False and a file descriptor rather than
+ a filename was given, the underlying file descriptor will be kept
+                open when the file is closed. If a filename is given, closefd must
+                be True (the default); otherwise, an error will be raised.
+ opener (Optional[Callable]): A custom opener can be used by passing
+ a callable as opener. The underlying file descriptor for the file
+ object is then obtained by calling opener with (file, flags).
+ opener must return an open file descriptor (passing os.open as opener
+ results in functionality similar to passing None).
+ See https://docs.python.org/3/library/functions.html#open for details.
+ Returns:
+ file: a file-like object.
+ """
+ self._check_kwargs(kwargs)
+ return open( # type: ignore
+ path,
+ mode,
+ buffering=buffering,
+ encoding=encoding,
+ errors=errors,
+ newline=newline,
+ closefd=closefd,
+ opener=opener,
+ )
+
+ def _copy(
+ self,
+ src_path: str,
+ dst_path: str,
+ overwrite: bool = False,
+ **kwargs: Any,
+ ) -> bool:
+ """
+ Copies a source path to a destination path.
+ Args:
+ src_path (str): A URI supported by this PathHandler
+ dst_path (str): A URI supported by this PathHandler
+ overwrite (bool): Bool flag for forcing overwrite of existing file
+ Returns:
+ status (bool): True on success
+ """
+ self._check_kwargs(kwargs)
+
+ if os.path.exists(dst_path) and not overwrite:
+ logger = logging.getLogger(__name__)
+ logger.error("Destination file {} already exists.".format(dst_path))
+ return False
+
+ try:
+ shutil.copyfile(src_path, dst_path)
+ return True
+ except Exception as e:
+ logger = logging.getLogger(__name__)
+ logger.error("Error in file copy - {}".format(str(e)))
+ return False
+
+ def _exists(self, path: str, **kwargs: Any) -> bool:
+ self._check_kwargs(kwargs)
+ return os.path.exists(path)
+
+ def _isfile(self, path: str, **kwargs: Any) -> bool:
+ self._check_kwargs(kwargs)
+ return os.path.isfile(path)
+
+ def _isdir(self, path: str, **kwargs: Any) -> bool:
+ self._check_kwargs(kwargs)
+ return os.path.isdir(path)
+
+ def _ls(self, path: str, **kwargs: Any) -> List[str]:
+ self._check_kwargs(kwargs)
+ return os.listdir(path)
+
+ def _mkdirs(self, path: str, **kwargs: Any) -> None:
+ self._check_kwargs(kwargs)
+ try:
+ os.makedirs(path, exist_ok=True)
+ except OSError as e:
+            # EEXIST can still happen if multiple processes create the dir concurrently
+ if e.errno != errno.EEXIST:
+ raise
+
+ def _rm(self, path: str, **kwargs: Any) -> None:
+ self._check_kwargs(kwargs)
+ os.remove(path)
+
+
+class PathManager:
+ """
+ A class for users to open generic paths or translate generic paths to file names.
+ """
+
+ _PATH_HANDLERS: MutableMapping[str, PathHandler] = OrderedDict()
+ _NATIVE_PATH_HANDLER = NativePathHandler()
+
+ @staticmethod
+ def __get_path_handler(path: str) -> PathHandler:
+ """
+ Finds a PathHandler that supports the given path. Falls back to the native
+ PathHandler if no other handler is found.
+ Args:
+ path (str): URI path to resource
+ Returns:
+ handler (PathHandler)
+ """
+ for p in PathManager._PATH_HANDLERS.keys():
+ if path.startswith(p):
+ return PathManager._PATH_HANDLERS[p]
+ return PathManager._NATIVE_PATH_HANDLER
+
+ @staticmethod
+ def open(
+ path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
+ ) -> Union[IO[str], IO[bytes]]:
+ """
+ Open a stream to a URI, similar to the built-in `open`.
+ Args:
+ path (str): A URI supported by this PathHandler
+ mode (str): Specifies the mode in which the file is opened. It defaults
+ to 'r'.
+ buffering (int): An optional integer used to set the buffering policy.
+ Pass 0 to switch buffering off and an integer >= 1 to indicate the
+ size in bytes of a fixed-size chunk buffer. When no buffering
+ argument is given, the default buffering policy depends on the
+ underlying I/O implementation.
+ Returns:
+ file: a file-like object.
+ """
+ return PathManager.__get_path_handler(path)._open( # type: ignore
+ path, mode, buffering=buffering, **kwargs
+ )
+
+ @staticmethod
+ def copy(
+ src_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any
+ ) -> bool:
+ """
+ Copies a source path to a destination path.
+ Args:
+ src_path (str): A URI supported by this PathHandler
+ dst_path (str): A URI supported by this PathHandler
+ overwrite (bool): Bool flag for forcing overwrite of existing file
+ Returns:
+ status (bool): True on success
+ """
+
+ # Copying across handlers is not supported.
+ assert PathManager.__get_path_handler( # type: ignore
+ src_path
+ ) == PathManager.__get_path_handler(dst_path)
+ return PathManager.__get_path_handler(src_path)._copy(
+ src_path, dst_path, overwrite, **kwargs
+ )
+
+ @staticmethod
+ def get_local_path(path: str, **kwargs: Any) -> str:
+ """
+ Get a filepath which is compatible with native Python I/O such as `open`
+ and `os.path`.
+ If URI points to a remote resource, this function may download and cache
+ the resource to local disk.
+ Args:
+ path (str): A URI supported by this PathHandler
+ Returns:
+ local_path (str): a file path which exists on the local file system
+ """
+ return PathManager.__get_path_handler( # type: ignore
+ path
+ )._get_local_path(path, **kwargs)
+
+ @staticmethod
+ def exists(path: str, **kwargs: Any) -> bool:
+ """
+ Checks if there is a resource at the given URI.
+ Args:
+ path (str): A URI supported by this PathHandler
+ Returns:
+ bool: true if the path exists
+ """
+ return PathManager.__get_path_handler(path)._exists( # type: ignore
+ path, **kwargs
+ )
+
+ @staticmethod
+ def isfile(path: str, **kwargs: Any) -> bool:
+ """
+        Checks if the resource at the given URI is a file.
+ Args:
+ path (str): A URI supported by this PathHandler
+ Returns:
+ bool: true if the path is a file
+ """
+ return PathManager.__get_path_handler(path)._isfile( # type: ignore
+ path, **kwargs
+ )
+
+ @staticmethod
+ def isdir(path: str, **kwargs: Any) -> bool:
+ """
+ Checks if the resource at the given URI is a directory.
+ Args:
+ path (str): A URI supported by this PathHandler
+ Returns:
+ bool: true if the path is a directory
+ """
+ return PathManager.__get_path_handler(path)._isdir( # type: ignore
+ path, **kwargs
+ )
+
+ @staticmethod
+ def ls(path: str, **kwargs: Any) -> List[str]:
+ """
+ List the contents of the directory at the provided URI.
+ Args:
+ path (str): A URI supported by this PathHandler
+ Returns:
+ List[str]: list of contents in given path
+ """
+ return PathManager.__get_path_handler(path)._ls( # type: ignore
+ path, **kwargs
+ )
+
+ @staticmethod
+ def mkdirs(path: str, **kwargs: Any) -> None:
+ """
+ Recursive directory creation function. Like mkdir(), but makes all
+ intermediate-level directories needed to contain the leaf directory.
+ Similar to the native `os.makedirs`.
+ Args:
+ path (str): A URI supported by this PathHandler
+ """
+ return PathManager.__get_path_handler(path)._mkdirs( # type: ignore
+ path, **kwargs
+ )
+
+ @staticmethod
+ def rm(path: str, **kwargs: Any) -> None:
+ """
+ Remove the file (not directory) at the provided URI.
+ Args:
+ path (str): A URI supported by this PathHandler
+ """
+ return PathManager.__get_path_handler(path)._rm( # type: ignore
+ path, **kwargs
+ )
+
+ @staticmethod
+ def register_handler(handler: PathHandler) -> None:
+ """
+ Register a path handler associated with `handler._get_supported_prefixes`
+ URI prefixes.
+ Args:
+ handler (PathHandler)
+ """
+ assert isinstance(handler, PathHandler), handler
+ for prefix in handler._get_supported_prefixes():
+ assert prefix not in PathManager._PATH_HANDLERS
+ PathManager._PATH_HANDLERS[prefix] = handler
+
+        # Sort path handlers in reverse order so longer prefixes take priority,
+        # e.g. http://foo/bar is matched before http://foo
+ PathManager._PATH_HANDLERS = OrderedDict(
+ sorted(
+ PathManager._PATH_HANDLERS.items(),
+ key=lambda t: t[0],
+ reverse=True,
+ )
+ )
+
+ @staticmethod
+ def set_strict_kwargs_checking(enable: bool) -> None:
+ """
+ Toggles strict kwargs checking. If enabled, a ValueError is thrown if any
+ unused parameters are passed to a PathHandler function. If disabled, only
+ a warning is given.
+ With a centralized file API, there's a tradeoff of convenience and
+ correctness delegating arguments to the proper I/O layers. An underlying
+ `PathHandler` may support custom arguments which should not be statically
+ exposed on the `PathManager` function. For example, a custom `HTTPURLHandler`
+ may want to expose a `cache_timeout` argument for `open()` which specifies
+ how old a locally cached resource can be before it's refetched from the
+ remote server. This argument would not make sense for a `NativePathHandler`.
+ If strict kwargs checking is disabled, `cache_timeout` can be passed to
+ `PathManager.open` which will forward the arguments to the underlying
+        handler. By default, checking is enabled, since blindly forwarding
+        arguments is innately unsafe: multiple `PathHandler`s could reuse
+        arguments with different semantic meanings or types.
+ Args:
+ enable (bool)
+ """
+ PathManager._NATIVE_PATH_HANDLER._strict_kwargs_check = enable
+ for handler in PathManager._PATH_HANDLERS.values():
+ handler._strict_kwargs_check = enable
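To see how these pieces fit together, here is a minimal sketch of a custom handler registered with `PathManager`. The `MemoryPathHandler` class, its `mem://` prefix, and the sample content are hypothetical illustrations, not part of this patch; `_get_supported_prefixes` and `_check_kwargs` are defined earlier in file_io.py:

    import io
    from typing import Any, List

    class MemoryPathHandler(PathHandler):
        # hypothetical handler serving in-memory strings for "mem://" URIs
        _FILES = {"mem://greeting.txt": "hello"}

        def _get_supported_prefixes(self) -> List[str]:
            return ["mem://"]

        def _open(self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any):
            self._check_kwargs(kwargs)  # rejects or warns on unused kwargs
            return io.StringIO(self._FILES[path])

        def _exists(self, path: str, **kwargs: Any) -> bool:
            self._check_kwargs(kwargs)
            return path in self._FILES

    # prefix matching routes "mem://" URIs to this handler; everything
    # else still falls back to NativePathHandler
    PathManager.register_handler(MemoryPathHandler())
    with PathManager.open("mem://greeting.txt") as f:
        assert f.read() == "hello"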
diff --git a/utils/history_buffer.py b/utils/history_buffer.py
new file mode 100644
index 0000000..033c2b0
--- /dev/null
+++ b/utils/history_buffer.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import numpy as np
+from typing import List, Optional, Tuple
+
+
+class HistoryBuffer:
+ """
+ Track a series of scalar values and provide access to smoothed values over a
+ window or the global average of the series.
+ """
+
+ def __init__(self, max_length: int = 1000000):
+ """
+ Args:
+ max_length: maximal number of values that can be stored in the
+ buffer. When the capacity of the buffer is exhausted, old
+ values will be removed.
+ """
+ self._max_length: int = max_length
+ self._data: List[Tuple[float, float]] = [] # (value, iteration) pairs
+ self._count: int = 0
+ self._global_avg: float = 0
+
+    def update(self, value: float, iteration: Optional[float] = None):
+ """
+ Add a new scalar value produced at certain iteration. If the length
+ of the buffer exceeds self._max_length, the oldest element will be
+ removed from the buffer.
+ """
+ if iteration is None:
+ iteration = self._count
+ if len(self._data) == self._max_length:
+ self._data.pop(0)
+ self._data.append((value, iteration))
+
+ self._count += 1
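+        # incremental mean over *all* values seen, including evicted ones:
+        # avg += (value - avg) / count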
+ self._global_avg += (value - self._global_avg) / self._count
+
+ def latest(self):
+ """
+ Return the latest scalar value added to the buffer.
+ """
+ return self._data[-1][0]
+
+ def median(self, window_size: int):
+ """
+ Return the median of the latest `window_size` values in the buffer.
+ """
+ return np.median([x[0] for x in self._data[-window_size:]])
+
+ def avg(self, window_size: int):
+ """
+ Return the mean of the latest `window_size` values in the buffer.
+ """
+ return np.mean([x[0] for x in self._data[-window_size:]])
+
+ def global_avg(self):
+ """
+        Return the mean of all values seen so far, including those that
+        have already been evicted due to the buffer's size limit.
+ """
+ return self._global_avg
+
+ def values(self):
+ """
+ Returns:
+ list[(number, iteration)]: content of the current buffer.
+ """
+ return self._data
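A quick usage sketch (the values are arbitrary): with `max_length=3`, the oldest value is evicted on the fourth update, yet the global average still accounts for it:

    buf = HistoryBuffer(max_length=3)
    for i, v in enumerate([1.0, 2.0, 3.0, 4.0]):
        buf.update(v, iteration=i)

    buf.latest()      # 4.0
    buf.avg(2)        # 3.5 -> mean of the last two stored values
    buf.median(3)     # 3.0 -> median of the three values still stored
    buf.global_avg()  # 2.5 -> mean of all four values, including the evicted 1.0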
diff --git a/utils/logger.py b/utils/logger.py
new file mode 100644
index 0000000..01e3da7
--- /dev/null
+++ b/utils/logger.py
@@ -0,0 +1,209 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import functools
+import logging
+import os
+import sys
+import time
+from collections import Counter
+from .file_io import PathManager
+from termcolor import colored
+
+
+class _ColorfulFormatter(logging.Formatter):
+ def __init__(self, *args, **kwargs):
+ self._root_name = kwargs.pop("root_name") + "."
+ self._abbrev_name = kwargs.pop("abbrev_name", "")
+ if len(self._abbrev_name):
+ self._abbrev_name = self._abbrev_name + "."
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
+
+ def formatMessage(self, record):
+ record.name = record.name.replace(self._root_name, self._abbrev_name)
+ log = super(_ColorfulFormatter, self).formatMessage(record)
+ if record.levelno == logging.WARNING:
+ prefix = colored("WARNING", "red", attrs=["blink"])
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
+ else:
+ return log
+ return prefix + " " + log
+
+
+@functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers
+def setup_logger(
+ output=None, distributed_rank=0, *, color=True, name="fastreid", abbrev_name=None
+):
+ """
+ Args:
+ output (str): a file name or a directory to save log. If None, will not save log file.
+ If ends with ".txt" or ".log", assumed to be a file name.
+ Otherwise, logs will be saved to `output/log.txt`.
+ name (str): the root module name of this logger
+ abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
+ Set to "" to not log the root module in logs.
+ By default, will abbreviate "detectron2" to "d2" and leave other
+ modules unchanged.
+ """
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.DEBUG)
+ logger.propagate = False
+
+ if abbrev_name is None:
+ abbrev_name = "d2" if name == "detectron2" else name
+
+ plain_formatter = logging.Formatter(
+ "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
+ )
+ # stdout logging: master only
+ if distributed_rank == 0:
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(logging.DEBUG)
+ if color:
+ formatter = _ColorfulFormatter(
+ colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
+ datefmt="%m/%d %H:%M:%S",
+ root_name=name,
+ abbrev_name=str(abbrev_name),
+ )
+ else:
+ formatter = plain_formatter
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+
+ # file logging: all workers
+ if output is not None:
+ if output.endswith(".txt") or output.endswith(".log"):
+ filename = output
+ else:
+ filename = os.path.join(output, "log.txt")
+ if distributed_rank > 0:
+ filename = filename + ".rank{}".format(distributed_rank)
+ PathManager.mkdirs(os.path.dirname(filename))
+
+ fh = logging.StreamHandler(_cached_log_stream(filename))
+ fh.setLevel(logging.DEBUG)
+ fh.setFormatter(plain_formatter)
+ logger.addHandler(fh)
+
+ return logger
+
+
+# cache the opened file object, so that different calls to `setup_logger`
+# with the same file name can safely write to the same file.
+@functools.lru_cache(maxsize=None)
+def _cached_log_stream(filename):
+ return PathManager.open(filename, "a")
+
+
+"""
+Below are some other convenient logging methods.
+They are mainly adopted from
+https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py
+"""
+
+
+def _find_caller():
+ """
+ Returns:
+ str: module name of the caller
+ tuple: a hashable key to be used to identify different callers
+ """
+ frame = sys._getframe(2)
+ while frame:
+ code = frame.f_code
+ if os.path.join("utils", "logger.") not in code.co_filename:
+ mod_name = frame.f_globals["__name__"]
+ if mod_name == "__main__":
+ mod_name = "detectron2"
+ return mod_name, (code.co_filename, frame.f_lineno, code.co_name)
+ frame = frame.f_back
+
+
+_LOG_COUNTER = Counter()
+_LOG_TIMER = {}
+
+
+def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
+ """
+ Log only for the first n times.
+ Args:
+ lvl (int): the logging level
+ msg (str):
+ n (int):
+ name (str): name of the logger to use. Will use the caller's module by default.
+ key (str or tuple[str]): the string(s) can be one of "caller" or
+ "message", which defines how to identify duplicated logs.
+ For example, if called with `n=1, key="caller"`, this function
+ will only log the first call from the same caller, regardless of
+ the message content.
+            If called with `n=1, key="message"`, this function will log the
+            same content only once, even if it is logged from different places.
+            If called with `n=1, key=("caller", "message")`, this function
+            will log a given message at most once per caller.
+ """
+ if isinstance(key, str):
+ key = (key,)
+ assert len(key) > 0
+
+ caller_module, caller_key = _find_caller()
+ hash_key = ()
+ if "caller" in key:
+ hash_key = hash_key + caller_key
+ if "message" in key:
+ hash_key = hash_key + (msg,)
+
+ _LOG_COUNTER[hash_key] += 1
+ if _LOG_COUNTER[hash_key] <= n:
+ logging.getLogger(name or caller_module).log(lvl, msg)
+
+
+def log_every_n(lvl, msg, n=1, *, name=None):
+ """
+ Log once per n times.
+ Args:
+ lvl (int): the logging level
+ msg (str):
+ n (int):
+ name (str): name of the logger to use. Will use the caller's module by default.
+ """
+ caller_module, key = _find_caller()
+ _LOG_COUNTER[key] += 1
+ if n == 1 or _LOG_COUNTER[key] % n == 1:
+ logging.getLogger(name or caller_module).log(lvl, msg)
+
+
+def log_every_n_seconds(lvl, msg, n=1, *, name=None):
+ """
+ Log no more than once per n seconds.
+ Args:
+ lvl (int): the logging level
+ msg (str):
+ n (int):
+ name (str): name of the logger to use. Will use the caller's module by default.
+ """
+ caller_module, key = _find_caller()
+ last_logged = _LOG_TIMER.get(key, None)
+ current_time = time.time()
+ if last_logged is None or current_time - last_logged >= n:
+ logging.getLogger(name or caller_module).log(lvl, msg)
+ _LOG_TIMER[key] = current_time
+
+# def create_small_table(small_dict):
+# """
+# Create a small table using the keys of small_dict as headers. This is only
+# suitable for small dictionaries.
+# Args:
+# small_dict (dict): a result dictionary of only a few items.
+# Returns:
+# str: the table as a string.
+# """
+# keys, values = tuple(zip(*small_dict.items()))
+# table = tabulate(
+# [values],
+# headers=keys,
+# tablefmt="pipe",
+# floatfmt=".3f",
+# stralign="center",
+# numalign="center",
+# )
+# return table
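A minimal usage sketch (the output directory and messages are placeholders): `setup_logger` is called once per process, after which the throttled helpers can be used from anywhere:

    import logging

    logger = setup_logger(output="./logs", name="fastreid")
    logger.info("training started")

    for it in range(100):
        log_first_n(logging.INFO, "config dump ...", n=1)        # logged once
        log_every_n(logging.INFO, "heartbeat", n=10)             # every 10th call
        log_every_n_seconds(logging.INFO, "still alive", n=5)    # at most every 5 s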
diff --git a/utils/registry.py b/utils/registry.py
new file mode 100644
index 0000000..ad5376b
--- /dev/null
+++ b/utils/registry.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+from typing import Dict, Optional
+
+
+class Registry(object):
+ """
+ The registry that provides name -> object mapping, to support third-party
+ users' custom modules.
+ To create a registry (e.g. a backbone registry):
+ .. code-block:: python
+ BACKBONE_REGISTRY = Registry('BACKBONE')
+ To register an object:
+ .. code-block:: python
+ @BACKBONE_REGISTRY.register()
+ class MyBackbone():
+ ...
+ Or:
+ .. code-block:: python
+ BACKBONE_REGISTRY.register(MyBackbone)
+ """
+
+ def __init__(self, name: str) -> None:
+ """
+ Args:
+ name (str): the name of this registry
+ """
+ self._name: str = name
+ self._obj_map: Dict[str, object] = {}
+
+ def _do_register(self, name: str, obj: object) -> None:
+ assert (
+ name not in self._obj_map
+ ), "An object named '{}' was already registered in '{}' registry!".format(
+ name, self._name
+ )
+ self._obj_map[name] = obj
+
+ def register(self, obj: object = None) -> Optional[object]:
+ """
+        Register the given object under the name `obj.__name__`.
+ Can be used as either a decorator or not. See docstring of this class for usage.
+ """
+ if obj is None:
+ # used as a decorator
+ def deco(func_or_class: object) -> object:
+ name = func_or_class.__name__ # pyre-ignore
+ self._do_register(name, func_or_class)
+ return func_or_class
+
+ return deco
+
+ # used as a function call
+ name = obj.__name__ # pyre-ignore
+ self._do_register(name, obj)
+
+ def get(self, name: str) -> object:
+ ret = self._obj_map.get(name)
+ if ret is None:
+ raise KeyError(
+ "No object named '{}' found in '{}' registry!".format(
+ name, self._name
+ )
+ )
+ return ret
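A short usage sketch mirroring the class docstring (`BACKBONE` and `MyBackbone` are the docstring's own placeholder names):

    BACKBONE_REGISTRY = Registry('BACKBONE')

    @BACKBONE_REGISTRY.register()        # decorator form
    class MyBackbone:
        pass

    assert BACKBONE_REGISTRY.get('MyBackbone') is MyBackbone
    # BACKBONE_REGISTRY.get('Missing') would raise a KeyError naming the registry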
diff --git a/utils/weight_init.py b/utils/weight_init.py
new file mode 100644
index 0000000..82fb262
--- /dev/null
+++ b/utils/weight_init.py
@@ -0,0 +1,37 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import math
+from torch import nn
+
+__all__ = [
+ 'weights_init_classifier',
+ 'weights_init_kaiming',
+]
+
+
+def weights_init_kaiming(m):
+ classname = m.__class__.__name__
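+    # despite the function name, Linear layers get a small normal init here;
+    # kaiming init is applied to conv layers only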
+ if classname.find('Linear') != -1:
+ nn.init.normal_(m.weight, 0, 0.01)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0.0)
+ elif classname.find('Conv') != -1:
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0.0)
+ elif classname.find('BatchNorm') != -1:
+ if m.affine:
+ nn.init.normal_(m.weight, 1.0, 0.02)
+ nn.init.constant_(m.bias, 0.0)
+
+
+def weights_init_classifier(m):
+ classname = m.__class__.__name__
+ if classname.find('Linear') != -1:
+ nn.init.normal_(m.weight, std=0.001)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0.0)
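Both initializers are written for `nn.Module.apply`, which invokes them on every submodule. A minimal sketch (the architecture is arbitrary):

    import torch.nn as nn

    backbone = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))
    head = nn.Linear(16, 8)

    backbone.apply(weights_init_kaiming)    # kaiming for convs, N(1.0, 0.02) for BN weights
    head.apply(weights_init_classifier)     # N(0, 0.001) for the classifier weights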
--
Gitblit v1.8.0