@@ -57,12 +57,12 @@ namespace {
 
 using stream_set = std::unordered_set<cuda::CUDAStream>;
 
-constexpr size_t kMinBlockSize = 512;       // all sizes are rounded to at least 512 bytes
-constexpr size_t kSmallSize = 1048576;      // largest "small" allocation is 1 MiB
-constexpr size_t kSmallBuffer = 2097152;    // "small" allocations are packed in 2 MiB blocks
-constexpr size_t kLargeBuffer = 20971520;   // "large" allocations may be packed in 20 MiB blocks
+constexpr size_t kMinBlockSize = 512;       // all sizes are rounded to at least 512 bytes
+constexpr size_t kSmallSize = 1048576;      // largest "small" allocation is 1 MiB
+constexpr size_t kSmallBuffer = 2097152;    // "small" allocations are packed in 2 MiB blocks
+constexpr size_t kLargeBuffer = 20971520;   // "large" allocations may be packed in 20 MiB blocks
 constexpr size_t kMinLargeAlloc = 10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
-constexpr size_t kRoundLarge = 2097152;     // round up large allocations to 2 MiB
+constexpr size_t kRoundLarge = 2097152;     // round up large allocations to 2 MiB
 
 typedef std::bitset<static_cast<size_t>(StatType::NUM_TYPES)> StatTypes;
 
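Read together, these constants describe a two-tier size-class policy. The following is a minimal sketch of that policy, assuming hypothetical helper names (round_size, get_allocation_size); it restates what the comments above say rather than quoting the allocator's actual helpers.

```cpp
#include <cstddef>

constexpr size_t kMinBlockSize = 512;
constexpr size_t kSmallSize = 1048576;
constexpr size_t kSmallBuffer = 2097152;
constexpr size_t kLargeBuffer = 20971520;
constexpr size_t kMinLargeAlloc = 10485760;
constexpr size_t kRoundLarge = 2097152;

// Every request is rounded up to a multiple of kMinBlockSize (512 B).
size_t round_size(size_t size) {
  return size < kMinBlockSize
      ? kMinBlockSize
      : kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize);
}

// Segment size handed to cudaMalloc for a rounded request:
//   <= 1 MiB          -> one 2 MiB "small" buffer (kSmallBuffer)
//   1 MiB .. < 10 MiB -> one 20 MiB "large" buffer (kLargeBuffer)
//   >= 10 MiB         -> rounded up to a multiple of 2 MiB (kRoundLarge)
size_t get_allocation_size(size_t size) {
  if (size <= kSmallSize) return kSmallBuffer;
  if (size < kMinLargeAlloc) return kLargeBuffer;
  return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
}
```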
@@ -242,56 +242,57 @@ class DeviceCachingAllocator {
       // Free all non-split cached blocks and retry alloc.
       || (free_cached_blocks() && alloc_block(params, true));
 
-    TORCH_INTERNAL_ASSERT((!block_found && params.err != cudaSuccess) || params.block);
     if (!block_found) {
-      if (params.err == cudaErrorMemoryAllocation) {
-        size_t device_free;
-        size_t device_total;
-        C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
-        std::string allowed_info;
-
-        if (set_fraction) {
-          allowed_info = format_size(allowed_memory_maximum) + " allowed; ";
-        }
+      // For any error code other than cudaErrorMemoryAllocation,
+      // alloc_block should have thrown an exception already.
+      TORCH_INTERNAL_ASSERT(params.err == cudaErrorMemoryAllocation);
 
-        stats.num_ooms += 1;
-
-        // "total capacity": total global memory on GPU
-        // "allowed": memory the program is allowed to use, as set by the fraction
-        // "already allocated": memory allocated by the program using the
-        //                      caching allocator
-        // "free": free memory as reported by the CUDA API
-        // "cached": memory held by the allocator but not used by the program
-        //
-        // The "allocated" amount does not include memory allocated outside
-        // of the caching allocator, such as memory allocated by other programs
-        // or memory held by the driver.
-        //
-        // The sum of "allocated" + "free" + "cached" may be less than the
-        // total capacity due to memory held by the driver and usage by other
-        // programs.
-        //
-        // Note that at this point free_cached_blocks has already returned all
-        // possible "cached" memory to the driver. The only remaining "cached"
-        // memory is split from a larger block that is partially in-use.
-        TORCH_CHECK_WITH(CUDAOutOfMemoryError, false,
-          "CUDA out of memory. Tried to allocate ", format_size(alloc_size),
-          " (GPU ", device, "; ",
-          format_size(device_total), " total capacity; ",
-          format_size(stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current),
-          " already allocated; ",
-          format_size(device_free), " free; ",
-          allowed_info,
-          format_size(stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current),
-          " reserved in total by PyTorch)");
-      } else {
-        C10_CUDA_CHECK(params.err);
+      size_t device_free;
+      size_t device_total;
+      C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
+      std::string allowed_info;
+
+      if (set_fraction) {
+        allowed_info = format_size(allowed_memory_maximum) + " allowed; ";
       }
+
+      stats.num_ooms += 1;
+
+      // "total capacity": total global memory on GPU
+      // "allowed": memory the program is allowed to use, as set by the fraction
+      // "already allocated": memory allocated by the program using the
+      //                      caching allocator
+      // "free": free memory as reported by the CUDA API
+      // "cached": memory held by the allocator but not used by the program
+      //
+      // The "allocated" amount does not include memory allocated outside
+      // of the caching allocator, such as memory allocated by other programs
+      // or memory held by the driver.
+      //
+      // The sum of "allocated" + "free" + "cached" may be less than the
+      // total capacity due to memory held by the driver and usage by other
+      // programs.
+      //
+      // Note that at this point free_cached_blocks has already returned all
+      // possible "cached" memory to the driver. The only remaining "cached"
+      // memory is split from a larger block that is partially in-use.
+      TORCH_CHECK_WITH(CUDAOutOfMemoryError, false,
+        "CUDA out of memory. Tried to allocate ", format_size(alloc_size),
+        " (GPU ", device, "; ",
+        format_size(device_total), " total capacity; ",
+        format_size(stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current),
+        " already allocated; ",
+        format_size(device_free), " free; ",
+        allowed_info,
+        format_size(stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current),
+        " reserved in total by PyTorch)");
     }
 
+    TORCH_INTERNAL_ASSERT(params.err == cudaSuccess &&
+                          params.block != nullptr &&
+                          params.block->ptr != nullptr);
     Block* block = params.block;
     Block* remaining = nullptr;
-    TORCH_INTERNAL_ASSERT(block);
 
     const bool already_split = block->is_split();
     if (should_split(block, size)) {
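The CUDAOutOfMemoryError raised here is recoverable from the caller's side; as the comments added in the next hunk note, a user can catch it, free memory, and retry. A hedged sketch of that pattern with libtorch, assuming the c10::CUDAOutOfMemoryError type and emptyCache() entry point as they exist around this change (names may differ in other PyTorch versions):

```cpp
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <torch/torch.h>

// Try an allocation; on OOM, return cached segments to the driver and retry once.
torch::Tensor empty_with_retry(c10::IntArrayRef shape) {
  try {
    return torch::empty(shape, torch::kCUDA);
  } catch (const c10::CUDAOutOfMemoryError&) {
    // Same remedy the allocator applies internally via free_cached_blocks,
    // but any tensors the script has since released can now be reclaimed too.
    c10::cuda::CUDACachingAllocator::emptyCache();
    return torch::empty(shape, torch::kCUDA);
  }
}
```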
@@ -647,30 +648,46 @@ class DeviceCachingAllocator {
   }
 
   bool alloc_block(AllocParams& p, bool isRetry) {
+    // Defensively check for preexisting CUDA error state.
+    C10_CUDA_CHECK(cudaGetLastError());
+
     size_t size = p.alloc_size;
     void* ptr;
 
     if (isRetry) {
       stats.num_alloc_retries += 1;
     }
+
     if (set_fraction && total_allocated_memory + size > allowed_memory_maximum) {
       p.err = cudaErrorMemoryAllocation;
+      return false;
     } else {
       p.err = cudaMalloc(&ptr, size);
-    }
-
-    if (p.err != cudaSuccess) {
-      if (!isRetry || p.err == cudaErrorMemoryAllocation)
-        cudaGetLastError(); // clear CUDA error
-      return false;
+      if (p.err != cudaSuccess) {
+        if (p.err == cudaErrorMemoryAllocation) {
+          // If this is the first attempt (!isRetry), we can forgive and clear CUDA's
+          // internal error state.
+          // If this is the second attempt (isRetry), malloc's TORCH_CHECK_WITH will take
+          // over to throw a helpful exception. The user can choose to catch the exception,
+          // free some stuff in their script, and attempt their allocation again.
+          // In this case, we can also forgive and clear CUDA's internal error state.
+          cudaGetLastError();
+        } else {
+          // If the error is unrelated to memory allocation, we should throw immediately.
+          C10_CUDA_CHECK(p.err);
+        }
+        return false;
+      }
     }
 
     total_allocated_memory += size;
     p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr);
     update_stat_array(stats.segment, 1, p.stat_types);
     update_stat_array(stats.reserved_bytes, size, p.stat_types);
 
-    return (p.block != nullptr);
+    // p.block came from new, not cudaMalloc. It should not be nullptr here.
+    TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr);
+    return true;
   }
 
   bool free_cached_blocks()
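For context on why alloc_block both checks the error state on entry and clears it on the forgivable paths: a failed CUDA runtime call leaves its error code latched in the runtime's per-thread last-error slot until something reads it. A standalone illustration of that behavior (my own demo, not part of the commit):

```cpp
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  void* p = nullptr;
  // Deliberately over-ask so cudaMalloc fails with cudaErrorMemoryAllocation.
  cudaError_t err = cudaMalloc(&p, ~size_t{0} / 2);
  if (err == cudaErrorMemoryAllocation) {
    // The failure stays latched in the runtime's last-error slot. If nobody
    // reads it, a later unrelated check (like the C10_CUDA_CHECK(cudaGetLastError())
    // now at the top of alloc_block) would be blamed for this old failure.
    cudaError_t stale = cudaGetLastError(); // reads AND resets the slot
    printf("cleared: %s\n", cudaGetErrorString(stale));
  }
  // The slot is clean again; subsequent CUDA calls start from cudaSuccess.
  printf("now: %s\n", cudaGetErrorString(cudaGetLastError()));
  return 0;
}
```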