reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-11 bdf3ad71583fb4ef100d3819ecdae8fd9f70083e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#pragma once
 
/**
 * This file contains functionality to take a C++ function and infer its
 * c10::FunctionSchema.
 */
 
#include <ATen/core/function_schema.h>
#include <c10/util/C++17.h>
#include <c10/util/Metaprogramming.h>
 
namespace c10 {
namespace detail {
 
namespace infer_schema {
 
/// The templated inference code creates `ArgumentDef` instead of `Argument`,
/// because that can be constructed at compile time and has a much smaller
/// binary size than having calls to `Argument` constructors in the template.
/// Creating `Argument` objects from `ArgumentDef` can then be done at
/// runtime in a non-templated way.
struct ArgumentDef final {
  using GetTypeFn = TypePtr();
  GetTypeFn* getTypeFn;
};
 
template<bool V>
struct bool_t {};
template<> struct bool_t<true> : std::true_type {};
template<> struct bool_t<false> : std::false_type {};
 
/// Checks the static C++ types `Types` for correctness to catch common error cases.
template <class... Types>
constexpr int checkStaticTypes() {
 // Give nice error messages for some of the common error cases.
 // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
 static_assert(guts::conjunction<
     bool_t<!std::is_integral<Types>::value || std::is_same<Types, int64_t>::value || std::is_same<Types, bool>::value>...
   >::value, "INVALID TYPE: Only int64_t and bool are supported as an integral argument type");
 static_assert(guts::conjunction<
     bool_t<!std::is_same<Types, float>::value>...
   >::value, "INVALID TYPE: float is not supported as an argument type, use double instead");
 return 0;
}
 
template <typename... Ts, size_t... Is>
constexpr std::array<ArgumentDef, sizeof...(Ts)> createArgumentVectorFromTypes(guts::index_sequence<Is...>) {
  return (
    // Check types for common errors
    checkStaticTypes<Ts...>(),
 
    // Create the return value
    std::array<ArgumentDef, sizeof...(Ts)>{{ArgumentDef{&getTypePtr_<guts::decay_t<Ts>>::call}...}}
  );
}
 
/// Creates a vector of `ArgumentDef` from a list of C++ types that are specified
/// as template arguments.
template<class ParameterTypes> struct createArguments final {};
template<class... ParameterTypes>
struct createArguments<guts::typelist::typelist<ParameterTypes...>> final {
  static constexpr std::array<ArgumentDef, sizeof...(ParameterTypes)> call() {
    return createArgumentVectorFromTypes<ParameterTypes...>(
        guts::make_index_sequence<sizeof...(ParameterTypes)>()
    );
  }
};
 
/// Creates a vector of `ArgumentDef` from a list of C++ types that are specified
/// as a tuple (i.e. in the way c10 kernels return values).
/// It can be a tuple<A, B, C> if there's three output arguments with types A, B, C.
/// It can be an empty tuple<>, or void for kernels that don't return anything.
/// It can be a single type A (i.e. no tuple) for the case where a kernel just
/// returns one value.
template<class ReturnTypeTuple, class Enable = void> struct createReturns final {};
 
template<class... ReturnTypes>
struct createReturns<std::tuple<ReturnTypes...>, void> final {
  static constexpr std::array<ArgumentDef, sizeof...(ReturnTypes)> call() {
    return createArgumentVectorFromTypes<ReturnTypes...>(
        guts::make_index_sequence<sizeof...(ReturnTypes)>()
    );
  }
};
 
template<class ReturnType>
struct createReturns<ReturnType, guts::enable_if_t<!std::is_same<void, ReturnType>::value && !guts::is_instantiation_of<std::tuple, ReturnType>::value>> final {
  static constexpr std::array<ArgumentDef, 1> call() {
    return createReturns<std::tuple<ReturnType>>::call();
  }
};
 
template<>
struct createReturns<void, void> final {
  static constexpr std::array<ArgumentDef, 0> call() {
    return createReturns<std::tuple<>>::call();
  }
};
 
template<size_t NumArgs>
std::vector<Argument> createArgumentVector(const std::array<ArgumentDef, NumArgs>& args) {
  std::vector<Argument> result;
  result.reserve(NumArgs);
  for (size_t i = 0; i < args.size(); ++i) {
    // Arguments are named "_<index>"
    result.push_back(Argument("_" + c10::guts::to_string(i), (*args[i].getTypeFn)()));
  }
  return result;
}
 
// This is intentionally a separate function
// because then the template is smaller and that benefits binary size
inline FunctionSchema make_function_schema(std::string&& name, std::string&& overload_name, std::vector<Argument>&& arguments, std::vector<Argument>&& returns) {
  return FunctionSchema(std::move(name), std::move(overload_name), std::move(arguments), std::move(returns));
}
 
template<size_t NumArgs, size_t NumReturns>
inline FunctionSchema make_function_schema(std::string&& name, std::string&& overload_name, const std::array<ArgumentDef, NumArgs>& arguments, const std::array<ArgumentDef, NumReturns>& returns) {
  return make_function_schema(std::move(name), std::move(overload_name), createArgumentVector(arguments), createArgumentVector(returns));
}
 
/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
/// function.
template <typename FunctionTraits>
FunctionSchema createFunctionSchemaFromTraits(std::string&& name, std::string&& overload_name) {
 using ReturnType = typename FunctionTraits::return_type;
 using ParameterTypes = typename FunctionTraits::parameter_types;
 
 constexpr auto arguments = createArguments<ParameterTypes>::call();
 constexpr auto returns = createReturns<ReturnType>::call();
 
 return make_function_schema(std::move(name), std::move(overload_name), arguments, returns);
}
}
}
 
template<class FuncType>
FunctionSchema inferFunctionSchema(std::string&& name, std::string&& overload_name) {
  return detail::infer_schema::createFunctionSchemaFromTraits<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name));
}
 
CAFFE2_API c10::optional<std::string> findSchemaDifferences(const FunctionSchema& inferred, const FunctionSchema& specified);
 
}