#pragma once
|
|
#include <torch/csrc/jit/pybind_utils.h>
|
#include <torch/csrc/jit/script/module.h>
|
#include <torch/csrc/jit/script/sugared_value.h>
|
#include <memory>
|
#include <sstream>
|
#include <string>
|
#include <vector>
|
|
namespace torch {
|
namespace jit {
|
namespace script {
|
|
// Declared here, defined out of line: takes a Python handle and returns a
// std::string. NOTE(review): presumably a printable name for the Python type
// of `h` -- confirm against the definition.
std::string typeString(py::handle h);
|
|
// Wraps a raw Value* in a SimpleValue and returns it through the SugaredValue
// base-class pointer.
inline std::shared_ptr<SugaredValue> toSimple(Value* v) {
  std::shared_ptr<SugaredValue> wrapped = std::make_shared<SimpleValue>(v);
  return wrapped;
}
|
|
// NB: This should be the single entry-point for instantiating a SugaredValue
|
// from a Python object. If you are adding support for converting a new Python
|
// type, *add it in this function's implementation*.
|
std::shared_ptr<SugaredValue> toSugaredValue(
|
py::object obj,
|
Function& m,
|
SourceRange loc,
|
bool is_constant = false);
|
|
// Declared here, defined out of line: yields a StrongFunctionPtr for `obj`, or
// an empty optional. NOTE(review): presumably empty when `obj` is not a
// compiled script function -- confirm against the definition.
c10::optional<StrongFunctionPtr> as_function(const py::object& obj);
|
|
// SugaredValue backed by a Python object (`self`), optionally paired with a
// second Python object (`rcb`). All non-trivial members are defined out of
// line.
// NOTE(review): `rcb` looks like a resolution callback -- confirm in the
// implementation file.
struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
  // Takes ownership of both Python references; `the_rcb` may be omitted.
  PythonValue(
      py::object the_self,
      c10::optional<py::object> the_rcb = c10::nullopt)
      : self(std::move(the_self)), rcb(std::move(the_rcb)) {}

  // Produces a FunctionSchema given the argument count, the number of output
  // binders, and the source location of the use.
  FunctionSchema getSchema(
      const size_t n_args,
      const size_t n_binders,
      const SourceRange& loc);

  // Invoke the value as a function, e.g. `outputs = this(inputs)`.
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& m,
      at::ArrayRef<NamedValue> inputs_,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override;

  // Overrides SugaredValue::kind().
  std::string kind() const override;

  // Overrides SugaredValue::asTuple(); `size_hint` is optional.
  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Function& m,
      const c10::optional<size_t>& size_hint = {}) override;

  // Attribute selection, e.g. `this.field`.
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) override;

 protected:
  // Helper available to subclasses: looks up `name`, with `loc` supplied for
  // diagnostics. NOTE(review): presumably a getattr on `self` -- confirm.
  py::object getattr(const SourceRange& loc, const std::string& name);

  // Writes into `ss`. NOTE(review): presumably appends a hint about
  // __constants__ to an error message -- confirm in the implementation.
  void checkForAddToConstantsError(std::stringstream& ss);

  // The wrapped Python object.
  py::object self;
  // Optional companion Python object supplied at construction.
  c10::optional<py::object> rcb;
};
|
|
// PythonValue variant that replaces attribute lookup with its own `attr`
// (defined out of line). It is constructed without an `rcb`.
// NOTE(review): presumably wraps an imported Python module -- confirm.
struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue {
  // Hands the Python reference straight to the PythonValue base.
  explicit PythonModuleValue(py::object mod)
      : PythonValue(std::move(mod)) {}

  // Attribute selection on the wrapped object, e.g. `mod.field`.
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) override;
};
|
|
// PythonValue variant that overrides tuple unpacking (`asTuple`) and
// conversion to a graph Value (`asValue`); both are defined out of line.
// NOTE(review): presumably `tup` is a Python tuple of constants -- confirm.
struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue {
  // Hands the Python reference straight to the PythonValue base.
  explicit ConstantPythonTupleValue(py::object tup)
      : PythonValue(std::move(tup)) {}

  // Overrides PythonValue::asTuple(); `size_hint` is optional.
  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Function& m,
      const c10::optional<size_t>& size_hint = {}) override;

  // Overrides SugaredValue::asValue().
  Value* asValue(const SourceRange& loc, Function& m) override;
};
|
|
// Represents all the parameters of a module as a List[Tensor]
|
struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
|
ConstantParameterList(Value* the_list) : the_list_(the_list) {}
|
std::string kind() const override {
|
return "constant parameter list";
|
}
|
std::shared_ptr<SugaredValue> call(
|
const SourceRange& loc,
|
Function& caller,
|
at::ArrayRef<NamedValue> inputs,
|
at::ArrayRef<NamedValue> attributes,
|
size_t n_binders) override {
|
return toSimple(the_list_);
|
}
|
|
private:
|
Value* the_list_;
|
};
|
|
// SugaredValue that holds an already-built list of element SugaredValues and
// returns that list from asTuple().
struct VISIBILITY_HIDDEN ConstantTupleValue : public SugaredValue {
  // `tup` is the element list (moved into the object); `callable` is recorded
  // in `callable_`. Previously the `callable` argument was dropped and
  // `callable_` was left uninitialized, so reading it was undefined behavior.
  explicit ConstantTupleValue(
      std::vector<std::shared_ptr<SugaredValue>> tup,
      bool callable = false)
      : tup_(std::move(tup)), callable_(callable) {}

