Notes: GPU Memory Performance Considerations
Access Patterns and Shared Memory
Continuing my exploration of GPU performance I decided to turn my focus to the card’s memory. For a reminder of what the memory hierarchy looks like from the hardware perspective you can look at one of my previous posts.
Theoretical Peak Bandwidth
Before jumping into experiments it’s good to know what we could theoretically achieve given our hardware.
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, DEVICE_ID));
printf("Device %d: %s\n", DEVICE_ID, prop.name);
printf("\tMemory Clock Rate: %d kHz\n", prop.memoryClockRate);
printf("\tMemory Bus Width : %d bits\n", prop.memoryBusWidth);
// - 2.0 multiplier: HBM is DDR (Double Data Rate) which means that
//                   it can transfer data twice per clock cycle (i.e.
//                   on both the rising and falling edges of the clock
//                   signal).
double bandwidthGBs = 2.0 * prop.memoryClockRate *
                      (prop.memoryBusWidth/8) / 1.0e6;
double bandwidthGiBs = bandwidthGBs / 1.073741824;
printf("\tTheoretical Peak Bandwidth: %.2f GB/s (%.2f GiB/s)\n",
       bandwidthGBs, bandwidthGiBs);
Running this on my machine I get:
Device 0: NVIDIA H100 80GB HBM3
Memory Clock Rate: 2619000 kHz
Memory Bus Width : 5120 bits
	Theoretical Peak Bandwidth: 3352.32 GB/s (3122.09 GiB/s)
