#pragma once #include #include #include #include #include #include #include namespace torch { namespace jit { namespace script { std::string typeString(py::handle h); inline std::shared_ptr toSimple(Value* v) { return std::make_shared(v); } // NB: This should be the single entry-point for instantiating a SugaredValue // from a Python object. If you are adding support for converting a new Python // type, *add it in this function's implementation*. std::shared_ptr toSugaredValue( py::object obj, Function& m, SourceRange loc, bool is_constant = false); c10::optional as_function(const py::object& obj); struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { PythonValue(py::object the_self, c10::optional rcb = c10::nullopt) : self(std::move(the_self)), rcb(std::move(rcb)) {} FunctionSchema getSchema( const size_t n_args, const size_t n_binders, const SourceRange& loc); // call it like a function, e.g. `outputs = this(inputs)` std::shared_ptr call( const SourceRange& loc, Function& m, at::ArrayRef inputs_, at::ArrayRef attributes, size_t n_binders) override; std::string kind() const override; std::vector> asTuple( const SourceRange& loc, Function& m, const c10::optional& size_hint = {}) override; std::shared_ptr attr( const SourceRange& loc, Function& m, const std::string& field) override; protected: py::object getattr(const SourceRange& loc, const std::string& name); void checkForAddToConstantsError(std::stringstream& ss); py::object self; c10::optional rcb; }; struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue { explicit PythonModuleValue(py::object mod) : PythonValue(std::move(mod)) {} std::shared_ptr attr( const SourceRange& loc, Function& m, const std::string& field) override; }; struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue { explicit ConstantPythonTupleValue(py::object tup) : PythonValue(std::move(tup)) {} std::vector> asTuple( const SourceRange& loc, Function& m, const c10::optional& size_hint = {}) override; Value* asValue(const SourceRange& loc, Function& m) override; }; // Represents all the parameters of a module as a List[Tensor] struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue { ConstantParameterList(Value* the_list) : the_list_(the_list) {} std::string kind() const override { return "constant parameter list"; } std::shared_ptr call( const SourceRange& loc, Function& caller, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override { return toSimple(the_list_); } private: Value* the_list_; }; struct VISIBILITY_HIDDEN ConstantTupleValue : public SugaredValue { explicit ConstantTupleValue( std::vector> tup, bool callable = false) : tup_(tup){}; std::vector> asTuple( const SourceRange& loc, Function& m, const c10::optional& size_hint = {}) override { return tup_; }; std::string kind() const override { return "constant tuple"; } std::vector> tup_; bool callable_; }; struct VISIBILITY_HIDDEN ConstantTupleMethod : public SugaredValue { explicit ConstantTupleMethod( std::vector> tup, const std::string& name) : tup_(tup), name_(name){}; std::string kind() const override { return name_; } std::shared_ptr call( const SourceRange& loc, Function& f, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override { if (inputs.size() || attributes.size()) { throw ErrorReport(loc) << name_ << " method does not accept any arguments"; } return std::make_shared(tup_); } std::vector> tup_; const std::string name_; }; struct VISIBILITY_HIDDEN OverloadedMethodValue : public SugaredValue { OverloadedMethodValue(Value* module, std::vector method_names) : module_(module), method_names_(std::move(method_names)) {} std::string kind() const override { return "overloaded function"; } std::shared_ptr call( const SourceRange& loc, Function& caller, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override; private: Value* module_; std::vector method_names_; }; struct VISIBILITY_HIDDEN OverloadedFunctionValue : public SugaredValue { OverloadedFunctionValue(std::vector compiled_overloads) : compiled_overloads_(std::move(compiled_overloads)) {} std::string kind() const override { return "overloaded function"; } std::shared_ptr call( const SourceRange& loc, Function& caller, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override; private: std::vector compiled_overloads_; }; // defines how modules/methods behave inside the script subset. // for now this does not have any interaction with python. // in the future, we will add the ability to resolve `self.foo` to python // {functions, modules, contants} so this SugaredValue is defined here // anticipating we will eventually need to replace Module with a py::object // holding the actual nn.Module class. struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue { ModuleValue(Value* self, Module module, py::object py_module) : self_(self), module_(std::move(module)), py_module_(std::move(py_module)) {} std::string kind() const override { return "module"; } Value* asValue(const SourceRange& loc, Function& m) override; // select an attribute on it, e.g. `this.field` std::shared_ptr attr( const SourceRange& loc, Function& m, const std::string& field) override; // call module.forward std::shared_ptr call( const SourceRange& loc, Function& caller, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override { return attr(loc, caller, "forward") ->call(loc, caller, inputs, attributes, n_binders); } std::vector> asTuple( const SourceRange& loc, Function& m, const c10::optional& size_hint = {}) override; void setAttr( const SourceRange& loc, Function& m, const std::string& field, Value* newValue) override; private: Value* self_; Module module_; py::object py_module_; std::vector> desugarModuleContainer( bool get_keys, bool get_values, const SourceRange& loc, Function& m); }; struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue { BooleanDispatchValue(py::dict dispatched_fn) : dispatched_fn_(std::move(dispatched_fn)) {} std::string kind() const override { return "boolean dispatch"; } std::shared_ptr call( const SourceRange& loc, Function& caller, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override; private: py::dict dispatched_fn_; }; } // namespace script } // namespace jit } // namespace torch