#pragma once #include #include #include #include #include #include #include namespace torch { using at::IterArgs; struct CountTensors : IterArgs { size_t out = 0; void operator()(const at::Tensor& x) { out += 1; } void operator()(at::ArrayRef xs) { out += xs.size(); } }; template size_t count_tensors(Args&&... args) { return CountTensors().apply(std::forward(args)...).out; } struct CountVariables : IterArgs { size_t out = 0; void operator()(const autograd::Variable& x) { out += 1; } void operator()(at::ArrayRef xs) { out += xs.size(); } }; template inline size_t count_variables(Args&&... args) { return CountVariables().apply(std::forward(args)...).out; } //===----------------------------------------------------------------------===// // std::index_sequence shim for C++11 //===----------------------------------------------------------------------===// // A container of type-template parameter indices. template struct Indices {}; // Decrements the index N, adds N-1 to the list of indices and forwards // whatever we arleady have. template struct MakeIndices : MakeIndices {}; // Partial specialization that forms our base case. When N is zero, we stop // and define a typedef that will be visible to earlier classes due to // inheritance. The typedef we define is an index list containing the numbers // 0 through N-1. template struct MakeIndices<0, Is...> { using indices = Indices; }; //===----------------------------------------------------------------------===// // Utilities //===----------------------------------------------------------------------===// template using enable_if_t = typename std::enable_if::type; template using disable_if_t = enable_if_t; template using decay_t = typename std::decay::type; namespace detail { template struct pack; } // namespace detail template struct all_of : std::is_same< detail::pack, detail::pack> {}; template struct any_of; template <> struct any_of<> : std::false_type {}; template struct any_of { static constexpr bool value = head || any_of::value; }; template struct none_of { static constexpr bool value = !any_of::value; }; template using enable_if_all_of_t = enable_if_t::value>; template using disable_if_contains_t = enable_if_all_of_t<(!std::is_same>::value)...>; template void apply(Function function, Ts&&... ts) { // https://stackoverflow.com/questions/13978916/inserting-a-variadic-argument-list-into-a-vector // Creates a dummy array, so that each function call is evaluated in order. // `(function(), 0)` is because `function` should (!) return `void`, so // according to the comma operator, it is evaluated and its result (`void`) // is discarded. Then the zero is evaluated and used as an element in the // array. The first zero ensures the array is not empty. int _[]{0, (function(std::forward(ts)), 0)...}; (void)_; } template ReturnType unpack(Function function, Accessor accessor) { return ReturnType(unpack( std::move(function), std::move(accessor), typename MakeIndices::indices())); } template ReturnType unpack(Function function, Accessor accessor, Indices) { return ReturnType(function(accessor.template operator()(Is)...)); } } // namespace torch