#pragma once
|
|
#include "caffe2/core/operator.h"
|
|
namespace caffe2 {
|
|
struct CAFFE2_API QShapeInfo {
|
QShapeInfo(float o = 0, float s = 1, uint32_t a = 1) {
|
offset.clear();
|
scale.clear();
|
offset.push_back(o);
|
scale.push_back(s);
|
axis = a;
|
}
|
|
uint32_t axis;
|
vector<float> offset;
|
vector<float> scale;
|
};
|
|
struct CAFFE2_API ShapeInfo {
|
enum DimType : int8_t { UNKNOWN = 0, CONSTANT = 1, BATCH = 2, SEQ = 3 };
|
ShapeInfo(bool q = false) : is_quantized(q) {}
|
ShapeInfo(DimType t, TensorShape&& s, bool q = false)
|
: dim_type(t), shape(std::move(s)), is_quantized(q) {}
|
ShapeInfo(DimType t, const TensorShape& s, bool q = false)
|
: dim_type(t), shape(s), is_quantized(q) {}
|
|
ShapeInfo(bool q, const QShapeInfo& info) : is_quantized(q), q_info(info) {}
|
ShapeInfo(DimType t, TensorShape&& s, bool q, const QShapeInfo& info)
|
: dim_type(t), shape(std::move(s)), is_quantized(q), q_info(info) {}
|
ShapeInfo(DimType t, const TensorShape& s, bool q, const QShapeInfo& info)
|
: dim_type(t), shape(s), is_quantized(q), q_info(info) {}
|
|
// type of the shape according its first dim
|
DimType dim_type{DimType::UNKNOWN};
|
TensorShape shape;
|
|
// quantization related information
|
bool is_quantized;
|
QShapeInfo q_info;
|
};
|
|
// Hash map from a string key to the ShapeInfo recorded for it.
// NOTE(review): keys are presumably blob names — confirm against callers.
using ShapeInfoMap = std::unordered_map<std::string, ShapeInfo>;
|
|
// Generates a ShapeInfo describing the given Blob.
// NOTE(review): the definition is not in this header. Which fields it fills
// (dim_type, shape, quantization info) and how it handles a null `blob` are
// not visible here — confirm in the implementation before relying on either.
ShapeInfo getShapeInfoFromBlob(const Blob* blob);
|
|
// Equality comparison between two ShapeInfo values.
// NOTE(review): defined out of line; which members take part in the comparison
// (dim_type, shape, is_quantized, q_info) is not visible here — confirm.
// This header declares no matching operator!=.
bool operator==(const ShapeInfo& lhs, const ShapeInfo& rhs);
|
|
} // namespace caffe2
|