#pragma once #include #include #include #include #include #include #include namespace at { class Tensor; } namespace c10 { struct IValue; template class List; struct Type; using TypePtr = std::shared_ptr; namespace detail { template struct ListImpl final : public c10::intrusive_ptr_target { using list_type = std::vector; explicit ListImpl(list_type list_, TypePtr elementType_) : list(std::move(list_)) , elementType(std::move(elementType_)) {} list_type list; TypePtr elementType; intrusive_ptr copy() const { return make_intrusive(list, elementType); } }; } namespace impl { template class ListIterator; template class ListElementReference; template void swap(ListElementReference&& lhs, ListElementReference&& rhs); template class ListElementReference final { public: operator T() const; ListElementReference& operator=(T&& new_value) &&; ListElementReference& operator=(const T& new_value) &&; // assigning another ref to this assigns the underlying value ListElementReference& operator=(ListElementReference&& rhs) &&; friend void swap(ListElementReference&& lhs, ListElementReference&& rhs); private: ListElementReference(Iterator iter) : iterator_(iter) {} ListElementReference(const ListElementReference&) = delete; ListElementReference& operator=(const ListElementReference&) = delete; // allow moving, but only our friends (i.e. the List class) can move us ListElementReference(ListElementReference&&) noexcept = default; ListElementReference& operator=(ListElementReference&& rhs) & noexcept { iterator_ = std::move(rhs.iterator_); return *this; } friend class List; friend class ListIterator; Iterator iterator_; }; // this wraps vector::iterator to make sure user code can't rely // on it being the type of the underlying vector. template class ListIterator final : public std::iterator { public: explicit ListIterator() = default; ~ListIterator() = default; ListIterator(const ListIterator&) = default; ListIterator(ListIterator&&) noexcept = default; ListIterator& operator=(const ListIterator&) = default; ListIterator& operator=(ListIterator&&) = default; ListIterator& operator++() { ++iterator_; return *this; } ListIterator operator++(int) { ListIterator copy(*this); ++*this; return copy; } ListIterator& operator--() { --iterator_; return *this; } ListIterator operator--(int) { ListIterator copy(*this); --*this; return copy; } ListIterator& operator+=(typename List::size_type offset) { iterator_ += offset; return *this; } ListIterator& operator-=(typename List::size_type offset) { iterator_ -= offset; return *this; } ListIterator operator+(typename List::size_type offset) const { return ListIterator{iterator_ + offset}; } ListIterator operator-(typename List::size_type offset) const { return ListIterator{iterator_ - offset}; } friend typename std::iterator::difference_type operator-(const ListIterator& lhs, const ListIterator& rhs) { return lhs.iterator_ - rhs.iterator_; } ListElementReference operator*() const { return {iterator_}; } private: explicit ListIterator(Iterator iterator): iterator_(std::move(iterator)) {} Iterator iterator_; friend bool operator==(const ListIterator& lhs, const ListIterator& rhs) { return lhs.iterator_ == rhs.iterator_; } friend bool operator!=(const ListIterator& lhs, const ListIterator& rhs) { return !(lhs == rhs); } friend bool operator<(const ListIterator& lhs, const ListIterator& rhs) { return lhs.iterator_ < rhs.iterator_; } friend bool operator<=(const ListIterator& lhs, const ListIterator& rhs) { return lhs.iterator_ <= rhs.iterator_; } friend bool operator>(const ListIterator& lhs, const ListIterator& rhs) { return lhs.iterator_ > rhs.iterator_; } friend bool operator>=(const ListIterator& lhs, const ListIterator& rhs) { return lhs.iterator_ >= rhs.iterator_; } friend class ListIterator::list_type::iterator, StorageT>; friend class List; }; template List toTypedList(List list); template List toGenericList(List list); const IValue* ptr_to_first_element(const List& list); template List toList(std::vector list); template const std::vector& toVector(const List& list); } template bool list_is_equal(const List& lhs, const List& rhs); /** * An object of this class stores a list of values of type T. * * This is a pointer type. After a copy, both Lists * will share the same storage: * * > List a; * > List b = a; * > b.push_back("three"); * > ASSERT("three" == a.get(0)); * * We use this class in the PyTorch kernel API instead of * std::vector, because that allows us to do optimizations * and switch out the underlying list implementation without * breaking backwards compatibility for the kernel API. */ template class List final { private: // List of types that don't use IValue based lists using types_with_direct_list_implementation = guts::typelist::typelist< int64_t, double, bool, at::Tensor >; using StorageT = guts::conditional_t< guts::typelist::contains::value, T, // The types listed in types_with_direct_list_implementation store the list as std::vector IValue // All other types store the list as std::vector >; // This is an intrusive_ptr because List is a pointer type. // Invariant: This will never be a nullptr, there will always be a valid // ListImpl. c10::intrusive_ptr> impl_; using internal_reference_type = impl::ListElementReference::StorageT>::list_type::iterator, typename List::StorageT>; public: using value_type = T; using size_type = typename detail::ListImpl::list_type::size_type; using iterator = impl::ListIterator::list_type::iterator, StorageT>; using reverse_iterator = impl::ListIterator::list_type::reverse_iterator, StorageT>; using internal_value_type_test_only = StorageT; /** * Constructs an empty list. */ explicit List(); /** * Constructs a list with some initial values. * Example: * List a({2, 3, 4}); */ explicit List(std::initializer_list initial_values); explicit List(ArrayRef initial_values); /** * Create a generic list with runtime type information. * This only works for c10::impl::GenericList and is not part of the public API * but only supposed to be used internally by PyTorch. */ explicit List(TypePtr elementType); List(const List&) = default; List& operator=(const List&) = default; List(List&&) noexcept; List& operator=(List&&) noexcept; /** * Create a new List pointing to a deep copy of the same data. * The List returned is a new list with separate storage. * Changes in it are not reflected in the original list or vice versa. */ List copy() const; /** * Returns the element at specified location pos, with bounds checking. * If pos is not within the range of the container, an exception of type std::out_of_range is thrown. */ value_type get(size_type pos) const; /** * Moves out the element at the specified location pos and returns it, with bounds checking. * If pos is not within the range of the container, an exception of type std::out_of_range is thrown. * The list contains an invalid element at position pos afterwards. Any operations * on it before re-setting it are invalid. */ value_type extract(size_type pos) const; /** * Returns a reference to the element at specified location pos, with bounds checking. * If pos is not within the range of the container, an exception of type std::out_of_range is thrown. * * You cannot store the reference, but you can read it and assign new values to it: * * List list = ...; * list[2] = 5; * int64_t v = list[1]; */ internal_reference_type operator[](size_type pos) const; /** * Assigns a new value to the element at location pos. */ void set(size_type pos, const value_type& value) const; /** * Assigns a new value to the element at location pos. */ void set(size_type pos, value_type&& value) const; /** * Returns an iterator to the first element of the container. * If the container is empty, the returned iterator will be equal to end(). */ iterator begin() const; /** * Returns an iterator to the element following the last element of the container. * This element acts as a placeholder; attempting to access it results in undefined behavior. */ iterator end() const; /** * Checks if the container has no elements. */ bool empty() const; /** * Returns the number of elements in the container */ size_type size() const; /** * Increase the capacity of the vector to a value that's greater or equal to new_cap. */ void reserve(size_type new_cap) const; /** * Erases all elements from the container. After this call, size() returns zero. * Invalidates any references, pointers, or iterators referring to contained elements. Any past-the-end iterators are also invalidated. */ void clear() const; /** * Inserts value before pos. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ iterator insert(iterator pos, const T& value) const; /** * Inserts value before pos. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ iterator insert(iterator pos, T&& value) const; /** * Inserts a new element into the container directly before pos. * The new element is constructed with the given arguments. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ template iterator emplace(iterator pos, Args&&... value) const; /** * Appends the given element value to the end of the container. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ void push_back(const T& value) const; /** * Appends the given element value to the end of the container. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ void push_back(T&& value) const; /** * Appends the given list to the end of the container. Uses at most one memory allocation. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ void append(List lst) const; /** * Appends the given element value to the end of the container. * The new element is constructed with the given arguments. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ template void emplace_back(Args&&... args) const; /** * Removes the element at pos. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ iterator erase(iterator pos) const; /** * Removes the elements in the range [first, last). * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ iterator erase(iterator first, iterator last) const; /** * Removes the last element of the container. * Calling pop_back on an empty container is undefined. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ void pop_back() const; /** * Resizes the container to contain count elements. * If the current size is less than count, additional default-inserted elements are appended. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ void resize(size_type count) const; /** * Resizes the container to contain count elements. * If the current size is less than count, additional copies of value are appended. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ void resize(size_type count, const T& value) const; /** * Compares two lists for equality. Two lists are equal if they have the * same number of elements and for each list position the elements at * that position are equal. */ friend bool list_is_equal(const List& lhs, const List& rhs); /** * Returns the number of Lists currently pointing to this same list. * If this is the only instance pointing to this list, returns 1. */ // TODO Test use_count size_t use_count() const; TypePtr elementType() const; // See [unsafe set type] for why this exists. void unsafeSetElementType(TypePtr t); private: explicit List(c10::intrusive_ptr>&& elements); friend struct IValue; template friend List impl::toTypedList(List); template friend List impl::toGenericList(List); friend const IValue* impl::ptr_to_first_element(const List& list); template friend List impl::toList(std::vector list); template friend const std::vector& impl::toVector(const List& list); }; namespace impl { // GenericList is how IValue stores lists. It is, however, not part of the // public API. Kernels should use Lists with concrete types instead // (maybe except for some internal prim ops). using GenericList = List; inline const IValue* ptr_to_first_element(const GenericList& list) { return &list.impl_->list[0]; } template const std::vector& toVector(const List& list) { static_assert(std::is_same::value || std::is_same::StorageT>::value, "toVector only works for lists that store their elements as std::vector. You tried to call it for a list that stores its elements as std::vector."); return list.impl_->list; } template List toList(std::vector list) { static_assert(std::is_same::value || std::is_same::StorageT>::value, "toList only works for lists that store their elements as std::vector. You tried to call it for a list that stores its elements as std::vector."); List result; result.impl_->list = std::move(list); return result; } } } namespace torch { template using List = c10::List; } #include