#ifndef CAFFE2_OPT_CONVERTER_H
|
#define CAFFE2_OPT_CONVERTER_H
|
|
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/opt/annotations.h"
#include "caffe2/proto/caffe2_pb.h"
#include "nomnigraph/Graph/Graph.h"
#include "nomnigraph/Representations/ControlFlow.h"
#include "nomnigraph/Representations/NeuralNet.h"

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
|
|
namespace caffe2 {
|
|
// Paired helpers that add / strip "data edge indicators" on a NetDef.
//
// `net` is passed by non-const pointer and nothing is returned, so the
// NetDef is presumably edited in place.
// NOTE(review): the exact transformation lives in the implementation file;
// the names suggest remove undoes inject -- confirm before relying on it.
CAFFE2_API void injectDataEdgeIndicators(NetDef* net);
CAFFE2_API void removeDataEdgeIndicators(NetDef* net);
|
|
// Default conversion to a NNModule
|
// Optionally strict -- which checks for various input and output conditions.
|
// Optionally this function will update a vector that maps operators in the
|
// netdef positionally to NodeRefs in the resultant NNModule.
|
CAFFE2_API nom::repr::NNModule convertToNNModule(
|
const caffe2::NetDef& net,
|
bool strict = false,
|
std::vector<nom::repr::NNGraph::NodeRef>* = nullptr);
|
/// Converts an NNModule back into a caffe2 NetDef.
///
/// NOTE(review): the module is taken by non-const reference; whether the
/// conversion mutates it is not visible from this header -- confirm in the
/// implementation.
CAFFE2_API NetDef convertToCaffe2Proto(nom::repr::NNModule& module);

/// Converts an NNModule back into a NetDef, copying all the attributes of
/// `oldNet` into the result.
///
/// Be warned: transformations that modify the graph's inputs or outputs are
/// not reflected in the result's external_input or external_output -- those
/// fields are not updated to track such changes.
///
/// @param module  Graph to convert.
/// @param oldNet  Network whose attributes are copied.
CAFFE2_API NetDef convertToCaffe2Proto(
    nom::repr::NNModule& module,
    const NetDef& oldNet);
|
|
// Use these functions instead of the registry directly.
|
CAFFE2_API std::unique_ptr<nom::repr::NeuralNetOperator> convertToNeuralNetOperator(
|
const caffe2::OperatorDef& op);
|
|
CAFFE2_API caffe2::OperatorDef convertToOperatorDef(
|
const nom::repr::NNGraph::NodeRef& instrNode);
|
|
// If the annotation doesn't exist, attempt to add it
|
CAFFE2_API Caffe2Annotation* getOrAddCaffe2Annotation(
|
nom::repr::NNGraph::NodeRef& instrNode);
|
|
class CAFFE2_API Converter {
|
public:
|
explicit Converter() = default;
|
virtual std::unique_ptr<nom::repr::NeuralNetOperator>
|
convertToNeuralNetOperator(const OperatorDef&) = 0;
|
virtual OperatorDef convertToOperatorDef(const nom::repr::NeuralNetOperator*);
|
static std::map<std::string, caffe2::Argument> getArgumentsFromOperator(
|
caffe2::OperatorDef op);
|
|
virtual ~Converter() {}
|
};
|
|
// Registry of Converter subclasses. Prefer the free conversion functions
// declared in this header over querying the registry directly.
C10_DECLARE_REGISTRY(ConverterRegistry, Converter);

// Registers the Converter subclass `ConverterClass` in ConverterRegistry
// under `key` (presumably the caffe2 operator type name -- confirm at call
// sites).
#define REGISTER_CONVERTER(key, ConverterClass) \
  C10_REGISTER_CLASS(ConverterRegistry, key, ConverterClass)
|
|
// Defines `opName##Converter`, a Converter whose convertToNeuralNetOperator
// ignores the incoming OperatorDef and returns a default-constructed
// nom::repr::opName. The macro does not register the class; do that
// separately (e.g. with REGISTER_CONVERTER).
//
// Members are explicitly public: with the class-default private access the
// generated type would hide even its destructor. The parameter is left
// unnamed so each expansion does not trigger -Wunused-parameter, and the
// destructor is marked `override` to match the other member (avoids clang's
// -Winconsistent-missing-destructor-override).
#define TRIVIAL_CONVERTER(opName)                                             \
  class opName##Converter : public Converter {                                \
   public:                                                                    \
    std::unique_ptr<nom::repr::NeuralNetOperator> convertToNeuralNetOperator( \
        const OperatorDef& /* op */) override {                               \
      return nom::util::make_unique<nom::repr::opName>();                     \
    }                                                                         \
    ~opName##Converter() override {}                                          \
  };
|
|
} // namespace caffe2
|
|
|
#endif // CAFFE2_OPT_CONVERTER_H
|