#pragma once #include #include #include #include #include #include #include #include #include namespace at { class Tensor; } // namespace at namespace torch { using at::Tensor; namespace jit { namespace script { struct Module; } // namespace script } // namespace jit } // namespace torch namespace torch { namespace serialize { /// A recursive representation of tensors that can be deserialized from a file /// or stream. In most cases, users should not have to interact with this class, /// and should instead use `torch::load`. class TORCH_API InputArchive final { public: /// Default-constructs the `InputArchive`. InputArchive(); // Move is allowed. InputArchive(InputArchive&&) = default; InputArchive& operator=(InputArchive&&) = default; // Copy is disallowed. InputArchive(InputArchive&) = delete; InputArchive& operator=(InputArchive&) = delete; ~InputArchive() = default; /// Reads an `IValue` associated with a given `key`. void read(const std::string& key, c10::IValue& ivalue); /// Reads a `tensor` associated with a given `key`. If there is no `tensor` /// associated with the `key`, this returns false, otherwise it returns true. /// If the tensor is expected to be a buffer (not differentiable), `is_buffer` /// must be `true`. bool try_read(const std::string& key, Tensor& tensor, bool is_buffer = false); /// Reads a `tensor` associated with a given `key`. /// If the tensor is expected to be a buffer (not differentiable), `is_buffer` /// must be `true`. void read(const std::string& key, Tensor& tensor, bool is_buffer = false); /// Reads a `InputArchive` associated with a given `key`. If there is no /// `InputArchive` associated with the `key`, this returns false, otherwise /// it returns true. bool try_read(const std::string& key, InputArchive& archive); /// Reads an `InputArchive` associated with a given `key`. /// The archive can thereafter be used for further deserialization of the /// nested data. void read(const std::string& key, InputArchive& archive); /// Loads the `InputArchive` from a serialized representation stored in the /// file at `filename`. Storage are remapped using device option. If device /// is not specified, the module is loaded to the original device. void load_from(const std::string& filename, c10::optional device = c10::nullopt); /// Loads the `InputArchive` from a serialized representation stored in the /// given `stream`. Storage are remapped using device option. If device /// is not specified, the module is loaded to the original device. void load_from(std::istream& stream, c10::optional device = c10::nullopt); /// Forwards all arguments to `read()`. /// Useful for generic code that can be re-used for both `InputArchive` and /// `OutputArchive` (where `operator()` forwards to `write()`). template void operator()(Ts&&... ts) { read(std::forward(ts)...); } private: jit::script::Module module_; }; } // namespace serialize } // namespace torch