reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 f7c4a3cfd07adede3308f8d9d3d7315427d90a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
#pragma once
 
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/TensorUtils.h>
#include <THC/THCAtomics.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <c10/macros/Macros.h>
#include <ATen/LegacyTHFunctionsCUDA.h>
 
#include <math.h>
 
//
// This file contains pointwise operation functions and kernels that
// work on both contiguous and non-contiguous tensor arguments of
// arbitrary (up to MAX_CUTORCH_DIMS) dimensioned arguments without
// copying or temporary storage.
//
 
/*
  NOTE [ CUDA_tensor_applyN helpers ]
 
  The following CUDA_tensor_applyN (where N currently can be 1, 2, 3, or 4)
  functions apply a pointwise operator to N tensor(s).
 
  The calling convention is
 
  1. The template arguments should be, sequentially,
    - First N typename args specify the scalar types of each of the N tensors.
    - (Optional) `int step` arg specifies the number of elements processed
      together at the same time.
      Default is 1.
    - A usually omitted (i.e., inferred) typename arg specifies the type of the
      function/functor applied on `N * step` values  in each iteration of each
      CUDA thread.
  2. The arguments should be, sequentially,
    - N tensors
    - op: a function/functor that processes `N * step` values at the same time.
      - If `step == 1`, it must have signature
        `void(*)(scalar1_t&, scalar2_t&, ..., scalarN_t&)`, where
        `scalar*_t`s are the first N typename template args, and the inputs
        are the `N` values from the `N` tensors retrieved at a common index.
      - Otherwise, it must must have signature
          void(*)(int n, scalar1_t&, scalar1_t&, ..., scalar1_t&,  // repeat `step` times
                         scalar2_t&, scalar2_t&, ..., scalar2_t&,  // repeat `step` times
                         ...,
                         scalarN_t&, scalarN_t&, ..., scalarN_t&)  // repeat `step` times
        Different from `step == 1` case, it processes `N * step` values taken
        from `step` common indices. Moreover, the first input `n` represents the
        number of valid indices (it will always have `0 < n <= step`). It will
        almost always be `step`, but at the boundary we may not have full `step`
        elements and `n` can be a lesser value.
 
        E.g., if `step == 4` and `N == 2`, `op` could be
 
          [](int n, scalar1_t &u1, scalar1_t &u2, scalar1_t &u3, scalar1_t &u4,
                    scalar2_t &v1, scalar2_t &v2, scalar2_t &v3, scalar2_t &v4) {
            // Only process u1, ..., un and v1, ..., vn.
            // So if `n == 3`, `u4` and `v4` need not to be considered.
          }
 
      In both cases, the references can actually be const, but at least one of
      them should be non-const in order to write the output.
    - (Optional, but recommended) N TensorArgType args that specify for each
      tensor whether `op` reads AND writes ] (i.e., TensorArgType::ReadWrite),
      or only reads (i.e., TensorArgType::ReadOnly).
      Default is TensorArgType::ReadWrite for first Tensor, and
                 TensorArgType::ReadOnly  for the rest.
 
  E.g.,
 
  to compute a = b^2 for a and b of same dtype, we can call
 
  CUDA_tensor_apply2<scalar, scalar>(
    a, b,
    [] __device__ (scalar &a_val, const scalar &b_val) { a_val = b_val * b_val; }
  );
 
  to work on 2 values at the same time, we can call
 
  CUDA_tensor_apply2<scalar1, scalar2, 2>(
    a, b,
    [] __device__ (int n, scalar1 &a_val1, scalar1 &a_val2,
                          const scalar2 &b_val1, const scalar2 &b_val2) {
      // call special vectorized op here, or just do elementwise and enjoy unrolling...
      // if n == 1, only process a_val1 and b_val1
    }
  );
*/
 
