reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 f7c4a3cfd07adede3308f8d9d3d7315427d90a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#ifndef CAFFE2_OPERATORS_FUSED_ROWWISE_RAND_CONVERSION_OPS_H_
#define CAFFE2_OPERATORS_FUSED_ROWWISE_RAND_CONVERSION_OPS_H_
 
#include <chrono>
 
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/reducer_functors.h"
#include "caffe2/perfkernels/math.h"
#include "caffe2/utils/math.h"
 
#ifdef CAFFE2_USE_MKL
#include <mkl.h>
#define FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
#endif
 
namespace caffe2 {
 
template <class Context>
class FloatToFusedRandRowwiseQuantizedOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit FloatToFusedRandRowwiseQuantizedOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        bitwidth_(OperatorBase::GetSingleArgument<int32_t>("bitwidth", 8)),
        random_(OperatorBase::GetSingleArgument<bool>("random", true)) {
    CAFFE_ENFORCE(
        bitwidth_ == 1 || bitwidth_ == 2 || bitwidth_ == 4 || bitwidth_ == 8,
        "Unsupported bitwidth");
    if (random_) {
#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
      int status = vslNewStream(
          &vslStream_,
          VSL_BRNG_MT19937,
          std::chrono::system_clock::now().time_since_epoch().count());
      if (status != VSL_STATUS_OK) {
        LOG(WARNING) << "vslNewStream returns " << status;
      }
#else
      gen_.seed(std::chrono::system_clock::now().time_since_epoch().count());
      dis_.reset(new std::uniform_real_distribution<float>(0.0f, 1.0f));
#endif
    }
  }
 
  ~FloatToFusedRandRowwiseQuantizedOp() {
    if (random_) {
#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
      int status = vslDeleteStream(&vslStream_);
      if (status != VSL_STATUS_OK) {
        LOG(WARNING) << "vslDeleteStream returns " << status;
      }
#endif
    }
  }
 
  bool RunOnDevice() override;
 
 private:
  INPUT_TAGS(DATA_FLOAT);
  OUTPUT_TAGS(DATA_FUSED_QUANTIZED);
 
 protected:
  size_t bitwidth_{8};
  bool random_{true};
  std::vector<float> random_buffer_;
 
#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
  VSLStreamStatePtr vslStream_;
#else
  std::unique_ptr<std::uniform_real_distribution<float>> dis_;
  std::minstd_rand gen_;
#endif
};
 
template <class Context>
class FusedRandRowwiseQuantizedToFloatOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(FusedRandRowwiseQuantizedToFloatOp)
 
  bool RunOnDevice() override;
 
 private:
  INPUT_TAGS(DATA_FUSED_QUANTIZED);
  OUTPUT_TAGS(DATA_FLOAT);
};
 
} // namespace caffe2
 
#endif // CAFFE2_OPERATORS_FUSED_ROWWISE_RAND_CONVERSION_OPS_H_