#pragma once
|
#include <ATen/ATen.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <ATen/core/interned_strings.h>

#include <torch/csrc/WindowsTorchApiMacro.h>
|
|
namespace torch {
|
namespace jit {
|
|
using ::c10::Symbol;
|
|
// NOTE(review): judging by the name, this caps how many tensor elements are
// shown when a tensor is displayed; the use sites are outside this header, so
// confirm there.
constexpr int max_tensor_display_size{10};
|
|
// Type tag for attribute values. Each scalar kind is followed by its list
// counterpart (same letter plus a trailing "s"). The payload types listed
// below are the ones fixed by the FloatAttr/IntAttr/... aliases and the
// GraphAttr/GraphsAttr classes in this header.
// The enumerator order matters: toString() indexes a name table with it.
enum class AttributeKind {
  f,  // double
  fs, // std::vector<double>
  i,  // int64_t
  is, // std::vector<int64_t>
  s,  // std::string
  ss, // std::vector<std::string>
  t,  // at::Tensor
  ts, // std::vector<at::Tensor>
  g,  // std::shared_ptr<Graph>
  gs, // std::vector<std::shared_ptr<Graph>>
};
|
// Returns the short textual tag of `kind` ("f", "fs", "i", ...).
// The table is indexed by the enumerator's integer value, so its entries must
// stay in the same order as the enumerators of AttributeKind.
static inline const char* toString(AttributeKind kind) {
  static const char* const names[] = {
      "f", "fs", "i", "is", "s", "ss", "t", "ts", "g", "gs"};
  // Number of entries = total size / size of ONE ENTRY. Dividing by
  // sizeof(AttributeKind) instead would overstate the length whenever a
  // pointer is larger than the enum (e.g. 20 instead of 10 on typical 64-bit
  // targets) and let out-of-range kinds read past the end of the table.
  AT_ASSERT(static_cast<size_t>(kind) < sizeof(names) / sizeof(*names));
  return names[static_cast<size_t>(kind)];
}
|
|
struct AttributeValue {
|
AttributeValue(Symbol name) : name(name) {}
|
using Ptr = std::unique_ptr<AttributeValue>;
|
Symbol name;
|
virtual AttributeKind kind() const = 0;
|
virtual Ptr clone() const = 0;
|
virtual ~AttributeValue() = default;
|
};
|
|
template <typename T, AttributeKind Kind>
|
struct ScalarAttributeValue : public AttributeValue {
|
using ConstructorType = T;
|
using ValueType = T;
|
ScalarAttributeValue(Symbol name, ConstructorType value_)
|
: AttributeValue(name), value_(std::move(value_)) {}
|
ValueType& value() {
|
return value_;
|
}
|
Ptr clone() const override {
|
return Ptr(new ScalarAttributeValue(name, value_));
|
}
|
AttributeKind kind() const override {
|
return Kind;
|
}
|
|
private:
|
ValueType value_;
|
};
|
|
template <typename T, AttributeKind Kind>
|
struct VectorAttributeValue : public AttributeValue {
|
using ConstructorType = std::vector<T>;
|
using ValueType = std::vector<T>;
|
VectorAttributeValue(Symbol name, ConstructorType value_)
|
: AttributeValue(name), value_(std::move(value_)) {}
|
ValueType& value() {
|
return value_;
|
}
|
AttributeKind kind() const override {
|
return Kind;
|
}
|
std::unique_ptr<AttributeValue> clone() const override {
|
auto copy = value_;
|
return Ptr(new VectorAttributeValue(name, std::move(copy)));
|
}
|
|
private:
|
ValueType value_;
|
};
|
|
using FloatAttr = ScalarAttributeValue<double, AttributeKind::f>;
|
using FloatsAttr = VectorAttributeValue<double, AttributeKind::fs>;
|
using IntAttr = ScalarAttributeValue<int64_t, AttributeKind::i>;
|
using IntsAttr = VectorAttributeValue<int64_t, AttributeKind::is>;
|
using StringAttr = ScalarAttributeValue<std::string, AttributeKind::s>;
|
using StringsAttr = VectorAttributeValue<std::string, AttributeKind::ss>;
|
using TensorAttr = ScalarAttributeValue<at::Tensor, AttributeKind::t>;
|
using TensorsAttr = VectorAttributeValue<at::Tensor, AttributeKind::ts>;
|
struct Graph;
|
|
// We special case Graph attributes like this because we want to ensure that
|
// Graph::copy() is called when we clone() these attributes.
|
struct TORCH_API GraphAttr : public AttributeValue {
|
using ConstructorType = std::shared_ptr<Graph>;
|
using ValueType = std::shared_ptr<Graph>;
|
GraphAttr(Symbol name, ConstructorType value_)
|
: AttributeValue(name), value_(value_) {}
|
ValueType& value() {
|
return value_;
|
}
|
Ptr clone() const override;
|
AttributeKind kind() const override {
|
return AttributeKind::g;
|
}
|
|
private:
|
std::shared_ptr<Graph> value_;
|
};
|
|
// List-of-graphs attribute; like GraphAttr, clone() is defined out of line.
struct TORCH_API GraphsAttr : public AttributeValue {
  using ConstructorType = std::vector<std::shared_ptr<Graph>>;
  using ValueType = std::vector<std::shared_ptr<Graph>>;

  // `value` is moved into the member. (Not named `value_`, which would shadow
  // the member of the same name.)
  GraphsAttr(Symbol name, ConstructorType value)
      : AttributeValue(name), value_(std::move(value)) {}

  // Mutable access to the stored list of graph pointers.
  ValueType& value() {
    return value_;
  }

  AttributeKind kind() const override {
    return AttributeKind::gs;
  }

  // Defined out of line. `Ptr` is std::unique_ptr<AttributeValue>, so this
  // declaration matches the existing definition.
  Ptr clone() const override;

 private:
  ValueType value_;
};
|
|
struct AttributeError : public std::exception {
|
AttributeError(Symbol name, bool defined) {
|
std::stringstream ss;
|
if (!defined) {
|
ss << "required keyword attribute '" << name.toUnqualString()
|
<< "' is undefined";
|
} else {
|
ss << "required keyword attribute '" << name.toUnqualString()
|
<< "' has the wrong type";
|
}
|
msg = ss.str();
|
}
|
const char* what() const noexcept override {
|
return msg.c_str();
|
}
|
|
private:
|
std::string msg;
|
};
|
} // namespace jit
|
} // namespace torch
|