#ifndef CAFFE2_OPERATORS_DATASET_OPS_H_ #define CAFFE2_OPERATORS_DATASET_OPS_H_ #include #include #include #include #include "caffe2/core/blob.h" #include "caffe2/core/blob_serialization.h" #include "caffe2/core/tensor.h" namespace caffe2 { namespace dataset_ops { // used for lengths tensors in the dataset using TLength = int32_t; // used for all internal dataset operations (offsets, sizes to read, etc.) using TOffset = int64_t; /** * Provides functionality to iterate across a list of tensors where some * of those tensors represent lengths in a hierarchical structure. */ class TreeIterator { public: struct FieldDesc { int id; int lengthFieldId = -1; std::string name; }; explicit TreeIterator(const std::vector& fields); void advance( const std::vector& lengths, std::vector& offsets, std::vector& sizes, std::vector& limits, TOffset num); // Corresponds to the number of fields that have "length" as its last name int numLengthFields() const { return lengthFieldIds_.size(); } // Corresponds to the number of length fields + 1 (for the top-level domain) int numOffsetFields() const { return numLengthFields() + 1; } // Get lengthField description for the given field const FieldDesc* lengthFieldFor(const FieldDesc& desc) { return (desc.lengthFieldId == -1) ? nullptr : &fields_.at(lengthFieldIds_.at(desc.lengthFieldId)); } // Get lengthField description for the given lengthFieldId, where // 0 <= lengthFieldId < numLengthFields() const FieldDesc& lengthField(int lengthFieldId) { return fields_.at(lengthFieldIds_.at(lengthFieldId)); } // Returns the index into the 'offset' vector for the given field. int offsetFieldIdFor(const FieldDesc& fieldDesc) { return fieldDesc.lengthFieldId + 1; } // Returns the field description for all fields. const std::vector& fields() { return fields_; } const std::vector& lengthFieldIds() const { return lengthFieldIds_; } private: // Description of each field std::vector fields_; // Index into fields_ above for the fields that are lengths. std::vector lengthFieldIds_; }; class TreeCursor { public: explicit TreeCursor(const TreeIterator& iterator) : it(iterator) {} std::vector offsets; std::mutex mutex_; TreeIterator it; }; /** * Simple wrapper class allowing an easy traversal of the tensors representing * the hirerarchical structure. */ class TreeWalker { public: TreeWalker(const vector& inputs, TreeCursor& cursor); // Returns the number of records in a dataset inline TOffset size() const { return limits_.at(0); } void advance(); private: inline const TensorCPU& input(int32_t idx) const { return inputs_[idx]->Get(); } // TODO: Change to fieldDesc inline const TreeIterator::FieldDesc& field(int idx) const { return cursor_.it.fields().at(idx); } inline int lengthIdx(int fieldId) const { return field(fieldId).lengthFieldId + 1; } inline TOffset offset(int fieldId) const { return prevOffsets_[lengthIdx(fieldId)]; } std::vector fieldDim(int fieldId) const; void* fieldPtr(int fieldId) const; public: // Simple Proxy class to expose nicer API for field access class Field { public: Field(TreeWalker& walker, int fieldId) : walker_(walker), fieldId_(fieldId) {} inline std::vector dim() const { return walker_.fieldDim(fieldId_); } inline int64_t size() const { int64_t size = 1; for (const auto d : dim()) { size *= d; } return size; } inline const TypeMeta& meta() const { return walker_.input(fieldId_).dtype(); } inline void* ptr() const { return walker_.fieldPtr(fieldId_); } int fieldId() const { return fieldId_; } inline TOffset offset() const { return walker_.offset(fieldId_); } private: const TreeWalker& walker_; const int fieldId_; }; // Notice that a reference is returned. If advance() is called the fields will // be updated to represent the new state. inline const std::vector& fields() const { return fields_; } private: void gatherLengthData(); void gatherSizeLimits(); const vector& inputs_; TreeCursor& cursor_; std::vector fields_; std::vector lengths_; std::vector limits_; std::vector sizes_; std::vector offsets_; std::vector prevOffsets_; }; using SharedTensorVectorPtr = std::shared_ptr>; using TensorVectorPtr = std::unique_ptr>; class SharedTensorVectorPtrSerializer : public BlobSerializerBase { public: void Serialize( const void* pointer, TypeMeta typeMeta, const string& name, BlobSerializerBase::SerializationAcceptor acceptor) override; }; class SharedTensorVectorPtrDeserializer : public BlobDeserializerBase { public: void Deserialize(const BlobProto& proto, Blob* blob) override; }; } // namespace dataset_ops } // namespace caffe2 #endif // CAFFE2_OPERATORS_DATASET_OPS_H_