#pragma once #include #include #include #include #include #include namespace torch { /// An ordered dictionary implementation, akin to Python's `OrderedDict`. template class OrderedDict { public: /// A (key, value) pair. class Item; // The lifetime of an iterator is bound to the lifetime of the `OrderedDict`. // Further, any `insert()` operation may invalidate all iterators // pointing into the vector. using Iterator = typename std::vector::iterator; using ConstIterator = typename std::vector::const_iterator; /// Constructs the `OrderedDict` with a short description of the kinds of keys /// stored in the `OrderedDict`. This description is used in error messages /// thrown by the `OrderedDict`. explicit OrderedDict(std::string key_description = "Key"); /// Copy constructs this `OrderedDict` from `other`. OrderedDict(const OrderedDict& other); /// Assigns items from `other` to this `OrderedDict`. OrderedDict& operator=(const OrderedDict& other); // NB: Move works by default, because you can move-construct vectors of const // values. I tried to make this noexcept (conditional on the move constructors // of index_ and items_ being noexcept) but the obvious spelling didn't // compile on Windows. OrderedDict(OrderedDict&& other) = default; OrderedDict& operator=(OrderedDict&& other) = default; ~OrderedDict() = default; /// Constructs a new `OrderedDict` and pre-populates it with the given /// `Item`s. /*implicit */ OrderedDict(std::initializer_list initializer_list); /// Returns the key description string the `OrderedDict` was constructed with. const std::string& key_description() const noexcept; // Element Access /// Returns the very first item in the `OrderedDict` and throws an exception /// if it is empty. Item& front(); /// Returns the very first item in the `OrderedDict` and throws an exception /// if it is empty. const Item& front() const; /// Returns the very last item in the `OrderedDict` and throws an exception /// if it is empty. Item& back(); /// Returns the very last item in the `OrderedDict` and throws an exception /// if it is empty. const Item& back() const; /// Returns the item at the `index`-th position in the `OrderedDict`. Throws /// an exception if the index is out of bounds. Item& operator[](size_t index); /// Returns the item at the `index`-th position in the `OrderedDict`. Throws /// an exception if the index is out of bounds. const Item& operator[](size_t index) const; /// Returns the value associated with the given `key`. Throws an exception if /// no such key is stored in the `OrderedDict`. Use `find()` for a /// non-throwing way of accessing a value if it is present. Value& operator[](const Key& key); /// Returns the value associated with the given `key`. Throws an exception if /// no such key is stored in the `OrderedDict`. Use `find()` for a /// non-throwing way of accessing a value if it is present. const Value& operator[](const Key& key) const; // Lookup /// Returns a pointer to the value associated with the given key, or a /// `nullptr` if no such key is stored in the `OrderedDict`. Value* find(const Key& key) noexcept; /// Returns a pointer to the value associated with the given key, or a /// `nullptr` if no such key is stored in the `OrderedDict`. const Value* find(const Key& key) const noexcept; /// Returns true if the key is present in the `OrderedDict`. bool contains(const Key& key) const noexcept; // Iterators /// Returns an iterator to the first item in the `OrderedDict`. Iteration is /// ordered. Iterator begin(); /// Returns an iterator to the first item in the `OrderedDict`. Iteration is /// ordered. ConstIterator begin() const; /// Returns an iterator one past the last item in the `OrderedDict`. Iterator end(); /// Returns an iterator one past the last item in the `OrderedDict`. ConstIterator end() const; // Capacity /// Returns the number of items currently stored in the `OrderedDict`. size_t size() const noexcept; /// Returns true if the `OrderedDict` contains no elements. bool is_empty() const noexcept; /// Resizes internal storage to fit at least `requested_capacity` items /// without requiring reallocation. void reserve(size_t requested_capacity); // Modifiers /// Inserts a new `(key, value)` pair into the `OrderedDict`. Throws an /// exception if the key is already present. If insertion is succesful, /// immediately returns a reference to the inserted value. template Value& insert(K&& key, V&& value); /// Inserts a new `(key, value)` pair into the `OrderedDict`. Throws an /// exception if the key is already present. If insertion is succesful, /// immediately returns a reference to the inserted value. Value& insert(Key key, Value&& value); /// Inserts all items from `other` into this `OrderedDict`. If any key from /// `other` is already present in this `OrderedDict`, an exception is thrown. void update(OrderedDict&& other); /// Inserts all items from `other` into this `OrderedDict`. If any key from /// `other` is already present in this `OrderedDict`, an exception is thrown. void update(const OrderedDict& other); /// Removes the item that has `key` from this `OrderedDict` if exists and if /// it doesn't an exception is thrown. void erase(const Key& key); /// Removes all items from this `OrderedDict`. void clear(); // Observers /// Returns the items stored in the `OrderedDict`. const std::vector& items() const noexcept; /// Returns a newly allocated vector and copies all keys from this /// `OrderedDict` into the vector. ::std::vector keys() const; /// Returns a newly allocated vector and copies all values from this /// `OrderedDict` into the vector. ::std::vector values() const; /// Returns a newly allocated vector and copies all keys and values from this /// `OrderedDict` into a vector of `std::pair`. ::std::vector> pairs() const; private: /// A mapping from a key to an index into the `items_` vector. ::std::unordered_map index_; /// The items stored in the `OrderedDict`. ::std::vector items_; /// A description of the keys stored in the `OrderedDict`. ::std::string key_description_{"Key"}; }; // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ OrderedDict::Item ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template class OrderedDict::Item { public: /// Constructs a new item. Item(Key key, Value value) : pair_(std::move(key), std::move(value)) {} /// Returns a reference to the value. Value& operator*() { return value(); } /// Returns a reference to the value. const Value& operator*() const { return value(); } /// Allows access to the value using the arrow operator. Value* operator->() { return &value(); } /// Allows access to the value using the arrow operator. const Value* operator->() const { return &value(); } /// Returns a reference to the key. const Key& key() const noexcept { return pair_.first; } /// Returns a reference to the value. Value& value() noexcept { return pair_.second; } /// Returns a reference to the value. const Value& value() const noexcept { return pair_.second; } /// Returns a `(key, value)` pair. const std::pair& pair() const noexcept { return pair_; } private: /// This is stored as an std::pair because it will make Python binding a lot, /// lot easier. ::std::pair pair_; }; // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ OrderedDict ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template OrderedDict::OrderedDict(std::string key_description) : key_description_(std::move(key_description)) {} template OrderedDict::OrderedDict(const OrderedDict& other) : index_(other.index_), key_description_(other.key_description_) { // Copy we have to do ourselves, because items' keys are const, so we have to // re-insert the items. for (const auto& item : other.items_) { items_.push_back(item); } } template OrderedDict& OrderedDict::operator=( const OrderedDict& other) { index_ = other.index_; items_.clear(); for (auto& item : other.items_) { items_.push_back(item); } key_description_ = other.key_description_; return *this; } template OrderedDict::OrderedDict( std::initializer_list initializer_list) : OrderedDict("Key") { items_.reserve(initializer_list.size()); for (auto& item : initializer_list) { // Copy the key here and move it into the index. items_.emplace_back(item.key(), std::move(item.value())); index_.emplace(std::move(item.key()), size() - 1); } } template typename OrderedDict::Iterator OrderedDict::begin() { return items_.begin(); } template typename OrderedDict::ConstIterator OrderedDict::begin() const { return items_.begin(); } template typename OrderedDict::Iterator OrderedDict::end() { return items_.end(); } template typename OrderedDict::ConstIterator OrderedDict::end() const { return items_.end(); } template typename OrderedDict::Item& OrderedDict::front() { TORCH_CHECK(!items_.empty(), "Called front() on an empty OrderedDict"); return items_.front(); } template const typename OrderedDict::Item& OrderedDict::front() const { TORCH_CHECK(!items_.empty(), "Called front() on an empty OrderedDict"); return items_.front(); } template typename OrderedDict::Item& OrderedDict::back() { TORCH_CHECK(!items_.empty(), "Called back() on an empty OrderedDict"); return items_.back(); } template const typename OrderedDict::Item& OrderedDict::back() const { TORCH_CHECK(!items_.empty(), "Called back() on an empty OrderedDict"); return items_.back(); } template typename OrderedDict::Item& OrderedDict::operator[]( size_t index) { TORCH_CHECK(index < items_.size(), "Index ", index, " is out of bounds"); return items_[index]; } template const typename OrderedDict:: Item& OrderedDict::operator[](size_t index) const { TORCH_CHECK(index < items_.size(), "Index ", index, " is out of bounds"); return items_[index]; } template Value& OrderedDict::operator[](const Key& key) { if (auto* value = find(key)) { return *value; } AT_ERROR(key_description_, " '", key, "' is not defined"); } template const Value& OrderedDict::operator[](const Key& key) const { if (auto* value = find(key)) { return *value; } AT_ERROR(key_description_, " '", key, "' is not defined"); } template template Value& OrderedDict::insert(K&& key, V&& value) { TORCH_CHECK( index_.count(key) == 0, key_description_, " '", key, "' already defined"); // Copy `key` here and move it into the index. items_.emplace_back(key, std::forward(value)); index_.emplace(std::forward(key), size() - 1); return items_.back().value(); } template Value& OrderedDict::insert(Key key, Value&& value) { return insert(std::move(key), std::move(value)); } template void OrderedDict::update(OrderedDict&& other) { reserve(size() + other.size()); for (auto& item : other) { // We want to call `insert()` to prevent duplicate keys. insert(std::move(item.key()), std::move(item.value())); } } template void OrderedDict::update(const OrderedDict& other) { reserve(size() + other.size()); for (auto& item : other) { // We want to call `insert()` to prevent duplicate keys. insert(item.key(), item.value()); } } template Value* OrderedDict::find(const Key& key) noexcept { auto iterator = index_.find(key); if (iterator == index_.end()) { return nullptr; } return &items_[iterator->second].value(); } template const Value* OrderedDict::find(const Key& key) const noexcept { auto iterator = index_.find(key); if (iterator == index_.end()) { return nullptr; } return &items_[iterator->second].value(); } template void OrderedDict::erase(const Key& key) { auto it = index_.find(key); TORCH_CHECK(it != index_.end(), "Key '", key, "' doesn't exist"); auto index = it->second; index_.erase(it); items_.erase(items_.begin() + index); for (auto& pair : index_) if (pair.second > index) --pair.second; } template bool OrderedDict::contains(const Key& key) const noexcept { return find(key) != nullptr; } template void OrderedDict::clear() { index_.clear(); items_.clear(); } template size_t OrderedDict::size() const noexcept { return items_.size(); } template bool OrderedDict::is_empty() const noexcept { return items_.empty(); } template const std::string& OrderedDict::key_description() const noexcept { return key_description_; } template const std::vector::Item>& OrderedDict< Key, Value>::items() const noexcept { return items_; } template ::std::vector OrderedDict::keys() const { std::vector keys; keys.reserve(size()); for (const auto& item : items_) { keys.push_back(item.key()); } return keys; } template ::std::vector OrderedDict::values() const { std::vector values; values.reserve(size()); for (const auto& item : items_) { values.push_back(item.value()); } return values; } template ::std::vector> OrderedDict::pairs() const { std::vector> values; values.reserve(size()); for (const auto& item : items_) { values.push_back(item.pair()); } return values; } template void OrderedDict::reserve(size_t requested_capacity) { index_.reserve(requested_capacity); items_.reserve(requested_capacity); } } // namespace torch