reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-10 c3765bd24fe73747688a0ec2a550f219c9acb384
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#pragma once
 
#include <torch/csrc/python_headers.h>
#include <ATen/ATen.h>
 
#include <torch/csrc/utils/python_arg_parser.h>
 
namespace torch { namespace autograd { namespace utils {
 
// The parameter allow_copy is to accept copy for Tensor.to (and by proxy
// PackedSequences.to) but not nn.Module.to.
inline std::tuple<c10::optional<at::Device>, c10::optional<at::ScalarType>, bool, bool>
  parse_to_conversion(PyObject *args, PyObject *kwargs, bool allow_copy) {
  static PythonArgParser parser({
    "to(Device device=None, ScalarType dtype=None, bool non_blocking=False, bool copy=False)",
    "to(ScalarType dtype, bool non_blocking=False, bool copy=False)",
    "to(Tensor tensor, bool non_blocking=False, bool copy=False)",
  });
  ParsedArgs<4> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx == 0) {
    if (!allow_copy && !r.isNone(3))
      throw std::runtime_error(".to() does not accept copy argument");
    return std::make_tuple(r.deviceOptional(0), r.scalartypeOptional(1), r.toBool(2), r.toBool(3));
  } else if (r.idx == 1) {
    if (!allow_copy && !r.isNone(2))
      throw std::runtime_error(".to() does not accept copy argument");
    return std::make_tuple(c10::nullopt, r.scalartype(0), r.toBool(1), r.toBool(2));
  } else {
    auto tensor = r.tensor(0);
    if (!allow_copy && !r.isNone(2))
      throw std::runtime_error(".to() does not accept copy argument");
    return std::make_tuple(
      tensor.device(),
      tensor.scalar_type(),
      r.toBool(1),
      r.toBool(2)
    );
  }
}
}}} // namespace torch::autograd::utils