#ifndef CAFFE2_OPERATORS_MAP_OPS_H_ #define CAFFE2_OPERATORS_MAP_OPS_H_ #include #include #include #include #include #include #include #include "caffe2/core/blob_serialization.h" #include "caffe2/core/context.h" #include "caffe2/core/operator.h" namespace caffe2 { template struct TypeNameTraits { static constexpr const char* name = "unknown"; }; template <> struct TypeNameTraits { static constexpr const char* name = "int64_t"; }; template <> struct TypeNameTraits { static constexpr const char* name = "int32_t"; }; template struct MapTypeTraits { using MapType = std::unordered_map; static string MapTypeName() { return string("(std::unordered_map<") + TypeNameTraits::name + ", " + TypeNameTraits::name + ">)"; } }; using MapType64To64 = MapTypeTraits::MapType; using MapType64To32 = MapTypeTraits::MapType; using MapType32To32 = MapTypeTraits::MapType; using MapType32To64 = MapTypeTraits::MapType; template class CreateMapOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit CreateMapOp(Args&&... args) : Operator(std::forward(args)...) {} ~CreateMapOp() {} bool RunOnDevice() override { TensorProto::DataType key_dtype = static_cast(this->template GetSingleArgument( "key_dtype", TensorProto_DataType_INT32)); return DispatchHelper>::call( this, DataTypeToTypeMeta(key_dtype)); } template bool DoRunWithType() { TensorProto::DataType value_dtype = static_cast(this->template GetSingleArgument( "value_dtype", TensorProto_DataType_INT32)); return DispatchHelper< TensorTypes2, KEY_T>::call(this, DataTypeToTypeMeta(value_dtype)); } template bool DoRunWithType2() { // clear to make sure the map is empty this->template Output::MapType>(MAP) ->clear(); return true; } template bool DoRunWithOtherType2() { TensorProto::DataType value_dtype = static_cast(this->template GetSingleArgument( "value_dtype", TensorProto_DataType_INT32)); CAFFE_THROW( "CreateMap is not implemented on value tensor of type ", DataTypeToTypeMeta(value_dtype).name(), "consider adding it as a type in the DispatchHelper list"); } OUTPUT_TAGS(MAP); }; template class KeyValueToMapOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit KeyValueToMapOp(Args&&... args) : Operator(std::forward(args)...) {} ~KeyValueToMapOp() {} bool RunOnDevice() override { return DispatchHelper>::call( this, Input(KEYS)); } template bool DoRunWithType() { return DispatchHelper< TensorTypes2, KEY_T>::call(this, Input(VALUES)); } template bool DoRunWithType2() { using MapType = typename MapTypeTraits::MapType; const auto& key_input = Input(KEYS); const auto& value_input = Input(VALUES); CAFFE_ENFORCE_EQ(key_input.numel(), value_input.numel()); auto* key_data = key_input.template data(); auto* value_data = value_input.template data(); auto* map_data = this->template Output(MAP); for (int i = 0; i < key_input.numel(); ++i) { map_data->emplace(key_data[i], value_data[i]); } return true; } template bool DoRunWithOtherType2() { CAFFE_THROW( "KeyValueToMap is not implemented on value tensor of type ", Input(VALUES).dtype().name(), "consider adding it as a type in the DispatchHelper list"); } INPUT_TAGS(KEYS, VALUES); OUTPUT_TAGS(MAP); }; template class MapToKeyValueOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit MapToKeyValueOp(Args&&... args) : Operator(std::forward(args)...) {} ~MapToKeyValueOp() {} bool RunOnDevice() override { return DispatchHelper>::call(this, OperatorBase::InputBlob(MAP)); } template bool DoRunWithType() { using key_type = typename MAP_T::key_type; using mapped_type = typename MAP_T::mapped_type; auto& map_data = this->template Input(MAP); auto* key_output = Output(KEYS, {static_cast(map_data.size())}, at::dtype()); auto* value_output = Output(VALUES, {static_cast(map_data.size())}, at::dtype()); auto* key_data = key_output->template mutable_data(); auto* value_data = value_output->template mutable_data(); for (const auto& it : map_data) { *key_data = it.first; *value_data = it.second; key_data++; value_data++; } return true; } INPUT_TAGS(MAP); OUTPUT_TAGS(KEYS, VALUES); }; template class MapSerializer : public BlobSerializerBase { public: using MapType = typename MapTypeTraits::MapType; void Serialize( const void* pointer, TypeMeta typeMeta, const string& name, BlobSerializerBase::SerializationAcceptor acceptor) override { CAFFE_ENFORCE(typeMeta.Match()); const MapType& map_data = *static_cast(pointer); int64_t sz = map_data.size(); Tensor key_tensor(CPU); key_tensor.Resize(sz); Tensor value_tensor(CPU); value_tensor.Resize(sz); auto* key_data = key_tensor.mutable_data(); auto* value_data = value_tensor.mutable_data(); for (const auto& it : map_data) { *key_data = it.first; *value_data = it.second; key_data++; value_data++; } TensorProtos tensor_protos; TensorSerializer ser; ser.Serialize( key_tensor, name, tensor_protos.add_protos(), 0, key_tensor.numel()); ser.Serialize( value_tensor, name, tensor_protos.add_protos(), 0, value_tensor.numel()); BlobProto blob_proto; blob_proto.set_name(name); blob_proto.set_type(MapTypeTraits::MapTypeName()); blob_proto.set_content(SerializeAsString_EnforceCheck(tensor_protos)); acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto)); } }; template class MapDeserializer : public BlobDeserializerBase { public: using MapType = typename MapTypeTraits::MapType; void Deserialize(const BlobProto& proto, Blob* blob) override { TensorProtos tensor_protos; CAFFE_ENFORCE( tensor_protos.ParseFromString(proto.content()), "Fail to parse TensorProtos"); TensorDeserializer deser; Tensor key_tensor = deser.Deserialize(tensor_protos.protos(0)); Tensor value_tensor = deser.Deserialize(tensor_protos.protos(1)); auto* key_data = key_tensor.data(); auto* value_data = value_tensor.data(); auto* map_ptr = blob->template GetMutable(); for (int i = 0; i < key_tensor.numel(); ++i) { map_ptr->emplace(key_data[i], value_data[i]); } } }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_MAP_OPS_H_