  // Returns a copy of the stored elements; `loc`, `m` and `size_hint` are
  // not consulted.
  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Function& m,
      const c10::optional<size_t>& size_hint = {}) override {
    return tup_;
  }

  std::string kind() const override {
    return "constant tuple";
  }

  // Elements of the tuple.
  std::vector<std::shared_ptr<SugaredValue>> tup_;
  // Flag supplied at construction; defaults to false.
  bool callable_ = false;
};
|
|
// A zero-argument method on a constant tuple: calling it yields a
// ConstantTupleValue built from the stored elements. `name_` is both the
// reported kind() and the method name used in error messages.
struct VISIBILITY_HIDDEN ConstantTupleMethod : public SugaredValue {
  // `tup` is taken by value and moved into place to avoid a second copy of
  // the vector (and a refcount bump for every element).
  explicit ConstantTupleMethod(
      std::vector<std::shared_ptr<SugaredValue>> tup,
      const std::string& name)
      : tup_(std::move(tup)), name_(name) {}

  std::string kind() const override {
    return name_;
  }

  // Rejects any positional or keyword argument with an ErrorReport at `loc`;
  // otherwise returns a fresh ConstantTupleValue holding a copy of the
  // stored elements. `f` and `n_binders` are not consulted.
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& f,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    if (!inputs.empty() || !attributes.empty()) {
      throw ErrorReport(loc)
          << name_ << " method does not accept any arguments";
    }
    return std::make_shared<ConstantTupleValue>(tup_);
  }

  // Elements handed to the ConstantTupleValue produced by call().
  std::vector<std::shared_ptr<SugaredValue>> tup_;
  // Method name; fixed for the lifetime of the object.
  const std::string name_;
};
|
|
// SugaredValue that pairs a module Value with a list of method names; the
// call logic lives out of line.
// NOTE(review): presumably call() tries each named method as an overload
// candidate -- confirm in the implementation.
struct VISIBILITY_HIDDEN OverloadedMethodValue : public SugaredValue {
  // Stores the (non-owning) module Value* and takes ownership of the names.
  OverloadedMethodValue(Value* module, std::vector<std::string> method_names)
      : module_(module),
        method_names_(std::move(method_names)) {}

  std::string kind() const override {
    return "overloaded function";
  }

  // Invoke the value as a function; defined out of line.
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override;

 private:
  // Non-owning pointer to the module Value.
  Value* module_;
  // Method names supplied at construction.
  std::vector<std::string> method_names_;
};
|
|
// SugaredValue that owns a list of StrongFunctionPtr overloads; the call
// logic lives out of line.
// NOTE(review): presumably call() selects among the stored overloads --
// confirm in the implementation.
struct VISIBILITY_HIDDEN OverloadedFunctionValue : public SugaredValue {
  // Takes ownership of the overload list.
  OverloadedFunctionValue(std::vector<StrongFunctionPtr> compiled_overloads)
      : compiled_overloads_(std::move(compiled_overloads)) {}

  std::string kind() const override {
    return "overloaded function";
  }

  // Invoke the value as a function; defined out of line.
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override;

 private:
  // Overloads supplied at construction.
  std::vector<StrongFunctionPtr> compiled_overloads_;
};
|
|
// Defines how modules/methods behave inside the script subset.
// For now this does not have any interaction with Python.
// In the future, we will add the ability to resolve `self.foo` to Python
// {functions, modules, constants}, so this SugaredValue is defined here
// anticipating we will eventually need to replace Module with a py::object
// holding the actual nn.Module class.
|
|
// SugaredValue for a script module. It bundles the graph Value for the module
// (`self_`), the Module itself, and a Python object for it.
struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {
  // `self` is stored as a non-owning pointer; `module` and `py_module` are
  // moved into the object.
  ModuleValue(Value* self, Module module, py::object py_module)
      : self_(self),
        module_(std::move(module)),
        py_module_(std::move(py_module)) {}

  std::string kind() const override {
    return "module";
  }

  // Overrides SugaredValue::asValue(); defined out of line.
  Value* asValue(const SourceRange& loc, Function& m) override;

  // Attribute selection, e.g. `this.field`.
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) override;

  // Calling a module means calling its `forward` attribute with the same
  // arguments.
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    auto forward_method = attr(loc, caller, "forward");
    return forward_method->call(loc, caller, inputs, attributes, n_binders);
  }

  // Overrides SugaredValue::asTuple(); `size_hint` is optional.
  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Function& m,
      const c10::optional<size_t>& size_hint = {}) override;

  // Attribute assignment, e.g. `this.field = newValue`.
  void setAttr(
      const SourceRange& loc,
      Function& m,
      const std::string& field,
      Value* newValue) override;

 private:
  // Non-owning pointer to the graph Value for the module.
  Value* self_;
  // The script Module supplied at construction.
  Module module_;
  // Python object supplied at construction.
  py::object py_module_;

  // Internal helper defined out of line. NOTE(review): presumably expands a
  // container module into its keys and/or values depending on the two
  // flags -- confirm in the implementation.
  std::vector<std::shared_ptr<SugaredValue>> desugarModuleContainer(
      bool get_keys,
      bool get_values,
      const SourceRange& loc,
      Function& m);
};
|
|
// SugaredValue that owns a Python dict (`dispatched_fn_`); the call logic
// lives out of line.
// NOTE(review): presumably the dict describes two implementations chosen by a
// boolean argument -- confirm in the implementation.
struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
  // Takes ownership of the dict reference.
  BooleanDispatchValue(py::dict dispatched_fn)
      : dispatched_fn_(std::move(dispatched_fn)) {}

  std::string kind() const override {
    return "boolean dispatch";
  }

  // Invoke the value as a function; defined out of line.
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override;

 private:
  // Dict supplied at construction.
  py::dict dispatched_fn_;
};
|
|
} // namespace script
|
} // namespace jit
|
} // namespace torch
|