#ifndef CAFFE2_OPERATORS_DATASET_OPS_H_
|
#define CAFFE2_OPERATORS_DATASET_OPS_H_
|
|
#include <memory>
|
#include <mutex>
|
#include <string>
|
#include <vector>
|
#include "caffe2/core/blob.h"
|
#include "caffe2/core/blob_serialization.h"
|
#include "caffe2/core/tensor.h"
|
|
namespace caffe2 {
|
namespace dataset_ops {
|
|
// used for lengths tensors in the dataset
|
using TLength = int32_t;
|
// used for all internal dataset operations (offsets, sizes to read, etc.)
|
using TOffset = int64_t;
|
|
/**
|
* Provides functionality to iterate across a list of tensors where some
|
* of those tensors represent lengths in a hierarchical structure.
|
*/
|
class TreeIterator {
|
public:
|
struct FieldDesc {
|
int id;
|
int lengthFieldId = -1;
|
std::string name;
|
};
|
|
explicit TreeIterator(const std::vector<std::string>& fields);
|
|
void advance(
|
const std::vector<const TLength*>& lengths,
|
std::vector<TOffset>& offsets,
|
std::vector<TOffset>& sizes,
|
std::vector<TOffset>& limits,
|
TOffset num);
|
|
// Corresponds to the number of fields that have "length" as its last name
|
int numLengthFields() const {
|
return lengthFieldIds_.size();
|
}
|
|
// Corresponds to the number of length fields + 1 (for the top-level domain)
|
int numOffsetFields() const {
|
return numLengthFields() + 1;
|
}
|
|
// Get lengthField description for the given field
|
const FieldDesc* lengthFieldFor(const FieldDesc& desc) {
|
return (desc.lengthFieldId == -1)
|
? nullptr
|
: &fields_.at(lengthFieldIds_.at(desc.lengthFieldId));
|
}
|
|
// Get lengthField description for the given lengthFieldId, where
|
// 0 <= lengthFieldId < numLengthFields()
|
const FieldDesc& lengthField(int lengthFieldId) {
|
return fields_.at(lengthFieldIds_.at(lengthFieldId));
|
}
|
|
// Returns the index into the 'offset' vector for the given field.
|
int offsetFieldIdFor(const FieldDesc& fieldDesc) {
|
return fieldDesc.lengthFieldId + 1;
|
}
|
|
// Returns the field description for all fields.
|
const std::vector<FieldDesc>& fields() {
|
return fields_;
|
}
|
|
const std::vector<int>& lengthFieldIds() const {
|
return lengthFieldIds_;
|
}
|
|
private:
|
// Description of each field
|
std::vector<FieldDesc> fields_;
|
// Index into fields_ above for the fields that are lengths.
|
std::vector<int> lengthFieldIds_;
|
};
|
|
class TreeCursor {
|
public:
|
explicit TreeCursor(const TreeIterator& iterator) : it(iterator) {}
|
std::vector<TOffset> offsets;
|
std::mutex mutex_;
|
TreeIterator it;
|
};
|
|
/**
|
* Simple wrapper class allowing an easy traversal of the tensors representing
|
* the hirerarchical structure.
|
*/
|
class TreeWalker {
|
public:
|
TreeWalker(const vector<const Blob*>& inputs, TreeCursor& cursor);
|
|
// Returns the number of records in a dataset
|
inline TOffset size() const {
|
return limits_.at(0);
|
}
|
|
void advance();
|
|
private:
|
inline const TensorCPU& input(int32_t idx) const {
|
return inputs_[idx]->Get<TensorCPU>();
|
}
|
|
// TODO: Change to fieldDesc
|
inline const TreeIterator::FieldDesc& field(int idx) const {
|
return cursor_.it.fields().at(idx);
|
}
|
|
inline int lengthIdx(int fieldId) const {
|
return field(fieldId).lengthFieldId + 1;
|
}
|
|
inline TOffset offset(int fieldId) const {
|
return prevOffsets_[lengthIdx(fieldId)];
|
}
|
|
std::vector<int64_t> fieldDim(int fieldId) const;
|
|
void* fieldPtr(int fieldId) const;
|
|
public:
|
// Simple Proxy class to expose nicer API for field access
|
class Field {
|
public:
|
Field(TreeWalker& walker, int fieldId)
|
: walker_(walker), fieldId_(fieldId) {}
|
|
inline std::vector<int64_t> dim() const {
|
return walker_.fieldDim(fieldId_);
|
}
|
|
inline int64_t size() const {
|
int64_t size = 1;
|
for (const auto d : dim()) {
|
size *= d;
|
}
|
return size;
|
}
|
|
inline const TypeMeta& meta() const {
|
return walker_.input(fieldId_).dtype();
|
}
|
|
inline void* ptr() const {
|
return walker_.fieldPtr(fieldId_);
|
}
|
|
int fieldId() const {
|
return fieldId_;
|
}
|
|
inline TOffset offset() const {
|
return walker_.offset(fieldId_);
|
}
|
|
private:
|
const TreeWalker& walker_;
|
const int fieldId_;
|
};
|
|
// Notice that a reference is returned. If advance() is called the fields will
|
// be updated to represent the new state.
|
inline const std::vector<Field>& fields() const {
|
return fields_;
|
}
|
|
private:
|
void gatherLengthData();
|
|
void gatherSizeLimits();
|
|
const vector<const Blob*>& inputs_;
|
TreeCursor& cursor_;
|
std::vector<Field> fields_;
|
|
std::vector<const TLength*> lengths_;
|
std::vector<TOffset> limits_;
|
std::vector<TOffset> sizes_;
|
std::vector<TOffset> offsets_;
|
std::vector<TOffset> prevOffsets_;
|
};
|
|
using SharedTensorVectorPtr = std::shared_ptr<std::vector<TensorCPU>>;
|
|
using TensorVectorPtr = std::unique_ptr<std::vector<Tensor>>;
|
|
class SharedTensorVectorPtrSerializer : public BlobSerializerBase {
|
public:
|
void Serialize(
|
const void* pointer,
|
TypeMeta typeMeta,
|
const string& name,
|
BlobSerializerBase::SerializationAcceptor acceptor) override;
|
};
|
|
class SharedTensorVectorPtrDeserializer : public BlobDeserializerBase {
|
public:
|
void Deserialize(const BlobProto& proto, Blob* blob) override;
|
};
|
|
} // namespace dataset_ops
|
} // namespace caffe2
|
|
#endif // CAFFE2_OPERATORS_DATASET_OPS_H_
|