namespace at {
namespace cuda {
 
// TODO: combine with TensorArg?  So far that's been for debugging, and this is functional...
enum class TensorArgType { ReadWrite, ReadOnly };
 
namespace {
 
// Rearrange dimensions for pointwise operations so that strides are in
// decreasing order as much as possible, so that kernels have better memory
// access patterns.
//
// For example, consider a binary operation on two "transposed" 2-dim tensors:
//    sizes:          256 512
//    aInfo->strides:   1 256
//    bInfo->strides:   1 256
//
// Given this, each concurrent memory access inside kernelPointwiseApply2() is
// exactly 256 elements apart, resulting in poor performance.
//
// This function exchanges dimensions so that memory access is contiguous:
//    sizes:          512 256
//    aInfo->strides: 256   1
//    bInfo->strides: 256   1
//
// (Actually, it becomes even better because now collapseDims() can turn each
// input into one contiguous array.)
//
// In general, given M (<=4) TensorInfo's with N dimensions, we can view each
// strides[i] (0 <= i < N) as an M-tuple.  Given each pair i < j, we exchange
// strides[i] and [j] if
//    (1) strides[i][k] < strides[j][k] for some k (0 <= k < M)
//        (exchanging them will benefit input #k), and
//    (2) strides[i][k] <= strieds[j][k] for all k
//        (exchanging them will not make any input worse).
template <typename T1, typename IndexType,
          typename T2 = void, typename T3 = void, typename T4 = void>
inline void rearrangeDims(detail::TensorInfo<T1, IndexType>* aInfo,
                          detail::TensorInfo<T2, IndexType>* bInfo = nullptr,
                          detail::TensorInfo<T3, IndexType>* cInfo = nullptr,
                          detail::TensorInfo<T4, IndexType>* dInfo = nullptr) {
  int numInfos = 1;
  int dims = aInfo->dims;
  IndexType *sizes[4] = { aInfo->sizes, };
  IndexType *strides[4] = { aInfo->strides, };
 
  if (bInfo != nullptr) {
    ++numInfos;
    if (bInfo->dims != dims) return;
    sizes[1] = bInfo->sizes;
    strides[1] = bInfo->strides;
  }
 
  if (cInfo != nullptr) {
    ++numInfos;
    if (cInfo->dims != dims) return;
    sizes[2] = cInfo->sizes;
    strides[2] = cInfo->strides;
  }
 
  if (dInfo != nullptr) {
    ++numInfos;
    if (dInfo->dims != dims) return;
    sizes[3] = dInfo->sizes;
    strides[3] = dInfo->strides;
  }
 
  // Bail out if sizes do not match: we are using "deprecated pointwise
  // behavior" among tensors of different shapes but same number of elements.
  for (int i = 1; i < numInfos; ++i) {
    for (int j = 0; j < dims; ++j) {
      if (sizes[i][j] != sizes[0][j]) return;
    }
  }
 
  for (int i = 0; i < dims - 1; ++i) {
    // No need to consider dimensions of size 1.
    if (sizes[0][i] == 1) continue;
 
    for (int j = i + 1; j < dims; ++j) {
      if (sizes[0][j] == 1) continue;
 
      // Compare the relative sizes of strides between dim #i and dim #j.
      bool hasIncreasingStrides = false;
      bool hasDecreasingStrides = false;
 
      for (int k = 0; k < numInfos; k++) {
        IndexType stride_i = strides[k][i];
        IndexType stride_j = strides[k][j];
        if (stride_i < stride_j) {
          hasIncreasingStrides = true;
        } else if (stride_i > stride_j) {
          hasDecreasingStrides = true;
        }
      }
 
      if (hasIncreasingStrides && !hasDecreasingStrides) {
        for (int k = 0; k < numInfos; k++) {
          IndexType size = sizes[k][i];
          sizes[k][i] = sizes[k][j];
          sizes[k][j] = size;
 
          IndexType stride = strides[k][i];
          strides[k][i] = strides[k][j];
          strides[k][j] = stride;
        }
      }
    }
  }
}
 
// Threads per block for our apply kernel
// FIXME: use occupancy calculator instead
constexpr uint32_t AT_APPLY_THREADS_PER_BLOCK = 512;
constexpr uint32_t AT_APPLY_BLOCKS_PER_SM = 4;
 
// The `remaining_steps` argument is used to support Op that operates on
// multiple elements at the same time. Generally, the strategy of ApplyOpN is to
//  1. Initialize `remaining_steps = step`, where `step` is the template arg of
//     CUDA_tensor_applyN helpers. The input arg `n` to `apply()` represents the
//     number of elements in bound for this call. It will almost always equal to
//     `step` except at boundaries.
//  2. If `remaining_steps > 0` convert the current linearIndex to offset (if in
//     bound), and recursively call `ApplyOpN` with `remaining_steps - 1`.
//  3. At `remaining_steps = 0`,
//       if `step = 1`, call `op(tensor1_val, tensor2_val, ...)`;
//       if `step > 1`, call `op(n, tensor1_val1, tensor1_val2, ..., tesor1_valstep,
//                                  tensor2_val1, tensor2_val2, ..., tesor2_valstep,
//                                       ...
//                                  tensorN_val1, tensorN_val2, ..., tesorN_valstep);`
//
// See NOTE [ CUDA_tensor_applyN helpers ] above for how Op may look like.
 
template <typename Op,
          typename scalar,
          typename IndexType,
          int ADims,
          int remaining_steps,
          typename... Offsets>
struct ApplyOp1 {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op, int n,
                  IndexType linearIndex, Offsets... aOffsets) {
  // Convert `linearIndex` into an offset of `a`
  const IndexType aOffset = sizeof...(Offsets) < n ?
    detail::IndexToOffset<scalar, IndexType, ADims>::get(linearIndex, a) : 0;
 
  ApplyOp1<Op, scalar, IndexType, ADims, remaining_steps - 1, const IndexType, Offsets...>::apply(
    a, op, n, linearIndex + 1, aOffsets..., aOffset
  );
}
};
 
// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`).
// We don't need to pass in how many elements need to processed in this case.
template <typename Op,
          typename scalar,
          typename IndexType,
          int ADims,
          typename Offset>
struct ApplyOp1<Op, scalar, IndexType, ADims, 0, Offset> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op,
                  int n, IndexType linearIndex, Offset offset) {
  op(a.data[offset]);
}
};
 
template <typename Op,
          typename scalar,
          typename IndexType,
          int ADims,
          typename... Offsets>
struct ApplyOp1<Op, scalar, IndexType, ADims, 0, Offsets...> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op, int n,
                 IndexType linearIndex, Offsets... offsets) {
  op(n, a.data[offsets]...);
}
};
 
template <typename Op,
          typename scalar,
          typename IndexType,
          int ADims,
          int step>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
#endif
__global__ void kernelPointwiseApply1(detail::TensorInfo<scalar, IndexType> a,
                                      IndexType totalElements, const Op op) {
  for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x * step) {
    ApplyOp1<Op, scalar, IndexType, ADims, step>::apply(
      a, op, ::min(step, static_cast<int>(totalElements - linearIndex)), linearIndex);
  }
}
 
 
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims,
          int BDims,
          int remaining_steps,
          typename... Offsets>
struct ApplyOp2 {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                  detail::TensorInfo<scalar2, IndexType> &b,
                  const Op &op, int n, IndexType linearIndex,
                  Offsets... aOffsets, Offsets... bOffsets) {
  // Convert `linearIndex` into an offset of `a`
  const IndexType aOffset = sizeof...(Offsets) < n ?
    detail::IndexToOffset<scalar1, IndexType, ADims>::get(linearIndex, a) : 0;
 
  // Convert `linearIndex` into an offset of `b`
  const IndexType bOffset = sizeof...(Offsets) < n ?
    detail::IndexToOffset<scalar2, IndexType, BDims>::get(linearIndex, b) : 0;
 
  ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, remaining_steps - 1, const IndexType, Offsets...>::apply(
    a, b, op, n, linearIndex + 1, aOffsets..., aOffset, bOffsets..., bOffset
  );
}
};
 
// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`).
// We don't need to pass in how many elements need to processed in this case.
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims,
          int BDims,
          typename Offset>
struct ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, 0, Offset> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                  detail::TensorInfo<scalar2, IndexType> &b,
                  const Op &op, int n, IndexType linearIndex,
                  Offset aOffset, Offset bOffset) {
  op(a.data[aOffset], b.data[bOffset]);
}
};
 
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims,
          int BDims,
          typename... Offsets>
struct ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, 0, Offsets...> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                  detail::TensorInfo<scalar2, IndexType> &b,
                  const Op &op, int n, IndexType linearIndex,
                  Offsets... aOffsets, Offsets... bOffsets) {
  op(n, a.data[aOffsets]..., b.data[bOffsets]...);
}
};
 
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims, int BDims,
          int step>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
#endif
__global__ void
kernelPointwiseApply2(detail::TensorInfo<scalar1, IndexType> a,
                      detail::TensorInfo<scalar2, IndexType> b,
                      IndexType totalElements,
                      const Op op) {
  for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x * step) {
    ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, step>::apply(
      a, b, op, ::min(step, static_cast<int>(totalElements - linearIndex)),
      linearIndex);
  }
}
 
 
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename scalar3,
          typename IndexType,
          int ADims,
          int BDims,
          int CDims,
          int remaining_steps,
          typename... Offsets>
struct ApplyOp3 {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                  detail::TensorInfo<scalar2, IndexType> &b,
                  detail::TensorInfo<scalar3, IndexType> &c,
                  const Op &op, int n, IndexType linearIndex,
                  Offsets... aOffsets, Offsets... bOffsets,
                  Offsets... cOffsets) {
  // Convert `linearIndex` into an offset of `a`
  const IndexType aOffset = sizeof...(Offsets) < n ?
    detail::IndexToOffset<scalar1, IndexType, ADims>::get(linearIndex, a) : 0;
 
  // Convert `linearIndex` into an offset of `b`
  const IndexType bOffset = sizeof...(Offsets) < n ?
    detail::IndexToOffset<scalar2, IndexType, BDims>::get(linearIndex, b) : 0;
 
  // Convert `linearIndex` into an offset of `c`
  const IndexType cOffset = sizeof...(Offsets) < n ?
    detail::IndexToOffset<scalar3, IndexType, CDims>::get(linearIndex, c) : 0;
 
  ApplyOp3<Op, scalar1, scalar2, scalar3, IndexType, ADims, BDims, CDims,
           remaining_steps - 1, const IndexType, Offsets...>::apply(
    a, b, c, op, n, linearIndex + 1, aOffsets..., aOffset, bOffsets..., bOffset,
    cOffsets..., cOffset
  );
}
};
 
// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`).
// We don't need to pass in how many elements need to processed in this case.
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename scalar3,
          typename IndexType,
          int ADims,
          int BDims,
          int CDims,
          typename Offset>
struct ApplyOp3<Op, scalar1, scalar2, scalar3, IndexType,
                ADims, BDims, CDims, 0, Offset> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                  detail::TensorInfo<scalar2, IndexType> &b,
                  detail::TensorInfo<scalar3, IndexType> &c,
                  const Op &op, int n, IndexType linearIndex,
                  Offset aOffset, Offset bOffset, Offset cOffset) {
  op(a.data[aOffset], b.data[bOffset], c.data[cOffset]);
}
};
 
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename scalar3,
          typename IndexType,
          int ADims,
          int BDims,
          int CDims,
          typename... Offsets>
struct ApplyOp3<Op, scalar1, scalar2, scalar3, IndexType,
                ADims, BDims, CDims, 0, Offsets...> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                  detail::TensorInfo<scalar2, IndexType> &b,
                  detail::TensorInfo<scalar3, IndexType> &c,
                  const Op &op, int n, IndexType linearIndex,
                  Offsets... aOffsets, Offsets... bOffsets,
                  Offsets... cOffsets) {
  op(n, a.data[aOffsets]..., b.data[bOffsets]..., c.data[cOffsets]...);
}
};
 
 
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename scalar3,
          typename IndexType,
          int ADims, int BDims, int CDims,
          int step>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
#endif
__global__ void
kernelPointwiseApply3(detail::TensorInfo<scalar1, IndexType> a,
                      detail::TensorInfo<scalar2, IndexType> b,
                      detail::TensorInfo<scalar3, IndexType> c,
                      IndexType totalElements,
                      const Op op) {
  for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x * step) {
    ApplyOp3<Op, scalar1, scalar2, scalar3, IndexType, ADims, BDims, CDims, step>::apply(
      a, b, c, op, ::min(step, static_cast<int>(totalElements - linearIndex)), linearIndex);
  }
}
 
 
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename scalar3,
          typename scalar4,
          typename IndexType,
          int ADims,
          int BDims,
          int CDims,
          int DDims,
          int remaining_steps,
          typename... Offsets>
struct ApplyOp4 {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                  detail::TensorInfo<scalar2, IndexType> &b,
                  detail::TensorInfo<scalar3, IndexType> &c,
                  detail::TensorInfo<scalar4, IndexType> &d,
                  const Op &op, int n, IndexType linearIndex,
                  Offsets... aOffsets, Offsets... bOffsets,
                  Offsets... cOffsets, Offsets... dOffsets) {
  // Convert `linearIndex` into an offset of `a`
  const IndexType aOffset = sizeof...(Offsets) < n ?
    detail::IndexToOffset<scalar1, IndexType, ADims>::get(linearIndex, a) : 0;
 
  // Convert `linearIndex` into an offset of `b`
  const IndexType bOffset = sizeof...(Offsets) < n ?
    detail::IndexToOffset<scalar2, IndexType, BDims>::get(linearIndex, b) : 0;
 
  // Convert `linearIndex` into an offset of `c`
  const IndexType cOffset = sizeof...(Offsets) < n ?
    detail::IndexToOffset<scalar3, IndexType, CDims>::get(linearIndex, c) : 0;
 
  // Convert `linearIndex` into an offset of `d`
  const IndexType dOffset = sizeof...(Offsets) < n ?
    detail::IndexToOffset<scalar4, IndexType, DDims>::get(linearIndex, d) : 0;
 
  ApplyOp4<Op, scalar1, scalar2, scalar3, scalar4, IndexType,
           ADims, BDims, CDims, DDims, remaining_steps - 1, const IndexType, Offsets...>::apply(
    a, b, c, d, op, n, linearIndex + 1, aOffsets..., aOffset, bOffsets..., bOffset,
    cOffsets..., cOffset, dOffsets..., dOffset
  );
}
};
 
// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`).
// We don't need to pass in how many elements need to processed in this case.
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename scalar3,
          typename scalar4,
          typename IndexType,
          int ADims,
          int BDims,
          int CDims,
          int DDims,
          typename Offset>
struct ApplyOp4<Op, scalar1, scalar2, scalar3, scalar4, IndexType,
                ADims, BDims, CDims, DDims, 0, Offset> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                  detail::TensorInfo<scalar2, IndexType> &b,
                  detail::TensorInfo<scalar3, IndexType> &c,
                  detail::TensorInfo<scalar4, IndexType> &d,
                  const Op &op, int n, IndexType linearIndex,
                  Offset aOffset, Offset bOffset,
                  Offset cOffset, Offset dOffset) {
  op(a.data[aOffset], b.data[bOffset], c.data[cOffset], d.data[dOffset]);
}
};
 
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename scalar3,
          typename scalar4,
          typename IndexType,
          int ADims,
          int BDims,
          int CDims,
          int DDims,
          typename... Offsets>
struct ApplyOp4<Op, scalar1, scalar2, scalar3, scalar4, IndexType,
                ADims, BDims, CDims, DDims, 0, Offsets...> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                  detail::TensorInfo<scalar2, IndexType> &b,
                  detail::TensorInfo<scalar3, IndexType> &c,
                  detail::TensorInfo<scalar4, IndexType> &d,
                  const Op &op, int n, IndexType linearIndex,
                  Offsets... aOffsets, Offsets... bOffsets,
                  Offsets... cOffsets, Offsets... dOffsets) {
  op(n, a.data[aOffsets]..., b.data[bOffsets]..., c.data[cOffsets]..., d.data[dOffsets]...);
}
};
 
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename scalar3,
          typename scalar4,
          typename IndexType,
          int ADims, int BDims, int CDims, int DDims,
          int step>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
#endif
__global__ void
kernelPointwiseApply4(detail::TensorInfo<scalar1, IndexType> a,
                      detail::TensorInfo<scalar2, IndexType> b,
                      detail::TensorInfo<scalar3, IndexType> c,
                      detail::TensorInfo<scalar4, IndexType> d,
                      IndexType totalElements,
                      const Op op) {
  for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x * step) {
    ApplyOp4<Op, scalar1, scalar2, scalar3, scalar4, IndexType,
             ADims, BDims, CDims, DDims, step>::apply(
      a, b, c, d, op, ::min(step, static_cast<int>(totalElements - linearIndex)), linearIndex);
  }
}
 
} // namespace
 
/**
   Computes ceil(a / b)
*/
template <typename T>
__host__ __device__ __forceinline__ T ATenCeilDiv(T a, T b) {
  return (a + b - 1) / b;
}
 
template <int step = 1>
inline bool getApplyGrid(uint64_t totalElements, dim3& grid, int64_t curDevice) {
  if (curDevice == -1) return false;
  uint64_t numel_per_thread = static_cast<uint64_t>(AT_APPLY_THREADS_PER_BLOCK) * static_cast<uint64_t>(step);
  uint64_t numBlocks = ATenCeilDiv(totalElements, numel_per_thread);
  uint64_t maxGridX = at::cuda::getDeviceProperties(curDevice)->maxGridSize[0];
  if (numBlocks > maxGridX)
      numBlocks = maxGridX;
  grid = dim3(numBlocks);
  return true;
}
 
inline dim3 getApplyBlock() {
  return dim3(AT_APPLY_THREADS_PER_BLOCK);
}
 
 
template <typename scalar, int step, typename Op>
inline bool CUDA_tensor_apply1(at::Tensor a,
                               const Op op,
                               TensorArgType aType = TensorArgType::ReadWrite) {
  checkBackend("CUDA_tensor_apply1", {a}, Backend::CUDA);
  auto dim = a.dim();
 
  /*
  Since this is a unary op, we can easily first check for expanded dimensions
  (with stride 0), and remove them, to avoid calling .contiguous() in such
  case when detail::maybeOverlappingIndices(a) returns true.
  */
  std::vector<int64_t> collapsed_shape;
  std::vector<int64_t> collapsed_strides;
  collapsed_shape.reserve(dim);
  collapsed_strides.reserve(dim);
  for (int64_t i = 0; i < dim; i++) {
    if (a.stride(i) != 0) {
      collapsed_shape.push_back(a.size(i));
      collapsed_strides.push_back(a.stride(i));
    }
  }
  if (collapsed_shape.size() != dim) {
    a = a.as_strided(collapsed_shape, collapsed_strides);
  }
 
  int64_t totalElements = a.numel();
 
  if (dim > MAX_TENSORINFO_DIMS) {
    return false;
  }
 
  if (totalElements == 0) {
    // Empty tensor; do nothing
    return true;
  }
  const dim3 block = getApplyBlock();
 
  dim3 grid;
  int64_t curDevice = current_device();
  if (curDevice == -1) return false;
  if (!getApplyGrid<step>(totalElements, grid, curDevice)) {
    return false;
  }
 
  /*
  Expands readable/writable tensors whose indices may be "overlapped."
  This ensures that each element of the tensor is operated on once and only
  once.
  */
  Tensor oldA;
 
  if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) {
    // Must perform in contiguous space
    oldA = a;
    a = a.contiguous();
  }
 
  // It is possible that the tensor dimensions are able to be collapsed,
  // and thus we can reduce the actual code complexity of the copy by
  // exploiting this knowledge statically, since the div/mod is the
  // most expensive part of the operation, more so than memory accesses.
  // For instance, when copying a non-contiguous to a contiguous tensor
  // (or vice versa), the contiguous tensor can be collapsed to one
  // dimension, and the loop to translate the linear index to the array
  // index can be similarly collapsed. That is what this unrolling is for.
 
