#ifndef GATHER_OP_H_
#define GATHER_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"

namespace caffe2 {

// This maintains index-mapping functions shared by Gather and BatchGather ops.
namespace gather_helper {

// New shape is concatenation:
// [data dims before axis] + [indices dims] + [data dims after axis]
template <typename IndexType, typename DataDimsVec, typename IndexDimsVec>
static vector<IndexType> calc_output_shape_vector(
    const DataDimsVec& data_dims,
    const IndexDimsVec& indices_dims,
    int axis,
    bool match_outer) {
  vector<IndexType> shape;
  // If the dimension we are indexing is empty, just use data_dims as shape.
  // This replicates behavior in (https://github.com/pytorch/pytorch/pull/13781)
  // needed to allow workflows with empty batch to succeed.
  if (data_dims[axis] == 0) {
    shape.insert(shape.end(), data_dims.begin(), data_dims.end());
  } else {
    shape.insert(shape.end(), data_dims.begin(), data_dims.begin() + axis);
    if (match_outer) {
      shape.insert(
          shape.end(), indices_dims.begin() + axis, indices_dims.end());
    } else {
      shape.insert(shape.end(), indices_dims.begin(), indices_dims.end());
    }
    shape.insert(shape.end(), data_dims.begin() + axis + 1, data_dims.end());
  }
  return shape;
}

// Check that indices fall within dimension array size with CAFFE_ENFORCE.
template <typename IndexType>
static void check_indexarray_range(
    const IndexType* indices,
    int64_t n,
    IndexType indexing_axis_dim,
    bool wrap_indices) {
  for (auto i = 0; i < n; ++i) {
    auto idx = indices[i];
    if (wrap_indices && idx < 0) {
      idx = idx + indexing_axis_dim;
    }
    CAFFE_ENFORCE(
        0 <= idx && idx < indexing_axis_dim,
        "INDICES element is out of DATA bounds, id=",
        idx,
        " axis_dim=",
        indexing_axis_dim);
  }
}

// Actual gather implementation - resizes output and copies indexed data.
template <typename Index, typename Context>
static bool gather_impl(
    Operator<Context>* op,
    int dataIdx,
    int indicesIdx,
    int outputIdx,
    int axis,
    bool wrap_indices,
    bool match_outer) {
  // If we end up using it on GPU, doing O(N) memcpy is probably not best :)
  // TODO: implement prefetching if it starts mattering (TF does it).
  const Tensor& data = op->Input(dataIdx);
  const Tensor& indices = op->Input(indicesIdx);
  const TypeMeta dataType = data.dtype();
  size_t item_bytesize = dataType.itemsize();

  // ONNX allows negative axis to index from the back, valid range: [-r, r].
  if (axis < 0) {
    axis = data.dim() + axis;
  }
  CAFFE_ENFORCE_GE(data.dim(), axis + 1, "DATA should be at least [axis+1]-D");
  CAFFE_ENFORCE_GE(axis, 0, "Axis should be non-negative");
  CAFFE_ENFORCE_LT(axis, data.dim(), "Axis out of range");

  // New shape:
  // [data dims before axis] + [indices dims] + [data dims after axis]
  vector<int64_t> shape = calc_output_shape_vector<int64_t>(
      data.sizes(), indices.sizes(), axis, match_outer);
  Tensor* output = op->Output(outputIdx, shape, at::dtype(dataType));
  auto out = static_cast<char*>(output->raw_mutable_data(dataType));

  // Succeed if size of output is zero, which can happen for empty batch which
  // would have data dimension size of 0.
  // This *must* be done AFTER output->raw_mutable_data() above, as that has an
  // important allocation side effect that we must see.
  if (output->numel() == 0) {
    return true;
  }

  const Index* idxs = indices.template data<Index>();
  auto src_base = static_cast<const char*>(data.raw_data());

  auto outer_dims_product = data.size_to_dim(axis);
  auto block_size = data.size_from_dim(axis + 1);
  auto block_bytesize = block_size * item_bytesize;

  auto src_indexing_axis_dim = data.size(axis);
  auto src_batch_bytesize = data.size_from_dim(axis) * item_bytesize;
  // Treat indices as a single block even if they have multiple dimensions.
  // The "gathered batch" is a cumulative result combining indexed blocks.
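  // Worked example (added for illustration, not in the original source): with
  // data shape [2, 3, 4], indices shape [5], and axis = 1, we get
  // outer_dims_product = 2, block_size = 4, src_indexing_axis_dim = 3, and an
  // output of shape [2, 5, 4]. With match_outer set, indices would instead
  // need shape [2, 5], and each outer batch consumes its own row of indices.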
  auto idx_inner_dims_product = indices.size_from_dim(axis);
  auto N = indices.numel();
  if (match_outer) {
    CAFFE_ENFORCE_GE(axis, 1, "Axis should be at least 1");
    for (auto i = 0; i < axis; i++) {
      CAFFE_ENFORCE_EQ(
          data.size(i),
          indices.size(i),
          "INDICES must have the same outer dims as DATA (before dim AXIS)");
    }
    N = idx_inner_dims_product;
  }

  auto gathered_batch_bytesize = N * block_size * item_bytesize;

  check_indexarray_range<Index>(idxs, N, src_indexing_axis_dim, wrap_indices);

  // Special-case single-float copy for efficiency.
  if (data.template IsType<float>() && block_size == 1) {
    for (auto batch = 0; batch < outer_dims_product; ++batch) {
      const float* src_floats =
          (const float*)(src_base + batch * src_batch_bytesize);
      float* dst_floats = (float*)(out + batch * gathered_batch_bytesize);

      for (auto i = 0; i < N; ++i) {
        auto idx = idxs[i];
        if (match_outer) {
          idx = idxs[batch * idx_inner_dims_product + i];
        }
        if (wrap_indices && idx < 0) {
          idx = idx + src_indexing_axis_dim;
        }
        dst_floats[i] = src_floats[idx];
      }
    }
  } else {
    // outer_dims_product specifies how many times we repeat inner dimensions,
    // so we just iterate over it to cover all outer dimensions.
    for (auto batch = 0; batch < outer_dims_product; ++batch) {
      for (auto i = 0; i < N; ++i) {
        auto idx = idxs[i];
        if (match_outer) {
          idx = idxs[batch * idx_inner_dims_product + i];
        }
        if (wrap_indices && idx < 0) {
          idx = idx + src_indexing_axis_dim;
        }

        auto src =
            src_base + batch * src_batch_bytesize + idx * block_bytesize;
        auto dst = out + batch * gathered_batch_bytesize + i * block_bytesize;
        op->getContext()->CopyItemsSameDevice(dataType, block_size, src, dst);
      }
    }
  }
  return true;
}

} // namespace gather_helper

template <class Context>
class GatherOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit GatherOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(int, "axis", axis_, 0),
        OP_SINGLE_ARG(bool, "match_outer", match_outer_, false) {
    // TBD: We may want to fix the old index wrap behaviour once we have
    // operator versioning, to only apply it when needed, as otherwise it's
    // likely an error.
    // Right now, we apply index wrapping by default only to axis == 0, since
    // we have ONNX conversion code that uses it. For other axes it needs to
    // be specified explicitly with an argument, or you don't get it.
    if (OperatorBase::HasArgument("wrap_indices")) {
      wrap_indices_ = Operator<Context>::template GetSingleArgument<bool>(
          "wrap_indices", false);
    } else {
      wrap_indices_ = (axis_ == 0);
    }
  }

  virtual ~GatherOp() noexcept {}

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, this->template Input<Tensor>(INDICES, CPU));
  }

  template <typename Index>
  bool DoRunWithType() {
    return gather_helper::gather_impl<Index, Context>(
        this, DATA, INDICES, 0, axis_, wrap_indices_, match_outer_);
  }

  INPUT_TAGS(DATA, INDICES);

 protected:
  int axis_;
  bool wrap_indices_;
  bool match_outer_;
};

} // namespace caffe2

#endif // GATHER_OP_H_
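// Hedged usage sketch (illustration only, not part of the original header).
// It mirrors the generic branch of gather_impl above in plain C++ so the
// address arithmetic can be checked in isolation: the src batch stride is
// axis_dim * block elements and the dst batch stride is N * block elements.
// GATHER_OP_H_USAGE_SKETCH is a hypothetical guard that is never defined, so
// this block does not compile into the header.
#ifdef GATHER_OP_H_USAGE_SKETCH
#include <cstdint>
#include <cstring>
#include <vector>

int main() {
  // data has shape [2, 3, 4], gathered along axis = 1:
  // outer = 2 batches, axis_dim = 3 indexable blocks, block = 4 elements each.
  const int64_t outer = 2, axis_dim = 3, block = 4;
  std::vector<float> data(outer * axis_dim * block);
  for (size_t i = 0; i < data.size(); ++i) {
    data[i] = static_cast<float>(i);
  }
  // Gather blocks 2 and 0 along axis 1; output shape is [2, 2, 4].
  const std::vector<int64_t> idxs = {2, 0};
  const int64_t N = static_cast<int64_t>(idxs.size());

  std::vector<float> out(outer * N * block);
  for (int64_t batch = 0; batch < outer; ++batch) {
    for (int64_t i = 0; i < N; ++i) {
      // Same offsets as gather_impl, with bytes replaced by element counts.
      const float* src =
          data.data() + batch * axis_dim * block + idxs[i] * block;
      float* dst = out.data() + batch * N * block + i * block;
      std::memcpy(dst, src, block * sizeof(float));
    }
  }
  // out now holds blocks {2, 0} of batch 0, then blocks {2, 0} of batch 1.
  return 0;
}
#endif // GATHER_OP_H_USAGE_SKETCH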