#ifndef CAFFE2_OPERATORS_LSTM_UNIT_OP_H_
|
#define CAFFE2_OPERATORS_LSTM_UNIT_OP_H_
|
|
#include "caffe2/core/context.h"
|
#include "caffe2/core/operator.h"
|
#include "caffe2/utils/conversions.h"
|
|
namespace caffe2 {
|
namespace detail {
|
/// Logistic function: returns 1 / (1 + exp(-x)), converted back to T.
/// `exp` is called unqualified on purpose, so whichever overload is visible
/// for T is the one selected.
template <typename T>
inline T sigmoid(T x) {
  const auto decay = exp(-x);
  return 1. / (1. + decay);
}
|
|
/// Hyperbolic tangent expressed through the logistic function:
/// returns 2 * sigmoid(2 * x) - 1, converted back to T.
template <typename T>
inline T host_tanh(T x) {
  const auto squashed = sigmoid(2. * x);
  return 2. * squashed - 1.;
}
|
|
template <typename T, typename Context>
|
void LSTMUnit(
|
int N,
|
int D,
|
int t,
|
const T* H_prev,
|
const T* C_prev,
|
const T* X,
|
const int32_t* seqLengths,
|
bool drop_states,
|
T* C,
|
T* H,
|
const float forget_bias,
|
Context* /*context*/) {
|
for (int n = 0; n < N; ++n) {
|
const bool valid = seqLengths == nullptr || t < seqLengths[n];
|
|
for (int d = 0; d < D; ++d) {
|
if (!valid) {
|
if (drop_states) {
|
H[d] = 0;
|
C[d] = 0;
|
} else {
|
H[d] = H_prev[d];
|
C[d] = C_prev[d];
|
}
|
} else {
|
const T i = sigmoid(X[d]);
|
const T f = sigmoid(X[1 * D + d] + convert::To<float, T>(forget_bias));
|
const T o = sigmoid(X[2 * D + d]);
|
const T g = host_tanh(X[3 * D + d]);
|
const T c_prev = C_prev[d];
|
const T c = f * c_prev + i * g;
|
C[d] = c;
|
const T host_tanh_c = host_tanh(c);
|
H[d] = o * host_tanh_c;
|
}
|
}
|
H_prev += D;
|
C_prev += D;
|
X += 4 * D;
|
C += D;
|
H += D;
|
}
|
}
|
|
template <typename T, typename Context>
|
void LSTMUnitGradient(
|
int N,
|
int D,
|
int t,
|
const T* C_prev,
|
const T* X,
|
const int32_t* seqLengths,
|
const T* C,
|
const T* H,
|
const T* C_diff,
|
const T* H_diff,
|
bool drop_states,
|
T* H_prev_diff,
|
T* C_prev_diff,
|
T* X_diff,
|
const float forget_bias,
|
Context* /*context*/) {
|
for (int n = 0; n < N; ++n) {
|
const bool valid = seqLengths == nullptr || t < seqLengths[n];
|
|
for (int d = 0; d < D; ++d) {
|
T* c_prev_diff = C_prev_diff + d;
|
T* h_prev_diff = H_prev_diff + d;
|
T* i_diff = X_diff + d;
|
T* f_diff = X_diff + 1 * D + d;
|
T* o_diff = X_diff + 2 * D + d;
|
T* g_diff = X_diff + 3 * D + d;
|
|
if (!valid) {
|
if (drop_states) {
|
*h_prev_diff = 0;
|
*c_prev_diff = 0;
|
} else {
|
*h_prev_diff = H_diff[d];
|
*c_prev_diff = C_diff[d];
|
}
|
*i_diff = 0;
|
*f_diff = 0;
|
*o_diff = 0;
|
*g_diff = 0;
|
} else {
|
const T i = sigmoid(X[d]);
|
const T f = sigmoid(X[1 * D + d] + convert::To<float, T>(forget_bias));
|
const T o = sigmoid(X[2 * D + d]);
|
const T g = host_tanh(X[3 * D + d]);
|
const T c_prev = C_prev[d];
|
const T c = C[d];
|
const T host_tanh_c = host_tanh(c);
|
const T c_term_diff = C_diff[d] + H_diff[d] * o * (1 - host_tanh_c * host_tanh_c);
|
*c_prev_diff = c_term_diff * f;
|
*h_prev_diff = 0; // not used in 'valid' case
|
*i_diff = c_term_diff * g * i * (1 - i);
|
*f_diff = c_term_diff * c_prev * f * (1 - f);
|
*o_diff = H_diff[d] * host_tanh_c * o * (1 - o);
|
*g_diff = c_term_diff * i * (1 - g * g);
|
}
|
}
|
C_prev += D;
|
X += 4 * D;
|
C += D;
|
H += D;
|
C_diff += D;
|
H_diff += D;
|
X_diff += 4 * D;
|
H_prev_diff += D;
|
C_prev_diff += D;
|
}
|
}
|
} // namespace detail
|
|
template <typename Context>
|
class LSTMUnitOp : public Operator<Context> {
|
public:
|
explicit LSTMUnitOp(const OperatorDef& operator_def, Workspace* ws)
|
: Operator<Context>(operator_def, ws),
|
forget_bias_(static_cast<float>(
|
this->template GetSingleArgument<float>("forget_bias", 0.0))),
|
sequence_lengths_(
|
this->template GetSingleArgument<bool>("sequence_lengths", true)),
|
drop_states_(
|
this->template GetSingleArgument<bool>("drop_states", false)) {}
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
using Operator<Context>::Operator;
|
|
template <typename T>
|
bool DoRunWithType() {
|
// handle potentially-missing sequence lengths input
|
const size_t TIMESTEP = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
|
|
// Extract N
|
const auto N = Input(CELL_T_M_1).size(1);
|
|
// Gates: 1xNxG
|
const auto G = Input(GATES).size(2);
|
const auto D = Input(CELL_T_M_1).size(2);
|
|
CAFFE_ENFORCE_EQ(4 * D, G);
|
const auto* H_prev = Input(HIDDEN_T_M_1).template data<T>();
|
const auto* C_prev = Input(CELL_T_M_1).template data<T>();
|
const auto* X = Input(GATES).template data<T>();
|
|
const int32_t* seqLengths = nullptr;
|
if (sequence_lengths_) {
|
CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).numel(), N);
|
seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
|
}
|
|
const auto t = static_cast<OperatorBase*>(this)
|
->Input<Tensor>(TIMESTEP, CPU)
|
.template data<int32_t>()[0];
|
Output(CELL_T)->ResizeLike(Input(CELL_T_M_1));
|
auto* C = Output(CELL_T)->template mutable_data<T>();
|
Output(HIDDEN_T)->ResizeLike(Input(CELL_T_M_1));
|
auto* H = Output(HIDDEN_T)->template mutable_data<T>();
|
detail::LSTMUnit<T, Context>(
|
N,
|
D,
|
t,
|
H_prev,
|
C_prev,
|
X,
|
seqLengths,
|
drop_states_,
|
C,
|
H,
|
forget_bias_,
|
&context_);
|
return true;
|
}
|
|
bool RunOnDevice() override {
|
return DoRunWithType<float>();
|
}
|
|
protected:
|
INPUT_TAGS(HIDDEN_T_M_1, CELL_T_M_1, GATES, SEQ_LENGTHS);
|
// additional input tags are determined dynamically based on whether
|
// sequence_lengths is present.
|
OUTPUT_TAGS(HIDDEN_T, CELL_T);
|
|
float forget_bias_;
|
bool sequence_lengths_;
|
|
private:
|
bool drop_states_;
|
};
|
|
template <typename Context>
|
class LSTMUnitGradientOp : public Operator<Context> {
|
public:
|
template <class... Args>
|
explicit LSTMUnitGradientOp(Args&&... args)
|
: Operator<Context>(std::forward<Args>(args)...),
|
forget_bias_(static_cast<float>(
|
this->template GetSingleArgument<float>("forget_bias", 0.0))),
|
sequence_lengths_(
|
this->template GetSingleArgument<bool>("sequence_lengths", true)),
|
drop_states_(
|
this->template GetSingleArgument<bool>("drop_states", false)) {}
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
template <typename T>
|
bool DoRunWithType() {
|
// handle potentially-missing sequence lengths input
|
const size_t inputOffset = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
|
const size_t TIMESTEP = inputOffset;
|
const size_t HIDDEN_T = inputOffset + 1;
|
const size_t CELL_T = inputOffset + 2;
|
const size_t HIDDEN_T_GRAD = inputOffset + 3;
|
const size_t CELL_T_GRAD = inputOffset + 4;
|
|
// Extract N
|
const auto N = Input(CELL_T_M_1).size(1);
|
|
// Gates: 1xNxG
|
const auto G = Input(GATES).size(2);
|
const auto D = Input(CELL_T_M_1).size(2);
|
|
CAFFE_ENFORCE_EQ(4 * D, G);
|
const auto* C_prev = Input(CELL_T_M_1).template data<T>();
|
const auto* X = Input(GATES).template data<T>();
|
const auto t = static_cast<OperatorBase*>(this)
|
->Input<Tensor>(TIMESTEP, CPU)
|
.template data<int32_t>()[0];
|
const auto* C = Input(CELL_T).template data<T>();
|
const auto* H = Input(HIDDEN_T).template data<T>();
|
const auto* C_diff = Input(CELL_T_GRAD).template data<T>();
|
const auto* H_diff = Input(HIDDEN_T_GRAD).template data<T>();
|
|
const int32_t* seqLengths = nullptr;
|
if (sequence_lengths_) {
|
CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).numel(), N);
|
seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
|
}
|
|
Output(HIDDEN_T_M_1_GRAD)->ResizeLike(Input(HIDDEN_T_M_1));
|
auto* H_prev_diff = Output(HIDDEN_T_M_1_GRAD)->template mutable_data<T>();
|
Output(CELL_T_M_1_GRAD)->ResizeLike(Input(CELL_T_M_1));
|
auto* C_prev_diff = Output(CELL_T_M_1_GRAD)->template mutable_data<T>();
|
Output(GATES_GRAD)->ResizeLike(Input(GATES));
|
auto* X_diff = Output(GATES_GRAD)->template mutable_data<T>();
|
|
detail::LSTMUnitGradient<T, Context>(
|
N,
|
D,
|
t,
|
C_prev,
|
X,
|
seqLengths,
|
C,
|
H,
|
C_diff,
|
H_diff,
|
drop_states_,
|
H_prev_diff,
|
C_prev_diff,
|
X_diff,
|
forget_bias_,
|
&context_);
|
return true;
|
}
|
|
bool RunOnDevice() override {
|
return DoRunWithType<float>();
|
}
|
|
protected:
|
INPUT_TAGS(HIDDEN_T_M_1, CELL_T_M_1, GATES, SEQ_LENGTHS);
|
// additional input tags are determined dynamically based on whether
|
// sequence_lengths is present.
|
OUTPUT_TAGS(HIDDEN_T_M_1_GRAD, CELL_T_M_1_GRAD, GATES_GRAD);
|
|
float forget_bias_;
|
bool sequence_lengths_;
|
|
private:
|
bool drop_states_;
|
};
|
} // namespace caffe2
|
|
#endif // CAFFE2_OPERATORS_LSTM_UNIT_OP_H_
|