#define HANDLE_CASE(TYPE, A)                                           \
  kernelPointwiseApply1<Op,                                            \
                        scalar,                                        \
                        TYPE, A, step>                                 \
   <<<grid, block, 0, at::cuda::getCurrentCUDAStream(curDevice)>>>(    \
       aInfo, static_cast<TYPE>(totalElements), op);
 
#define HANDLE_A_CASE(TYPE, A) {            \
  switch (A) {                              \
    case 1:                                 \
      HANDLE_CASE(TYPE, 1);                 \
      break;                                \
    case 2:                                 \
      HANDLE_CASE(TYPE, 2);                 \
      break;                                \
    default:                                \
      HANDLE_CASE(TYPE, -1);                \
      break;                                \
  }                                         \
}
 
  if (detail::canUse32BitIndexMath(a)) {
    detail::TensorInfo<scalar, unsigned int> aInfo =
      detail::getTensorInfo<scalar, unsigned int>(a);
 
    rearrangeDims(&aInfo);
    aInfo.collapseDims();
 
    HANDLE_A_CASE(unsigned int, aInfo.dims);
  } else {
    detail::TensorInfo<scalar, uint64_t> aInfo =
      detail::getTensorInfo<scalar, uint64_t>(a);
 
    rearrangeDims(&aInfo);
    aInfo.collapseDims();
 
    /*
    Only instantiates the all 1D special case and the fallback all nD case for
    large (64-bit indexed) tensors to reduce compilation time.
    */
    if (aInfo.dims == 1) {
      HANDLE_CASE(uint64_t, 1);
    } else {
      HANDLE_CASE(uint64_t, -1);
    }
  }
#undef HANDLE_CASE
#undef HANDLE_A_CASE
 
  if (oldA.defined()) {
    // Ignore overlaps when copying back; if we use copy
    // instead, it will recursively try and invoke ourselves to make
    // oldA contiguous.
    at::native::legacy::cuda::_th_copy_ignoring_overlaps_(oldA, a);
  }
 
  return true;
}
 
/* Provides default step = 1 to CUDA_tensor_apply1. */
template <typename scalar, typename Op>
inline bool CUDA_tensor_apply1(at::Tensor a,
                               const Op op,
                               TensorArgType aType = TensorArgType::ReadWrite) {
  return CUDA_tensor_apply1<scalar, 1, Op>(a, op, aType);
}
 
 
template <typename scalar1, typename scalar2, int step, typename Op>
inline bool CUDA_tensor_apply2(at::Tensor a,
                               at::Tensor b,
                               const Op op,
                               TensorArgType aType = TensorArgType::ReadWrite,
                               TensorArgType bType = TensorArgType::ReadOnly) {
  checkBackend("CUDA_tensor_apply2", {a, b}, Backend::CUDA);
  int64_t totalElements = a.numel();
 
  if (totalElements != b.numel()) {
    return false;
  }
 
  if (a.dim() > MAX_TENSORINFO_DIMS ||
      b.dim() > MAX_TENSORINFO_DIMS) {
    return false;
  }
 
  if (a.numel() == 0) {
    // Empty tensor; do nothing
    return true;
  }
  const dim3 block = getApplyBlock();
 
  dim3 grid;
  int64_t curDevice = current_device();
  if (curDevice == -1) return false;
  if (!getApplyGrid<step>(totalElements, grid, curDevice)) {
    return false;
  }
 
  /*
  Expands readable/writable tensors whose indices may be "overlapped."
  This ensures that each element of the tensor is operated on once and only
  once.
  */
  Tensor oldA;
  Tensor oldB;
 
  if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) {
    // Must perform in contiguous space
    oldA = a;
    a = a.contiguous();
  }
  if (bType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(b)) {
    // Must perform in contiguous space
    oldB = b;
    b = b.contiguous();
  }
 
  // It is possible that the tensor dimensions are able to be collapsed,
  // and thus we can reduce the actual code complexity of the copy by
  // exploiting this knowledge statically, since the div/mod is the
  // most expensive part of the operation, more so than memory accesses.
  // For instance, when copying a non-contiguous to a contiguous tensor
  // (or vice versa), the contiguous tensor can be collapsed to one
  // dimension, and the loop to translate the linear index to the array
  // index can be similarly collapsed. That is what this unrolling is for.
 
#define HANDLE_CASE(TYPE, A, B)                                        \
  kernelPointwiseApply2<Op,                                            \
                        scalar1,                                       \
                        scalar2,                                       \
                        TYPE, A, B, step>                              \
   <<<grid, block, 0, at::cuda::getCurrentCUDAStream(curDevice)>>>(    \
       aInfo, bInfo, static_cast<TYPE>(totalElements), op);
 
#define HANDLE_B_CASE(TYPE, A, B) {         \
  switch (B) {                              \
    case 1:                                 \
      HANDLE_CASE(TYPE, A, 1);              \
      break;                                \
    case 2:                                 \
      HANDLE_CASE(TYPE, A, 2);              \
      break;                                \
    default:                                \
      HANDLE_CASE(TYPE, A, -1);             \
      break;                                \
  }                                         \
}
 
