#ifndef CAFFE2_OPERATORS_FUSED_ROWWISE_RAND_CONVERSION_OPS_H_ #define CAFFE2_OPERATORS_FUSED_ROWWISE_RAND_CONVERSION_OPS_H_ #include #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/operators/reducer_functors.h" #include "caffe2/perfkernels/math.h" #include "caffe2/utils/math.h" #ifdef CAFFE2_USE_MKL #include #define FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL #endif namespace caffe2 { template class FloatToFusedRandRowwiseQuantizedOp : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit FloatToFusedRandRowwiseQuantizedOp(Args&&... args) : Operator(std::forward(args)...), bitwidth_(OperatorBase::GetSingleArgument("bitwidth", 8)), random_(OperatorBase::GetSingleArgument("random", true)) { CAFFE_ENFORCE( bitwidth_ == 1 || bitwidth_ == 2 || bitwidth_ == 4 || bitwidth_ == 8, "Unsupported bitwidth"); if (random_) { #ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL int status = vslNewStream( &vslStream_, VSL_BRNG_MT19937, std::chrono::system_clock::now().time_since_epoch().count()); if (status != VSL_STATUS_OK) { LOG(WARNING) << "vslNewStream returns " << status; } #else gen_.seed(std::chrono::system_clock::now().time_since_epoch().count()); dis_.reset(new std::uniform_real_distribution(0.0f, 1.0f)); #endif } } ~FloatToFusedRandRowwiseQuantizedOp() { if (random_) { #ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL int status = vslDeleteStream(&vslStream_); if (status != VSL_STATUS_OK) { LOG(WARNING) << "vslDeleteStream returns " << status; } #endif } } bool RunOnDevice() override; private: INPUT_TAGS(DATA_FLOAT); OUTPUT_TAGS(DATA_FUSED_QUANTIZED); protected: size_t bitwidth_{8}; bool random_{true}; std::vector random_buffer_; #ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL VSLStreamStatePtr vslStream_; #else std::unique_ptr> dis_; std::minstd_rand gen_; #endif }; template class FusedRandRowwiseQuantizedToFloatOp : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; USE_SIMPLE_CTOR_DTOR(FusedRandRowwiseQuantizedToFloatOp) bool RunOnDevice() override; private: INPUT_TAGS(DATA_FUSED_QUANTIZED); OUTPUT_TAGS(DATA_FLOAT); }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_FUSED_ROWWISE_RAND_CONVERSION_OPS_H_