#pragma once #include #include #include // This is in the c10 namespace because we use ADL to find the functions in it. namespace c10 { // FIXME: this should be (and was) Scalar::toTensor, but there is currently no way // to implement this without going through Derived Types (which are not part of core). inline at::Tensor scalar_to_tensor(Scalar s, const Device device = at::kCPU) { if (s.isFloatingPoint()) { return at::scalar_tensor(s, at::device(device).dtype(at::kDouble)); } else if (s.isBoolean()) { return at::scalar_tensor(s, at::device(device).dtype(at::kBool)); } else if (s.isComplex()) { return at::scalar_tensor(s, at::device(device).dtype(at::kComplexDouble)); } else { AT_ASSERT(s.isIntegral(false)); return at::scalar_tensor(s, at::device(device).dtype(at::kLong)); } } }