#define HANDLE_A_CASE(TYPE, A, B) {         \
  switch (A) {                              \
    case 1:                                 \
      HANDLE_B_CASE(TYPE, 1, B);            \
      break;                                \
    case 2:                                 \
      HANDLE_B_CASE(TYPE, 2, B);            \
      break;                                \
    default:                                \
      HANDLE_B_CASE(TYPE, -1, B);           \
      break;                                \
  }                                         \
}
 
  if (detail::canUse32BitIndexMath(a) &&
      detail::canUse32BitIndexMath(b)) {
    detail::TensorInfo<scalar1, unsigned int> aInfo =
      detail::getTensorInfo<scalar1, unsigned int>(a);
 
    detail::TensorInfo<scalar2, unsigned int> bInfo =
      detail::getTensorInfo<scalar2, unsigned int>(b);
    rearrangeDims(&aInfo, &bInfo);
    aInfo.collapseDims();
    bInfo.collapseDims();
 
    HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims);
  } else {
    detail::TensorInfo<scalar1, uint64_t> aInfo =
      detail::getTensorInfo<scalar1, uint64_t>(a);
 
    detail::TensorInfo<scalar2, uint64_t> bInfo =
      detail::getTensorInfo<scalar2, uint64_t>(b);
    rearrangeDims(&aInfo, &bInfo);
    aInfo.collapseDims();
    bInfo.collapseDims();
 
    /*
    Only instantiates the all 1D special case and the fallback all nD case for
    large (64-bit indexed) tensors to reduce compilation time.
    */
    if (aInfo.dims == 1 && bInfo.dims == 1) {
      HANDLE_CASE(uint64_t, 1, 1);
    } else {
      HANDLE_CASE(uint64_t, -1, -1);
    }
  }
#undef HANDLE_CASE
#undef HANDLE_B_CASE
#undef HANDLE_A_CASE
 
  if (oldA.defined()) {
    // Ignore overlaps when copying back; if we use copy
    // instead, it will recursively try and invoke ourselves to make
    // oldA contiguous.
    at::native::legacy::cuda::_th_copy_ignoring_overlaps_(oldA, a);
  }
 
  if (oldB.defined()) {
    // Ignore overlaps when copying back; if we use copy
    // instead, it will recursively try and invoke ourselves to make
    // oldB contiguous.
    at::native::legacy::cuda::_th_copy_ignoring_overlaps_(oldB, b);
  }
 
  return true;
}
 
/* Provides default step = 1 to CUDA_tensor_apply2. */
template <typename scalar1, typename scalar2, typename Op>
inline bool CUDA_tensor_apply2(at::Tensor a,
                               at::Tensor b,
                               const Op op,
                               TensorArgType aType = TensorArgType::ReadWrite,
                               TensorArgType bType = TensorArgType::ReadOnly) {
  return CUDA_tensor_apply2<scalar1, scalar2, 1, Op>(a, b, op, aType, bType);
}
 
 
template <typename scalar1, typename scalar2, typename scalar3, int step, typename Op>
inline bool CUDA_tensor_apply3(at::Tensor a,
                               at::Tensor b,
                               at::Tensor c,
                               const Op op,
                               TensorArgType aType = TensorArgType::ReadWrite,
                               TensorArgType bType = TensorArgType::ReadOnly,
                               TensorArgType cType = TensorArgType::ReadOnly) {
  checkBackend("CUDA_tensor_apply3", {a, b, c}, Backend::CUDA);
  int64_t totalElements = a.numel();
 
  if (totalElements != b.numel() ||
      totalElements != c.numel()) {
    return false;
  }
 
  if (a.dim() > MAX_TENSORINFO_DIMS ||
      b.dim() > MAX_TENSORINFO_DIMS ||
      c.dim() > MAX_TENSORINFO_DIMS) {
    return false;
  }
 
  if (a.numel() == 0) {
    // Empty tensor; do nothing
    return true;
  }
 
  const dim3 block = getApplyBlock();
 
  dim3 grid;
  int64_t curDevice = current_device();
  if (curDevice == -1) return false;
  if (!getApplyGrid<step>(totalElements, grid, curDevice)) {
    return false;
  }
 
  /*
  Expands readable/writable tensors whose indices may be "overlapped."
  This ensures that each element of the tensor is operated on once and only
  once.
  */
  Tensor oldA;
  Tensor oldB;
  Tensor oldC;
 
  if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) {
    // Must perform in contiguous space
    oldA = a;
    a = a.contiguous();
  }
  if (bType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(b)) {
    // Must perform in contiguous space
    oldB = b;
    b = b.contiguous();
  }
  if (cType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(c)) {
    // Must perform in contiguous space
    oldC = c;
    c = c.contiguous();
  }
 
#define HANDLE_CASE(TYPE, A, B, C)                                     \
  kernelPointwiseApply3<Op,                                            \
                        scalar1,                                       \
                        scalar2,                                       \
                        scalar3,                                       \
                        TYPE, A, B, C, step>                           \
    <<<grid, block, 0, at::cuda::getCurrentCUDAStream(curDevice)>>>(   \
      aInfo, bInfo, cInfo, static_cast<TYPE>(totalElements), op);
 
#define HANDLE_C_CASE(TYPE, A, B, C) {      \
  switch (C) {                              \
    case 1:                                 \
      HANDLE_CASE(TYPE, A, B, 1);           \
      break;                                \
    case 2:                                 \
      HANDLE_CASE(TYPE, A, B, 2);           \
      break;                                \
    default:                                \
      HANDLE_CASE(TYPE, A, B, -1);          \
      break;                                \
  }                                         \
}
 
#define HANDLE_B_CASE(TYPE, A, B, C) {      \
  switch (B) {                              \
    case 1:                                 \
      HANDLE_C_CASE(TYPE, A, 1, C);         \
      break;                                \
    case 2:                                 \
      HANDLE_C_CASE(TYPE, A, 2, C);         \
      break;                                \
    default:                                \
      HANDLE_C_CASE(TYPE, A, -1, C);        \
      break;                                \
  }                                         \
}
 
