#ifndef CAFFE2_OPERATORS_ELEMENTWISE_LOGICAL_OPS_H_ #define CAFFE2_OPERATORS_ELEMENTWISE_LOGICAL_OPS_H_ #include "caffe2/core/common_omp.h" #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/operators/elementwise_ops.h" #include namespace caffe2 { template class WhereOp final : public Operator { public: USE_OPERATOR_FUNCTIONS(Context); USE_DISPATCH_HELPER; template explicit WhereOp(Args&&... args) : Operator(std::forward(args)...), OP_SINGLE_ARG(bool, "broadcast_on_rows", enable_broadcast_, 0) {} bool RunOnDevice() override { return DispatchHelper< TensorTypes>:: call(this, Input(1)); } template bool DoRunWithType() { auto& select = Input(0); auto& left = Input(1); auto& right = Input(2); if (enable_broadcast_) { CAFFE_ENFORCE_EQ(select.dim(), 1); CAFFE_ENFORCE_EQ(select.size(0), right.size(0)); CAFFE_ENFORCE_EQ(left.sizes(), right.sizes()); } else { CAFFE_ENFORCE_EQ(select.sizes(), left.sizes()); CAFFE_ENFORCE_EQ(select.sizes(), right.sizes()); } auto* output = Output(0, left.sizes(), at::dtype()); const bool* select_data = select.template data(); const T* left_data = left.template data(); const T* right_data = right.template data(); T* output_data = output->template mutable_data(); if (enable_broadcast_) { size_t block_size = left.size_from_dim(1); for (int i = 0; i < select.numel(); i++) { size_t offset = i * block_size; if (select_data[i]) { context_.CopyItemsSameDevice( output->dtype(), block_size, left_data + offset, output_data + offset); } else { context_.CopyItemsSameDevice( output->dtype(), block_size, right_data + offset, output_data + offset); } } } else { for (int i = 0; i < select.numel(); ++i) { output_data[i] = select_data[i] ? left_data[i] : right_data[i]; } } return true; } private: bool enable_broadcast_; }; class IsMemberOfValueHolder { std::unordered_set int32_values_; std::unordered_set int64_values_; std::unordered_set bool_values_; std::unordered_set string_values_; bool has_values_ = false; public: template std::unordered_set& get(); template void set(const std::vector& args) { has_values_ = true; auto& values = get(); values.insert(args.begin(), args.end()); } bool has_values() { return has_values_; } }; template class IsMemberOfOp final : public Operator { USE_OPERATOR_CONTEXT_FUNCTIONS; USE_DISPATCH_HELPER; static constexpr const char* VALUE_TAG = "value"; public: using TestableTypes = TensorTypes; template explicit IsMemberOfOp(Args&&... args) : Operator(std::forward(args)...) { auto dtype = static_cast(this->template GetSingleArgument( "dtype", TensorProto_DataType_UNDEFINED)); switch (dtype) { case TensorProto_DataType_INT32: values_.set(this->template GetRepeatedArgument(VALUE_TAG)); break; case TensorProto_DataType_INT64: values_.set(this->template GetRepeatedArgument(VALUE_TAG)); break; case TensorProto_DataType_BOOL: values_.set(this->template GetRepeatedArgument(VALUE_TAG)); break; case TensorProto_DataType_STRING: values_.set(this->template GetRepeatedArgument(VALUE_TAG)); break; case TensorProto_DataType_UNDEFINED: // If dtype is not provided, values_ will be filled the first time that // DoRunWithType is called. break; default: CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype); } } virtual ~IsMemberOfOp() noexcept {} bool RunOnDevice() override { return DispatchHelper< TensorTypes>::call(this, Input(0)); } template bool DoRunWithType() { auto& input = Input(0); auto* output = Output(0, input.sizes(), at::dtype()); if (!values_.has_values()) { values_.set(this->template GetRepeatedArgument(VALUE_TAG)); } const auto& values = values_.get(); const T* input_data = input.template data(); bool* output_data = output->template mutable_data(); for (int i = 0; i < input.numel(); ++i) { output_data[i] = values.find(input_data[i]) != values.end(); } return true; } protected: IsMemberOfValueHolder values_; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_ELEMENTWISE_LOGICAL_OPS_H_