Now real-world bandwidth/throughput is always lower, for reasons including, but not limited to, the following:
PCIe or NVLink bottlenecks
Contention from other processes/kernels
ECC Overhead
Access patterns
In this post I’ll be looking into access patterns, as this is the factor we have the most direct control over as programmers.
Aside: Bandwidth vs Throughput
Given that both use the same units I figured I’d look into how they differ. Basically:
Bandwidth → the maximum possible rate the hardware could achieve (theoretical peak)
Throughput → the actual rate achieved in practice (real-world performance)
Another way to think about it, which I found through one of the bots (I think it was ChatGPT), is a highway:
Bandwidth = how many cars could fit on all lanes at top speed
Throughput = how many cars are actually moving during rush hour
Sequential Access / Unit-Stride Performance
Similar to CPUs, GPUs perform better when accessing sequential memory locations. Here is an example kernel that just copies elements from an input array to an output array:
// Ideal case:
// Each thread copies one element from input to output at the same index
__global__ void
coalesced_copy(const float* input, float* output, size_t n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        output[idx] = input[idx];
    }
}
Now if we introduce a stride to make the access more sparse/pseudo-random we can write something like this:
__global__ void
strided_copy(const float* input, float* output, size_t n, int stride) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // Guard on the strided index so we never read past the end of input
    if ((size_t)idx * stride < n) {
        output[idx] = input[idx * stride];
    }
}
Letting the above two race and trying different strides for the latter one I get the following results:
coalesced_copy        -> Bandwidth: 2230.25 GiB/s | Duration: 0.00022419s
strided_copy stride=2 -> Bandwidth: 1744.55 GiB/s | Duration: 0.000286607s
strided_copy stride=4 -> Bandwidth: 1157.11 GiB/s | Duration: 0.000432112s
strided_copy stride=8 -> Bandwidth: 645.33 GiB/s  | Duration: 0.000774798s
The sequential/coalesced code achieves 2230.25/3122.09 → ~71% of peak bandwidth, and from what my googling says that’s pretty good.
Now, just like on CPUs, non-sequential accesses result in performance degradation. Some of the hardware details behind why, though, are slightly different because of how GPU cores are organized into warps.
When threads in a warp read consecutive memory addresses, the GPU hardware merges those requests into a single memory transaction. So for a warp where threads 0-31 read addresses A[0] … A[31], the hardware combines these accesses into a single transaction - a 128-byte load from global memory. When we have each thread within a warp access memory randomly, each can trigger a separate memory transaction, and the more transactions we have the lower the performance gets.
Now if we look at our results above we see that starting from stride = 2, every time we double the stride our bandwidth almost halves. This is because we almost double our memory transactions every time we double our stride. Here is a more visual way of looking at this:
# unit-stride memory layout & accesses:
---------------------------------------------------------
| A[0] | A[1] | A[2] | A[3] | A[4] | A[5] | A[6] | A[7] | Transaction 0
---------------------------------------------------------
↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
<warp reads all its elements in one transaction>
# stride=2 memory layout & accesses
---------------------------------------------------------
| A[0] | A[1] | A[2] | A[3] | A[4] | A[5] | A[6] | A[7] | Transaction 0
---------------------------------------------------------
↑ ↑ ↑ ↑
---------------------------------------------------------
| A[8] | A[9] | A[10] | A[11] | A[12] |A[13]|A[14]|A[15]| Transaction 1
---------------------------------------------------------
↑ ↑ ↑ ↑
<same warp needs 2 transactions now>
# stride=4 memory layout & accesses
---------------------------------------------------------
| A[0] | A[1] | A[2] | A[3] | A[4] | A[5] | A[6] | A[7] | Transaction 0
---------------------------------------------------------
↑ ↑
---------------------------------------------------------
| A[8] | A[9] | A[10] | A[11] | A[12] |A[13]|A[14]|A[15]| Transaction 1
---------------------------------------------------------
↑ ↑
---------------------------------------------------------
|A[16]|A[17]| A[18]| A[19] | A[20] | A[21] | A[22]|A[23]| Transaction 2
---------------------------------------------------------
↑ ↑
---------------------------------------------------------
|A[24]|A[25]|A[26]| A[27] | A[28] | A[29] | A[30] |A[31]| Transaction 3
---------------------------------------------------------
↑ ↑
<same warp needs 4 transactions now>So going back to comparison with CPUs: Just like CPUs, GPUs perform better with sequential accesses due to caching and can benefit from spatial/temporal locality. Unlike CPUs though the penalty of random accesses is a lot more devastating for performance because the hardware cannot amortize latency over coalesced transactions. CPUs have deeper cache hierarchies that are also smarter (e.g. they are equipped with sophisticated branch prediction, out-of-order execution, etc..). GPUs don’t have that and in order to hide their latency they rely heavily in memory coalescing to enable massive parallelism (see race car vs freight-truck metaphor in my previous post).
Shared Memory & Tiling
Non-coalesced access patterns can be hard to avoid. There are some scenarios though where shared memory can help. If you recall from one of my previous posts, there’s a piece of L1 data cache shared between all the warps in the same SM, and it can actually be partitioned explicitly by the programmer (i.e. it is a user-managed cache). CUDA programmers can leverage this shared memory with a pattern commonly referred to as tiling. The idea is that we break a large problem (like a big matrix) into smaller chunks/tiles that fit into shared memory and do our work there. By loading tiles into shared memory and then computing on them, we essentially convert non-coalesced global memory accesses into coalesced ones. Additionally, even if shared memory accesses aren’t perfectly coalesced they are still way faster than scattered global memory accesses.
The most common example used for demonstrating tiling is transposing a matrix. Here is a naive implementation:
__global__ void
naive_transpose(const float* input, float* output,
                int width, int height) {
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    int y = blockIdx.y * blockDim.y + threadIdx.y;
    if (x < width && y < height) {
        output[x * height + y] = input[y * width + x];
    }
}
The accesses from input[] are coalesced - here’s an example visual assuming a 4x4 matrix:
Thread x y Accesses input[y * width + x]
====== = = =============================
0 0 0 input[0]
1 1 0 input[1]
2 2 0 input[2]
3 3 0 input[3]
4 0 1 input[4]
...etc...
On the other hand, for the same threads the output[] accesses have a 4-element stride:
Thread x y Accesses output[x * height + y]
====== = = =============================
0 0 0 output[0]
1 1 0 output[4]
2 2 0 output[8]
3 3 0 output[12]
4 0 1 output[16]
...etc...
So the reads are coalesced but the writes are not. Now if we were to use tiling, the same transposition function would be written like so:
#define TILE_DIM 32 // Tile dimension
__global__ void
shared_mem_transpose(const float* input, float* output, int width, int height) {
    __shared__ float tile[TILE_DIM][TILE_DIM];
    int x = blockIdx.x * TILE_DIM + threadIdx.x;
    int y = blockIdx.y * TILE_DIM + threadIdx.y;
    int tile_x = threadIdx.x;
    int tile_y = threadIdx.y;
    // Load tile from global to shared memory (coalesced)
    if (x < width && y < height) {
        tile[tile_y][tile_x] = input[y * width + x];
    }
    // Ensure that tile is fully loaded before proceeding
    __syncthreads();
    // Update indices for transposed write
    x = blockIdx.y * TILE_DIM + threadIdx.x;
    y = blockIdx.x * TILE_DIM + threadIdx.y;
    // Write tile from shared to global memory (coalesced)
    if (x < height && y < width) {
        output[y * height + x] = tile[tile_x][tile_y];
    }
}
and since the tile array is statically sized with __shared__, we don’t need to pass a dynamic shared-memory size at launch (that third launch parameter is only for extern __shared__ allocations), so the call is simply:
shared_mem_transpose<<<gridSize, blockSize>>>(
    d_input,
    d_output,
    width,
    height);
