#pragma once
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace torch {
namespace jit {

namespace script {

struct Def;
struct ClassDef;
struct SugaredValue;
struct Resolver;

using ResolverPtr = std::shared_ptr<Resolver>;
struct Self {
  virtual ~Self() {}
  virtual std::shared_ptr<SugaredValue> makeSugared(Value* v) const = 0;
  virtual ClassTypePtr getClassType() const = 0;
};

// A CompilationUnit is a list of named Functions
// with helper methods to iterate the list or invoke a function.
// Classes have a CompilationUnit holding the class methods,
// and Modules also have a CompilationUnit holding the Functions that
// are used to implement their Methods.

struct TORCH_API CompilationUnit {
  // constructor that takes a set of functions to compile using the native
  // resolver
  explicit CompilationUnit(const std::string& source);
  CompilationUnit() = default;

  CompilationUnit& operator=(CompilationUnit&&) = default;
  CompilationUnit(CompilationUnit&&) = default;
  CompilationUnit& operator=(const CompilationUnit&) = delete;
  CompilationUnit(const CompilationUnit&) = delete;

  // Returns nullptr if no function with this name is registered.
  Function* find_function(const c10::QualifiedName& name) const {
    auto it = dict_.find(name);
    if (it == dict_.end()) {
      return nullptr;
    }
    return functions_[it->second].get();
  }

  // Like find_function, but fails with an error if the function is undefined.
  Function& get_function(const c10::QualifiedName& name) const {
    if (auto r = find_function(name)) {
      return *r;
    }
    TORCH_CHECK(false, "attempted to get undefined function ", name.name());
  }

  void set_optimized(bool o) {
    AT_WARN(
        "CompilationUnit::set_optimized() is deprecated and has no effect. "
        "Please use setGraphExecutorOptimize()");
  }

  bool is_optimized() const {
    AT_WARN(
        "CompilationUnit::is_optimized() is deprecated and always returns true. "
        "Please use getGraphExecutorOptimize()");
    return true;
  }

  // for historic reasons, these are defined in compiler.cpp
  // Returns the list of Functions just defined.
  std::vector<Function*> define(
      const c10::optional<c10::QualifiedName>& prefix,
      const std::vector<Def>& definitions,
      // determines how we handle free variables in each definition
      const std::vector<ResolverPtr>& resolvers,
      // if non-null, the first argument to each def is bound to this value
      const Self* self,
      // see [name mangling]
      bool shouldMangle = false);

  // same as above but parse the definitions from source
  // Returns the list of Functions just defined.
  std::vector<Function*> define(
      // prefix namespace to put all the defined functions into
      const c10::optional<c10::QualifiedName>& prefix,
      const std::string& source,
      const ResolverPtr& resolver,
      const Self* self);

  void define_interface(
      const c10::QualifiedName& qualifiedName,
      const ClassDef& classDef,
      ResolverPtr rcb);

  Function* create_function(
      c10::QualifiedName name,
      std::shared_ptr<Graph> graph,
      bool shouldMangle = false) {
    if (shouldMangle) {
      name = mangle(name);
    }
    auto fn = torch::make_unique<Function>(
        std::move(name), std::move(graph), nullptr);
    auto ret = fn.get();
    register_function(std::move(fn));
    return ret;
  }

  std::vector<Function*> get_functions() const {
    return fmap(functions_, [](const std::unique_ptr<Function>& fn) {
      return fn.get();
    });
  }

  /// Run a method from this compilation unit.
  ///
  /// For example:
  /// @code
  ///   IValue output = cu.run_method("relu_script", a, b);
  /// @endcode
  ///
  /// To compile a module from a source string, see torch::jit::compile
  ///
  /// @param method_name The name of the method to run
  /// @param args Arguments to be passed to the method
  /// @return An IValue containing the return value (or values if it is a tuple)
  /// from the method
  template <typename... Types>
  IValue run_method(const c10::QualifiedName& method_name, Types&&... args) {
    return get_function(method_name)({IValue(std::forward<Types>(args))...});
  }

  void drop_all_functions() {
    dict_.clear();
    functions_.clear();
  }

  /**
   * Register a class as being owned by this compilation unit.
   */
  void register_type(c10::NamedTypePtr namedType) {
    // TODO: class types cannot be redefined because we have no way right now
    // of invalidating their methods.
    // NamedTuples are fine though, since they don't have methods.
    TORCH_CHECK(
        0 == classDict_.count(*namedType->name()),
        "class '",
        namedType->name()->qualifiedName(),
        "' already defined.");
    classes_.push_back(std::move(namedType));
    classDict_[*classes_.back()->name()] = classes_.size() - 1;
  }

  c10::ClassTypePtr get_class(const c10::QualifiedName& name) const {
    auto type = get_type(name);
    if (!type) {
      return nullptr;
    }
    return type->cast<c10::ClassType>();
  }

  c10::TupleTypePtr get_named_tuple(const c10::QualifiedName& name) const {
    for (const auto& cls : classes_) {
      if (cls->name()->qualifiedName() == name.qualifiedName()) {
        return cls->expect<c10::TupleType>();
      }
    }
    return nullptr;
  }

  c10::NamedTypePtr get_type(const c10::QualifiedName& name) const {
    auto it = classDict_.find(name);
    if (it == classDict_.end()) {
      return nullptr;
    }
    return classes_[it->second];
  }

  // For testing: clear all Python-defined classes to ensure that unit tests
  // have isolation.
  void _clear_python_cu() {
    // Delete all the associated class methods
    for (auto type : classes_) {
      if (auto cls = type->cast<c10::ClassType>()) {
        for (auto method : cls->methods()) {
          // Tombstone the method in the compilation unit.
          // Don't erase it from functions_, because dict_ stores indices into
          // functions_ and erasing would invalidate the remaining entries.
          auto it = dict_.find(method->qualname());
          TORCH_INTERNAL_ASSERT(it != dict_.end());
          functions_[it->second] = nullptr;
          // Erase in our big lookup table
          dict_.erase(it);
        }
      }
    }
    classes_.clear();
    classDict_.clear();
  }

  // [name mangling] All code objects must have a unique qualified name in a
  // CompilationUnit. In Python, sometimes functions won't have unique qualified
  // names (for example, nested functions). So we mangle Python functions to
  // ensure that they are uniquely named.
  //
  // We also use mangling to distinguish different Module instances. Since each
  // Module is a singleton class instance, different instances of the same
  // Python Module will have different types but the same qualified name.
  c10::QualifiedName mangle(const c10::QualifiedName& name) const;

 private:
  std::unique_ptr<Function> define(
      const c10::optional<c10::QualifiedName>& prefix,
      const Def& def,
      const ResolverPtr& resolver,
      const Self* self,
      const std::unordered_map<std::string, Function*>& function_table,
      bool shouldMangle = false) const;

  Function& register_function(std::unique_ptr<Function> fn) {
    TORCH_CHECK(
        0 == dict_.count(fn->qualname().qualifiedName()),
        "method '",
        fn->qualname().qualifiedName(),
        "' already defined.");
    functions_.emplace_back(std::move(fn));
    dict_[functions_.back()->qualname()] = functions_.size() - 1;
    return *functions_.back();
  }

  std::vector<std::unique_ptr<Function>> functions_;
  // for fast lookup: maps a qualified name to an index into functions_
  std::unordered_map<c10::QualifiedName, size_t> dict_;
  // maps a qualified name to an index into classes_
  std::unordered_map<c10::QualifiedName, size_t> classDict_;

  // [class ownership] Right now there are two relationships between classes
  // and compilation units:
  // 1. Classes have compilation units internally that hold their methods.
  // 2. On load, the TypePtrs of any imported classes are owned by the main
  // module's compilation unit.
  std::vector<c10::NamedTypePtr> classes_;

  mutable size_t mangleIndex_ = 0;
};

} // namespace script

// An owning pointer to a Function. Just a pair of a raw Function ptr and its
// owning CU. We need this because pybind requires a ref-counted way to refer to
// Functions.
struct StrongFunctionPtr {
  StrongFunctionPtr(
      std::shared_ptr<script::CompilationUnit> cu,
      Function* function)
      : cu_(std::move(cu)), function_(function) {
    TORCH_INTERNAL_ASSERT(cu_);
    TORCH_INTERNAL_ASSERT(function_);
  }
  std::shared_ptr<script::CompilationUnit> cu_;
  Function* function_;
};
} // namespace jit
} // namespace torch