#pragma once #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/operators/filler_op.h" #include "caffe2/utils/cast.h" #include "caffe2/utils/math.h" namespace caffe2 { template class GivenTensorFillOp final : public FillerOp { public: USE_OPERATOR_CONTEXT_FUNCTIONS; explicit GivenTensorFillOp(const OperatorDef& operator_def, Workspace* ws) : FillerOp(operator_def, ws) { const ArgumentHelper helper(operator_def); // GivenTensorFillOp can be provided with a "dtype" arg if float is // is specified as T. Otherwise, "dtype" is ignored. // In the ideal world, we would get rid of templating of T at all, but we // need to provide backwards compatibility. if (!std::is_same::value || !helper.HasArgument("dtype")) { ExtractValues(); } else { auto dtype = cast::GetCastDataType(helper, "dtype"); switch (dtype) { case TensorProto_DataType_FLOAT: ExtractValues(); break; case TensorProto_DataType_DOUBLE: ExtractValues(); break; case TensorProto_DataType_BOOL: ExtractValues(); break; case TensorProto_DataType_INT16: ExtractValues(); break; case TensorProto_DataType_INT32: ExtractValues(); break; case TensorProto_DataType_INT64: ExtractValues(); break; case TensorProto_DataType_STRING: ExtractValues(); break; case TensorProto_DataType_UNDEFINED: CAFFE_THROW("Cannot have undefined 'dtype' argument"); default: CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype); } } } bool Fill(Tensor* output) override { return (this->*body_)(output); } private: template void ExtractValues() { auto source_values = this->template GetRepeatedArgument("values"); ReinitializeTensor(&values_, {static_cast(source_values.size())}, at::dtype().device(CPU)); Type* values_data = values_.template mutable_data(); for (int i = 0; i < source_values.size(); i++) { values_data[i] = static_cast(source_values[i]); } body_ = &GivenTensorFillOp::FillWithType; } template bool FillWithType(Tensor* output) { DCHECK_EQ(output->numel(), values_.numel()) << "output size: " << output->numel() << " given size: " << values_.numel(); auto* data = output->template mutable_data(); const Type* values_data = values_.template data(); if (output->numel()) { context_.CopyItemsFromCPU( TypeMeta::Make(), output->numel(), values_data, data); } return true; } bool (GivenTensorFillOp::*body_)(Tensor* output); Tensor values_; }; } // namespace caffe2