#pragma once // Wrap tensor operation outputs as PyObject* #include #include #include #include #include #include #include #include #include #include #include namespace torch { namespace autograd { namespace utils { inline PyObject* wrap(bool value) { if (value) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; } } inline PyObject* wrap(int64_t value) { return THPUtils_packInt64(value); } inline PyObject* wrap(double value) { return PyFloat_FromDouble(value); } inline PyObject* wrap(std::complex value) { // I could probably also use FromComplex with a reinterpret cast, // but... eh. return PyComplex_FromDoubles(value.real(), value.imag()); } inline PyObject* wrap(void* value) { return THPUtils_packInt64(reinterpret_cast(value)); } inline PyObject* wrap(THPDtype *dtype) { Py_INCREF(dtype); return (PyObject*)dtype; } inline PyObject* wrap(at::ScalarType scalarType) { return wrap(getDtype(scalarType)); } inline PyObject* wrap(THPLayout *layout) { Py_INCREF(layout); return (PyObject*)layout; } inline PyObject* wrap(at::Tensor tensor) { return THPVariable_Wrap(Variable(std::move(tensor))); } inline PyObject* wrap(at::Scalar scalar) { return wrap(make_variable(scalar_to_tensor(scalar))); } inline PyObject* wrap(at::QScheme qscheme) { auto* thp_qscheme = torch::utils::getTHPQScheme(qscheme); Py_INCREF(thp_qscheme); return thp_qscheme; } inline PyObject* wrap(std::tuple tensors) { auto r = THPObjectPtr{PyTuple_New(2)}; if (!r) throw python_error(); PyTuple_SET_ITEM(r.get(), 0, wrap(std::get<0>(tensors))); PyTuple_SET_ITEM(r.get(), 1, wrap(std::get<1>(tensors))); return r.release(); } inline PyObject* wrap(PyTypeObject *type, std::tuple tensors) { auto r = THPObjectPtr{PyStructSequence_New(type)}; if (!r) throw python_error(); PyStructSequence_SET_ITEM(r.get(), 0, wrap(std::get<0>(tensors))); PyStructSequence_SET_ITEM(r.get(), 1, wrap(std::get<1>(tensors))); return r.release(); } inline PyObject* wrap(std::tuple tensors) { auto r = THPObjectPtr{PyTuple_New(3)}; if (!r) throw python_error(); PyTuple_SET_ITEM(r.get(), 0, wrap(std::move(std::get<0>(tensors)))); PyTuple_SET_ITEM(r.get(), 1, wrap(std::move(std::get<1>(tensors)))); PyTuple_SET_ITEM(r.get(), 2, wrap(std::move(std::get<2>(tensors)))); return r.release(); } inline PyObject* wrap(PyTypeObject *type, std::tuple tensors) { auto r = THPObjectPtr{PyStructSequence_New(type)}; if (!r) throw python_error(); PyStructSequence_SET_ITEM(r.get(), 0, wrap(std::get<0>(tensors))); PyStructSequence_SET_ITEM(r.get(), 1, wrap(std::get<1>(tensors))); PyStructSequence_SET_ITEM(r.get(), 2, wrap(std::get<2>(tensors))); return r.release(); } inline PyObject* wrap(std::tuple tensors) { auto r = THPObjectPtr{PyTuple_New(4)}; if (!r) throw python_error(); PyTuple_SET_ITEM(r.get(), 0, wrap(std::move(std::get<0>(tensors)))); PyTuple_SET_ITEM(r.get(), 1, wrap(std::move(std::get<1>(tensors)))); PyTuple_SET_ITEM(r.get(), 2, wrap(std::move(std::get<2>(tensors)))); PyTuple_SET_ITEM(r.get(), 3, wrap(std::get<3>(tensors))); return r.release(); } inline PyObject* wrap(std::tuple tensors) { auto r = THPObjectPtr{PyTuple_New(4)}; if (!r) throw python_error(); PyTuple_SET_ITEM(r.get(), 0, wrap(std::move(std::get<0>(tensors)))); PyTuple_SET_ITEM(r.get(), 1, wrap(std::move(std::get<1>(tensors)))); PyTuple_SET_ITEM(r.get(), 2, wrap(std::move(std::get<2>(tensors)))); PyTuple_SET_ITEM(r.get(), 3, wrap(std::move(std::get<3>(tensors)))); return r.release(); } inline PyObject* wrap(std::tuple tensors) { auto r = THPObjectPtr{PyTuple_New(4)}; if (!r) throw python_error(); PyTuple_SET_ITEM(r.get(), 0, wrap(std::move(std::get<0>(tensors)))); PyTuple_SET_ITEM(r.get(), 1, wrap(std::move(std::get<1>(tensors)))); PyTuple_SET_ITEM(r.get(), 2, wrap(std::move(std::get<2>(tensors)))); PyTuple_SET_ITEM(r.get(), 3, wrap(std::move(std::get<3>(tensors)))); return r.release(); } inline PyObject* wrap(std::tuple tensors) { auto r = THPObjectPtr{PyTuple_New(5)}; if (!r) throw python_error(); PyTuple_SET_ITEM(r.get(), 0, wrap(std::move(std::get<0>(tensors)))); PyTuple_SET_ITEM(r.get(), 1, wrap(std::move(std::get<1>(tensors)))); PyTuple_SET_ITEM(r.get(), 2, wrap(std::move(std::get<2>(tensors)))); PyTuple_SET_ITEM(r.get(), 3, wrap(std::move(std::get<3>(tensors)))); PyTuple_SET_ITEM(r.get(), 4, wrap(std::move(std::get<4>(tensors)))); return r.release(); } inline PyObject* wrap(at::TensorList tl) { auto r = THPObjectPtr{PyTuple_New(tl.size())}; if (!r) throw python_error(); for (size_t i = 0; i < tl.size(); ++i) { PyTuple_SET_ITEM(r.get(), i, wrap(tl[i])); } return r.release(); } inline PyObject* wrap(at::IntArrayRef list) { auto r = THPObjectPtr{PyTuple_New(list.size())}; if (!r) throw python_error(); for (size_t i = 0; i < list.size(); ++i) { PyTuple_SET_ITEM(r.get(), i, wrap(list[i])); } return r.release(); } }}} // namespace torch::autograd::utils