#ifndef CAFFE2_CORE_BLOB_H_
|
#define CAFFE2_CORE_BLOB_H_
|
|
#include <cstddef>
|
#include <sstream>
|
#include <typeinfo>
|
#include <type_traits>
|
#include <vector>
|
#include "caffe2/core/common.h"
|
|
#include <ATen/core/blob.h>
|
#include <c10/util/typeid.h>
|
#include "caffe2/core/logging.h"
|
#include "caffe2/core/tensor.h"
|
#include "caffe2/core/tensor_int8.h"
|
|
namespace caffe2 {
|
|
inline bool BlobIsInt8TensorCPUType(const Blob& blob) {
|
return blob.meta().Match<int8::Int8TensorCPU>();
|
}
|
|
inline bool BlobIsTensorType(const Blob& blob, DeviceType device_type) {
|
bool is_match = blob.meta().Match<Tensor>();
|
if (!is_match) {
|
return false;
|
}
|
const Tensor* tensor = &blob.Get<Tensor>();
|
return tensor && *tensor && tensor->GetDeviceType() == device_type;
|
}
|
|
// Moves `tensor` into a newly heap-allocated Tensor, passes that pointer to
// blob->Reset<Tensor>() (which presumably takes ownership and replaces the
// blob's previous contents -- confirm against Blob::Reset), and returns the
// pointer Reset hands back.
inline Tensor* BlobSetTensor(Blob* blob, Tensor&& tensor) {
  auto* relocated = new Tensor(std::move(tensor));
  return blob->Reset<Tensor>(relocated);
}
|
|
// Produces a Tensor with shape `dims` that matches `options`, reusing
// `previous_tensor` when possible.
//
// - If `previous_tensor` is undefined, is on a different device, or has a
//   dtype different from `options.dtype()`, a fresh tensor from
//   caffe2::empty(dims, options) is returned.
// - Otherwise `previous_tensor` is resized to `dims` (only when its sizes
//   differ), raw_mutable_data() is called on it, and it is returned.
//
// Device and dtype are both checked *before* resizing. A tensor that is about
// to be discarded is therefore left untouched. Its TensorImpl may still be
// shared with other Tensor handles (e.g. ones obtained via
// UnsafeSharedInstance()), so resizing it would be a visible, pointless side
// effect.
inline Tensor GetSizedTensorWithOptions(
    Tensor&& previous_tensor,
    at::IntArrayRef dims,
    at::TensorOptions options) {
  Tensor tensor = std::move(previous_tensor);
  if (!tensor.defined()) {
    return caffe2::empty(dims, options);
  }
  // When the existing tensor carries no device index, only the device type is
  // compared (some Tensors are presumably not initialized with an index).
  const bool same_device = tensor.GetDevice() == options.device() ||
      (!tensor.GetDevice().has_index() &&
       tensor.GetDeviceType() == options.device().type());
  if (!same_device || tensor.dtype() != options.dtype()) {
    // Device or data type mismatch: create a new Tensor instead of reusing.
    return caffe2::empty(dims, options);
  }
  if (tensor.sizes() != dims) {
    // Resize when the dims don't match.
    tensor.Resize(dims);
  }
  tensor.raw_mutable_data();
  return tensor;
}
|
|
// We need to keep both the function that returns Tensor* and the one that
// returns Tensor (XBlobGetMutableTensor) for the clangr codemod.
|
// Returns a pointer to a Tensor held by `blob` that has shape `dims` and
// matches `options`.
//
// If `blob` already holds a defined Tensor on a compatible device, that Tensor
// is reused:
// - it is resized when its sizes differ from `dims`;
// - raw_mutable_data(options.dtype()) is called when its dtype differs,
//   otherwise raw_mutable_data() with no argument.
//
// In every other case (no Tensor, an undefined Tensor, or a device mismatch)
// a new tensor from caffe2::empty(dims, options) is stored into `blob` via
// BlobSetTensor and returned.
inline Tensor*
BlobGetMutableTensor(Blob* blob, at::IntArrayRef dims, at::TensorOptions options) {
  Tensor* reusable = nullptr;
  if (blob->IsType<Tensor>()) {
    Tensor* candidate = blob->GetMutable<Tensor>();
    // When the existing Tensor has no device index, only the device type is
    // compared, since some Tensors are apparently created without one.
    // TODO: remove the extra check when all the Tensors are properly initialized
    if (*candidate &&
        (candidate->GetDevice() == options.device() ||
         (!candidate->GetDevice().has_index() &&
          candidate->GetDeviceType() == options.device().type()))) {
      reusable = candidate;
    }
  }

  if (reusable == nullptr) {
    // No Tensor, an undefined Tensor, or a device mismatch: replace it.
    VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
            << " dims: " << dims;
    // << " options: " << options; (operator<< for Options is in at:: now)
    return BlobSetTensor(blob, caffe2::empty(dims, options));
  }

  if (reusable->sizes() != dims) {
    reusable->Resize(dims);
  }
  if (reusable->dtype() == options.dtype()) {
    reusable->raw_mutable_data();
  } else {
    reusable->raw_mutable_data(options.dtype());
  }
  return reusable;
}
|
|
// Tensor-returning counterpart of BlobGetMutableTensor(blob, dims, options).
// It performs the same lookup/creation, then returns
// UnsafeSharedInstance() of the stored Tensor. That is presumably a handle
// sharing its implementation, rather than a pointer.
inline Tensor
XBlobGetMutableTensor(Blob* blob, at::IntArrayRef dims, at::TensorOptions options) {
  Tensor* stored = BlobGetMutableTensor(blob, dims, options);
  return stored->UnsafeSharedInstance();
}
|
|
// Returns a pointer to the Tensor stored in `blob` when it is a defined Tensor
// whose device type is `device_type`. Otherwise stores Tensor(device_type) in
// `blob` via BlobSetTensor and returns a pointer to that.
inline Tensor* BlobGetMutableTensor(Blob* blob, DeviceType device_type) {
  Tensor* existing =
      blob->IsType<Tensor>() ? blob->GetMutable<Tensor>() : nullptr;
  const bool reusable = existing != nullptr && *existing &&
      existing->GetDeviceType() == device_type;
  if (reusable) {
    return existing;
  }

  // The Blob held no Tensor, an undefined Tensor, or a Tensor with the wrong
  // DeviceType.
  VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
          << " DeviceType:" << device_type;
  return BlobSetTensor(blob, Tensor(device_type));
}
|
|
// Returns a const reference to the Tensor stored in `blob`.
//
// Throws (via CAFFE_THROW) if `blob` does not hold a Tensor, holds an
// undefined Tensor, or holds a Tensor whose device type differs from
// `device_type`.
inline const Tensor& BlobGetTensor(const Blob& blob, DeviceType device_type) {
  if (blob.IsType<Tensor>()) {
    const auto& tensor = blob.Get<Tensor>();
    // Check definedness before querying the device type. This matches
    // BlobIsTensorType and BlobGetMutableTensor, which never call
    // GetDeviceType() on an undefined Tensor.
    if (tensor && tensor.GetDeviceType() == device_type) {
      return tensor;
    }
  }
  CAFFE_THROW("Blob didn't contain a Tensor or the device_type doesn't match");
}
|
|
// Returns UnsafeSharedInstance() of the Tensor stored in `blob`, or a
// default-constructed Tensor() when `blob` does not hold a Tensor.
inline Tensor BlobGetTensorOrUndefined(const Blob& blob) {
  if (!blob.IsType<Tensor>()) {
    return Tensor();
  }
  return blob.Get<Tensor>().UnsafeSharedInstance();
}
|
|
} // namespace caffe2
|
#endif // CAFFE2_CORE_BLOB_H_
|