#pragma once #include #include #include #include #include #ifdef BUILD_NAMEDTENSOR namespace at { using NameVector = SmallVector; inline bool has_names(TensorList tensors) { return std::any_of( tensors.begin(), tensors.end(), [](const Tensor& t) { return t.has_names(); }); } // Converts dim to an positional index. Errors if `dim` cannot be used to // refer to any dimension of tensor. CAFFE2_API int64_t dimname_to_position(const Tensor& tensor, Dimname dim); CAFFE2_API std::vector dimnames_to_positions(const Tensor& tensor, DimnameList dims); // Unifies two DimnameList to produce a third. This is useful for implementing // the named inference rule for binary broadcasting operations like add. // // There are three main constraints: // 1) Check matching: Names must match positionally from the right. // 2) Check misaligned: If a name `n` is in `names`, then it must appear at // the same index from the right in other. // 3) The output names are obtained by unifying the names individually from the right. CAFFE2_API std::vector unify_from_right(DimnameList names, DimnameList other, const char* action = "broadcast"); [[noreturn]] inline void reportNYIDimnameOverload(const char* op_name) { TORCH_CHECK( false, op_name, ": You passed a dimname (string) to this op in place of a dimension " "index but it does not yet support this behavior. Please pass a dimension " "index to work around this."); } namespace namedinference { // Names get propagated via the following rules: // 1) If result does not have names, then `names` get propagated. // 2) If result has names, then `names` must be equal to result.names void propagate_names(Tensor& result, optional names); void propagate_names(Tensor& result, std::vector&& names, bool validate_names); CAFFE2_API void propagate_names(Tensor& result, optional>&& maybe_names, bool validate_names); void propagate_names(TensorImpl* result, optional names); void propagate_names(TensorImpl* result, std::vector&& names, bool validate_names); void propagate_names(TensorImpl* result, optional>&& maybe_names, bool validate_names); // Propagates all names from src to result. CAFFE2_API void propagate_names(Tensor& result, const Tensor& src); void propagate_names(TensorImpl* result, /*const */TensorImpl* src); // Propagates all names except for those at the excluded_idxs. void propagate_names_except(Tensor& result, const Tensor& src, IntArrayRef excluded_idxs); // Used for reduction ops that have a `keepdim` arg. void propagate_names_for_reduction(Tensor& result, const Tensor& src, IntArrayRef excluded_idxs, bool keepdim); // result = m1 @ m2 + bias void propagate_names_for_addmm( TensorImpl* result, /*const*/TensorImpl* m1, /*const*/TensorImpl* m2, /*const*/TensorImpl* bias); void propagate_names_for_addmv( TensorImpl* result, TensorImpl* mat, TensorImpl* vec, TensorImpl* bias); void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2); void propagate_names_for_expand(Tensor& result, const Tensor& self); optional> compute_cat_outnames(TensorList tensors); optional> compute_broadcast_outnames( const Tensor& self, const Tensor& other); optional> broadcast_to_outnames( const Tensor& tensor, const Tensor& reference_tensor, const char* op_name); optional> compute_baddbmm_outnames( TensorImpl* result, TensorImpl* self, TensorImpl* other, TensorImpl* bias); optional> compute_matmul_outnames(const Tensor& self, const Tensor& other); optional> compute_bmm_outnames( Tensor& result, const Tensor& self, const Tensor& other); optional> compute_squeeze_outnames(const Tensor& tensor); } // namespace namedinference } // namespace at #endif