#define HANDLE_A_CASE(TYPE, A, B, C) {      \
  switch (A) {                              \
    case 1:                                 \
      HANDLE_B_CASE(TYPE, 1, B, C);         \
      break;                                \
    case 2:                                 \
      HANDLE_B_CASE(TYPE, 2, B, C);         \
      break;                                \
    default:                                \
      HANDLE_B_CASE(TYPE, -1, B, C);        \
      break;                                \
  }                                         \
}
 
  if (detail::canUse32BitIndexMath(a) &&
      detail::canUse32BitIndexMath(b) &&
      detail::canUse32BitIndexMath(c)) {
    detail::TensorInfo<scalar1, unsigned int> aInfo =
      detail::getTensorInfo<scalar1, unsigned int>(a);
 
    detail::TensorInfo<scalar2, unsigned int> bInfo =
      detail::getTensorInfo<scalar2, unsigned int>(b);
 
    detail::TensorInfo<scalar3, unsigned int> cInfo =
      detail::getTensorInfo<scalar3, unsigned int>(c);
 
    rearrangeDims(&aInfo, &bInfo, &cInfo);
    aInfo.collapseDims();
    bInfo.collapseDims();
    cInfo.collapseDims();
 
    HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims, cInfo.dims);
  } else {
    detail::TensorInfo<scalar1, uint64_t> aInfo =
      detail::getTensorInfo<scalar1, uint64_t>(a);
 
    detail::TensorInfo<scalar2, uint64_t> bInfo =
      detail::getTensorInfo<scalar2, uint64_t>(b);
 
    detail::TensorInfo<scalar3, uint64_t> cInfo =
      detail::getTensorInfo<scalar3, uint64_t>(c);
 
    rearrangeDims(&aInfo, &bInfo, &cInfo);
    aInfo.collapseDims();
    bInfo.collapseDims();
    cInfo.collapseDims();
 
    /*
    Only instantiates the all 1D special case and the fallback all nD case for
    large (64-bit indexed) tensors to reduce compilation time.
    */
    if (aInfo.dims == 1 && bInfo.dims == 1 && cInfo.dims == 1) {
      HANDLE_CASE(uint64_t, 1, 1, 1);
    } else {
      HANDLE_CASE(uint64_t, -1, -1, -1);
    }
  }
#undef HANDLE_CASE
#undef HANDLE_C_CASE
#undef HANDLE_B_CASE
#undef HANDLE_A_CASE
 
  if (oldA.defined()) {
    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldA contiguous.
    at::native::legacy::cuda::_th_copy_ignoring_overlaps_(oldA, a);
    a = oldA;
  }
 
  if (oldB.defined()) {
    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldB contiguous.
    at::native::legacy::cuda::_th_copy_ignoring_overlaps_(oldB, b);
    b = oldB;
  }
 
  if (oldC.defined()) {
    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldC contiguous.
    at::native::legacy::cuda::_th_copy_ignoring_overlaps_(oldC, c);
    c = oldC;
  }
 
  return true;
}
 
/* Provides default step = 1 to CUDA_tensor_apply3. */
template <typename scalar1, typename scalar2, typename scalar3, typename Op>
inline bool CUDA_tensor_apply3(at::Tensor a,
                               at::Tensor b,
                               at::Tensor c,
                               const Op op,
                               TensorArgType aType = TensorArgType::ReadWrite,
                               TensorArgType bType = TensorArgType::ReadOnly,
                               TensorArgType cType = TensorArgType::ReadOnly) {
  return CUDA_tensor_apply3<scalar1, scalar2, scalar3, 1, Op>(
    a, b, c, op, aType, bType, cType);
}
 
 
template <typename scalar1, typename scalar2, typename scalar3, typename scalar4,
          int step, typename Op>
inline bool CUDA_tensor_apply4(at::Tensor a,
                               at::Tensor b,
                               at::Tensor c,
                               at::Tensor d,
                               const Op op,
                               TensorArgType aType = TensorArgType::ReadWrite,
                               TensorArgType bType = TensorArgType::ReadOnly,
                               TensorArgType cType = TensorArgType::ReadOnly,
                               TensorArgType dType = TensorArgType::ReadOnly) {
  checkBackend("CUDA_tensor_apply4", {a, b, c, d}, Backend::CUDA);
  int64_t totalElements = a.numel();
 
  if (totalElements != b.numel() ||
      totalElements != c.numel() ||
      totalElements != d.numel()) {
    return false;
  }
 
  if (a.dim() > MAX_TENSORINFO_DIMS ||
      b.dim() > MAX_TENSORINFO_DIMS ||
      c.dim() > MAX_TENSORINFO_DIMS ||
      d.dim() > MAX_TENSORINFO_DIMS) {
    return false;
  }
 
  if (a.numel() == 0) {
    // Empty tensor; do nothing
    return true;
  }
 
  const dim3 block = getApplyBlock();
 
  dim3 grid;
  int64_t curDevice = current_device();
  if (curDevice == -1) return false;
  if (!getApplyGrid<step>(totalElements, grid, curDevice)) {
    return false;
  }
 
  /*
  Expands readable/writable tensors whose indices may be "overlapped."
  This ensures that each element of the tensor is operated on once and only
  once.
  */
  Tensor oldA;
  Tensor oldB;
  Tensor oldC;
  Tensor oldD;
 
  if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) {
    // Must perform in contiguous space
    oldA = a;
    a = a.contiguous();
  }
  if (bType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(b)) {
    // Must perform in contiguous space
    oldB = b;
    b = b.contiguous();
  }
  if (cType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(c)) {
    // Must perform in contiguous space
    oldC = c;
    c = c.contiguous();
  }
  if (dType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(c)) {
    // Must perform in contiguous space
    oldD = d;
    d = d.contiguous();
  }
 
#define HANDLE_CASE(TYPE, A, B, C, D)                                  \
  kernelPointwiseApply4<Op,                                            \
                        scalar1,                                       \
                        scalar2,                                       \
                        scalar3,                                       \
                        scalar4,                                       \
                        TYPE, A, B, C, D, step>                        \
    <<<grid, block, 0, at::cuda::getCurrentCUDAStream(curDevice)>>>(   \
    aInfo, bInfo, cInfo, dInfo, static_cast<TYPE>(totalElements), op);
 
