#pragma once #include #include #include #include #include #include #include #include #include #include namespace py = pybind11; namespace pybind11 { namespace detail { // torch.autograd.Variable <-> at::Tensor conversions (without unwrapping) template <> struct type_caster { public: PYBIND11_TYPE_CASTER(at::Tensor, _("at::Tensor")); bool load(handle src, bool) { PyObject* obj = src.ptr(); if (THPVariable_Check(obj)) { value = reinterpret_cast(obj)->cdata; return true; } return false; } static handle cast(const at::Tensor& src, return_value_policy /* policy */, handle /* parent */) { if (!src.is_variable()) { throw std::runtime_error( "Expected tensor's dynamic type to be Variable, not Tensor"); } return handle(THPVariable_Wrap(torch::autograd::Variable(src))); } }; template<> struct type_caster { public: PYBIND11_TYPE_CASTER(torch::autograd::Variable, _("torch::autograd::Variable")); bool load(handle src, bool) { PyObject *source = src.ptr(); if (THPVariable_Check(source)) { value = ((THPVariable*)source)->cdata; return true; } else { return false; } } static handle cast(torch::autograd::Variable src, return_value_policy /* policy */, handle /* parent */) { return handle(THPVariable_Wrap(std::move(src))); } }; template<> struct type_caster { public: PYBIND11_TYPE_CASTER(at::IntArrayRef, _("at::IntArrayRef")); bool load(handle src, bool) { PyObject *source = src.ptr(); auto tuple = PyTuple_Check(source); if (tuple || PyList_Check(source)) { auto size = tuple ? PyTuple_GET_SIZE(source) : PyList_GET_SIZE(source); v_value.resize(size); for (int idx = 0; idx < size; idx++) { PyObject* obj = tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx); if (THPVariable_Check(obj)) { v_value[idx] = THPVariable_Unpack(obj).item(); } else if (PyLong_Check(obj)) { // use THPUtils_unpackLong after it is safe to include python_numbers.h v_value[idx] = THPUtils_unpackLong(obj); } else { return false; } } value = v_value; return true; } return false; } static handle cast(at::IntArrayRef src, return_value_policy /* policy */, handle /* parent */) { return handle(THPUtils_packInt64Array(src.size(), src.data())); } private: std::vector v_value; }; // Pybind11 bindings for our optional type. // http://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers template struct type_caster> : optional_caster> {}; }} // namespace pybind11::detail