#pragma once #include #include #include #ifdef _WIN32 #define WIN32_LEAN_AND_MEAN #endif #include "tbb/tbb.h" #define INTRA_OP_PARALLEL namespace at { template inline void parallel_for( const int64_t begin, const int64_t end, const int64_t grain_size, const F& f) { TORCH_CHECK(grain_size >= 0); if (begin >= end) { return; } if ((end - begin) < grain_size || get_num_threads() == 1) { f(begin, end); return; } std::atomic_flag err_flag = ATOMIC_FLAG_INIT; std::exception_ptr eptr; tbb::parallel_for(tbb::blocked_range(begin, end, grain_size), [&eptr, &err_flag, f](const tbb::blocked_range& r) { try { f(r.begin(), r.end()); } catch (...) { if (!err_flag.test_and_set()) { eptr = std::current_exception(); } } }); if (eptr) { std::rethrow_exception(eptr); } } template inline scalar_t parallel_reduce( const int64_t begin, const int64_t end, const int64_t grain_size, const scalar_t ident, const F& f, const SF& sf) { TORCH_CHECK(grain_size >= 0); if (begin >= end) { return ident; } if ((end - begin) < grain_size || get_num_threads() == 1) { return f(begin, end, ident); } scalar_t result; std::atomic_flag err_flag = ATOMIC_FLAG_INIT; std::exception_ptr eptr; result = tbb::parallel_reduce( tbb::blocked_range(begin, end, grain_size), ident, [&eptr, &err_flag, f, ident] (const tbb::blocked_range& r, scalar_t ident) { try { return f(r.begin(), r.end(), ident); } catch (...) { if (!err_flag.test_and_set()) { eptr = std::current_exception(); } return ident; } }, sf ); if (eptr) { std::rethrow_exception(eptr); } return result; } template void intraop_invoke(const F0& f0, const F1& f1) { tbb::parallel_invoke(f0, f1); } } // namespace at