#pragma once #include #include #include #include namespace torch { /// Serializes the given `value`. /// There must be an overload of `operator<<` between `serialize::OutputArchive` /// and `Value` for this method to be well-formed. Currently, such an overload /// is provided for (subclasses of): /// /// - `torch::nn::Module`, /// - `torch::optim::Optimizer` /// - `torch::Tensor` /// /// To perform the serialization, a `serialize::OutputArchive` is constructed, /// and all arguments after the `value` are forwarded to its `save_to` method. /// For example, you can pass a filename, or an `ostream`. /// /// \rst /// .. code-block:: cpp /// /// torch::nn::Linear model(3, 4); /// torch::save(model, "model.pt"); /// /// torch::optim::SGD sgd(/*lr=*/0.9); /// std::ostringstream stream; /// // Note that the same stream cannot be used in multiple torch::save(...) /// // invocations, otherwise the header will be corrupted. /// torch::save(sgd, stream); /// /// auto tensor = torch::ones({3, 4}); /// torch::save(tensor, "my_tensor.pt"); /// \endrst template void save(const Value& value, SaveToArgs&&... args) { serialize::OutputArchive archive( std::make_shared()); archive << value; archive.save_to(std::forward(args)...); } /// Serializes the given `tensor_vec` of type `std::vector`. /// /// To perform the serialization, a `serialize::OutputArchive` is constructed, /// and all arguments after the `tensor_vec` are forwarded to its `save_to` /// method. For example, you can pass a filename, or an `ostream`. /// /// \rst /// .. code-block:: cpp /// /// std::vector tensor_vec = { torch::randn({1, 2}), torch::randn({3, 4}) }; /// torch::save(tensor_vec, "my_tensor_vec.pt"); /// /// std::vector tensor_vec = { torch::randn({5, 6}), torch::randn({7, 8}) }; /// std::ostringstream stream; /// // Note that the same stream cannot be used in multiple torch::save(...) /// // invocations, otherwise the header will be corrupted. /// torch::save(tensor_vec, stream); /// \endrst template void save(const std::vector& tensor_vec, SaveToArgs&&... args) { serialize::OutputArchive archive( std::make_shared()); for (size_t i = 0; i < tensor_vec.size(); i++) { auto& value = tensor_vec[i]; archive.write(std::to_string(i), value); } archive.save_to(std::forward(args)...); } TORCH_API std::vector pickle_save(const torch::IValue& ivalue); /// Deserializes the given `value`. /// There must be an overload of `operator>>` between `serialize::InputArchive` /// and `Value` for this method to be well-formed. Currently, such an overload /// is provided for (subclasses of): /// /// - `torch::nn::Module`, /// - `torch::optim::Optimizer` /// - `torch::Tensor` /// /// To perform the serialization, a `serialize::InputArchive` is constructed, /// and all arguments after the `value` are forwarded to its `load_from` method. /// For example, you can pass a filename, or an `istream`. /// /// \rst /// .. code-block:: cpp /// /// torch::nn::Linear model(3, 4); /// torch::load(model, "model.pt"); /// /// torch::optim::SGD sgd(/*lr=*/0.9); /// std::istringstream stream("..."); /// torch::load(sgd, stream); /// /// auto tensor = torch::ones({3, 4}); /// torch::load(tensor, "my_tensor.pt"); /// \endrst template void load(Value& value, LoadFromArgs&&... args) { serialize::InputArchive archive; archive.load_from(std::forward(args)...); archive >> value; } /// Deserializes the given `tensor_vec` of type `std::vector`. /// /// To perform the serialization, a `serialize::InputArchive` is constructed, /// and all arguments after the `value` are forwarded to its `load_from` method. /// For example, you can pass a filename, or an `istream`. /// /// \rst /// .. code-block:: cpp /// /// std::vector tensor_vec; /// torch::load(tensor_vec, "my_tensor_vec.pt"); /// /// std::vector tensor_vec; /// std::istringstream stream("..."); /// torch::load(tensor_vec, stream); /// \endrst template void load(std::vector& tensor_vec, LoadFromArgs&&... args) { serialize::InputArchive archive; archive.load_from(std::forward(args)...); // NOTE: The number of elements in the serialized `std::vector` // is not known ahead of time, so we need a while-loop to increment the index, // and use `archive.try_read(...)` to check whether we have reached the end of // the serialized `std::vector`. size_t index = 0; torch::Tensor value; while (archive.try_read(std::to_string(index), value)) { tensor_vec.push_back(std::move(value)); value = torch::Tensor(); index++; } } } // namespace torch