#pragma once #include "caffe2/core/common.h" #include "caffe2/proto/caffe2_pb.h" #include "caffe2/utils/proto_utils.h" #include "caffe2/utils/string_utils.h" #include #include #include namespace caffe2 { namespace transform { /** * Graph representation of an operator. */ struct CAFFE2_API Node { public: // Empty constructor for resize Node() {} // Alternate constructor Node( const OperatorDef& op, bool active, std::map> parents, std::map> children) : op(op), active(active), parents(parents), children(children) {} // The OperatorDef which this node represents. OperatorDef op; // Keeps track of if an operator has been deleted through a transformation. bool active = true; // Stores a pair (idx, blob_list), // idx = index of the child // blob_list = a list of strings, containing the blobs that connect the nodes std::map> parents; std::map> children; }; /** * Graph representation of a Netdef. */ struct CAFFE2_API Graph { public: /** * Given a subgraph, gets all of the parents of the subgraph, as well as * their associated blob names. Sorted by blob names. * * := (name of blob writing into subgraph, * index of node that writes into subgraph using that blob) */ const std::vector> GetSubgraphInput( const std::vector& subgraph); /** * Given a subgraph, gets all of the children of the subgraph, as well as * their associated blob names. Sorted by blob names. * * := (name of blob reading from subgraph, * index of node that reads from subgraph using that blob) */ const std::vector> GetSubgraphOutput( const std::vector& subgraph); /** * Graph generation. * Given a netdef, returns a Graph. * * Each node represents an operator. * An edge exists between two nodes if the parent op writes to a blob, which * is the input of the child blob, with no other op writing to the blob in * between the execution order. * * Time Complexity: O(E), where E is the number of blobs */ explicit Graph(const NetDef& net_def); /** * Generates a NetDef Representation for the current graph. * Nodes are visited in topological order, which is proper Opdef ordering. * TODO(benz): * There exists conflicts with repeated blob names, where topological sorting * is not sufficient for correct netdef representation, unless blobs are * renamed. * For example, if after a transformation, We have operator ancestry: * A --> B --> C, and also A --> D --> E, where B -> C and D -> E uses the * same blob name, then A, B, D, E, C is a correct topological ordering, * but D will write to the blob that C reads from, instead of B. * Currently believe that there will always be ambiguity unless blobs are * renamed. * This is solved by performing SSA on all transformed blob names. */ NetDef GetNetDef(); /** * Deactivate a subgraph, and get rid of all edges into this subgraph. */ void DeactivateSubgraph(std::vector subgraph); size_t size() const { return nodes_.size(); } void push_node(const Node& new_node) { return nodes_.push_back(new_node); } void resize_nodes(size_t new_size) { nodes_.resize(new_size); } // Index safe, less verbose way to access nodes inline const Node& node(size_t idx) const { return nodes_.at(idx); } inline Node& node(size_t idx) { return nodes_.at(idx); } inline bool is_node_active(size_t idx) { return node(idx).active; } inline const std::set& external_input() const { return external_input_; } inline const std::set& external_output() const { return external_output_; } private: const std::vector> GetSubgraphPerimeterHelper( bool from_children, const std::vector& match); // Stores the netdef representation. Is updated upon calls to GetNetDef. NetDef netdef_; // Stores which blobs the graph reads from, and writes to. std::set external_input_; std::set external_output_; // Keeps track of all the Operators currently within graph, even if inactive. std::vector nodes_; }; } // namespace transform // Adds an operator def to a netdef. // Returns the ptr, if you want to add anything extra (such as device_option) CAFFE2_API OperatorDef* AddOp( NetDef* netdef_ptr, string op_type, std::vector inputs, std::vector outputs); /** * This allows for the use of * and | to match operator types, * engines, or any other property that is represented by strings. * * For example, if we wanted to match an operator to Conv or FC, we can give: * "Conv|FC" as the type() of that op. */ CAFFE2_API bool MatchStrings(string p, string s); /** * This ensures that each named arg that exists in the pattern exists in g_op, * is equal in value. */ CAFFE2_API bool MatchArguments(const OperatorDef& p_op, const OperatorDef& g_op); } // namespace caffe2