#pragma once #include "caffe2/core/context.h" #include "caffe2/core/tensor.h" #include "caffe2/core/types.h" #include "caffe2/proto/caffe2_pb.h" #include "caffe2/python/dlpack.h" #include #include namespace caffe2 { namespace python { namespace py = pybind11; const DLDeviceType* CaffeToDLDeviceType(int device_type); const DLDataType* CaffeToDLType(const TypeMeta& meta); const TypeMeta& DLTypeToCaffe(const DLDataType& dl_type); // TODO: remove context template class DLPackWrapper { public: DLPackWrapper(Tensor* tensor, DeviceOption device_option) : tensor(tensor), device_option(device_option) {} py::object data() { DLContext tensor_context; auto device_type_ptr = CaffeToDLDeviceType(device_option.device_type()); CAFFE_ENFORCE( device_type_ptr, "Unsupported device type: ", device_option.device_type()); tensor_context.device_type = *device_type_ptr; tensor_context.device_id = device_option.device_id(); if (tensor->numel() <= 0) { tensor->Resize(0); } if (tensor->dtype().id() == TypeIdentifier::uninitialized()) { // treat uninitialized tensor as float tensor tensor->template mutable_data(); } CAFFE_ENFORCE_GT(tensor->dim(), 0); auto type_ptr = CaffeToDLType(tensor->dtype()); CAFFE_ENFORCE( type_ptr, "Tensor type is not supported in DLPack: ", tensor->dtype().name()); DLDataType tensor_type = *type_ptr; DLTensor dlTensor; dlTensor.data = const_cast(tensor->raw_data()); dlTensor.ctx = tensor_context; dlTensor.ndim = tensor->dim(); dlTensor.dtype = tensor_type; dlTensor.shape = const_cast(&(tensor->sizes()[0])); dlTensor.strides = nullptr; dlTensor.byte_offset = 0; managed_tensor.dl_tensor = dlTensor; // C2 Tensor memory is managed by C2 managed_tensor.manager_ctx = nullptr; managed_tensor.deleter= [](DLManagedTensor*) {}; return py::reinterpret_steal( PyCapsule_New(&managed_tensor, "dltensor", nullptr)); } void feed(py::object obj) { CAFFE_ENFORCE(PyCapsule_CheckExact(obj.ptr()), "Expected DLPack capsule"); DLManagedTensor* dlMTensor = (DLManagedTensor*)PyCapsule_GetPointer(obj.ptr(), "dltensor"); CAFFE_ENFORCE(dlMTensor, "Invalid DLPack capsule"); DLTensor* dlTensor = &dlMTensor->dl_tensor; auto device_type_ptr = CaffeToDLDeviceType(device_option.device_type()); CAFFE_ENFORCE( device_type_ptr, "Unsupported device type: ", device_option.device_type()); CAFFE_ENFORCE( dlTensor->ctx.device_type == *device_type_ptr, "DLPack tensor device type mismatch"); int dlpack_device_id = dlTensor->ctx.device_id; CAFFE_ENFORCE_EQ( dlpack_device_id, device_option.device_id(), "Expected same device id for DLPack and C2 tensors"); std::vector dims; dims.reserve(dlTensor->ndim); for (int idx = 0; idx < dlTensor->ndim; ++idx) { dims.push_back(dlTensor->shape[idx]); } if (dlTensor->strides) { int64_t stride = 1; for (int idx = dims.size() - 1; idx >= 0; --idx) { CAFFE_ENFORCE_EQ( stride, dlTensor->strides[idx], "Tensors with non-standard strides are not supported"); stride *= dims[idx]; } } tensor->Resize(dims); caffe2::TypeMeta meta = DLTypeToCaffe(dlTensor->dtype); at::Device device = at::Device(tensor->GetDeviceType()); tensor->ShareExternalPointer( at::DataPtr( (void*)(((int8_t*)dlTensor->data) + dlTensor->byte_offset), static_cast(dlMTensor), [](void* t_ptr) -> void { DLManagedTensor* mt_ptr = static_cast(t_ptr); if (mt_ptr->deleter) { mt_ptr->deleter(mt_ptr); } }, device), meta, 0); } Tensor* tensor; DeviceOption device_option; DLManagedTensor managed_tensor; }; } // namespace python } // namespace caffe2