reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-02-28 27bef7116852ea5e165bfe454b86345bd57a16ef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#pragma once
 
#include <torch/nn/cloneable.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>
 
#include <cstddef>
#include <vector>
 
namespace torch {
namespace nn {
 
/// Options for the `Embedding` module.
struct TORCH_API EmbeddingOptions {
  EmbeddingOptions(int64_t count, int64_t dimension);
  /// The number of embeddings (number of rows in the table).
  TORCH_ARG(int64_t, count);
  /// The size of each embedding vector (number of columns in the table).
  TORCH_ARG(int64_t, dimension);
};
 
/// Performs a lookup in a fixed size embedding table.
class TORCH_API EmbeddingImpl : public torch::nn::Cloneable<EmbeddingImpl> {
 public:
  EmbeddingImpl(int64_t count, int64_t dimension)
      : EmbeddingImpl(EmbeddingOptions(count, dimension)) {}
  explicit EmbeddingImpl(EmbeddingOptions options);
 
  void reset() override;
 
  /// Pretty prints the `Embedding` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;
 
  /// Performs a lookup on the embedding table stored in `weight` using the
  /// `indices` supplied and returns the result.
  Tensor forward(const Tensor& indices);
 
  /// The `Options` used to configure this `Embedding` module.
  /// Changes to `EmbeddingOptions` *after construction* have no effect.
  EmbeddingOptions options;
 
  /// The embedding table.
  Tensor weight;
};
 
/// A `ModuleHolder` subclass for `EmbeddingImpl`.
/// See the documentation for `EmbeddingImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(Embedding);
 
} // namespace nn
} // namespace torch