#ifdef THRUST_DEVICE_LOWER_BOUND_WORKS
#include <thrust/binary_search.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#endif

namespace c10::cuda {

// Device-side lower_bound.
//
// Returns an iterator to the first element `it` in [start, end) for which
// `*it < value` is false, or `end` if every element compares less than
// `value`. As with std::lower_bound, [start, end) must be partitioned with
// respect to `*it < value` (e.g. sorted ascending under operator<); otherwise
// the returned position is unspecified.
//
// Declared __device__ only: callable from device code, not from the host.

#ifdef THRUST_DEVICE_LOWER_BOUND_WORKS
// The build has indicated that thrust's device-side lower_bound is usable, so
// delegate to it with the thrust::device execution policy.
template <typename Iter, typename Scalar>
__forceinline__ __device__ Iter
lower_bound(Iter start, Iter end, Scalar value) {
  return thrust::lower_bound(thrust::device, start, end, value);
}
#else
// thrust::lower_bound is broken on device; see
// https://github.com/NVIDIA/thrust/issues/1734
// Implementation inspired by
// https://github.com/pytorch/pytorch/blob/805120ab572efef66425c9f595d9c6c464383336/aten/src/ATen/native/cuda/Bucketization.cu#L28
//
// Hand-written binary search. Iter must support iterator difference
// (`end - start`), offsetting (`start + n`), ordering (`start < end`) and
// dereference; `*it < value` must be a valid comparison.
template <typename Iter, typename Scalar>
__device__ Iter lower_bound(Iter start, Iter end, Scalar value) {
  // Invariant (for a partitioned range): every element before `start` is
  // < value, and every element from `end` to the original end is not.
  while (start < end) {
    // Midpoint from the distance rather than (start + end) / 2: adding two
    // iterators is ill-formed, and for integer indices the sum could overflow.
    auto mid = start + ((end - start) >> 1);
    if (*mid < value) {
      start = mid + 1;
    } else {
      end = mid;
    }
  }
  // The loop exits with start == end, i.e. the first element not < value.
  return end;
}
#endif // THRUST_DEVICE_LOWER_BOUND_WORKS

} // namespace c10::cuda