@@ -204,7 +204,7 @@ cudaError_t GetDeviceCount(int* dev_count) {
204204// x = torch.empty(1, device=“cuda:1”) # no CUDA context on cuda:0 after this
205205// call y = torch.empty(1, device=“cuda”) # CUDA context is created on cuda:0
206206// ```
207- #if CUDA_VERSION >= 11000
207+ #if CUDA_VERSION >= 12000
208208thread_local int targetDeviceIndex = -1 ;
209209
210210cudaError_t GetDevice (int * device) {
@@ -223,9 +223,7 @@ cudaError_t SetDevice(int device) {
223223 if (device == cur_device) {
224224 return cudaSuccess;
225225 }
226- cudaError_t err = cudaSetDevice (device);
227- C10_CUDA_CHECK (cudaFree (0 ));
228- return err;
226+ return cudaSetDevice (device);
229227}
230228
231229cudaError_t MaybeSetDevice (int device) {
@@ -236,6 +234,8 @@ cudaError_t MaybeSetDevice(int device) {
236234 return cudaSuccess;
237235}
238236
237+ // This function always initializes the CUDA context
238+ // on to_device
239239int ExchangeDevice (int to_device) {
240240 int cur_device = targetDeviceIndex;
241241 targetDeviceIndex = -1 ;
@@ -246,10 +246,11 @@ int ExchangeDevice(int to_device) {
246246 }
247247 }
248248 C10_CUDA_CHECK (cudaSetDevice (to_device));
249- C10_CUDA_CHECK (cudaFree (0 ));
250249 return cur_device;
251250}
252251
252+ // This function does not initialize the CUDA context
253+ // on to_device if it does not already exist
253254int MaybeExchangeDevice (int to_device) {
254255 int cur_device = -1 ;
255256 C10_CUDA_CHECK (cudaGetDevice (&cur_device));
@@ -258,7 +259,6 @@ int MaybeExchangeDevice(int to_device) {
258259 }
259260 if (hasPrimaryContext (to_device)) {
260261 C10_CUDA_CHECK (cudaSetDevice (to_device));
261- C10_CUDA_CHECK (cudaFree (0 ));
262262 } else {
263263 targetDeviceIndex = to_device;
264264 }
0 commit comments