#ifndef CAFFE2_OPERATORS_BYTE_WEIGHT_DEQUANT_OP_H_ #define CAFFE2_OPERATORS_BYTE_WEIGHT_DEQUANT_OP_H_ #include "caffe2/core/operator.h" #include "caffe2/utils/eigen_utils.h" #include "caffe2/utils/math.h" namespace caffe2 { template class ByteWeightDequantOp : public Operator { public: ByteWeightDequantOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), min_(this->template GetSingleArgument("min", -3)), max_(this->template GetSingleArgument("max", 3)), shape_(this->template GetRepeatedArgument("shape")) {} USE_OPERATOR_FUNCTIONS(Context); using Operator::Operator; bool RunOnDevice() override { const auto& WI = Input(0); auto* Y = Output(0, shape_, at::dtype()); float bin_interval = (max_ - min_) / 255.0; int total = 1; for (int i = 0; i < shape_.size(); i++) { total *= Y->size(i); } const uint8_t* Xdata; if (WI.template IsType()) { CAFFE_ENFORCE(total, WI.nbytes()); Xdata = WI.template data(); } else { CAFFE_ENFORCE(total, WI.template data()[0].size()); Xdata = reinterpret_cast( WI.template data()[0].c_str()); } auto* Ydata = Y->template mutable_data(); ConstEigenVectorMap index(&Xdata[0], total); EigenVectorMap weights(&Ydata[0], total); weights = (index.cast().array() * bin_interval) + min_; return true; } private: float min_; float max_; std::vector shape_; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_BYTE_WEIGHT_DEQUANT_OP_H_