#ifndef QUANT_DECODE_OP_H_
#define QUANT_DECODE_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"

#include <algorithm>
#include <functional>
#include <map>
#include <utility>

namespace caffe2 {

namespace {

template <class CodebookT, class CodeT>
void Decode(
    const Tensor& codebook,
    const Tensor& codes,
    /* optional */ const Tensor* const decoded_grad,
    Tensor* const output,
    bool resizeOnly) {
  CAFFE_ENFORCE(codebook.IsType<CodebookT>());

  auto* cb_ptr = codebook.data<CodebookT>();
  int cb_size = codebook.numel();

  CAFFE_ENFORCE(codes.IsType<CodeT>());
  auto* code_ptr = codes.data<CodeT>();

  if (decoded_grad == nullptr) {
    // Forward pass: decode and store codebook values in output.
    output->ResizeLike(codes);
    auto* out_ptr = output->template mutable_data<CodebookT>();
    if (resizeOnly) {
      return;
    }

    int sz = output->numel();
    for (int i = 0; i < sz; i++) {
      DCHECK_LE(*code_ptr, cb_size);
      *out_ptr++ = cb_ptr[*code_ptr++];
    }
  } else {
    // Backward pass: decode and accumulate gradient w.r.t. codebook values.
    CAFFE_ENFORCE_EQ(codes.numel(), decoded_grad->numel());

    auto* gradient_ptr = decoded_grad->data<CodebookT>();
    auto* const gradient_end = gradient_ptr + decoded_grad->numel();

    CAFFE_ENFORCE_EQ(cb_size, output->numel());
    auto* out_ptr = output->template mutable_data<CodebookT>();
    while (gradient_ptr < gradient_end) {
      DCHECK_LE(*code_ptr, cb_size);
      out_ptr[*code_ptr++] += *gradient_ptr++;
    }
  }
}

#define REGISTER_DECODER(codebookType, codesType)                      \
  {                                                                    \
    {TypeMeta::Id<codebookType>(), TypeMeta::Id<codesType>()},         \
        [](const Tensor& codebook_,                                    \
           const Tensor& codes_,                                       \
           const Tensor* gradient_,                                    \
           Tensor* outDecoded_,                                        \
           bool resizeOnly_) {                                         \
          Decode<codebookType, codesType>(                             \
              codebook_, codes_, gradient_, outDecoded_, resizeOnly_); \
        }                                                              \
  }

inline void DecodeGeneral(
    const Tensor& codebook,
    const Tensor& codes,
    const Tensor* gradient,
    Tensor* outDecoded,
    bool resizeOnly) {
  const static std::map<
      std::pair<TypeIdentifier, TypeIdentifier>,
      std::function<void(
          const Tensor& codebook,
          const Tensor& codes,
          const Tensor* gradient,
          Tensor* outDecoded,
          bool resizeOnly)>>
      gDecoderMapper = {
          REGISTER_DECODER(float, uint8_t),
          REGISTER_DECODER(float, uint16_t),
          REGISTER_DECODER(float, int32_t)};

  gDecoderMapper.at({codebook.dtype().id(), codes.dtype().id()})(
      codebook, codes, gradient, outDecoded, resizeOnly);
}

} // namespace

// Decode tensors based on the given codebook.
// The codebook is generated by model_quantize.py.
enum class QuantDecodeRunTy {
  RUN_ALWAYS,
  RUN_ONCE,
};
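// Illustrative example (hypothetical values): given a codebook [0.5f, 1.5f, 2.5f]
// and a code tensor [2, 0, 1, 1], QuantDecode produces the decoded tensor
// [2.5f, 0.5f, 1.5f, 1.5f], i.e. output[i] = codebook[codes[i]].
// With RUN_ONCE, only the first run performs the lookup; subsequent runs only
// resize the output and keep the previously decoded values, which suits codes
// that do not change between runs. With RUN_ALWAYS, every run re-decodes.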
template <QuantDecodeRunTy QuantDecodeRun>
class QuantDecodeOp final : public Operator<CPUContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CPUContext);
  template <class... Args>
  explicit QuantDecodeOp(Args&&... args)
      : Operator<CPUContext>(std::forward<Args>(args)...) {}

  ~QuantDecodeOp() {}

  bool RunOnDevice() override {
    CAFFE_ENFORCE_GT(InputSize(), 1);
    // The first input is the codebook; each remaining input is a code tensor
    // with a corresponding decoded output.
    CAFFE_ENFORCE_EQ(InputSize(), OutputSize() + 1);

    const auto& codebook = Input(0);
    CAFFE_ENFORCE(codebook.template IsType<float>(), codebook.dtype().name());

    for (int i = 0; i < OutputSize(); i++) {
      auto& ci = Input(i + 1);
      auto* co = Output(i);

      DecodeGeneral(
          codebook,
          ci,
          nullptr,
          co,
          /*resizeOnly=*/QuantDecodeRun == QuantDecodeRunTy::RUN_ONCE &&
              hasRun_);
    }
    hasRun_ = true;
    return true;
  }

 private:
  bool hasRun_{false};
};

class QuantDecodeGradientOp final : public Operator<CPUContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CPUContext);
  template <class... Args>
  explicit QuantDecodeGradientOp(Args&&... args)
      : Operator<CPUContext>(std::forward<Args>(args)...) {}
  ~QuantDecodeGradientOp() {}

  bool RunOnDevice() override {
    // Inputs: 1 codebook, n tensors of codes, and n corresponding gradients.
    // Output: a single tensor, the accumulated gradient w.r.t. the codebook.
    CAFFE_ENFORCE(InputSize() >= 3 && InputSize() % 2 == 1);
    const int num_code_tensors = (InputSize() - 1) / 2;
    CAFFE_ENFORCE_EQ(OutputSize(), 1);

    const auto& codebook = Input(0);
    CAFFE_ENFORCE(codebook.template IsType<float>(), codebook.dtype().name());

    auto* gradient = Output(0, codebook.sizes(), at::dtype<float>());
    auto* gradient_ptr = gradient->template mutable_data<float>();
    std::fill(gradient_ptr, gradient_ptr + gradient->numel(), 0);

    for (int i = 0; i < num_code_tensors; i++) {
      auto& codes_i = Input(i + 1);
      auto& output_gradient_i = Input(i + num_code_tensors + 1);
      DecodeGeneral(codebook, codes_i, &output_gradient_i, gradient, false);
    }
    return true;
  }
};

} // namespace caffe2
#endif // QUANT_DECODE_OP_H_