#pragma once #include #include #include #include #include #include #include #include namespace torch { namespace jit { using ::c10::TensorTypePtr; struct ProfilingRecord { // N.B. ProfilingRecord's copy and move c-tor are disabled, so we won't // end up accidentally copying or moving ProfilingRecords whose addresses // are captured in callbacks_ ProfilingRecord(const ProfilingRecord&) = delete; ProfilingRecord(ProfilingRecord&&) noexcept = delete; static TensorTypePtr toTensorTypePtr(const IValue& ival); TORCH_API static std::unique_ptr instrumentGraph( const std::shared_ptr& graph); std::shared_ptr profiled_graph_; std::mutex mutex_; size_t profiling_count_; bool ready() const { return profiling_count_ == 0; } std::shared_ptr graph() const { return profiled_graph_; } private: ProfileOp* createProfileNode( const std::function& fp, at::ArrayRef inputs); void instrumentBlock(Block* block); ProfilingRecord(std::shared_ptr g); }; } // namespace jit } // namespace torch