Now let’s unpack what’s happening in the function above:
We first declare the tile/working-set matrix in shared memory, using the __shared__ keyword.
We then declare our global coordinates x and y, and our tile coordinates tile_x and tile_y.
We then perform our coalesced reads from global memory - like we did in the naive implementation.
Then we use __syncthreads(), which is a barrier for all the threads in the same block. The idea is that we want all the threads to have finished their writes to the tile, because the tile coordinate that one thread populated, another thread will read in the next part of the function.
Then we update our indices and essentially do the transpose by swapping blockIdx.x and blockIdx.y.
We then write our transposed tile to global memory through coalesced writes.
Running both versions we see that the added shared-memory complexity is definitely worth it: throughput has more than doubled:
naive      -> Bandwidth: 424.814 GiB/s | Avg.Duration: 7.35616e-05s
shared_mem -> Bandwidth: 1026.75 GiB/s | Avg.Duration: 3.04358e-05s
Bank Conflicts
As I showed above, shared memory can give CUDA applications a nice boost. Like any piece of memory accessed by many threads though, it needs to be used carefully. A problem with my code above is that it contains bank conflicts.
Shared memory on NVIDIA GPUs is divided into banks. Each bank can serve one address per clock cycle. If multiple threads in a warp (typically 32 threads) access different addresses in the same bank, a bank conflict occurs and the accesses are serialized, which hurts performance. On Hopper and other modern NVIDIA GPUs there are 32 banks in each shared memory unit, each 4 bytes (the size of a float) wide, so successive 4-byte words map to successive banks. We can therefore find the bank of a float at word index X in shared memory by computing X % 32.
In our code above the conflict happens in the following part:
int tile_x = threadIdx.x;
int tile_y = threadIdx.y;
..snip..
output[y * height + x] = tile[tile_x][tile_y];
Each thread in a warp has the same tile_y and a different tile_x. So within the first warp for example (threadIdx.x: 0-31, threadIdx.y = 0) the threads access tile[0][0], tile[1][0], ..., tile[31][0], which is the same column and in our case the same bank (because the size of each row is 32 elements):
Tile Access Index In Memory Layout Bank
----------- ---------------------- ------------
tile[0][0] (0 x 32) + 0 = 0 0 % 32 = 0
tile[1][0] (1 x 32) + 0 = 32 32 % 32 = 0
...
tile[31][0]  (31 x 32) + 0 = 992     992 % 32 = 0
To get rid of these conflicts all we need to do is change one line:
__global__ void
shared_mem_i_transpose(const float* input, float* output, int width, int height) {
    __shared__ float tile[TILE_DIM][TILE_DIM + 1]; // Avoid bank conflicts
    ..etc..
}
By padding each row with 1 extra element (33 elements per row in total) we spread the threads in a warp across banks. In our previous example with the first warp:
Tile Access Index In Memory Layout Bank
----------- ---------------------- -------------
tile[0][0] (0 x 33) + 0 = 0 0 % 32 = 0
tile[1][0] (1 x 33) + 0 = 33 33 % 32 = 1
...
tile[31][0]  (31 x 33) + 0 = 1023    1023 % 32 = 31
Now the tradeoff here is that we do waste some memory for this - specifically 4 bytes (size of float) x 32 (one padding element per row) → 128 bytes, which is a ~3% increase on our initial 4 KiB tile, and pretty small compared to the total shared memory of an SM, which is generally in the high tens to low hundreds of KiB. On the other hand we save ourselves from a 32-way conflict, which would turn a single-cycle access into 32 serialized accesses.
Here’s the trade-off in action.
naive        -> Bandwidth: 424.814 GiB/s | Avg.Duration: 7.35616e-05s
shared_mem   -> Bandwidth: 1026.75 GiB/s | Avg.Duration: 3.04358e-05s
shared_mem_i -> Bandwidth: 1967.92 GiB/s | Avg.Duration: 1.58797e-05s
Definitely worth the trade-off!

