#pragma once #include "caffe2/core/common.h" #include "caffe2/core/logging.h" #include "caffe2/proto/caffe2_pb.h" #include "nomnigraph/Representations/NeuralNet.h" namespace caffe2 { class CAFFE2_API Caffe2Annotation : public nom::repr::Annotation { public: Caffe2Annotation() : Annotation(AnnotationKind::Caffe2) {} Caffe2Annotation(std::string device) : Annotation(AnnotationKind::Caffe2), Device(device) {} virtual ~Caffe2Annotation() {} void setOperatorDef(const caffe2::OperatorDef& opDef); bool hasOperatorDef() const; const caffe2::OperatorDef& getOperatorDef() const; caffe2::OperatorDef* getMutableOperatorDef(); void setDeviceOption(const caffe2::DeviceOption& opDef); bool hasDeviceOption() const; const caffe2::DeviceOption& getDeviceOption() const; caffe2::DeviceOption* getMutableDeviceOption(); // Distributed annotations void setDevice(std::string device); const std::string getDevice() const; void setDeviceType(int device); int getDeviceType() const; enum class ParallelizationScheme { none, split_by_batch, split_by_length, shard, shard_by_number }; void setParallelization(ParallelizationScheme, int num = -1); ParallelizationScheme getParallelizationScheme() const; int getParallelization() const; void setKeyNode(nom::repr::NNGraph::NodeRef); const nom::repr::NNGraph::NodeRef& getKeyNode() const; void setLengthNode(nom::repr::NNGraph::NodeRef); const nom::repr::NNGraph::NodeRef& getLengthNode() const; void setComponentLevels(std::vector components); std::vector getComponentLevels() const; static bool classof(const Annotation* A); private: std::string Device = ""; caffe2::OperatorDef OpDef; bool OpDefExists = false; // Distributed annotations int DeviceType = caffe2::DeviceTypeProto::PROTO_CPU; ParallelizationScheme parallelization_scheme_ = ParallelizationScheme::none; int parallelization_ = -1; nom::repr::NNGraph::NodeRef key_node_ = nullptr; nom::repr::NNGraph::NodeRef length_node_ = nullptr; std::vector component_levels_; }; } // namespace caffe2