#pragma once #include #include namespace c10 { template TypePtr getTypePtr(); std::string toString(TypePtr typePtr); namespace impl { inline bool shallowEquals(const IValue& lhs, const IValue& rhs) { if (lhs.isNone()) { return rhs.isNone(); } else if (lhs.isInt()) { return rhs.isInt() && lhs.toInt() == rhs.toInt(); } else if (lhs.isString()) { return rhs.isString() && lhs.toStringRef() == rhs.toStringRef(); } else if (lhs.isDouble()) { return rhs.isDouble() && lhs.toDouble() == rhs.toDouble(); } else if (lhs.isBool()) { return rhs.isBool() && lhs.toBool() == rhs.toBool(); } else if (lhs.isIntList()) { return rhs.isIntList() && lhs.toIntListRef() == rhs.toIntListRef(); } else if (lhs.isTensor()) { return lhs.toTensor().is_same(rhs.toTensor()); } else { AT_ERROR("shallowEquals(IValue, IValue) not implemented for type ", lhs.tagKind()); } } template Dict toTypedDict(GenericDict dict) { TORCH_INTERNAL_ASSERT(*getTypePtr() == *dict.impl_->elementTypes.keyType, "Tried to cast a Dict<", toString(dict.impl_->elementTypes.keyType), ", ", toString(dict.impl_->elementTypes.valueType) ,"> to a Dict<", toString(getTypePtr()), ", ", toString(getTypePtr()), ">. Key types mismatch."); TORCH_INTERNAL_ASSERT(*getTypePtr() == *dict.impl_->elementTypes.valueType, "Tried to cast a Dict<", toString(dict.impl_->elementTypes.keyType), ", ", toString(dict.impl_->elementTypes.valueType) ,"> to a Dict<", toString(getTypePtr()), ", ", toString(getTypePtr()), ">. Value types mismatch."); return Dict(std::move(dict.impl_)); } template GenericDict toGenericDict(Dict dict) { return GenericDict(std::move(dict.impl_)); } } namespace detail { inline size_t DictKeyHash::operator()(const IValue& ivalue) const { if (ivalue.isInt()) { return std::hash()(ivalue.toInt()); } else if (ivalue.isString()) { return std::hash()(ivalue.toStringRef()); } else if (ivalue.isDouble()) { return std::hash()(ivalue.toDouble()); } else if (ivalue.isBool()) { return std::hash()(ivalue.toBool()); } else if (ivalue.isTensor()) { return std::hash()(ivalue.toTensor().unsafeGetTensorImpl()); } else { throw std::runtime_error("Can't hash IValues with this tag"); } } inline intrusive_ptr DictImpl::copy() const { return make_intrusive(dict, elementTypes); } } template Dict::Dict() :Dict(make_intrusive( detail::DictImpl::dict_map_type(), detail::DictImpl::DictElementTypes{getTypePtr(), getTypePtr()})) { static_assert(!std::is_same::value, "This constructor is not valid for Dict. Please use c10::impl::GenericDict(keyType, valueType) instead, or if you absolutely have to, use c10::impl::GenericDict(c10::impl::deprecatedUntypedDict())."); static_assert(!std::is_same::value, "This constructor is not valid for Dict<_, IValue>. Please use c10::impl::GenericDict(keyType, valueType) instead, or if you absolutely have to, use c10::impl::GenericDict(c10::impl::deprecatedUntypedDict())."); } template Dict::Dict(TypePtr keyType, TypePtr valueType) : Dict(make_intrusive( detail::DictImpl::dict_map_type(), detail::DictImpl::DictElementTypes {std::move(keyType), std::move(valueType)})) { static_assert(std::is_same::value, "This constructor is only valid for c10::impl::GenericDict."); static_assert(std::is_same::value, "This constructor is only valid for c10::impl::GenericDict."); } template Dict::Dict(Dict&& rhs) noexcept: impl_(std::move(rhs.impl_)) { rhs.impl_ = make_intrusive(detail::DictImpl::dict_map_type(), impl_->elementTypes); } template Dict::Dict(c10::intrusive_ptr&& impl): impl_(std::move(impl)) {} template Dict& Dict::operator=(Dict&& rhs) noexcept { impl_ = std::move(rhs.impl_); rhs.impl_ = make_intrusive(detail::DictImpl::dict_map_type(), impl_->elementTypes); return *this; } template Dict Dict::copy() const { return Dict(impl_->copy()); } template typename Dict::iterator Dict::begin() const { return iterator{impl_->dict.begin()}; } template typename Dict::iterator Dict::end() const { return iterator{impl_->dict.end()}; } template bool Dict::empty() const { return impl_->dict.empty(); } template typename Dict::size_type Dict::size() const { return impl_->dict.size(); } template void Dict::clear() const { impl_->dict.clear(); } template template std::pair::iterator, bool> Dict::insert(Key_&& key, Value_&& value) const { static_assert(std::is_constructible::value, "Wrong type for the key argument of Dict::insert"); static_assert(std::is_constructible::value, "Wrong type for the value argument of Dict::insert"); auto inserted = impl_->dict.insert(std::pair{ Key(std::forward(key)), Value(std::forward(value))}); return {iterator{inserted.first}, inserted.second}; } template template std::pair::iterator, bool> Dict::insert_or_assign(Key_&& key, Value_&& value) const { static_assert(std::is_constructible::value, "Wrong type for the key argument of Dict::insert_or_assign"); static_assert(std::is_constructible::value, "Wrong type for the value argument of Dict::insert_or_assign"); auto inserted = impl_->dict.insert_or_assign( Key(std::forward(key)), Value(std::forward(value))); return {iterator{inserted.first}, inserted.second}; } template void Dict::erase(iterator iter) const { impl_->dict.erase(iter.entryRef_.iterator_); } template C10_NODISCARD size_t Dict::erase(const Key& key) const { return impl_->dict.erase(key); } template Value Dict::at(const Key& key) const { return impl_->dict.at(key).template to(); } template typename Dict::iterator Dict::find(const Key& key) const { return iterator{impl_->dict.find(key)}; } template bool Dict::contains(const Key& key) const { return end() != find(key); } template void Dict::reserve(size_type count) const { impl_->dict.reserve(count); } template TypePtr Dict::keyType() const { return impl_->elementTypes.keyType; } template TypePtr Dict::valueType() const { return impl_->elementTypes.valueType; } template void Dict::unsafeSetKeyType(TypePtr t) { impl_->elementTypes.keyType = std::move(t); } template void Dict::unsafeSetValueType(TypePtr t) { impl_->elementTypes.valueType = std::move(t); } }