#pragma once #include #include #include #include namespace at { struct CAFFE2_API CPUGenerator : public Generator { // Constructors CPUGenerator(uint64_t seed_in = default_rng_seed_val); ~CPUGenerator() = default; // CPUGenerator methods std::shared_ptr clone() const; void set_current_seed(uint64_t seed) override; uint64_t current_seed() const override; uint64_t seed() override; static DeviceType device_type(); uint32_t random(); uint64_t random64(); c10::optional next_float_normal_sample(); c10::optional next_double_normal_sample(); void set_next_float_normal_sample(c10::optional randn); void set_next_double_normal_sample(c10::optional randn); at::mt19937 engine(); void set_engine(at::mt19937 engine); private: CPUGenerator* clone_impl() const override; at::mt19937 engine_; c10::optional next_float_normal_sample_; c10::optional next_double_normal_sample_; }; namespace detail { CAFFE2_API CPUGenerator* getDefaultCPUGenerator(); CAFFE2_API std::shared_ptr createCPUGenerator(uint64_t seed_val = default_rng_seed_val); } // namespace detail }