#ifndef CAFFE2_OPERATORS_H_SOFTMAX_OP_H_
#define CAFFE2_OPERATORS_H_SOFTMAX_OP_H_

#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <c10/util/Optional.h>

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/proto/hsm.pb.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

// Shared base for the hierarchical softmax forward and gradient operators.
// The constructor parses the serialized HierarchyProto passed through the
// "hierarchy" argument and indexes each word's path by its word id.
template <typename T, class Context>
class HSoftmaxOpBase : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit HSoftmaxOpBase(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {
    HierarchyProto hierarchy;
    CAFFE_ENFORCE(hierarchy.ParseFromString(
        this->template GetSingleArgument<std::string>("hierarchy", "")));
    for (const auto& path : hierarchy.paths()) {
      hierarchy_all_map_.emplace(path.word_id(), path);
    }
  }

 protected:
  std::unordered_map<int, PathProto> hierarchy_all_map_;
  c10::optional<Tensor> scale_;
  c10::optional<Tensor> sum_multiplier_;
  c10::optional<Tensor> bias_multiplier_;

  static constexpr T kLOG_THRESHOLD() {
    return 1e-20f;
  }

  // Collects the hierarchy paths for the distinct labels appearing in the
  // current batch of M labels.
  static std::unordered_map<int, PathProto> getHierarchyForLabels(
      int M,
      const int* labels,
      const std::unordered_map<int, PathProto>& hierarchy_all_map) {
    std::unordered_map<int, PathProto> hierarchy_map;
    std::set<int> label_set = std::set<int>(labels, labels + M);
    for (const auto& label : label_set) {
      auto search = hierarchy_all_map.find(label);
      CAFFE_ENFORCE(search != hierarchy_all_map.end(), "incorrect label.");
      hierarchy_map.emplace(search->first, search->second);
    }
    return hierarchy_map;
  }

  // Total size of the intermediate output buffer: every node on every
  // label's path contributes its FC output plus its softmax output.
  int getIntermediateOutputSize(
      const int* labels,
      int M,
      std::unordered_map<int, PathProto>& hierarchy) const {
    int size = 0;
    for (int label = 0; label < M; ++label) {
      int word_id = labels[label];
      const auto& path = hierarchy[word_id];
      size += std::accumulate(
          path.path_nodes().begin(),
          path.path_nodes().end(),
          0,
          // Output of FC + Output of Softmax
          [](int sz, PathNodeProto node) { return sz + 2 * node.length(); });
    }
    return size;
  }
};

// Forward pass of hierarchical softmax.
template <typename T, class Context>
class HSoftmaxOp : public HSoftmaxOpBase<T, Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  using HSoftmaxOpBase<T, Context>::HSoftmaxOpBase;

  bool RunOnDevice() override;

 protected:
  float RunForwardSingle(
      const float* X,
      const float* W,
      const float* b,
      int target,
      float* output,
      const float* bias_multiplier,
      int w_length,
      int K,
      int& output_offset);
};

// Gradient of hierarchical softmax w.r.t. the input, weights, and bias.
template <typename T, class Context>
class HSoftmaxGradientOp final : public HSoftmaxOpBase<T, Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  using HSoftmaxOpBase<T, Context>::HSoftmaxOpBase;
  bool RunOnDevice() override;

 private:
  void RunBackwardSingle(
      const float* X,
      const float* dY,
      const float* W,
      int target,
      const float* int_output,
      float* dX,
      float* dW,
      float* db,
      float* dOutput,
      int dim_in,
      int w_length,
      int& output_offset);
};

// Beam search over the serialized tree passed through the "tree" argument,
// returning the top-N scoring candidates per sample.
template <typename T, class Context>
class HSoftmaxSearchOp final : public HSoftmaxOp<T, Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit HSoftmaxSearchOp(Args&&... args)
      : HSoftmaxOp<T, Context>(std::forward<Args>(args)...),
        top_n_(this->template GetSingleArgument<int>("topN", 5)),
        beam_(this->template GetSingleArgument<float>("beam", 0.01f)) {
    CAFFE_ENFORCE(tree_.ParseFromString(
        this->template GetSingleArgument<std::string>("tree", "")));
  }
  bool RunOnDevice() override;

 private:
  int top_n_;
  float beam_;
  TreeProto tree_;
  bool pruning(
      const float* X,
      int sample,
      int K,
      const float* W,
      const float* b,
      const NodeProto& src_node,
      NodeProto& dst_node,
      float parent_score,
      float beam);
  bool extractNodes(
      const NodeProto& node,
      std::vector<std::pair<std::string, float>>& info);
};
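// Illustrative sketch (not part of the original header): one way the
// serialized "hierarchy" argument consumed by HSoftmaxOpBase above might be
// assembled. The setter names assume the standard protobuf-generated API for
// the messages declared in caffe2/proto/hsm.pb.h, and only fields already
// referenced in this header (paths, word_id, path_nodes, length) are used.
//
//   HierarchyProto hierarchy;
//   PathProto* path = hierarchy.add_paths();
//   path->set_word_id(42);                  // label this path encodes
//   PathNodeProto* node = path->add_path_nodes();
//   node->set_length(2);                    // fan-out of this internal node
//   std::string arg = hierarchy.SerializeAsString();
//   // Pass `arg` as the "hierarchy" argument when creating the HSoftmax op.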
// Builds a Huffman tree hierarchy over the classes based on the label
// frequencies observed in the input.
template <typename T, class Context>
class HuffmanTreeHierarchyOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit HuffmanTreeHierarchyOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        num_classes_(this->template GetSingleArgument<int>("num_classes", -1)) {
  }
  bool RunOnDevice() override;

 private:
  // Internal huffman tree data.
  struct Node {
    Node(T l, int count)
        : label(l), count(count), left_ch_index(-1), right_ch_index(-1) {}
    T label;
    int count;
    int left_ch_index;
    int right_ch_index;
  };

  // Orders nodes so that a std::priority_queue pops the lowest-count node
  // first, which is the order in which Huffman construction merges nodes.
  struct NodeComparator {
    bool operator()(const Node& node_a, const Node& node_b) {
      return node_a.count > node_b.count;
    }
  };

  int num_classes_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_H_SOFTMAX_OP_H_