reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-14 3c5565466db64950d797fdc3e40c599d73a3e239
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
#pragma once
 
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/jit/script/module.h>
#include <torch/csrc/jit/script/sugared_value.h>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
 
namespace torch {
namespace jit {
namespace script {
 
std::string typeString(py::handle h);
 
inline std::shared_ptr<SugaredValue> toSimple(Value* v) {
  return std::make_shared<SimpleValue>(v);
}
 
// NB: This should be the single entry-point for instantiating a SugaredValue
// from a Python object. If you are adding support for converting a new Python
// type, *add it in this function's implementation*.
std::shared_ptr<SugaredValue> toSugaredValue(
    py::object obj,
    Function& m,
    SourceRange loc,
    bool is_constant = false);
 
c10::optional<StrongFunctionPtr> as_function(const py::object& obj);
 
struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
  PythonValue(py::object the_self, c10::optional<py::object> rcb = c10::nullopt)
      : self(std::move(the_self)), rcb(std::move(rcb)) {}
 
  FunctionSchema getSchema(
      const size_t n_args,
      const size_t n_binders,
      const SourceRange& loc);
 
  // call it like a function, e.g. `outputs = this(inputs)`
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& m,
      at::ArrayRef<NamedValue> inputs_,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override;
 
  std::string kind() const override;
 
  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Function& m,
      const c10::optional<size_t>& size_hint = {}) override;
 
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) override;
 
 protected:
  py::object getattr(const SourceRange& loc, const std::string& name);
 
  void checkForAddToConstantsError(std::stringstream& ss);
 
  py::object self;
  c10::optional<py::object> rcb;
};
 
struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue {
  explicit PythonModuleValue(py::object mod) : PythonValue(std::move(mod)) {}
 
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) override;
};
 
struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue {
  explicit ConstantPythonTupleValue(py::object tup)
      : PythonValue(std::move(tup)) {}
  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Function& m,
      const c10::optional<size_t>& size_hint = {}) override;
 
  Value* asValue(const SourceRange& loc, Function& m) override;
};
 
// Represents all the parameters of a module as a List[Tensor]
struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
  ConstantParameterList(Value* the_list) : the_list_(the_list) {}
  std::string kind() const override {
    return "constant parameter list";
  }
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    return toSimple(the_list_);
  }
 
 private:
  Value* the_list_;
};
 
struct VISIBILITY_HIDDEN ConstantTupleValue : public SugaredValue {
  explicit ConstantTupleValue(
      std::vector<std::shared_ptr<SugaredValue>> tup,
      bool callable = false)
      : tup_(tup){};
 
  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Function& m,
      const c10::optional<size_t>& size_hint = {}) override {
    return tup_;
  };
 
  std::string kind() const override {
    return "constant tuple";
  }
 
  std::vector<std::shared_ptr<SugaredValue>> tup_;
  bool callable_;
};
 
struct VISIBILITY_HIDDEN ConstantTupleMethod : public SugaredValue {
  explicit ConstantTupleMethod(
      std::vector<std::shared_ptr<SugaredValue>> tup,
      const std::string& name)
      : tup_(tup), name_(name){};
 
  std::string kind() const override {
    return name_;
  }
 
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& f,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    if (inputs.size() || attributes.size()) {
      throw ErrorReport(loc)
          << name_ << " method does not accept any arguments";
    }
    return std::make_shared<ConstantTupleValue>(tup_);
  }
 
  std::vector<std::shared_ptr<SugaredValue>> tup_;
  const std::string name_;
};
 
struct VISIBILITY_HIDDEN OverloadedMethodValue : public SugaredValue {
  OverloadedMethodValue(Value* module, std::vector<std::string> method_names)
      : module_(module), method_names_(std::move(method_names)) {}
 
  std::string kind() const override {
    return "overloaded function";
  }
 
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override;
 
 private:
  Value* module_;
  std::vector<std::string> method_names_;
};
 
struct VISIBILITY_HIDDEN OverloadedFunctionValue : public SugaredValue {
  OverloadedFunctionValue(std::vector<StrongFunctionPtr> compiled_overloads)
      : compiled_overloads_(std::move(compiled_overloads)) {}
 
  std::string kind() const override {
    return "overloaded function";
  }
 
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override;
 
 private:
  std::vector<StrongFunctionPtr> compiled_overloads_;
};
 
// defines how modules/methods behave inside the script subset.
// for now this does not have any interaction with python.
// in the future, we will add the ability to resolve `self.foo` to python
// {functions, modules, contants} so this SugaredValue is defined here
// anticipating we will eventually need to replace Module with a py::object
// holding the actual nn.Module class.
 
struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {
  ModuleValue(Value* self, Module module, py::object py_module)
      : self_(self),
        module_(std::move(module)),
        py_module_(std::move(py_module)) {}
 
  std::string kind() const override {
    return "module";
  }
 
  Value* asValue(const SourceRange& loc, Function& m) override;
 
  // select an attribute on it, e.g. `this.field`
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) override;
 
  // call module.forward
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    return attr(loc, caller, "forward")
        ->call(loc, caller, inputs, attributes, n_binders);
  }
 
  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Function& m,
      const c10::optional<size_t>& size_hint = {}) override;
 
  void setAttr(
      const SourceRange& loc,
      Function& m,
      const std::string& field,
      Value* newValue) override;
 
 private:
  Value* self_;
  Module module_;
  py::object py_module_;
 
  std::vector<std::shared_ptr<SugaredValue>> desugarModuleContainer(
      bool get_keys,
      bool get_values,
      const SourceRange& loc,
      Function& m);
};
 
struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
  BooleanDispatchValue(py::dict dispatched_fn)
      : dispatched_fn_(std::move(dispatched_fn)) {}
 
  std::string kind() const override {
    return "boolean dispatch";
  }
 
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override;
 
 private:
  py::dict dispatched_fn_;
};
 
} // namespace script
} // namespace jit
} // namespace torch