#ifndef CAFFE2_OPERATORS_DISTANCE_OP_H_
#define CAFFE2_OPERATORS_DISTANCE_OP_H_

#include <algorithm>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

template <typename T, class Context>
class SquaredL2DistanceOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit SquaredL2DistanceOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  // Input: X, Y; Output: Distance
};

template <typename T, class Context>
class SquaredL2DistanceGradientOp final : public Operator<Context> {
 public:
  template <class... Args>
  explicit SquaredL2DistanceGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override {
    auto& X = Input(0);
    auto& Y = Input(1);
    auto& dDistance = Input(2);

    int N = X.dim() > 0 ? X.dim32(0) : 1;
    int D = N > 0 ? X.numel() / N : 0;
    CAFFE_ENFORCE(X.dim() == Y.dim());
    for (int i = 0; i < X.dim(); ++i) {
      CAFFE_ENFORCE(X.dim32(i) == Y.dim32(i));
    }
    CAFFE_ENFORCE(dDistance.dim() == 1);
    CAFFE_ENFORCE(dDistance.dim32(0) == N);
    auto* dX = Output(0, X.sizes(), at::dtype<T>());
    auto* dY = Output(1, Y.sizes(), at::dtype<T>());

    // dX = X - Y, then scaled row-wise by the incoming gradient dDistance.
    math::Sub<T, Context>(
        X.numel(),
        X.template data<T>(),
        Y.template data<T>(),
        dX->template mutable_data<T>(),
        &context_);
    for (int i = 0; i < N; ++i) {
      math::Scale<T, T, Context>(
          D,
          dDistance.template data<T>() + i,
          dX->template data<T>() + i * D,
          dX->template mutable_data<T>() + i * D,
          &context_);
    }
    // The gradient with respect to Y is simply the negative of dX.
    math::Scale<T, T, Context>(
        X.numel(),
        -1,
        dX->template data<T>(),
        dY->template mutable_data<T>(),
        &context_);
    return true;
  }

 protected:
  // Input: X, Y, dDistance; Output: dX, dY
};
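// A minimal scalar reference sketch of the computation above, for
// illustration only; it is not used by any operator in this file and the
// name SquaredL2DistanceGradientRef is hypothetical. Assuming the forward op
// produces Distance[i] = 0.5 * ||X_i - Y_i||^2 per row (which is what the
// gradient above implies, since no factor of 2 appears), the backward pass is
//   dX[i][j] = dDistance[i] * (X[i][j] - Y[i][j]),  dY = -dX.
template <typename T>
void SquaredL2DistanceGradientRef(
    int N,
    int D,
    const T* X,
    const T* Y,
    const T* dDistance,
    T* dX,
    T* dY) {
  for (int i = 0; i < N; ++i) {
    for (int j = 0; j < D; ++j) {
      // Per-row chain rule: scale the elementwise difference by dDistance[i].
      dX[i * D + j] = dDistance[i] * (X[i * D + j] - Y[i * D + j]);
      dY[i * D + j] = -dX[i * D + j];
    }
  }
}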
template <typename T, class Context>
class L1DistanceOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit L1DistanceOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  // Input: X, Y; Output: Distance
};

template <typename T, class Context>
class L1DistanceGradientOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit L1DistanceGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  // Input: X, Y, dDistance; Output: dX, dY
};

template <typename T, class Context>
class DotProductOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit DotProductOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  INPUT_TAGS(X_IN, Y_IN);
  OUTPUT_TAGS(DOT_OUT);
};

template <typename T, class Context>
class DotProductGradientOp final : public Operator<Context> {
 public:
  template <class... Args>
  explicit DotProductGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  INPUT_TAGS(X_IN, Y_IN, DER_DOT_IN);
  OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);
};

template <typename T, class Context>
class DotProductWithPaddingOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit DotProductWithPaddingOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        pad_value_(this->template GetSingleArgument<float>("pad_value", 0.0)),
        replicate_(this->template GetSingleArgument<bool>("replicate", false)) {
  }
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  float pad_value_;
  bool replicate_;
  INPUT_TAGS(X_IN, Y_IN);
  OUTPUT_TAGS(DOT_OUT);
};

template <typename T, class Context>
class CosineSimilarityOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit CosineSimilarityOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  INPUT_TAGS(X_IN, Y_IN);
  OUTPUT_TAGS(COS_OUT);

 private:
  Tensor aux_;
};

template <typename T, class Context>
class CosineSimilarityGradientOp final : public Operator<Context> {
 public:
  template <class... Args>
  explicit CosineSimilarityGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  INPUT_TAGS(X_IN, Y_IN, DER_COS_IN);
  OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);

 private:
  Tensor aux_;
};

template <typename T, class Context>
class DotProductWithPaddingGradientOp final : public Operator<Context> {
 public:
  template <class... Args>
  explicit DotProductWithPaddingGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        pad_value_(this->template GetSingleArgument<float>("pad_value", 0.0)),
        replicate_(this->template GetSingleArgument<bool>("replicate", false)) {
  }
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override {
    auto& X = Input(X_IN);
    auto& Y = Input(Y_IN);
    auto& dDot = Input(DER_DOT_IN);

    int N, D, DX, DY, restD;
    if (X.numel() > 0) {
      N = X.dim() > 0 ? X.dim32(0) : 1;
      DX = X.numel() / N;
      DY = Y.numel() / N;
    } else {
      N = 0;
      DX = 0;
      DY = 0;
    }
    CAFFE_ENFORCE(!replicate_ || DX % DY == 0 || DY % DX == 0);
    D = std::min(DX, DY);
    restD = std::max(DX, DY) - D;
    CAFFE_ENFORCE_EQ(X.dim(), Y.dim());
    CAFFE_ENFORCE_EQ(X.dim32(0), Y.dim32(0));
    CAFFE_ENFORCE_EQ(dDot.dim(), 1);
    CAFFE_ENFORCE_EQ(dDot.dim32(0), N);
    auto* dX = Output(DER_X_OUT, X.sizes(), at::dtype<T>());
    auto* dY = Output(DER_Y_OUT, Y.sizes(), at::dtype<T>());

    const auto* X_data = X.template data<T>();
    const auto* Y_data = Y.template data<T>();
    const auto* dDot_data = dDot.template data<T>();
    auto* dX_data = dX->template mutable_data<T>();
    auto* dY_data = dY->template mutable_data<T>();
    for (int i = 0; i < N; ++i) { // TODO: multithreading
      auto offsetX = i * DX;
      auto offsetY = i * DY;
      if (replicate_) {
        // L_ refers to the longer vector, S_ to the shorter one.
        const T *L_data, *S_data;
        T *dL_data, *dS_data;
        int DL, DS;
        if (DX > DY) {
          L_data = X_data + offsetX;
          S_data = Y_data + offsetY;
          dL_data = dX_data + offsetX;
          dS_data = dY_data + offsetY;
          DL = DX;
          DS = DY;
        } else {
          L_data = Y_data + offsetY;
          S_data = X_data + offsetX;
          dL_data = dY_data + offsetY;
          dS_data = dX_data + offsetX;
          DL = DY;
          DS = DX;
        }

        // TODO: get rid of the temporary memory use.
        std::vector<T> tmp_data(DS);
        math::Set<T, Context>(DS, 0.0, dS_data, &context_);
        for (int j = 0; j < DL / DS; j++) {
          math::Scale<T, T, Context>(
              DS, dDot_data[i], S_data, dL_data + j * DS, &context_);
          math::Scale<T, T, Context>(
              DS, dDot_data[i], L_data + j * DS, tmp_data.data(), &context_);
          math::Axpy<float, T, Context>(
              DS, 1.0, tmp_data.data(), dS_data, &context_);
        }
      } else {
        math::Scale<T, T, Context>(
            D, dDot_data[i], X_data + offsetX, dY_data + offsetY, &context_);
        math::Scale<T, T, Context>(
            D, dDot_data[i], Y_data + offsetY, dX_data + offsetX, &context_);
      }

      if (!replicate_ && DX != DY) {
        // The padded tail of the longer input was dotted against pad_value_,
        // so its gradient is dDot * pad_value_.
        T* rest_data;
        if (DX > DY) {
          rest_data = dX_data + offsetX + D;
        } else {
          rest_data = dY_data + offsetY + D;
        }
        auto pad_gradient = dDot_data[i] * pad_value_;
        math::Set<T, Context>(restD, pad_gradient, rest_data, &context_);
      }
    }

    return true;
  }

 protected:
  float pad_value_;
  bool replicate_;
  INPUT_TAGS(X_IN, Y_IN, DER_DOT_IN);
  OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_DISTANCE_OP_H_
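// Worked example for DotProductWithPaddingGradientOp (illustration only; the
// concrete numbers are hypothetical). With replicate=false and pad_value=p,
// a row X_i = [x0, x1, x2] dotted against Y_i = [y0] is treated by the
// forward op as if Y_i were padded to [y0, p, p], so
//   Dot_i = x0*y0 + p*(x1 + x2),
// and the gradient code fills the common part with dX_i[0] = dDot_i*y0 and
// dY_i[0] = dDot_i*x0, then sets the restD tail of dX_i to dDot_i*p via
// math::Set. With replicate=true and Y_i = [y0, y1] against
// X_i = [x0, x1, x2, x3], the shorter vector is tiled to [y0, y1, y0, y1]:
//   Dot_i = x0*y0 + x1*y1 + x2*y0 + x3*y1,
//   dX_i  = dDot_i * [y0, y1, y0, y1],
//   dY_i  = dDot_i * [x0 + x2, x1 + x3]  (the Axpy accumulation in the loop).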