#define HANDLE_D_CASE(TYPE, A, B, C, D) {       \
  switch (D) {                                  \
    case 1:                                     \
      HANDLE_CASE(TYPE, A, B, C, 1);            \
      break;                                    \
    case 2:                                     \
      HANDLE_CASE(TYPE, A, B, C, 2);            \
      break;                                    \
    default:                                    \
      HANDLE_CASE(TYPE, A, B, C, -1);           \
      break;                                    \
  }                                             \
}
 
#define HANDLE_C_CASE(TYPE, A, B, C, D) {       \
  switch (C) {                                  \
    case 1:                                     \
      HANDLE_D_CASE(TYPE, A, B, 1, D);          \
      break;                                    \
    case 2:                                     \
      HANDLE_D_CASE(TYPE, A, B, 2, D);          \
      break;                                    \
    default:                                    \
      HANDLE_D_CASE(TYPE, A, B, -1, D);         \
      break;                                    \
  }                                             \
}
 
#define HANDLE_B_CASE(TYPE, A, B, C, D) {       \
  switch (B) {                                  \
    case 1:                                     \
      HANDLE_C_CASE(TYPE, A, 1, C, D);          \
      break;                                    \
    case 2:                                     \
      HANDLE_C_CASE(TYPE, A, 2, C, D);          \
      break;                                    \
    default:                                    \
      HANDLE_C_CASE(TYPE, A, -1, C, D);         \
      break;                                    \
  }                                             \
}
 
#define HANDLE_A_CASE(TYPE, A, B, C, D) {       \
  switch (A) {                                  \
    case 1:                                     \
      HANDLE_B_CASE(TYPE, 1, B, C, D);          \
      break;                                    \
    case 2:                                     \
      HANDLE_B_CASE(TYPE, 2, B, C, D);          \
      break;                                    \
    default:                                    \
      HANDLE_B_CASE(TYPE, -1, B, C, D);         \
      break;                                    \
  }                                             \
}
 
  if (detail::canUse32BitIndexMath(a) &&
      detail::canUse32BitIndexMath(b) &&
      detail::canUse32BitIndexMath(c) &&
      detail::canUse32BitIndexMath(d)) {
    detail::TensorInfo<scalar1, unsigned int> aInfo =
      detail::getTensorInfo<scalar1, unsigned int>(a);
 
    detail::TensorInfo<scalar2, unsigned int> bInfo =
      detail::getTensorInfo<scalar2, unsigned int>(b);
 
    detail::TensorInfo<scalar3, unsigned int> cInfo =
      detail::getTensorInfo<scalar3, unsigned int>(c);
 
    detail::TensorInfo<scalar4, unsigned int> dInfo =
      detail::getTensorInfo<scalar4, unsigned int>(d);
 
    rearrangeDims(&aInfo, &bInfo, &cInfo, &dInfo);
    aInfo.collapseDims();
    bInfo.collapseDims();
    cInfo.collapseDims();
    dInfo.collapseDims();
 
    HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims, cInfo.dims, dInfo.dims);
  } else {
    detail::TensorInfo<scalar1, uint64_t> aInfo =
      detail::getTensorInfo<scalar1, uint64_t>(a);
 
    detail::TensorInfo<scalar2, uint64_t> bInfo =
      detail::getTensorInfo<scalar2, uint64_t>(b);
 
    detail::TensorInfo<scalar3, uint64_t> cInfo =
      detail::getTensorInfo<scalar3, uint64_t>(c);
 
    detail::TensorInfo<scalar4, uint64_t> dInfo =
      detail::getTensorInfo<scalar4, uint64_t>(d);
 
    rearrangeDims(&aInfo, &bInfo, &cInfo, &dInfo);
    aInfo.collapseDims();
    bInfo.collapseDims();
    cInfo.collapseDims();
    dInfo.collapseDims();
 
    /*
    Only instantiates the all 1D special case and the fallback all nD case for
    large (64-bit indexed) tensors to reduce compilation time.
    */
    if (aInfo.dims == 1 && bInfo.dims == 1 && cInfo.dims == 1 && dInfo.dims == 1) {
      HANDLE_CASE(uint64_t, 1, 1, 1, 1);
    } else {
      HANDLE_CASE(uint64_t, -1, -1, -1, -1);
    }
  }
#undef HANDLE_CASE
#undef HANDLE_D_CASE
#undef HANDLE_C_CASE
#undef HANDLE_B_CASE
#undef HANDLE_A_CASE
 
  if (oldA.defined()) {
    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldA contiguous.
    at::native::legacy::cuda::_th_copy_ignoring_overlaps_(oldA, a);
  }
 
  if (oldB.defined()) {
    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldB contiguous.
    at::native::legacy::cuda::_th_copy_ignoring_overlaps_(oldB, b);
  }
 
  if (oldC.defined()) {
    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldC contiguous.
    at::native::legacy::cuda::_th_copy_ignoring_overlaps_(oldC, c);
  }
 
  if (oldD.defined()) {
    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldC contiguous.
    at::native::legacy::cuda::_th_copy_ignoring_overlaps_(oldD, c);
  }
 
  return true;
}
 
/* Provides default step = 1 to CUDA_tensor_apply4. */
template <typename scalar1, typename scalar2, typename scalar3, typename scalar4,
          typename Op>
inline bool CUDA_tensor_apply4(at::Tensor a,
                               at::Tensor b,
                               at::Tensor c,
                               at::Tensor d,
                               const Op op,
                               TensorArgType aType = TensorArgType::ReadWrite,
                               TensorArgType bType = TensorArgType::ReadOnly,
                               TensorArgType cType = TensorArgType::ReadOnly,
                               TensorArgType dType = TensorArgType::ReadOnly) {
  return CUDA_tensor_apply4<scalar1, scalar2, scalar3, scalar4, 1, Op>(
    a, b, c, d, op, aType, bType, cType);
}
 
} // cuda
} // at