#pragma once #include "caffe2/core/operator.h" #include "caffe2/utils/eigen_utils.h" #include "caffe2/utils/math.h" namespace caffe2 { template class EnsureClippedOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit EnsureClippedOp(Args&&... args) : Operator(std::forward(args)...), min_(std::numeric_limits::lowest()), max_(std::numeric_limits::max()) { if (HasArgument("min")) { min_ = static_cast(this->template GetSingleArgument("min", 0)); } if (HasArgument("max")) { max_ = static_cast(this->template GetSingleArgument("max", 0)); } } bool RunOnDevice() override { if (InputSize() > INDICES) { // spares gradient, selective checking clipping CAFFE_ENFORCE_EQ( Input(PARAM).size_from_dim(1), Input(GRAD).size_from_dim(Input(INDICES).dim())); return DispatchHelper>::call( this, Input(INDICES)); } else { auto& X = Input(PARAM); auto* Y = Output(OUTPUT_PARAM, X.sizes(), at::dtype()); EigenVectorMap(Y->template mutable_data(), Y->numel()) = ConstEigenVectorMap(X.template data(), X.numel()) .cwiseMax(min_) .cwiseMin(max_); return true; } } template bool DoRunWithType(); protected: T min_; T max_; INPUT_TAGS(PARAM, INDICES, GRAD); OUTPUT_TAGS(OUTPUT_PARAM); }; } // namespace caffe2