#pragma once
|
|
#include "caffe2/core/context.h"
|
#include "caffe2/core/tensor.h"
|
#include "caffe2/core/types.h"
|
#include "caffe2/proto/caffe2_pb.h"
|
#include "caffe2/python/dlpack.h"
|
|
#include <pybind11/pybind11.h>
|
#include <pybind11/stl.h>
|
|
namespace caffe2 {
|
namespace python {
|
|
namespace py = pybind11;
|
|
const DLDeviceType* CaffeToDLDeviceType(int device_type);
|
|
const DLDataType* CaffeToDLType(const TypeMeta& meta);
|
|
const TypeMeta& DLTypeToCaffe(const DLDataType& dl_type);
|
|
// TODO: remove context
|
template <class Context>
|
class DLPackWrapper {
|
public:
|
DLPackWrapper(Tensor* tensor, DeviceOption device_option)
|
: tensor(tensor), device_option(device_option) {}
|
|
py::object data() {
|
DLContext tensor_context;
|
auto device_type_ptr = CaffeToDLDeviceType(device_option.device_type());
|
CAFFE_ENFORCE(
|
device_type_ptr,
|
"Unsupported device type: ",
|
device_option.device_type());
|
tensor_context.device_type = *device_type_ptr;
|
tensor_context.device_id = device_option.device_id();
|
|
if (tensor->numel() <= 0) {
|
tensor->Resize(0);
|
}
|
if (tensor->dtype().id() == TypeIdentifier::uninitialized()) {
|
// treat uninitialized tensor as float tensor
|
tensor->template mutable_data<float>();
|
}
|
CAFFE_ENFORCE_GT(tensor->dim(), 0);
|
|
auto type_ptr = CaffeToDLType(tensor->dtype());
|
CAFFE_ENFORCE(
|
type_ptr,
|
"Tensor type is not supported in DLPack: ",
|
tensor->dtype().name());
|
DLDataType tensor_type = *type_ptr;
|
|
DLTensor dlTensor;
|
dlTensor.data = const_cast<void*>(tensor->raw_data());
|
dlTensor.ctx = tensor_context;
|
dlTensor.ndim = tensor->dim();
|
dlTensor.dtype = tensor_type;
|
dlTensor.shape = const_cast<int64_t*>(&(tensor->sizes()[0]));
|
dlTensor.strides = nullptr;
|
dlTensor.byte_offset = 0;
|
|
managed_tensor.dl_tensor = dlTensor;
|
// C2 Tensor memory is managed by C2
|
managed_tensor.manager_ctx = nullptr;
|
managed_tensor.deleter= [](DLManagedTensor*) {};
|
|
return py::reinterpret_steal<py::object>(
|
PyCapsule_New(&managed_tensor, "dltensor", nullptr));
|
}
|
|
void feed(py::object obj) {
|
CAFFE_ENFORCE(PyCapsule_CheckExact(obj.ptr()), "Expected DLPack capsule");
|
DLManagedTensor* dlMTensor =
|
(DLManagedTensor*)PyCapsule_GetPointer(obj.ptr(), "dltensor");
|
CAFFE_ENFORCE(dlMTensor, "Invalid DLPack capsule");
|
DLTensor* dlTensor = &dlMTensor->dl_tensor;
|
auto device_type_ptr = CaffeToDLDeviceType(device_option.device_type());
|
CAFFE_ENFORCE(
|
device_type_ptr,
|
"Unsupported device type: ",
|
device_option.device_type());
|
CAFFE_ENFORCE(
|
dlTensor->ctx.device_type == *device_type_ptr,
|
"DLPack tensor device type mismatch");
|
int dlpack_device_id = dlTensor->ctx.device_id;
|
CAFFE_ENFORCE_EQ(
|
dlpack_device_id,
|
device_option.device_id(),
|
"Expected same device id for DLPack and C2 tensors");
|
|
std::vector<int64_t> dims;
|
dims.reserve(dlTensor->ndim);
|
for (int idx = 0; idx < dlTensor->ndim; ++idx) {
|
dims.push_back(dlTensor->shape[idx]);
|
}
|
|
if (dlTensor->strides) {
|
int64_t stride = 1;
|
for (int idx = dims.size() - 1; idx >= 0; --idx) {
|
CAFFE_ENFORCE_EQ(
|
stride,
|
dlTensor->strides[idx],
|
"Tensors with non-standard strides are not supported");
|
stride *= dims[idx];
|
}
|
}
|
|
tensor->Resize(dims);
|
caffe2::TypeMeta meta = DLTypeToCaffe(dlTensor->dtype);
|
at::Device device = at::Device(tensor->GetDeviceType());
|
tensor->ShareExternalPointer(
|
at::DataPtr(
|
(void*)(((int8_t*)dlTensor->data) + dlTensor->byte_offset),
|
static_cast<void*>(dlMTensor),
|
[](void* t_ptr) -> void {
|
DLManagedTensor* mt_ptr = static_cast<DLManagedTensor*>(t_ptr);
|
if (mt_ptr->deleter) {
|
mt_ptr->deleter(mt_ptr);
|
}
|
},
|
device),
|
meta,
|
0);
|
}
|
|
Tensor* tensor;
|
DeviceOption device_option;
|
DLManagedTensor managed_tensor;
|
};
|
|
} // namespace python
|
} // namespace caffe2
|