#ifndef CAFFE2_OPERATORS_HALF_FLOAT_OPS_H_ #define CAFFE2_OPERATORS_HALF_FLOAT_OPS_H_ #include "caffe2/core/context.h" #include "caffe2/core/operator.h" namespace caffe2 { template class FloatToHalfOp : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; USE_SIMPLE_CTOR_DTOR(FloatToHalfOp); bool RunOnDevice() override; }; template class HalfToFloatOp : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; USE_SIMPLE_CTOR_DTOR(HalfToFloatOp); bool RunOnDevice() override; }; class Float16ConstantFillOp : public Operator { public: template explicit Float16ConstantFillOp(Args&&... args) : Operator(std::forward(args)...), shape_(this->template GetRepeatedArgument("shape")) {} USE_OPERATOR_FUNCTIONS(CPUContext); virtual ~Float16ConstantFillOp() {} bool RunOnDevice() override; private: vector shape_; }; class Float16UniformFillOp : public Operator { public: template explicit Float16UniformFillOp(Args&&... args) : Operator(std::forward(args)...), shape_(this->template GetRepeatedArgument("shape")), min_(this->template GetSingleArgument("min", 0)), max_(this->template GetSingleArgument("max", 1)) { if (InputSize() == 3) { CAFFE_ENFORCE( !this->template HasSingleArgumentOfType("min"), "Cannot set both min arg and min input blob"); CAFFE_ENFORCE( !this->template HasSingleArgumentOfType("max"), "Cannot set both max arg and max input blob"); } else { CAFFE_ENFORCE_LT( min_, max_, "Max value should be bigger than min value."); } } USE_OPERATOR_FUNCTIONS(CPUContext); virtual ~Float16UniformFillOp() {} bool RunOnDevice() override; private: vector shape_; float min_; float max_; }; inline std::vector Float16FillerTensorInference( const OperatorDef& def, const vector& in) { vector out(1); ArgumentHelper helper(def); out[0].set_data_type(static_cast( helper.GetSingleArgument("dtype", TensorProto_DataType_FLOAT16))); auto shape = helper.GetRepeatedArgument("shape"); for (int d : shape) { out[0].add_dims(d); } return out; } } // namespace caffe2 #endif // CAFFE2_OPERATORS_HALF_FLOAT_OPS_H_