// reid from https://github.com/michuanhaohao/reid-strong-baseline
// zhangmeng
// 2020-01-10 c3765bd24fe73747688a0ec2a550f219c9acb384
#pragma once
 
#include <type_traits>

#include "pickler.h"
 
namespace torch {
namespace jit {
 
using ClassResolver =
    std::function<c10::StrongTypePtr(const c10::QualifiedName&)>;
 
using ObjLoader =
    std::function<c10::intrusive_ptr<c10::ivalue::Object>(at::StrongTypePtr, IValue)>;
 
// [unpickler refactor] there is some cruft around PickleOpCode::BUILD,
// PickleOpCode::NEWOBJ, and the last_opcode_ member below that should be deleted at
// some point, the Pickler doesn't produce it and it's only around to support
// models saved before 1.1
class Unpickler {
  TH_DISALLOW_COPY_AND_ASSIGN(Unpickler);
 
 public:
  // tensors inside the pickle are references to the tensor_table
  Unpickler(
      std::function<bool(char*, size_t)> reader,
      ClassResolver class_resolver,
      const std::vector<at::Tensor>* tensor_table)
      : reader_(reader),
        tensor_table_(tensor_table),
        class_resolver_(std::move(class_resolver)) {}
 
  // tensors inside the pickle contain meta-data, the raw tensor
  // dead is retrieved by calling `read_record`.
  Unpickler(
      std::function<bool(char*, size_t)> reader,
      ClassResolver class_resolver,
      ObjLoader obj_loader,
      std::function<at::DataPtr(const std::string&)> read_record,
      c10::optional<at::Device> device)
      : reader_(reader),
        tensor_table_(nullptr),
        class_resolver_(std::move(class_resolver)),
        obj_loader_(std::move(obj_loader)),
        read_record_(std::move(read_record)),
        device_(std::move(device)) {}
 
  IValue parse_ivalue();
 
 private:
  // No arguments ensures that a template arugment must be specified
  // so that the number of bytes read / type read is explicit
  template <typename T>
  T read() {
    T item;
    if (!reader_(reinterpret_cast<char*>(&item), sizeof(item))) {
      AT_ERROR("Unexpected end of pickler archive.");
    }
    return item;
  }
 
  std::string readBytes(size_t num_bytes);
 
  double readFloat();
  PickleOpCode readInstruction();
  PickleOpCode readOpCode();
  std::string readString();
  void readList(IValue list_ivalue);
  void setInput(size_t memo_id);
  void run();
 
  // Returns a pointer to the number of bytes requested. This should state-fully
  // remember how many bytes have been read
  std::function<bool(char*, size_t)> reader_;
 
  std::vector<IValue> stack_;
 
  // globals are represented on the stack as IValue integer indices
  // into this list
  std::vector<std::function<void(void)>> globals_;
  std::vector<IValue> memo_table_;
  std::vector<size_t> marks_;
  const std::vector<at::Tensor>* tensor_table_;
 
  // optionally nullptr, needs to be present for creating classes
  ClassResolver class_resolver_;
  ObjLoader obj_loader_;
  IValue empty_tuple_;
 
  std::function<at::DataPtr(const std::string&)> read_record_;
  c10::optional<at::Device> device_;
};
 
} // namespace jit
} // namespace torch