#ifndef CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_
|
#define CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_
|
|
#include <iostream>
|
#include <mutex>
|
|
#include "caffe2/core/db.h"
|
#include "caffe2/operators/prefetch_op.h"
|
|
namespace caffe2 {
|
|
template <class Context>
|
class TensorProtosDBInput final : public PrefetchOperator<Context> {
|
public:
|
using OperatorBase::OutputSize;
|
using PrefetchOperator<Context>::prefetch_thread_;
|
explicit TensorProtosDBInput(const OperatorDef& operator_def, Workspace* ws);
|
~TensorProtosDBInput() {
|
PrefetchOperator<Context>::Finalize();
|
}
|
|
bool Prefetch() override;
|
bool CopyPrefetched() override;
|
|
private:
|
// Prefetch will always just happen on the CPU side.
|
vector<Blob> prefetched_blobs_;
|
int batch_size_;
|
bool shape_inferred_ = false;
|
string key_;
|
string value_;
|
};
|
|
template <class Context>
|
TensorProtosDBInput<Context>::TensorProtosDBInput(
|
const OperatorDef& operator_def,
|
Workspace* ws)
|
: PrefetchOperator<Context>(operator_def, ws),
|
prefetched_blobs_(operator_def.output_size()),
|
batch_size_(
|
this->template GetSingleArgument<int>("batch_size", 0)) {}
|
|
template <class Context>
|
bool TensorProtosDBInput<Context>::Prefetch() {
|
const db::DBReader& reader = this->template Input<db::DBReader>(0);
|
TensorDeserializer deserializer;
|
if (batch_size_ == 0) {
|
// We do not need to construct a batch. As a result, we will simply
|
// deserialize everything into the target prefetched blob.
|
reader.Read(&key_, &value_);
|
TensorProtos protos;
|
CAFFE_ENFORCE(protos.ParseFromString(value_));
|
CAFFE_ENFORCE(protos.protos_size() == OutputSize());
|
for (int i = 0; i < protos.protos_size(); ++i) {
|
if (protos.protos(i).has_device_detail()) {
|
protos.mutable_protos(i)->clear_device_detail();
|
}
|
BlobSetTensor(
|
&prefetched_blobs_[i], deserializer.Deserialize(protos.protos(i)));
|
// deserializer.Deserialize(
|
// protos.protos(i), BlobGetMutableTensor(&prefetched_blobs_[i],
|
// CPU));
|
}
|
} else {
|
for (int item_id = 0; item_id < batch_size_; ++item_id) {
|
reader.Read(&key_, &value_);
|
TensorProtos protos;
|
CAFFE_ENFORCE(protos.ParseFromString(value_));
|
CAFFE_ENFORCE(protos.protos_size() == OutputSize());
|
// Note: shape_inferred_ is ignored, we'll always get dimensions from
|
// proto
|
for (int i = 0; i < protos.protos_size(); ++i) {
|
vector<int64_t> dims(
|
protos.protos(i).dims().begin(), protos.protos(i).dims().end());
|
dims.insert(dims.begin(), batch_size_);
|
if (protos.protos(i).has_device_detail()) {
|
protos.mutable_protos(i)->clear_device_detail();
|
}
|
Tensor src = deserializer.Deserialize(protos.protos(i));
|
Tensor* dst = BlobGetMutableTensor(
|
&prefetched_blobs_[i], dims, at::dtype(src.dtype()).device(CPU));
|
DCHECK_EQ(src.numel() * batch_size_, dst->numel());
|
this->context_.CopyItemsSameDevice(
|
src.dtype(),
|
src.numel(),
|
src.raw_data(),
|
static_cast<char*>(dst->raw_mutable_data(src.dtype())) +
|
src.nbytes() * item_id);
|
}
|
}
|
}
|
return true;
|
}
|
|
template <class Context>
|
bool TensorProtosDBInput<Context>::CopyPrefetched() {
|
for (int i = 0; i < OutputSize(); ++i) {
|
OperatorBase::template Output<Tensor>(i, Context::GetDeviceType())
|
->CopyFrom(
|
prefetched_blobs_[i].template Get<TensorCPU>(), /* async */ true);
|
}
|
return true;
|
}
|
|
} // namespace caffe2
|
|
#endif // CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_
|