#ifndef CAFFE2_CORE_OPERATOR_GRADIENT_H_
#define CAFFE2_CORE_OPERATOR_GRADIENT_H_

#include "c10/util/Registry.h"
#include "caffe2/core/operator_schema.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"

namespace caffe2 {

/**
 * @brief A struct that abstracts on top of dense and sparse blobs.
 *
 * For a dense blob, its gradient name should be written into dense_, and for
 * a sparse blob, its gradient name should be written into indices_ for the
 * sparse indices and values_ for the values.
 */
struct CAFFE2_API GradientWrapper {
  string dense_;
  string indices_;
  string values_;

  inline bool IsDense() const {
    return (dense_.size() != 0);
  }
  inline bool IsSparse() const {
    return (indices_.size() != 0 || values_.size() != 0);
  }
  inline bool IsEmpty() const {
    return (!IsDense() && !IsSparse());
  }
};

/**
 * A struct that holds the gradient operators and related gradient maps.
 */
struct CAFFE2_API GradientOpsMeta {
  vector<OperatorDef> ops_;
  vector<GradientWrapper> g_input_;

  GradientOpsMeta() {}
  GradientOpsMeta(
      const vector<OperatorDef>& ops,
      const vector<GradientWrapper>& v)
      : ops_(ops), g_input_(v) {}
};

class CAFFE2_API GradientMakerBase {
 public:
  GradientMakerBase(
      const OperatorDef& def,
      const vector<GradientWrapper>& g_output)
      : def_(def), g_output_(g_output), g_input_(def.input_size()) {}
  virtual ~GradientMakerBase() {}
  virtual bool CopyDeviceOption() const {
    return true;
  }
  virtual bool CopyEngine() const {
    return true;
  }
  virtual bool CopyArguments() const {
    return true;
  }

  virtual void VerifyOp() const {
    auto* schema = OpSchemaRegistry::Schema(def_.type());
    if (schema) {
      CAFFE_ENFORCE(
          schema->Verify(def_),
          "(GradientMaker) Operator def did not pass schema checking: ",
          ProtoDebugString(def_));
    }
  }

  /**
   * @brief Returns the gradient ops meta.
   *
   * If your gradient op generator only uses standard input and output
   * manipulations, you can simply implement GetGradientDefs() so that it
   * returns a vector<OperatorDef>. In that function, you can call GI, GI_V
   * and GI_I, which will automatically create the gradient registration for
   * you.
   *
   * If you need to do custom gradient name registration, overload this
   * function directly.
   */
  virtual GradientOpsMeta Get() {
    VerifyOp();
    vector<OperatorDef> new_defs = GetGradientDefs();
    for (auto& opdef : new_defs) {
      opdef.set_is_gradient_op(true);
    }
    return GradientOpsMeta(new_defs, g_input_);
  }

  const OperatorDef& Def() const {
    return def_;
  }

 protected:
  virtual vector<OperatorDef> GetGradientDefs() {
    CAFFE_NOT_IMPLEMENTED;
  }

  // Helper functions to return names for the gradient computation.
  // I(idx), O(idx): return the input and output names.
  // GO(idx): return the name of the gradient for output idx.
  // GI(idx), GI_I(idx), GI_V(idx): return the name of the gradient for
  //     input idx, and also register that name into the gradient
  //     registry to be returned.
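  //
  // As an illustrative sketch (the "Foo" operator and its maker below are
  // hypothetical, not part of Caffe2), a typical maker for a unary op whose
  // gradient op consumes the forward input and the output gradient, and
  // produces the input gradient, would look like:
  //
  //   class GetFooGradient : public GradientMakerBase {
  //     using GradientMakerBase::GradientMakerBase;
  //     vector<OperatorDef> GetGradientDefs() override {
  //       return SingleGradientDef(
  //           "FooGradient",
  //           "",
  //           vector<string>{I(0), GO(0)},
  //           vector<string>{GI(0)});
  //     }
  //   };
  //   REGISTER_GRADIENT(Foo, GetFooGradient);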
  string I(const int i) {
    CAFFE_ENFORCE((i >= 0) && (i < def_.input().size()));
    return def_.input(i);
  }
  string O(const int i) {
    CAFFE_ENFORCE((i >= 0) && (i < def_.output().size()));
    return def_.output(i);
  }
  string GI(const int i) {
    CAFFE_ENFORCE(
        !g_input_.at(i).IsSparse(),
        "Input ",
        def_.input(i),
        " already set to sparse.");
    g_input_.at(i).dense_ = GradientName(def_.input(i));
    return GradientName(def_.input(i));
  }
  string GI_I(const int i) {
    CAFFE_ENFORCE(
        !g_input_.at(i).IsDense(),
        "Input ",
        def_.input(i),
        " already set to dense.");
    g_input_.at(i).indices_ = GradientSliceIndices(def_.input(i));
    return GradientSliceIndices(def_.input(i));
  }
  string GI_V(const int i) {
    CAFFE_ENFORCE(
        !g_input_.at(i).IsDense(),
        "Input ",
        def_.input(i),
        " already set to dense.");
    g_input_.at(i).values_ = GradientSliceValues(def_.input(i));
    return GradientSliceValues(def_.input(i));
  }
  string GO(const int i) {
    CAFFE_ENFORCE(
        g_output_.at(i).IsDense(),
        "Gradient of output ",
        def_.output(i),
        (g_output_.at(i).IsSparse() ? " is sparse (expected dense)."
                                    : " is not provided!"));
    return g_output_.at(i).dense_;
  }
  string GO_I(const int i) {
    CAFFE_ENFORCE(
        g_output_.at(i).IsSparse(),
        "Gradient of output ",
        def_.output(i),
        (g_output_.at(i).IsDense() ? " is dense (expected sparse)."
                                   : " is not provided!"));
    return g_output_.at(i).indices_;
  }
  string GO_V(const int i) {
    CAFFE_ENFORCE(
        g_output_.at(i).IsSparse(),
        "Gradient of output ",
        def_.output(i),
        (g_output_.at(i).IsDense() ? " is dense (expected sparse)."
                                   : " is not provided!"));
    return g_output_.at(i).values_;
  }
  const GradientWrapper& GradOut(int i) {
    return g_output_.at(i);
  }

  // Functions to add a gradient pair to the gradient map.
  void SetDense(const int i, const string& name) {
    CAFFE_ENFORCE(
        !g_input_.at(i).IsSparse(),
        "Input ",
        def_.input(i),
        " already set to sparse.");
    g_input_.at(i).dense_ = name;
  }
  void SetSparse(const int i, const string& indices, const string& values) {
    CAFFE_ENFORCE(
        !g_input_.at(i).IsDense(),
        "Input ",
        def_.input(i),
        " already set to dense.");
    g_input_.at(i).indices_ = indices;
    g_input_.at(i).values_ = values;
  }

  /**
   * @brief A helper function to allow one to create one single operator
   * def, which is usually the case for many simple operators.
   */
  template <class... Args>
  inline static vector<OperatorDef> SingleGradientDef(const Args&... args) {
    return vector<OperatorDef>{CreateOperatorDef(args...)};
  }

 public:
  /**
   * Returns a map from the gradient blob names produced by this operator
   * to the parameter (blob) names those gradients are for.
   */
  static CaffeMap<string, string> MatchGradsToParams(const OperatorDef& op) {
    // NOTE: how to go beyond string-matching?
    CaffeMap<string, string> m;
    for (auto& out : op.output()) {
      if (IsGradientBlob(out)) {
        m[out] = out.substr(0, out.length() - 5);
      }
    }
    return m;
  }

 private:
  // Utility functions for gradient name computation. We don't expose them
  // in order to discourage the use of such names explicitly.
  static string GradientName(const string& name) {
    return name + "_grad";
  }

  static bool IsGradientBlob(const string& name) {
    return name.length() > 5 && name.find("_grad") == name.length() - 5;
  }

  static string GradientNameToParam(const string& name) {
    CHECK(IsGradientBlob(name));
    return name.substr(0, name.length() - 5);
  }

  static string GradientSliceIndices(const string& name) {
    return name + "_grad_indices";
  }

  static string GradientSliceValues(const string& name) {
    return name + "_grad_values";
  }
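  // For reference, for a blob named "W" the utilities above produce the
  // following names (an internal convention, shown for illustration only;
  // do not construct these names by hand):
  //   GradientName("W")             -> "W_grad"
  //   GradientSliceIndices("W")     -> "W_grad_indices"
  //   GradientSliceValues("W")      -> "W_grad_values"
  //   GradientNameToParam("W_grad") -> "W"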
 protected:
  // We make the member variables protected in case someone wants to write
  // a fully custom Get() function.
  const OperatorDef& def_;
  const vector<GradientWrapper>& g_output_;
  vector<GradientWrapper> g_input_;
};

/**
 * @brief A helper class to indicate that the operator does not need gradient
 * computation.
 *
 * Use the macro NO_GRADIENT to register operators that do not have gradients.
 * Note that this is different from SHOULD_NOT_DO_GRADIENT: the latter means
 * that the gradient computation should not flow through it at all, and throws
 * an error if it is called.
 */
class CAFFE2_API NoGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return vector<OperatorDef>();
  }
};

/**
 * @brief A helper class to indicate that the operator should have no gradient.
 *
 * This is used when the operator definition is designed to not have a
 * gradient. Requesting a gradient for this operator def will throw an error.
 */
struct ThrowInTheTowelIfGradientIsCalled : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  GradientOpsMeta Get() override {
    CAFFE_THROW(
        "One should not call gradient for operator ", def_.type(), ".");
  }
};

/**
 * @brief A helper class to indicate that the gradient mechanism is not ready.
 *
 * This should only be used sparingly when the gradient does exist, but we
 * have not implemented it yet and are using this as a lazy excuse. Eventually,
 * a gradient operator should be implemented.
 */
struct GradientNotImplementedYet : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  GradientOpsMeta Get() override {
    CAFFE_THROW(
        "Operator ",
        def_.type(),
        " should have a gradient but is not implemented yet.");
  }
};

C10_DECLARE_REGISTRY(
    GradientRegistry,
    GradientMakerBase,
    const OperatorDef&,
    const vector<GradientWrapper>&);

#ifdef CAFFE2_NO_GRADIENT_OPS

#define REGISTER_GRADIENT(name, ...) /* No gradients. */
#define REGISTER_GRADIENT_STR(str_name, ...) /* No gradients. */

#else

#define REGISTER_GRADIENT(name, ...) \
  C10_REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__)
#define REGISTER_GRADIENT_STR(str_name, ...) \
  C10_REGISTER_TYPED_CLASS(GradientRegistry, str_name, __VA_ARGS__)

#endif

// NO_GRADIENT means that the operator does not need any gradient computation.
#define NO_GRADIENT(name) REGISTER_GRADIENT(name, NoGradient)

// SHOULD_NOT_DO_GRADIENT means that the operator is not designed to have
// gradient operators. If you attempt to call the gradient, an error will be
// thrown.
#define SHOULD_NOT_DO_GRADIENT(name) \
  REGISTER_GRADIENT(name, ThrowInTheTowelIfGradientIsCalled)

#define GRADIENT_NOT_IMPLEMENTED_YET(name) \
  REGISTER_GRADIENT(name, GradientNotImplementedYet)

/**
 * @brief Gets the GradientOpsMeta for the given operator def.
 */
CAFFE2_API GradientOpsMeta GetGradientForOp(
    const OperatorDef& def,
    const vector<GradientWrapper>& g_output);

} // namespace caffe2

#endif // CAFFE2_CORE_OPERATOR_GRADIENT_H_