#pragma once
|
|
#include "caffe2/core/logging.h"
|
#include "caffe2/opt/shape_info.h"
|
#include "caffe2/proto/caffe2_pb.h"
|
|
#include <sstream>
|
#include <string>
|
#include <unordered_map>
|
#include <unordered_set>
|
|
namespace caffe2 {
|
// This struct stores the max bound size for batch in the general sense. We have
|
// the conventional batch size and the look-up sequence, which is also batch in a
|
// sense.
|
struct CAFFE2_API BoundShapeSpec {
|
explicit BoundShapeSpec(int64_t b, int64_t q)
|
: max_batch_size(b), max_seq_size(q) {}
|
int64_t max_batch_size;
|
int64_t max_seq_size;
|
};
|
|
/// \brief A class that does bound shape inference given a C2 net. Depending on
|
/// its type, each op has a maximum shape that it accepts. We define some
|
/// initial bound for certain dimension, for example max batch size or max
|
/// sequence lookup size. And the inference will first infer the input size and
|
/// then propagates the bound shape down the network. For now the variable part
|
/// (bound part) is the first dimension of the shape, which usually corresponds
|
/// to the batch size or sequence lookup size.
|
class BoundShapeInferencerBase {
|
public:
|
explicit BoundShapeInferencerBase(const BoundShapeSpec& spec) : spec_(spec) {
|
CAFFE_ENFORCE_GE(spec_.max_batch_size, 0);
|
CAFFE_ENFORCE_GE(spec_.max_seq_size, 0);
|
}
|
|
virtual ~BoundShapeInferencerBase() {}
|
|
virtual void InferBoundShapeAndType(
|
const NetDef& net,
|
const std::unordered_map<std::string, ShapeInfo>& info,
|
caffe2::Workspace* ws) = 0;
|
|
const ShapeInfoMap& shape_info() const {
|
return shape_info_;
|
}
|
|
/// Print out all the shape info
|
std::string PrintShapeInfo() const {
|
std::stringstream ss;
|
for (const auto& kv : shape_info_) {
|
const auto& s = kv.second;
|
ss << s.shape.name() << ": dim_type: " << s.dim_type << ", dims: [";
|
for (const auto d : s.shape.dims()) {
|
ss << d << ", ";
|
}
|
ss << "], dtype: " << s.shape.data_type() << "\n";
|
}
|
return ss.str();
|
}
|
|
protected:
|
const BoundShapeSpec spec_;
|
std::unordered_map<std::string, ShapeInfo> shape_info_;
|
};
|
|
class CAFFE2_API BoundShapeInferencer : public BoundShapeInferencerBase {
 public:
  // Forwards `spec` to BoundShapeInferencerBase, which validates and stores
  // it.
  explicit BoundShapeInferencer(const BoundShapeSpec& spec)
      : BoundShapeInferencerBase(spec) {}

  ~BoundShapeInferencer() override = default;

  // Implements the base-class inference entry point; defined out of line.
  void InferBoundShapeAndType(
      const NetDef& net,
      const std::unordered_map<std::string, ShapeInfo>& info,
      caffe2::Workspace* ws) override;

 protected:
  // Returns a mutable reference to a TensorShape for `name`.
  // NOTE(review): judging by the name, this validates against any existing
  // entry before setting it, with `allow_existing_shape` (default: false)
  // relaxing that check -- confirm against the implementation.
  TensorShape& CheckAndSetTensorShapeAndType(
      const std::string& name,
      ShapeInfo::DimType t,
      std::vector<int64_t> bound_dims,
      TensorProto::DataType type,
      bool is_quantized,
      bool allow_existing_shape = false);

  // Returns a mutable reference to a TensorShape for `name`.
  // NOTE(review): presumably only writes the shape/type when no entry exists
  // yet -- confirm against the implementation.
  TensorShape& SetTensorShapeAndTypeIfNotExist(
      const std::string& name,
      ShapeInfo::DimType t,
      std::vector<int64_t> bound_dims,
      TensorProto::DataType type,
      bool is_quantized);

  // Virtual hook for handling a single operator, so subclasses can override
  // it.
  virtual void InferOps(const OperatorDef& op, caffe2::Workspace* ws);

  void InferConcatInputs(const OperatorDef& op);

  // Helpers taking a single operator definition. NOTE(review): each name
  // suggests the operator type it handles; the actual dispatch lives in the
  // implementation file.
  void InferGivenTensorFill(const OperatorDef& op);
  void InferSparseLengthsSum(const OperatorDef& op);
  void InferFC(const OperatorDef& op);
  void InferConcat(const OperatorDef& op);
  void InferShape(const OperatorDef& op);
  void InferReshape(const OperatorDef& op);
  void InferLengthsRangeFill(const OperatorDef& op);

  // Standard shape/type inference that relies on the shape inference function
  // registered with the op schema.
  void InferCommonOp(const OperatorDef& op);

  void EnsureShapeNames(std::unordered_map<std::string, ShapeInfo>* info) const;

  // Initialized to DimType::BATCH.
  ShapeInfo::DimType current_dim_type_{ShapeInfo::DimType::BATCH};
  // Initialized to 0.
  int64_t current_max_batch_size_{0};
};
|
|
// Declared here, defined elsewhere: takes a BoundShapeSpec and hands back
// shared ownership of a BoundShapeInferencerBase.
// NOTE(review): which concrete inferencer is returned is decided by the
// implementation -- confirm there.
CAFFE2_API std::shared_ptr<BoundShapeInferencerBase>
getBoundShapeInferencer(const BoundShapeSpec& spec);
|
|
// Declares BoundShapeInferencerRegistry, a registry for
// BoundShapeInferencerBase objects whose creators take a
// `const BoundShapeSpec&`.
// NOTE(review): "SHARED" presumably means creators return std::shared_ptr --
// confirm against the c10 registry macros.
C10_DECLARE_SHARED_REGISTRY(
    BoundShapeInferencerRegistry,
    BoundShapeInferencerBase,
    const BoundShapeSpec&);
|
|
} // namespace caffe2
|