// in memory description of all ATen Ops similar to Caffe2 schema // once C10 exists this can be removed, or stubbed out, but we need // it now to implement correct semantic checking for script #pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace torch { namespace jit { struct Node; using ::c10::Symbol; using ::c10::FunctionSchema; using OperationCreator = Operation (*)(const Node*); /* * Note: JIT relies on Operator instances having static lifetime, because * it, for example, stores a non-owning FunctionSchema* pointer in the Node class, * which points to the function schema stored in the Operator instance. * Also, jit::Operator is meant to store more operator related information like * symbolic derivatives, which also requires them to have static lifetime * so that changes to symbolic derivatives are remembered. * * Now, currently, the c10 operator library doesn't store jit::Operator instances, * but we use a listener pattern that notifies JIT about changes in the * c10 operator library and then registers jit::Operator instances to the JIT * operator registry, acting as wrappers to the c10 operators. * * However, that results in code duplication as JIT and c10 will likely get * their own mechanisms for storing derivatives and other operator related * information, and all of this would have to be wrapped from c10 into JIT. * * We should consider merging the JIT and c10 registries, moving jit::Operator * to c10 and storing these jit::Operator instances in the c10 operator library * instead, allowing us to have these mechanisms only implemented once. 
* However, the current jit::Operator implementation has additional features * like OperationCreator that aren't needed in c10 (they're only used for * prim ops like If/Else or While which wouldn't be in the c10 operator library), * and which depend on other JIT features which we don't want to move to c10 * (notably jit/ir.h). We might, however, be able, to split jit::Operator into * a c10::Operator with the core features and a jit::Operator that adds the * JIT-only features like OperationCreator, and then use c10::Operator in the * c10 operator library. */ struct TORCH_API Operator { Operator(c10::OperatorHandle opHandle, Operation operation) : schema_(std::make_shared(opHandle.schema())), op_(std::make_shared(std::move(operation))), c10Handle_(opHandle), options_(c10Handle_->options()) {} Operator( FunctionSchema schema, OperationCreator op_creator, c10::OperatorOptions options = c10::OperatorOptions()) : schema_(std::make_shared(std::move(schema))), op_creator_(std::move(op_creator)), options_(std::move(options)) {} Operator( const std::string& schema, OperationCreator op_creator, c10::OperatorOptions options = c10::OperatorOptions()) : schema_string_(schema), op_creator_(std::move(op_creator)), options_(std::move(options)) {} // Helper constructor to register `op` to run // run for _every_ IR Node where n.kind() == name, regardless of arguments. // This is accomplished by marking the schema varargs and having no required // arguments. This is used for things like prim::While or prim::If that can // take a number of different valid input types and lengths. 
Operator( Symbol name, OperationCreator op_creator, c10::OperatorOptions options = c10::OperatorOptions()) : Operator( FunctionSchema( name, "", {}, {}, /*is_vararg*/ true, /*is_varret*/ true), std::move(op_creator), std::move(options)) {} Operator( FunctionSchema schema, Operation op, c10::OperatorOptions options = c10::OperatorOptions()) : schema_(std::make_shared(std::move(schema))), op_(std::make_shared(std::move(op))), options_(std::move(options)) {} Operator( const std::string& schema, int(*op)(Stack&), c10::OperatorOptions options = c10::OperatorOptions()) : schema_string_(schema), op_(std::make_shared(std::move(op))), options_(std::move(options)) {} bool matches(const Node* node) const; Operation getOperation(const Node* node = nullptr) const { if (op_) { return *op_; } AT_ASSERT(node != nullptr); return op_creator_(node); } const FunctionSchema& schema() const { // we lazily parse schema initialized from strings so that // we do less work during static operator registration if (!schema_) { schema_ = std::make_shared(parseSchema(schema_string_.value())); schema_string_ = c10::nullopt; } return *schema_; } bool isC10Op() const { return c10Handle_.has_value(); } c10::AliasAnalysisKind aliasAnalysisKind() const { return options_.aliasAnalysis(); } private: mutable c10::optional schema_string_; // cannot use c10::optional because windows has issues that require an // assignment operator to be generated cannot use std::unique_ptr because // initializer lists of Operators end up copying the Operator mutable std::shared_ptr schema_; // Essentially a variant. // NB: std::function has a default state (where it == nullptr). 
std::shared_ptr op_; OperationCreator op_creator_; c10::optional c10Handle_; c10::OperatorOptions options_; }; TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema); TORCH_API const std::vector> getAllOperators(); TORCH_API const std::vector>& getAllOperatorsFor( Symbol name); std::shared_ptr findOperatorFor(const Node* node); const Operator& getOperatorFor(const Node* node); inline Operation getOperation(const Node* node) { // note: getOperatorFor ensures that getOperatorFor(node).matches(node) == // true so the call to selectVariant is always valid. return getOperatorFor(node).getOperation(node); } TORCH_API std::vector findSimilarOperators(Symbol input_op); TORCH_API void registerOperator(Operator&& op); // XXX: this function is meant to be used with string literals only! Operator& sig(const char* signature_literal); struct OperatorSet { OperatorSet(std::initializer_list sig_literals); // XXX: Returns a nullptr if no Operator in the set matches n Operator* find(const Node* n) const; private: std::unordered_map>> ops; }; } // namespace jit } // namespace torch