#include #include #include #include #ifdef USE_CUDA // NB: It's a list of *optional* CUDAStream; when nullopt, that means to use // whatever the current stream of the device the input is associated with was. std::vector> THPUtils_PySequence_to_CUDAStreamList(PyObject* obj) { if (!PySequence_Check(obj)) { throw std::runtime_error( "Expected a sequence in THPUtils_PySequence_to_CUDAStreamList"); } THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, nullptr)); if (seq.get() == nullptr) { throw std::runtime_error( "expected PySequence, but got " + std::string(THPUtils_typename(obj))); } std::vector> streams; Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get()); for (Py_ssize_t i = 0; i < length; i++) { PyObject* stream = PySequence_Fast_GET_ITEM(seq.get(), i); if (PyObject_IsInstance(stream, THCPStreamClass)) { // Spicy hot reinterpret cast!! streams.emplace_back(at::cuda::CUDAStream::unpack3( (reinterpret_cast(stream))->stream_id, (reinterpret_cast(stream))->device_index, static_cast( (reinterpret_cast(stream))->device_type))); } else if (stream == Py_None) { streams.emplace_back(); } else { // NOLINTNEXTLINE(bugprone-throw-keyword-missing) std::runtime_error( "Unknown data type found in stream list. Need torch.cuda.Stream or None"); } } return streams; } #endif