reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-10 c3765bd24fe73747688a0ec2a550f219c9acb384
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#pragma once
 
#include <torch/data/datasets/base.h>
#include <torch/data/example.h>
 
#include <cstddef>
#include <vector>
 
namespace torch {
namespace serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch
 
namespace torch {
namespace data {
namespace datasets {
 
/// A stateful dataset is a dataset that maintains some internal state, which
/// will be `reset()` at the beginning of each epoch. Subclasses can override
/// the `reset()` method to configure this behavior. Further, the return type of
/// a stateful dataset's `get_batch()` method is always an `optional`. When the
/// stateful dataset wants to indicate to the dataloader that its epoch has
/// ended, it should return an empty optional. The dataloader knows to modify
/// its implementation based on whether the dataset is stateless or stateful.
///
/// Note that when subclassing a from `StatefulDataset<Self, T>`, the return
/// type of `get_batch()`, which the subclass must override, will be
/// `optional<T>` (i.e. the type specified in the `StatefulDataset` specialization is automatically boxed into an `optional` for the datast's `BatchType`).
template <
    typename Self,
    typename Batch = std::vector<Example<>>,
    typename BatchRequest = size_t>
class StatefulDataset
    : public BatchDataset<Self, optional<Batch>, BatchRequest> {
 public:
  /// Resets internal state of the dataset.
  virtual void reset() = 0;
 
  /// Saves the statefulDataset's state to OutputArchive.
  virtual void save(serialize::OutputArchive& archive) const = 0;
 
  /// Deserializes the statefulDataset's state from the `archive`.
  virtual void load(serialize::InputArchive& archive) = 0;
};
 
/// Serializes a statefulDataset to `OutputArchive`.
template <typename... Args>
serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const StatefulDataset<Args...>& statefulDataset) {
  statefulDataset.save(archive);
  return archive;
}
 
/// Deserializes a statefulDataset from an `InputArchive`.
template <typename... Args>
serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    StatefulDataset<Args...>& statefulDataset) {
  statefulDataset.load(archive);
  return archive;
}
 
} // namespace datasets
} // namespace data
} // namespace torch