#pragma once

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <mutex>
#include <random>
#include <thread>
#include <vector>

namespace torch {
namespace throughput_benchmark {
namespace detail {

template <class Input, class Output, class Model>
BenchmarkExecutionStats BenchmarkHelper<Input, Output, Model>::benchmark(
    const BenchmarkConfig& config) const {
  CHECK(initialized_);
  TORCH_CHECK(
      config.num_worker_threads == 1,
      "Only parallelization by callers is supported");

  // We pre-generate inputs here for each of the threads. This allows us to
  // safely move inputs out for each of the threads independently and thus
  // avoid overhead from the benchmark runner itself.
  std::vector<std::vector<Input>> thread_inputs(config.num_calling_threads);
  std::vector<size_t> input_iters(config.num_calling_threads);
  {
    std::random_device seeder;
    std::mt19937 engine(seeder());
    TORCH_CHECK(
        !inputs_.empty(),
        "Please provide benchmark inputs. "
        "Did you forget to call add_input()?");
    std::uniform_int_distribution<int> dist(0, inputs_.size() - 1);
    for (int thread_id = 0; thread_id < config.num_calling_threads;
         ++thread_id) {
      // Just in case, we generate num_iters + num_warmup_iters inputs for each
      // of the threads. This way, even if one thread ends up doing all the
      // work, we will have enough pre-generated inputs.
      for (int i = 0; i < config.num_iters + config.num_warmup_iters; ++i) {
        thread_inputs[thread_id].push_back(cloneInput(inputs_[dist(engine)]));
      }
      input_iters[thread_id] = 0;
    }
  }

  // Synchronization state shared between the main thread and the callers.
  std::mutex m;
  std::condition_variable worker_main_cv;
  std::condition_variable main_worker_cv;
  // TODO: add GUARDED_BY once it is available
  int64_t initialized{0};
  int64_t finished{0};
  bool start{false};
  std::atomic<int64_t> num_attempted_iters{0};

  std::vector<std::thread> callers;
  for (auto thread_id = 0; thread_id < config.num_calling_threads;
       ++thread_id) {
    callers.emplace_back([&, thread_id]() {
      // We use a condition variable as a barrier to make sure each thread
      // performs the required warm-up iterations before we start measuring.
      for (auto j = 0; j < config.num_warmup_iters; ++j) {
        runOnce(std::move(thread_inputs[thread_id][input_iters[thread_id]]));
        ++input_iters[thread_id];
      }
      {
        std::unique_lock<std::mutex> lock(m);
        ++initialized;
        worker_main_cv.notify_one();
        while (!start) {
          main_worker_cv.wait(lock);
        }
      }
      LOG(INFO) << "Starting forward thread " << thread_id;
      while (num_attempted_iters.fetch_add(1) < config.num_iters) {
        runOnce(std::move(thread_inputs[thread_id][input_iters[thread_id]]));
        ++input_iters[thread_id];
      }

      {
        std::unique_lock<std::mutex> lock(m);
        ++finished;
        worker_main_cv.notify_one();
        LOG(INFO) << "Shutting down forward thread " << thread_id
                  << ". Total number of finished threads: " << finished;
      }
    });
  }

  using Clock = std::chrono::high_resolution_clock;
  using TimePoint = std::chrono::time_point<Clock>;
  TimePoint start_time;
  {
    // Wait until every calling thread has finished warm-up and reached the
    // barrier, then release them all at once and start the clock.
    std::unique_lock<std::mutex> lock(m);
    while (initialized != config.num_calling_threads) {
      worker_main_cv.wait(lock);
    }
    LOG(INFO) << "Starting threads";
    start = true;
    start_time = Clock::now();
  }

  main_worker_cv.notify_all();
  {
    // Wait until every calling thread has completed its measured iterations.
    std::unique_lock<std::mutex> lock(m);
    worker_main_cv.wait(
        lock, [&]() { return finished == config.num_calling_threads; });
  }
  auto end_time = std::chrono::high_resolution_clock::now();
  LOG(INFO) << "Finished benchmark";

  BenchmarkExecutionStats stats;
  float total_time_ms = std::chrono::duration_cast<std::chrono::nanoseconds>(
                            end_time - start_time)
                            .count() /
      1000.0 / 1000.0;
  // We use config.num_iters instead of num_attempted_iters as it is
  // representative of the real work done. The last attempted iteration on
  // each calling thread doesn't represent real work (i.e. running the model).
  stats.latency_avg_ms =
      total_time_ms * config.num_calling_threads / config.num_iters;
  stats.num_iters = config.num_iters;

  for (auto& t : callers) {
    t.join();
  }
  return stats;
}

} // namespace detail
} // namespace throughput_benchmark
} // namespace torch
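
// Illustrative usage sketch (an assumption for documentation purposes, not
// part of the original header): driving a BenchmarkHelper instantiation
// directly from C++. "MyInput", "MyOutput", "MyModel", "model", and
// "example_input" are hypothetical placeholders; the concrete instantiations
// and the addInput()/benchmark() declarations live in the companion
// throughput_benchmark.h header and may differ between versions.
//
//   BenchmarkHelper<MyInput, MyOutput, MyModel> bench(std::move(model));
//   bench.addInput(std::move(example_input)); // add a few representative inputs
//
//   BenchmarkConfig config;
//   config.num_calling_threads = 4; // threads calling runOnce() concurrently
//   config.num_warmup_iters = 10;   // per-thread warm-up before the timer starts
//   config.num_iters = 100;         // measured iterations shared across threads
//
//   BenchmarkExecutionStats stats = bench.benchmark(config);
//   LOG(INFO) << "Average latency (ms): " << stats.latency_avg_ms;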