1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
| #pragma once
|
| #include <ATen/cuda/CUDAContext.h>
|
| namespace at { namespace cuda {
|
| // Check if every tensor in a list of tensors matches the current
| // device.
| inline bool check_device(ArrayRef<Tensor> ts) {
| if (ts.empty()) {
| return true;
| }
| Device curDevice = Device(kCUDA, current_device());
| for (const Tensor& t : ts) {
| if (t.device() != curDevice) return false;
| }
| return true;
| }
|
| }} // namespace at::cuda
|
|