#pragma once #include #include #include #include #include namespace torch { namespace jit { using ::c10::Symbol; constexpr int max_tensor_display_size = 10; enum class AttributeKind { f, fs, i, is, s, ss, t, ts, g, gs }; static inline const char* toString(AttributeKind kind) { static const char* names[] = { "f", "fs", "i", "is", "s", "ss", "t", "ts", "g", "gs"}; AT_ASSERT(size_t(kind) < sizeof(names) / sizeof(AttributeKind)); return names[int(kind)]; } struct AttributeValue { AttributeValue(Symbol name) : name(name) {} using Ptr = std::unique_ptr; Symbol name; virtual AttributeKind kind() const = 0; virtual Ptr clone() const = 0; virtual ~AttributeValue() = default; }; template struct ScalarAttributeValue : public AttributeValue { using ConstructorType = T; using ValueType = T; ScalarAttributeValue(Symbol name, ConstructorType value_) : AttributeValue(name), value_(std::move(value_)) {} ValueType& value() { return value_; } Ptr clone() const override { return Ptr(new ScalarAttributeValue(name, value_)); } AttributeKind kind() const override { return Kind; } private: ValueType value_; }; template struct VectorAttributeValue : public AttributeValue { using ConstructorType = std::vector; using ValueType = std::vector; VectorAttributeValue(Symbol name, ConstructorType value_) : AttributeValue(name), value_(std::move(value_)) {} ValueType& value() { return value_; } AttributeKind kind() const override { return Kind; } std::unique_ptr clone() const override { auto copy = value_; return Ptr(new VectorAttributeValue(name, std::move(copy))); } private: ValueType value_; }; using FloatAttr = ScalarAttributeValue; using FloatsAttr = VectorAttributeValue; using IntAttr = ScalarAttributeValue; using IntsAttr = VectorAttributeValue; using StringAttr = ScalarAttributeValue; using StringsAttr = VectorAttributeValue; using TensorAttr = ScalarAttributeValue; using TensorsAttr = VectorAttributeValue; struct Graph; // We special case Graph attributes like this because we want to ensure that // Graph::copy() is called when we clone() these attributes. struct TORCH_API GraphAttr : public AttributeValue { using ConstructorType = std::shared_ptr; using ValueType = std::shared_ptr; GraphAttr(Symbol name, ConstructorType value_) : AttributeValue(name), value_(value_) {} ValueType& value() { return value_; } Ptr clone() const override; AttributeKind kind() const override { return AttributeKind::g; } private: std::shared_ptr value_; }; struct TORCH_API GraphsAttr : public AttributeValue { using ConstructorType = std::vector>; using ValueType = std::vector>; GraphsAttr(Symbol name, ConstructorType value_) : AttributeValue(name), value_(std::move(value_)) {} ValueType& value() { return value_; } AttributeKind kind() const override { return AttributeKind::gs; } std::unique_ptr clone() const override; private: ValueType value_; }; struct AttributeError : public std::exception { AttributeError(Symbol name, bool defined) { std::stringstream ss; if (!defined) { ss << "required keyword attribute '" << name.toUnqualString() << "' is undefined"; } else { ss << "required keyword attribute '" << name.toUnqualString() << "' has the wrong type"; } msg = ss.str(); } const char* what() const noexcept override { return msg.c_str(); } private: std::string msg; }; } // namespace jit } // namespace torch