#pragma once
|
|
#include "caffe2/core/common.h"
|
#include "caffe2/proto/caffe2_pb.h"
|
#include "caffe2/utils/proto_utils.h"
|
#include "caffe2/utils/string_utils.h"
|
|
#include <algorithm>
|
#include <unordered_map>
|
#include <unordered_set>
|
|
namespace caffe2 {
|
|
namespace transform {
|
|
/**
|
* Graph representation of an operator.
|
*/
|
struct CAFFE2_API Node {
|
public:
|
// Empty constructor for resize
|
Node() {}
|
|
// Alternate constructor
|
Node(
|
const OperatorDef& op,
|
bool active,
|
std::map<int, std::vector<string>> parents,
|
std::map<int, std::vector<string>> children)
|
: op(op), active(active), parents(parents), children(children) {}
|
|
// The OperatorDef which this node represents.
|
OperatorDef op;
|
|
// Keeps track of if an operator has been deleted through a transformation.
|
bool active = true;
|
|
// Stores a pair (idx, blob_list),
|
// idx = index of the child
|
// blob_list = a list of strings, containing the blobs that connect the nodes
|
std::map<int, std::vector<string>> parents;
|
std::map<int, std::vector<string>> children;
|
};
|
|
/**
|
* Graph representation of a Netdef.
|
*/
|
struct CAFFE2_API Graph {
|
public:
|
/**
|
* Given a subgraph, gets all of the parents of the subgraph, as well as
|
* their associated blob names. Sorted by blob names.
|
*
|
* <string, int> := (name of blob writing into subgraph,
|
* index of node that writes into subgraph using that blob)
|
*/
|
const std::vector<std::pair<string, int>> GetSubgraphInput(
|
const std::vector<int>& subgraph);
|
|
/**
|
* Given a subgraph, gets all of the children of the subgraph, as well as
|
* their associated blob names. Sorted by blob names.
|
*
|
* <string, int> := (name of blob reading from subgraph,
|
* index of node that reads from subgraph using that blob)
|
*/
|
const std::vector<std::pair<string, int>> GetSubgraphOutput(
|
const std::vector<int>& subgraph);
|
|
/**
|
* Graph generation.
|
* Given a netdef, returns a Graph.
|
*
|
* Each node represents an operator.
|
* An edge exists between two nodes if the parent op writes to a blob, which
|
* is the input of the child blob, with no other op writing to the blob in
|
* between the execution order.
|
*
|
* Time Complexity: O(E), where E is the number of blobs
|
*/
|
explicit Graph(const NetDef& net_def);
|
|
/**
|
* Generates a NetDef Representation for the current graph.
|
* Nodes are visited in topological order, which is proper Opdef ordering.
|
* TODO(benz):
|
* There exists conflicts with repeated blob names, where topological sorting
|
* is not sufficient for correct netdef representation, unless blobs are
|
* renamed.
|
* For example, if after a transformation, We have operator ancestry:
|
* A --> B --> C, and also A --> D --> E, where B -> C and D -> E uses the
|
* same blob name, then A, B, D, E, C is a correct topological ordering,
|
* but D will write to the blob that C reads from, instead of B.
|
* Currently believe that there will always be ambiguity unless blobs are
|
* renamed.
|
* This is solved by performing SSA on all transformed blob names.
|
*/
|
NetDef GetNetDef();
|
|
/**
|
* Deactivate a subgraph, and get rid of all edges into this subgraph.
|
*/
|
void DeactivateSubgraph(std::vector<int> subgraph);
|
|
size_t size() const {
|
return nodes_.size();
|
}
|
|
void push_node(const Node& new_node) {
|
return nodes_.push_back(new_node);
|
}
|
|
void resize_nodes(size_t new_size) {
|
nodes_.resize(new_size);
|
}
|
|
// Index safe, less verbose way to access nodes
|
inline const Node& node(size_t idx) const {
|
return nodes_.at(idx);
|
}
|
|
inline Node& node(size_t idx) {
|
return nodes_.at(idx);
|
}
|
|
inline bool is_node_active(size_t idx) {
|
return node(idx).active;
|
}
|
|
inline const std::set<string>& external_input() const {
|
return external_input_;
|
}
|
|
inline const std::set<string>& external_output() const {
|
return external_output_;
|
}
|
|
private:
|
const std::vector<std::pair<string, int>> GetSubgraphPerimeterHelper(
|
bool from_children,
|
const std::vector<int>& match);
|
|
// Stores the netdef representation. Is updated upon calls to GetNetDef.
|
NetDef netdef_;
|
|
// Stores which blobs the graph reads from, and writes to.
|
std::set<string> external_input_;
|
std::set<string> external_output_;
|
|
// Keeps track of all the Operators currently within graph, even if inactive.
|
std::vector<Node> nodes_;
|
};
|
|
} // namespace transform
|
|
// Adds an operator def to a netdef.
|
// Returns the ptr, if you want to add anything extra (such as device_option)
|
CAFFE2_API OperatorDef* AddOp(
|
NetDef* netdef_ptr,
|
string op_type,
|
std::vector<string> inputs,
|
std::vector<string> outputs);
|
|
/**
|
* This allows for the use of * and | to match operator types,
|
* engines, or any other property that is represented by strings.
|
*
|
* For example, if we wanted to match an operator to Conv or FC, we can give:
|
* "Conv|FC" as the type() of that op.
|
*/
|
CAFFE2_API bool MatchStrings(string p, string s);
|
|
/**
|
* This ensures that each named arg that exists in the pattern exists in g_op,
|
* is equal in value.
|
*/
|
CAFFE2_API bool MatchArguments(const OperatorDef& p_op, const OperatorDef& g_op);
|
|
} // namespace caffe2
|