#pragma once #include #include #include #include #ifdef BUILD_NAMEDTENSOR namespace at { // XXX: This file exists because TensorImpl is in c10, but Dimname is in ATen. // Due to the c10/ATen library split, TensorImpl cannot depend on Dimname, // so we have a couple of workarounds. // // In the long term, we'll move Dimname to c10 and everything in this file // can be refactored out. The main blocker for that is that "c10::Symbol" // actually exists outside of c10 and needs to be moved in. // TensorImpl has a unique_ptr field. // XXX: Ideally we would just put optional> into TensorImpl. struct CAFFE2_API NamedTensorMeta : public c10::NamedTensorMetaInterface { explicit NamedTensorMeta(int64_t num_names) : names_(std::vector(num_names, Dimname::wildcard())) {} explicit NamedTensorMeta(DimnameList names) : names_(names.vec()) {} explicit NamedTensorMeta(std::vector&& names) : names_(std::move(names)) {} std::unique_ptr clone() const override { return c10::guts::make_unique(names_); } bool has_names() const; DimnameList names() const { return names_; } // Used for an assertion in TensorImpl.h int64_t slow_dim() const override { return names_.size(); } void set_names(DimnameList new_names) { TORCH_INTERNAL_ASSERT(new_names.size() == names_.size()); std::copy(new_names.begin(), new_names.end(), names_.begin()); } void set_names(std::vector&& new_names) { TORCH_INTERNAL_ASSERT(new_names.size() == names_.size()); names_ = std::move(new_names); } private: std::vector names_; }; // When NamesMode is disabled, then all operations ignore tensors' names fields. // Concretely speaking, all tensors are treated as having nullopt names. struct CAFFE2_API NamesMode { static bool is_enabled(); static void set_enabled(bool enabled); }; // A RAII, thread local (!) guard that enables or disables names upon // construction, and sets it back to the original value upon destruction. struct CAFFE2_API NoNamesGuard { NoNamesGuard() : prev_mode(NamesMode::is_enabled()) { NamesMode::set_enabled(false); } ~NoNamesGuard() { NamesMode::set_enabled(prev_mode); } private: bool prev_mode; }; void check_names_valid_for(const Tensor& tensor, DimnameList names); // Sets the names of `tensor` to be `names`. CAFFE2_API Tensor& internal_set_names_inplace(Tensor& tensor, optional names); CAFFE2_API Tensor& internal_set_names_inplace(Tensor& tensor, std::vector&& names, bool validate_names); constexpr size_t kMaxNamedTensorDim = 64; DimnameList default_names(size_t len); namespace impl { // Some helper functions on TensorImpl. Useful for working with names in TH. // XXX: Ideally these would exist as methods on TensorImpl CAFFE2_API void internal_set_names_inplace(TensorImpl* impl, optional names); CAFFE2_API void internal_set_names_inplace(TensorImpl* impl, std::vector&& names, bool validate_names); void check_names_valid_for(TensorImpl* impl, DimnameList names); // Returns true if the tensor's names exist and are not all 'None'. // Returns false if the tensor's names don't exist (were not allocated), // or if all names are 'None'. // We treat not-allocated-names the same as allocated names that are all 'None'. CAFFE2_API bool has_names(const TensorImpl* impl); // Returns the names of the tensor's dimensions. // Unnamed tensors are treated as having 'None' in all dimension; this method // would return a DimnameList of all 'None's for an unnamed tensor. CAFFE2_API DimnameList get_names(const TensorImpl* impl); // This is more of an implementation detail; one should use impl::get_names / // Tensor::names() whenever possible because it provides a cleaner API. // Returns the names of the tensor if they have been allocated; returns nullopt // instead if the haven't been. The names of a tensor are not allocated if a // tensor is constructed with names=None. CAFFE2_API optional get_opt_names(const TensorImpl* impl); } // namespace impl } // namespace at #endif