#ifndef QUANT_DECODE_OP_H_
|
#define QUANT_DECODE_OP_H_
|
|
#include "caffe2/core/context.h"
|
#include "caffe2/core/operator.h"
|
#include "caffe2/core/tensor.h"
|
#include <c10/util/typeid.h>
|
|
namespace caffe2 {
|
|
namespace {
|
|
template <class CodebookT, class CodeT>
|
void Decode(
|
const Tensor& codebook,
|
const Tensor& codes,
|
/* optional */ const Tensor* const decoded_grad,
|
Tensor* const output,
|
bool resizeOnly) {
|
CAFFE_ENFORCE(codebook.IsType<CodebookT>());
|
|
auto* cb_ptr = codebook.data<CodebookT>();
|
int cb_size = codebook.numel();
|
|
CAFFE_ENFORCE(codes.IsType<CodeT>());
|
auto* code_ptr = codes.data<CodeT>();
|
|
if (decoded_grad == nullptr) {
|
// Forward pass: decode and store codebook values in output.
|
output->ResizeLike(codes);
|
auto* out_ptr = output->template mutable_data<CodebookT>();
|
if (resizeOnly) {
|
return;
|
}
|
|
int sz = output->numel();
|
for (int i = 0; i < sz; i++) {
|
DCHECK_LE(*code_ptr, cb_size);
|
*out_ptr++ = cb_ptr[*code_ptr++];
|
}
|
} else {
|
// Backward pass: decode and accumulate gradient w.r.t. codebook values.
|
CAFFE_ENFORCE_EQ(codes.numel(), decoded_grad->numel());
|
auto* gradient_ptr = decoded_grad->data<CodebookT>();
|
auto* const gradient_end = gradient_ptr + decoded_grad->numel();
|
|
CAFFE_ENFORCE_EQ(cb_size, output->numel());
|
auto* out_ptr = output->template mutable_data<CodebookT>();
|
while (gradient_ptr < gradient_end) {
|
DCHECK_LE(*code_ptr, cb_size);
|
out_ptr[*code_ptr++] += *gradient_ptr++;
|
}
|
}
|
}
|
|
// Expands to one {key, value} initializer for the decoder dispatch table:
//   key   - the pair (type id of codebookType, type id of codesType)
//   value - the Decode<codebookType, codesType> instantiation, which is
//           callable with the (codebook, codes, gradient, output, resizeOnly)
//           argument list.
#define REGISTER_DECODER(codebookType, codesType)                \
  {                                                              \
    {TypeMeta::Id<codebookType>(), TypeMeta::Id<codesType>()},   \
        &Decode<codebookType, codesType>                         \
  }
|
|
inline void DecodeGeneral(
|
const Tensor& codebook,
|
const Tensor& codes,
|
const Tensor* gradient,
|
Tensor* outDecoded,
|
bool resizeOnly) {
|
const static std::map<
|
std::pair<TypeIdentifier, TypeIdentifier>,
|
std::function<void(
|
const Tensor& codebook,
|
const Tensor& codes,
|
const Tensor* gradient,
|
Tensor* outDecoded,
|
bool resizeOnly)>>
|
gDecoderMapper = {REGISTER_DECODER(float, uint8_t),
|
REGISTER_DECODER(float, uint16_t),
|
REGISTER_DECODER(float, int32_t)};
|
|
gDecoderMapper.at({codebook.dtype().id(), codes.dtype().id()})(
|
codebook, codes, gradient, outDecoded, resizeOnly);
|
}
|
|
} // namespace
|
|
// Decode tensors based on a given codebook.
// The codebook is generated by model_quantize.py.
|
|
// Selects how often QuantDecodeOp actually performs the decode.
enum class QuantDecodeRunTy {
  // Decode the inputs on every run of the operator.
  RUN_ALWAYS,
  // Decode the inputs on the first run only; later runs just resize and
  // allocate the outputs without rewriting their contents.
  RUN_ONCE,
};
|
|
template <QuantDecodeRunTy QuantDecodeRun>
|
class QuantDecodeOp final : public Operator<CPUContext> {
|
public:
|
USE_OPERATOR_FUNCTIONS(CPUContext);
|
template <class... Args>
|
explicit QuantDecodeOp(Args&&... args)
|
: Operator<CPUContext>(std::forward<Args>(args)...) {}
|
|
~QuantDecodeOp() {}
|
|
bool RunOnDevice() override {
|
CAFFE_ENFORCE_GT(InputSize(), 1);
|
// first input is the codebook
|
CAFFE_ENFORCE_EQ(InputSize(), OutputSize() + 1);
|
|
const auto& codebook = Input(0);
|
CAFFE_ENFORCE(codebook.template IsType<float>(), codebook.dtype().name());
|
|
for (int i = 0; i < OutputSize(); i++) {
|
auto& ci = Input(i + 1);
|
auto* co = Output(i);
|
|
DecodeGeneral(
|
codebook,
|
ci,
|
nullptr,
|
co,
|
/*resizeOnly=*/QuantDecodeRun == QuantDecodeRunTy::RUN_ONCE &&
|
hasRun_);
|
}
|
hasRun_ = true;
|
return true;
|
}
|
|
private:
|
bool hasRun_{false};
|
};
|
|
// CPU operator computing the gradient of the decoded tensors with respect to
// the float codebook.
//
// Inputs:  [codebook, codes_0, ..., codes_{n-1}, grad_0, ..., grad_{n-1}]
//          where grad_k is the gradient of the tensor decoded from codes_k.
// Output:  a single float tensor with the codebook's sizes; entry j is the sum
//          of all grad values whose corresponding code equals j.
class QuantDecodeGradientOp final : public Operator<CPUContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CPUContext);

  template <class... Args>
  explicit QuantDecodeGradientOp(Args&&... args)
      : Operator<CPUContext>(std::forward<Args>(args)...) {}

  ~QuantDecodeGradientOp() override = default;

  // Zeroes the codebook gradient, then accumulates the contribution of every
  // (codes, gradient) input pair into it. Returns true on success.
  bool RunOnDevice() override {
    // Inputs: 1 codebook, n tensors of codes, and n corresponding gradients.
    CAFFE_ENFORCE(InputSize() >= 3 && InputSize() % 2 == 1);
    const int pair_count = (InputSize() - 1) / 2;
    CAFFE_ENFORCE_EQ(OutputSize(), 1);

    const auto& codebook = Input(0);
    CAFFE_ENFORCE(codebook.template IsType<float>(), codebook.dtype().name());

    auto* codebook_grad = Output(0, codebook.sizes(), at::dtype<float>());

    // The decoder only accumulates, so start from an all-zero buffer.
    float* const grad_begin = codebook_grad->template mutable_data<float>();
    float* const grad_end = grad_begin + codebook_grad->numel();
    std::fill(grad_begin, grad_end, 0.0f);

    for (int k = 0; k < pair_count; ++k) {
      // Codes occupy inputs [1, pair_count]; their gradients follow at the
      // same offset within the second half of the inputs.
      const auto& codes_k = Input(k + 1);
      const auto& decoded_grad_k = Input(k + pair_count + 1);
      DecodeGeneral(codebook, codes_k, &decoded_grad_k, codebook_grad, false);
    }
    return true;
  }
};
|
|
} // namespace caffe2
|
#endif // QUANT_DECODE_OP_H_
|