#pragma once #include #include #include #include #include #include template inline std::vector makeStack(Inputs&&... inputs) { return {std::forward(inputs)...}; } inline at::Tensor dummyTensor(c10::TensorTypeId dispatch_key) { auto* allocator = c10::GetCPUAllocator(); int64_t nelements = 1; auto dtype = caffe2::TypeMeta::Make(); auto storage_impl = c10::make_intrusive( dtype, nelements, allocator->allocate(nelements * dtype.itemsize()), allocator, /*resizable=*/true); return at::detail::make_tensor(storage_impl, dispatch_key); } template inline std::vector callOp(const c10::OperatorHandle& op, Args... args) { auto stack = makeStack(std::forward(args)...); c10::Dispatcher::singleton().callBoxed(op, &stack); return stack; } template inline Result callOpUnboxed(const c10::OperatorHandle& op, c10::TensorTypeId dispatchKey, Args... args) { return c10::Dispatcher::singleton() .template callUnboxed(op, dispatchKey, std::forward(args)...); } inline void expectDoesntFindKernel(const char* op_name, c10::TensorTypeId dispatch_key) { auto op = c10::Dispatcher::singleton().findSchema({op_name, ""}); EXPECT_ANY_THROW( callOp(*op, dummyTensor(dispatch_key), 5); ); } inline void expectDoesntFindOperator(const char* op_name) { auto op = c10::Dispatcher::singleton().findSchema({op_name, ""}); EXPECT_FALSE(op.has_value()); } template inline void expectThrows(Functor&& functor, const char* expectMessageContains) { try { std::forward(functor)(); } catch (const Exception& e) { EXPECT_THAT(e.what(), testing::HasSubstr(expectMessageContains)); return; } ADD_FAILURE() << "Expected to throw exception containing \"" << expectMessageContains << "\" but didn't throw"; } template void expectListEquals(c10::ArrayRef expected, c10::List actual) { EXPECT_EQ(expected.size(), actual.size()); for (size_t i = 0; i < expected.size(); ++i) { EXPECT_EQ(expected[i], actual.get(i)); } } template void expectListEquals(c10::ArrayRef expected, std::vector actual) { EXPECT_EQ(expected.size(), actual.size()); for (size_t i = 0; i < expected.size(); ++i) { EXPECT_EQ(expected[i], actual[i]); } } // NB: This is not really sound, but all of the type sets constructed here // are singletons so it's fine static inline c10::TensorTypeId extractTypeId(const at::Tensor& t) { return legacyExtractTypeId(t.type_set()); }