reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-16 a47fccb11fa3470901aebcb27f861d242d0925e1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#pragma once
 
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/data/samplers/base.h>
#include <torch/data/samplers/custom_batch_request.h>
#include <torch/types.h>
 
#include <cstddef>
 
namespace torch {
namespace serialize {
class InputArchive;
class OutputArchive;
} // namespace serialize
} // namespace torch
 
namespace torch {
namespace data {
namespace samplers {
 
/// A wrapper around a batch size value, which implements the
/// `CustomBatchRequest` interface.
struct TORCH_API BatchSize : public CustomBatchRequest {
  explicit BatchSize(size_t size);
  size_t size() const noexcept override;
  operator size_t() const noexcept;
  size_t size_;
};
 
/// A sampler for (potentially infinite) streams of data.
///
/// The major feature of the `StreamSampler` is that it does not return
/// particular indices, but instead only the number of elements to fetch from
/// the dataset. The dataset has to decide how to produce those elements.
class TORCH_API StreamSampler : public Sampler<BatchSize> {
 public:
  /// Constructs the `StreamSampler` with the number of individual examples that
  /// should be fetched until the sampler is exhausted.
  explicit StreamSampler(size_t epoch_size);
 
  /// Resets the internal state of the sampler.
  void reset(optional<size_t> new_size = nullopt) override;
 
  /// Returns a `BatchSize` object with the number of elements to fetch in the
  /// next batch. This number is the minimum of the supplied `batch_size` and
  /// the difference between the `epoch_size` and the current index. If the
  /// `epoch_size` has been reached, returns an empty optional.
  optional<BatchSize> next(size_t batch_size) override;
 
  /// Serializes the `StreamSampler` to the `archive`.
  void save(serialize::OutputArchive& archive) const override;
 
  /// Deserializes the `StreamSampler` from the `archive`.
  void load(serialize::InputArchive& archive) override;
 
 private:
  size_t examples_retrieved_so_far_ = 0;
  size_t epoch_size_;
};
 
} // namespace samplers
} // namespace data
} // namespace torch