#pragma once #include #include #include #include #include #include #include #include #include #include #include namespace caffe2 { class Event; } // namespace caffe2 namespace at { class BaseContext; /** * Virtual interface for the Context class in Caffe2. * * A Context defines all the necessities to run an operator on a specific * device. Specific Context classes needs to implement all the pure virtual * functions in the BaseContext class. * TODO: add docs after this is finalized. */ class CAFFE2_API BaseContext { public: virtual ~BaseContext() noexcept {} virtual Device device() const = 0; /* Sorry for the naming, will get rid of this in future diff */ virtual DeviceType device_type() const = 0; virtual void SwitchToDevice(int /*stream_id*/) = 0; inline void SwitchToDevice() { SwitchToDevice(0); } virtual void WaitEvent(const caffe2::Event& ev) = 0; virtual void Record(caffe2::Event* ev, const char* err_msg = nullptr) const = 0; virtual void FinishDeviceComputation() = 0; // This used to be arbitrary cross-device copy, but it turns out everyone // did direct CPU-X copy, so we just make three functions for it (to avoid // double dispatch). This will get obsoleted by C10. where copies // will be proper operators (and get to rely on multiple dispatch there.) virtual void CopyBytesSameDevice( size_t nbytes, const void* src, void* dst) = 0; virtual void CopyBytesFromCPU(size_t nbytes, const void* src, void* dst) = 0; virtual void CopyBytesToCPU(size_t nbytes, const void* src, void* dst) = 0; template inline void CopySameDevice(size_t n, const T* src, T* dst) { static_assert( std::is_fundamental::value, "CopySameDevice requires fundamental types"); CopyBytesSameDevice( n * sizeof(T), static_cast(src), static_cast(dst)); } template inline void CopyFromCPU(size_t n, const T* src, T* dst) { static_assert( std::is_fundamental::value, "CopyFromCPU requires fundamental types"); CopyBytesFromCPU( n * sizeof(T), static_cast(src), static_cast(dst)); } template inline void CopyToCPU(size_t n, const T* src, T* dst) { static_assert( std::is_fundamental::value, "CopyToCPU requires fundamental types"); CopyBytesToCPU( n * sizeof(T), static_cast(src), static_cast(dst)); } virtual bool SupportsNonFundamentalTypes() const { return false; } inline void EnforceMetaCopyOK() { AT_ASSERTM( SupportsNonFundamentalTypes(), "Context requires fundamental types"); } void CopyItemsSameDevice( const caffe2::TypeMeta& meta, size_t n, const void* src, void* dst) { if (meta.copy()) { EnforceMetaCopyOK(); meta.copy()(src, dst, n); } else { CopyBytesSameDevice(n * meta.itemsize(), src, dst); } } void CopyItemsFromCPU( const caffe2::TypeMeta& meta, size_t n, const void* src, void* dst) { if (meta.copy()) { EnforceMetaCopyOK(); meta.copy()(src, dst, n); } else { CopyBytesFromCPU(n * meta.itemsize(), src, dst); } } void CopyItemsToCPU( const caffe2::TypeMeta& meta, size_t n, const void* src, void* dst) { if (meta.copy()) { EnforceMetaCopyOK(); meta.copy()(src, dst, n); } else { CopyBytesToCPU(n * meta.itemsize(), src, dst); } } }; // Context constructor registry C10_DECLARE_TYPED_REGISTRY( ContextRegistry, at::DeviceType, at::BaseContext, std::unique_ptr, at::Device); #define REGISTER_CONTEXT(type, ...) \ C10_REGISTER_TYPED_CLASS(ContextRegistry, type, __VA_ARGS__) inline std::unique_ptr CreateContext( const at::Device& device) { return at::ContextRegistry()->Create(device.type(), device); } } // namespace at namespace caffe2 { using at::BaseContext; using at::CreateContext; } // namespace caffe2