#ifndef CAFFE2_OPERATORS_GRU_UNIT_OP_H_ #define CAFFE2_OPERATORS_GRU_UNIT_OP_H_ #include "caffe2/core/context.h" #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" namespace caffe2 { namespace detail { template inline T sigmoid(T x) { return 1.0f / (1.0f + exp(-x)); } template inline T host_tanh(T x) { return 2.0f * sigmoid(2.0f * x) - 1.0f; } template void GRUUnit( int N, int D, int t, const T* H_prev, const T* X, const int32_t* seqLengths, bool drop_states, T* H, Context* /*context*/) { for (int n = 0; n < N; ++n) { const bool valid = seqLengths == nullptr || t < seqLengths[n]; for (int d = 0; d < D; ++d) { if (!valid) { if (drop_states) { H[d] = 0; } else { H[d] = H_prev[d]; } } else { const T update = X[1 * D + d]; const T output = X[2 * D + d]; T sigmoid_update = sigmoid(update); H[d] = H_prev[d] * sigmoid_update + host_tanh(output) * (1.0f - sigmoid_update); } } H_prev += D; X += 3 * D; H += D; } } template void GRUUnitGradient( int N, int D, int t, const T* H_prev, const T* X, const int32_t* seqLengths, const T* H, const T* H_diff, bool drop_states, T* H_prev_diff, T* X_diff, Context* /*context*/) { for (int n = 0; n < N; ++n) { const bool valid = seqLengths == nullptr || t < seqLengths[n]; for (int d = 0; d < D; ++d) { T* h_prev_diff = H_prev_diff + d; T* reset_diff = X_diff + 0 * D + d; T* update_diff = X_diff + 1 * D + d; T* output_diff = X_diff + 2 * D + d; if (!valid) { if (drop_states) { *h_prev_diff = 0; } else { *h_prev_diff = H_diff[d]; } *reset_diff = 0; *update_diff = 0; *output_diff = 0; } else { // Calculate Gate Outputs const T u = sigmoid(X[1 * D + d]); const T o = host_tanh(X[2 * D + d]); *h_prev_diff = H_diff[d] * u; *reset_diff = 0; // 0 contribution to gradient from this operation *update_diff = (H_diff[d] * H_prev[d] - H_diff[d] * o) * u * (1.0f - u); *output_diff = H_diff[d] * (1.0f - u) * (1.0f - o * o); } } H_prev += D; X += 3 * D; H += D; H_diff += D; X_diff += 3 * D; H_prev_diff += D; } } } // namespace detail template class GRUUnitOp : public Operator { public: template explicit GRUUnitOp(Args&&... args) : Operator(std::forward(args)...), drop_states_( this->template GetSingleArgument("drop_states", false)), sequence_lengths_( this->template GetSingleArgument("sequence_lengths", true)) {} USE_OPERATOR_CONTEXT_FUNCTIONS; bool RunOnDevice() override { // handle potentially-missing sequence lengths input const size_t TIMESTEP = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0); // Extract N const auto N = Input(HIDDEN_T_M_1).size(1); // Gates: 1xNxG const auto G = Input(GATES).size(2); const auto D = Input(HIDDEN_T_M_1).size(2); CAFFE_ENFORCE_EQ(3 * D, G); const auto* H_prev = Input(HIDDEN_T_M_1).template data(); const auto* X = Input(GATES).template data(); const int32_t* seqLengths = nullptr; if (sequence_lengths_) { CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).numel(), N); seqLengths = Input(SEQ_LENGTHS).template data(); } const auto t = static_cast(this) ->Input(TIMESTEP, CPU) .template data()[0]; Output(HIDDEN_T)->ResizeLike(Input(HIDDEN_T_M_1)); auto* H = Output(HIDDEN_T)->template mutable_data(); detail::GRUUnit( N, D, t, H_prev, X, seqLengths, drop_states_, H, &context_); return true; } protected: INPUT_TAGS(HIDDEN_T_M_1, GATES, SEQ_LENGTHS); // additional input tags are determined dynamically based on whether // sequence_lengths is present. OUTPUT_TAGS(HIDDEN_T); private: bool drop_states_; bool sequence_lengths_; }; template class GRUUnitGradientOp : public Operator { public: template explicit GRUUnitGradientOp(Args&&... args) : Operator(std::forward(args)...), drop_states_( this->template GetSingleArgument("drop_states", false)), sequence_lengths_( this->template GetSingleArgument("sequence_lengths", true)) {} USE_OPERATOR_CONTEXT_FUNCTIONS; bool RunOnDevice() override { // handle potentially-missing sequence lengths input const size_t inputOffset = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0); const size_t TIMESTEP = inputOffset; const size_t HIDDEN_T = inputOffset + 1; const size_t HIDDEN_T_GRAD = inputOffset + 2; // Extract N const auto N = Input(HIDDEN_T_M_1).size(1); // Gates: 1xNxG const auto G = Input(GATES).size(2); const auto D = Input(HIDDEN_T_M_1).size(2); CAFFE_ENFORCE_EQ(3 * D, G); const auto* H_prev = Input(HIDDEN_T_M_1).template data(); const auto* X = Input(GATES).template data(); const auto t = static_cast(this) ->Input(TIMESTEP, CPU) .template data()[0]; const auto* H = Input(HIDDEN_T).template data(); const auto* H_diff = Input(HIDDEN_T_GRAD).template data(); const int32_t* seqLengths = nullptr; if (sequence_lengths_) { CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).numel(), N); seqLengths = Input(SEQ_LENGTHS).template data(); } Output(HIDDEN_T_M_1_GRAD)->ResizeLike(Input(HIDDEN_T_M_1)); auto* H_prev_diff = Output(HIDDEN_T_M_1_GRAD)->template mutable_data(); Output(GATES_GRAD)->ResizeLike(Input(GATES)); auto* X_diff = Output(GATES_GRAD)->template mutable_data(); detail::GRUUnitGradient( N, D, t, H_prev, X, seqLengths, H, H_diff, drop_states_, H_prev_diff, X_diff, &context_); return true; } protected: INPUT_TAGS(HIDDEN_T_M_1, GATES, SEQ_LENGTHS); OUTPUT_TAGS(HIDDEN_T_M_1_GRAD, GATES_GRAD); private: bool drop_states_; bool sequence_lengths_; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_GRU_UNIT_OP_H_