#pragma once #include #include #include #include #include namespace torch { namespace nn { /// Options for the `Embedding` module. struct TORCH_API EmbeddingOptions { EmbeddingOptions(int64_t count, int64_t dimension); /// The number of embeddings (number of rows in the table). TORCH_ARG(int64_t, count); /// The size of each embedding vector (number of columns in the table). TORCH_ARG(int64_t, dimension); }; /// Performs a lookup in a fixed size embedding table. class TORCH_API EmbeddingImpl : public torch::nn::Cloneable { public: EmbeddingImpl(int64_t count, int64_t dimension) : EmbeddingImpl(EmbeddingOptions(count, dimension)) {} explicit EmbeddingImpl(EmbeddingOptions options); void reset() override; /// Pretty prints the `Embedding` module into the given `stream`. void pretty_print(std::ostream& stream) const override; /// Performs a lookup on the embedding table stored in `weight` using the /// `indices` supplied and returns the result. Tensor forward(const Tensor& indices); /// The `Options` used to configure this `Embedding` module. /// Changes to `EmbeddingOptions` *after construction* have no effect. EmbeddingOptions options; /// The embedding table. Tensor weight; }; /// A `ModuleHolder` subclass for `EmbeddingImpl`. /// See the documentation for `EmbeddingImpl` class to learn what methods it /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's /// module storage semantics. TORCH_MODULE(Embedding); } // namespace nn } // namespace torch