How are 2D and 3D thread blocks linearized into warps in CUDA?
Introduction
In CUDA, the kernel and thread hierarchy are two fundamental and related concepts. A kernel is a function that defines the work done by a single CUDA thread, while the thread hierarchy defines how threads are grouped together for execution, and how they relate to each other.
For example, to do matrix multiplication, we could write a kernel that computes one element of the output matrix by taking the dot product of one row of the left matrix and one column of the right matrix. We could then launch enough threads to “cover” all elements of the output matrix in order to parallelize the computation.
In CUDA, threads are grouped into thread blocks, which contain one or more threads. Thread blocks are then grouped into a grid, which contains one or more thread blocks.[1] This grouping maps onto the conceptual architecture of the GPU/device exposed through the CUDA APIs:
- A thread runs on a CUDA core.
- A thread block is executed by a streaming multiprocessor (SM), which contains multiple CUDA cores.
- The entire grid is collectively executed by the GPU/device, which contains many SMs.
For example, the RTX 4090 contains 128 SMs, each with 128 CUDA cores, giving 16384 CUDA cores in total. That is two to three orders of magnitude more cores than even a high-end CPU, though each CUDA core is far less powerful than a typical CPU core.
The CUDA execution model is fundamentally different from that of a CPU. In CUDA, the fundamental execution unit is not a single thread but a group of threads known as a warp. A warp (currently) consists of 32 threads from the same block, which all execute the same instruction[2] at the same time. This is referred to as Single Instruction Multiple Thread (SIMT).
A consequence of this execution model is that, to access memory efficiently, you want the threads in a warp to access consecutive memory locations. If they do, the accesses by the 32 threads in a warp can be coalesced into fewer than 32 memory transactions.[3] This optimization is crucial to making a kernel run efficiently on the GPU/device, since kernel throughput is often limited by memory bandwidth.
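As a rough illustration of what “coalesced into fewer transactions” means, the sketch below counts how many distinct memory sectors the 32 addresses of one warp fall into. It assumes aligned 32-byte sectors as the unit of transfer, which is a simplification of the real memory system, and it runs on the host so no GPU is needed. The helper names (`floatAddrs`, `sectorsTouched`, `checkSectorExamples`) are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <set>

constexpr int kWarpSize = 32;
// Assumed transaction granularity for this rough model: aligned 32-byte sectors.
constexpr std::uint64_t kSectorBytes = 32;

using WarpAddrs = std::array<std::uint64_t, kWarpSize>;

// Byte addresses issued when lane k reads float element (firstElem + k * strideElems)
// of an array that starts at byte address 0.
WarpAddrs floatAddrs(std::uint64_t firstElem, std::uint64_t strideElems) {
    WarpAddrs addrs{};
    for (int lane = 0; lane < kWarpSize; ++lane) {
        addrs[lane] = (firstElem + lane * strideElems) * sizeof(float);
    }
    return addrs;
}

// Number of distinct 32-byte sectors that the warp's addresses fall into.
std::size_t sectorsTouched(const WarpAddrs& addrs) {
    std::set<std::uint64_t> sectors;
    for (std::uint64_t a : addrs) {
        sectors.insert(a / kSectorBytes);
    }
    return sectors.size();
}

// Usage: the three patterns that come up in this article.
void checkSectorExamples() {
    assert(sectorsTouched(floatAddrs(0, 1)) == 4);      // 32 consecutive floats = 128 bytes
    assert(sectorsTouched(floatAddrs(0, 0)) == 1);      // every lane reads the same float
    assert(sectorsTouched(floatAddrs(0, 4000)) == 32);  // one float from each of 32 rows
}
```

Consecutive floats need only 4 sectors, whereas one float per row of a 4000-column matrix needs 32. Shifting the consecutive case by one element (`floatAddrs(1, 1)`) spills into a fifth sector, which is the kind of alignment effect footnote 3 sets aside.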
But how are threads grouped/linearized into warps when using 2D or 3D thread blocks?
Note: The remainder of this article assumes some basic understanding of CUDA, namely the thread execution model and memory hierarchy. I’m by no means an expert, and I’m still learning in this area.
2D, 3D thread blocks and matrix multiplication
In CUDA, the thread block (and grid) size can be defined not just as a single number, but also as a 2D or 3D size. For example, you could define the thread block size as `dim3 threadsPerBlock(32, 32);`, which would be a 2D block of 32 x 32 threads (1024 threads total). In this example, `threadIdx.x` would range over $[0, 31]$, and `threadIdx.y` would also range over $[0, 31]$. The question we want to answer is: How are these threads grouped into warps of 32 threads?
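As far as I can tell, this happens in row-major order. That is, threads with consecutive `threadIdx.x` values are placed together, i.e. `threadIdx.x` changes the fastest. In this example, this would likely mean each warp is one row of the block: warp 0 holds the threads with `threadIdx.y` = 0 and `threadIdx.x` = 0 to 31, warp 1 holds those with `threadIdx.y` = 1, and so on.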
Similarly, a 3D thread block would be linearized with `threadIdx.x` varying fastest, then `threadIdx.y`, followed by `threadIdx.z`.
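The row-major rule amounts to a linear thread ID, $x + y \cdot D_x + z \cdot D_x \cdot D_y$ for thread $(x, y, z)$ in a block of size $(D_x, D_y, D_z)$, with consecutive groups of 32 IDs forming the warps. The sketch below evaluates this on the host; `Dim3` is a hypothetical stand-in for CUDA’s `dim3` so the code compiles as plain C++, and `warpOf`/`laneOf` are illustrative names.

```cpp
#include <cassert>

// Hypothetical stand-in for CUDA's dim3/uint3, so this compiles as plain C++.
struct Dim3 { unsigned x, y, z; };

constexpr unsigned kWarpSize = 32;  // current warp size on NVIDIA GPUs

// Row-major linearization: x varies fastest, then y, then z.
unsigned linearThreadId(Dim3 idx, Dim3 blockDim) {
    assert(idx.x < blockDim.x && idx.y < blockDim.y && idx.z < blockDim.z);
    return idx.x + idx.y * blockDim.x + idx.z * blockDim.x * blockDim.y;
}

// Which warp of the block the thread lands in, and its lane within that warp.
unsigned warpOf(Dim3 idx, Dim3 blockDim) { return linearThreadId(idx, blockDim) / kWarpSize; }
unsigned laneOf(Dim3 idx, Dim3 blockDim) { return linearThreadId(idx, blockDim) % kWarpSize; }
```

For the 32 x 32 block, thread (5, 3, 0) gets ID 101, i.e. lane 5 of warp 3, so each warp is exactly one row of the block. With a 16 x 16 block, a warp instead spans two rows: thread (0, 1, 0) has ID 16 and is lane 16 of warp 0. In device code, the same ID would be `threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y`.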
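While this makes sense, it isn’t stated in one obvious place in the official programming guide, and it can look like an implementation detail. It does follow from two passages, though: the Thread Hierarchy section defines the thread ID of a thread with index $(x, y, z)$ in a block of size $(D_x, D_y, D_z)$ as $x + y \cdot D_x + z \cdot D_x \cdot D_y$, and the SIMT Architecture section states that each warp contains threads of consecutive, increasing thread IDs, with the first warp containing thread 0. Many third-party sources also state this, and it’s even mentioned in this GPU Teaching Kit produced by NVIDIA and the University of Illinois (a different version of which is found here).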
Matrix multiplication example
We can test this out by writing a simple matrix multiplication kernel that assumes a row-major ordering of the threads in a warp and structures its memory accesses to achieve good coalescing, and then writing a “bad” matrix multiplication kernel that ignores this. We can then profile each kernel to see what, if any, performance difference exists, and if so, why. (Full code here)
Here’s the straightforward matrix multiplication kernel (`matMul()`), which assumes a row-major ordering of `threadIdx`; its index computation and inner loop are reproduced in the Explanation section below.
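For reference, here is a hedged sketch of what `matMul()` might look like, assembled from the index computation and loop shown in the Explanation section. The parameter list (`left`, `right`, `out`, `a`, `b`, `c`) is inferred from those snippets rather than copied from the author’s full code, and the helpers `HOST_DEVICE`, `matMulElement` and `matMulHost` are hypothetical. The per-thread work is factored into a `__host__ __device__` function so it can be checked on the CPU; the `__global__` wrapper is only compiled by `nvcc`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Lets the per-thread body compile for both the GPU (nvcc) and the CPU (plain C++).
#ifdef __CUDACC__
#define HOST_DEVICE __host__ __device__
#else
#define HOST_DEVICE
#endif

// Work done by one thread: the dot product of row `row` of `left` (a x b) and
// column `col` of `right` (b x c), stored into `out` (a x c). All row-major.
HOST_DEVICE inline void matMulElement(const float* left, const float* right, float* out,
                                      int a, int b, int c, int row, int col) {
    if (row < a && col < c) {  // Bounds check
        float sum = 0.0f;
        for (int i = 0; i < b; i++) {
            sum += left[row * b + i] * right[i * c + col];
        }
        out[row * c + col] = sum;
    }
}

#ifdef __CUDACC__
// Assumed kernel signature. threadIdx.x, the fastest-varying index within a
// warp, selects the column, so neighbouring lanes touch neighbouring addresses.
__global__ void matMul(const float* left, const float* right, float* out,
                       int a, int b, int c) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    matMulElement(left, right, out, a, b, c, row, col);
}
#endif

// Host check: run the per-thread body for every (row, col) the grid would cover.
std::vector<float> matMulHost(const std::vector<float>& left, const std::vector<float>& right,
                              int a, int b, int c) {
    assert(left.size() == static_cast<std::size_t>(a) * b);
    assert(right.size() == static_cast<std::size_t>(b) * c);
    std::vector<float> out(static_cast<std::size_t>(a) * c, 0.0f);
    for (int row = 0; row < a; ++row) {
        for (int col = 0; col < c; ++col) {
            matMulElement(left.data(), right.data(), out.data(), a, b, c, row, col);
        }
    }
    return out;
}
```

On the GPU, `matMulElement` would run once per thread; on the host, `matMulHost` loops over the same `(row, col)` pairs, which makes it easy to check the indexing against a tiny example such as a 2 x 3 by 3 x 2 product.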
Here’s the “bad” version (`matMulBad()`), which just reverses the mapping, assigning `row` to `threadIdx.x` and `col` to `threadIdx.y`. (This kernel needs to be called slightly differently, since its x dimension now covers rows and its y dimension covers columns; see the full code for more details.)
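One plausible reading of “called slightly differently” (the full code is not reproduced here, so this is an assumption): because `matMulBad()` takes its row from the x dimension and its column from the y dimension, its grid has to be transposed relative to `matMul()`’s. The helpers below (`ceilDiv`, `GridSize`, `gridForMatMul`, `gridForMatMulBad`) are hypothetical.

```cpp
#include <cassert>

// Number of blocks of size `blockSize` needed to cover `n` elements.
constexpr unsigned ceilDiv(unsigned n, unsigned blockSize) {
    return (n + blockSize - 1) / blockSize;
}

// Hypothetical stand-in for the x/y fields of the dim3 passed as the grid size.
struct GridSize { unsigned x, y; };

// matMul(): x (via threadIdx.x) walks the c columns of out, y walks the a rows.
GridSize gridForMatMul(unsigned a, unsigned c, unsigned blockX, unsigned blockY) {
    assert(blockX > 0 && blockY > 0);
    return {ceilDiv(c, blockX), ceilDiv(a, blockY)};
}

// matMulBad(): x walks the a rows of out, y walks the c columns, so the grid is transposed.
GridSize gridForMatMulBad(unsigned a, unsigned c, unsigned blockX, unsigned blockY) {
    assert(blockX > 0 && blockY > 0);
    return {ceilDiv(a, blockX), ceilDiv(c, blockY)};
}
```

With 32 x 32 blocks, a 3000 x 2000 output would need a 63 x 94 grid for `matMul()` but a 94 x 63 grid for `matMulBad()`; for the square 3000 x 3000 output profiled below, both come out as 94 x 94. In a launch, the result would be passed as, e.g., `dim3(g.x, g.y)`.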
Here are the results from using `ncu` to profile the program when multiplying a $3000 \times 4000$ matrix with a $4000 \times 3000$ matrix to produce a $3000 \times 3000$ matrix.
First, `matMul()`:
[Figure: ncu profile of matMul()]
And here are the results for `matMulBad()`:
[Figure: ncu profile of matMulBad()]
Observations:
- The straightforward `matMul()` kernel does fairly well, especially considering it hasn’t been optimized with techniques like explicit tiling.
- With `matMulBad()`, Max Bandwidth is reduced from 98.43% to 11.86%.
- The Source Counters section for `matMulBad()` gives the reason why:
[Figure: ncu Source Counters section for matMulBad()]
Explanation
In our `matMul()` (good) code, we index into the row of `left` and the column of `right` as follows:
```cpp
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
```
All of the matrices (`left`, `right`, and `out`) are stored in row-major order. The output is then computed as:
```cpp
if (row < a && col < c) {  // Bounds check
    float sum = 0.0f;
    for (int i = 0; i < b; i++) {
        sum += left[row * b + i] * right[i * c + col];
    }
    out[row * c + col] = sum;
}
```
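Written out with the row-major offset rule (element $(r, k)$ of a matrix with $K$ columns lives at offset $r \cdot K + k$), and with $\text{left}$ being $a \times b$, $\text{right}$ being $b \times c$ and $\text{out}$ being $a \times c$, the snippet above computes:

$$
\text{out}[\text{row} \cdot c + \text{col}] = \sum_{i=0}^{b-1} \text{left}[\text{row} \cdot b + i] \cdot \text{right}[i \cdot c + \text{col}]
$$

Increasing $\text{col}$ by one moves a single float (4 bytes) along $\text{right}$ and $\text{out}$, whereas increasing $\text{row}$ by one jumps $b$ floats in $\text{left}$ and $c$ floats in $\text{out}$. For the $3000 \times 4000$ by $4000 \times 3000$ example, that is $4000 \cdot 4 = 16000$ and $3000 \cdot 4 = 12000$ bytes respectively, which is why it matters which of the two indices follows `threadIdx.x`.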
If threads are linearized (grouped) into warps in row-major order, then all the threads in a warp will likely have consecutive values of `threadIdx.x`, while the value of `threadIdx.y` will likely stay the same. This means that across the threads in a warp, `row` will likely be fixed and `col` will take consecutive values.
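The sketch below models this on the host for one warp: it reconstructs each lane’s `threadIdx` from its linear ID (assuming the row-major linearization described earlier and a 2D block) and then applies `matMul()`’s `row`/`col` formulas. The names `RowCol`, `goodMapping` and `rowFixedColsConsecutive` are hypothetical.

```cpp
#include <cassert>

struct RowCol { int row, col; };

// Model of matMul()'s index computation for one lane of one warp, assuming
// row-major linearization of a (blockDimX x blockDimY) 2D block.
RowCol goodMapping(int warp, int lane, int blockIdxX, int blockIdxY,
                   int blockDimX, int blockDimY) {
    assert(lane >= 0 && lane < 32 && blockDimX > 0 && blockDimY > 0);
    int tid = warp * 32 + lane;              // linear thread ID within the block
    int tx = tid % blockDimX;                // threadIdx.x
    int ty = (tid / blockDimX) % blockDimY;  // threadIdx.y
    return {blockIdxY * blockDimY + ty,      // row = blockIdx.y * blockDim.y + threadIdx.y
            blockIdxX * blockDimX + tx};     // col = blockIdx.x * blockDim.x + threadIdx.x
}

// True if every lane of the warp shares one row and the columns are consecutive.
bool rowFixedColsConsecutive(int warp, int blockIdxX, int blockIdxY,
                             int blockDimX, int blockDimY) {
    RowCol first = goodMapping(warp, 0, blockIdxX, blockIdxY, blockDimX, blockDimY);
    for (int lane = 1; lane < 32; ++lane) {
        RowCol rc = goodMapping(warp, lane, blockIdxX, blockIdxY, blockDimX, blockDimY);
        if (rc.row != first.row || rc.col != first.col + lane) return false;
    }
    return true;
}
```

With a 32-wide block, every warp has a single `row` and 32 consecutive `col` values (e.g. warp 5 of block (2, 3) covers row 101, columns 64 to 95). With a 16 x 16 block, each warp spans two rows, which is one reason for the “likely” in the paragraph above.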
The fixed `row` and consecutive `col` values lead to a decent memory access pattern:[4]
1. Read row from `left`: In each loop iteration, every thread reads the same element from the same row of `left`. Across all iterations of the loop, the whole warp therefore reads the same row. Although this isn’t efficient, it does mean that the accesses of all threads in the warp coalesce into a single read per iteration. Additionally, because the memory controller accesses at least 32 bytes at a time, the first read may bring subsequent elements into the L1 cache (i.e. more than a single float is read at a time), reducing global accesses.
2. Read columns from `right`: In each loop iteration, the 32 threads of the warp read 32 consecutive elements from a single row of `right`. Across all iterations of the loop, this results in 32 columns being read, one by each thread in the warp. Because each iteration reads 32 consecutive elements (the matrix is in row-major order), the accesses are properly coalesced into a single global access (ideally, 32 x 4-byte floats read in a single 128-byte transaction).
3. Write to `out`: Each thread in the warp writes to the same row of `out`. Specifically, because `col` varies across the threads but `row` likely does not, the 32 threads in the warp write to 32 consecutive positions in the same row of `out`. Again, this memory access can be coalesced.
(2) and (3) result in the “Neighboring Columns” access pattern, depicted in this slide from this University of Illinois lecture:
[Figure: “Neighboring Columns” access pattern (diagram from David Kirk/NVIDIA and Wen-mei Hwu)]
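Under a simple model that assumes global memory is fetched in aligned 32-byte sectors (a simplification, not a description of the exact hardware), the sketch below counts the sectors one warp of `matMul()` touches for each of the three accesses. Each matrix is modeled as starting at byte address 0, and the names `SectorCounts` and `matMulWarpSectors` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <set>

// Simplifying assumption: global memory is fetched in aligned 32-byte sectors.
constexpr std::uint64_t kSectorBytes = 32;

struct SectorCounts { std::size_t left, right, out; };

// Sectors touched by one warp of matMul() in inner-loop iteration i, for a warp
// whose 32 lanes share `row` and have columns col0, col0 + 1, ..., col0 + 31.
// Row-major floats: left is a x b, right is b x c, out is a x c.
SectorCounts matMulWarpSectors(std::uint64_t b, std::uint64_t c,
                               std::uint64_t row, std::uint64_t col0, std::uint64_t i) {
    assert(i < b && col0 + 32 <= c);  // stay inside the matrices, no masked-off lanes
    std::set<std::uint64_t> l, r, o;
    for (std::uint64_t lane = 0; lane < 32; ++lane) {
        std::uint64_t col = col0 + lane;
        l.insert((row * b + i) * sizeof(float) / kSectorBytes);    // left[row * b + i]
        r.insert((i * c + col) * sizeof(float) / kSectorBytes);    // right[i * c + col]
        o.insert((row * c + col) * sizeof(float) / kSectorBytes);  // out[row * c + col], written once after the loop
    }
    return {l.size(), r.size(), o.size()};
}
```

For the 3000 x 4000 by 4000 x 3000 example (`b` = 4000, `c` = 3000), a warp covering row 101 and columns 64 to 95 touches 1 sector of `left` and 4 sectors each of `right` and `out`: the broadcast and 128-byte patterns described in the list above.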
Contrast this behaviour with `matMulBad()`, which merely reverses the `row`/`col` mapping:
```cpp
int row = blockIdx.x * blockDim.x + threadIdx.x;
int col = blockIdx.y * blockDim.y + threadIdx.y;
```
Now `row` is associated with `threadIdx.x` and hence changes across the threads in the warp, while `col` is associated with `threadIdx.y` and is likely fixed within the warp. This changes the memory access pattern across the threads in the warp:
1. Read rows from `left`: Because `row` varies across the threads, on each loop iteration the 32 threads of the warp read one element from each of 32 different rows. That is, they read 32 elements from the same column on each iteration, and over the whole loop, 32 rows are read, one by each thread in the warp. This is not an efficient memory access, because the 32 elements are spread across multiple rows and thus not adjacent to each other in a row-major matrix (here, consecutive rows of `left` are $4000 \times 4 = 16000$ bytes apart). Hence, the accesses cannot be coalesced and result in many separate memory transactions.
2. Read column from `right`: Because `col` likely does not vary across the threads in the warp, every thread reads the same element of the column in each loop iteration. This results in each thread reading the same column of `right`, so each iteration results in just a single read.
3. Write to `out`: Each thread writes to a different row of `out`, because `row` now varies with each thread in the warp. Because `col` likely does not vary, the 32 threads in the warp write to 32 positions in the same column. Again, because the access pattern spans rows, this is an inefficient write that cannot be coalesced.
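Reasons (1) and (3) above are likely responsible for the majority of the performance decrease observed with the `matMulBad()` kernel, with (1) probably dominating, since it repeats on every iteration of the inner loop while (3) happens only once per thread. Both follow the “Neighboring Rows” access pattern, also depicted in this University of Illinois lecture:
[Figure: “Neighboring Rows” access pattern (diagram from David Kirk/NVIDIA and Wen-mei Hwu)]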
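The same kind of rough sector count can be applied to `matMulBad()`. Again this assumes aligned 32-byte sectors and models each matrix as starting at byte address 0; `SectorCounts` and `matMulBadWarpSectors` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <set>

constexpr std::uint64_t kSectorBytes = 32;  // simplifying assumption: aligned 32-byte sectors

struct SectorCounts { std::size_t left, right, out; };

// Sectors touched by one warp of matMulBad() in inner-loop iteration i, for a warp
// whose 32 lanes share `col` and have rows row0, row0 + 1, ..., row0 + 31.
// Row-major floats: left is a x b, right is b x c, out is a x c.
SectorCounts matMulBadWarpSectors(std::uint64_t b, std::uint64_t c,
                                  std::uint64_t row0, std::uint64_t col, std::uint64_t i) {
    assert(i < b && col < c);
    std::set<std::uint64_t> l, r, o;
    for (std::uint64_t lane = 0; lane < 32; ++lane) {
        std::uint64_t row = row0 + lane;
        l.insert((row * b + i) * sizeof(float) / kSectorBytes);    // left[row * b + i]
        r.insert((i * c + col) * sizeof(float) / kSectorBytes);    // right[i * c + col]
        o.insert((row * c + col) * sizeof(float) / kSectorBytes);  // out[row * c + col], written once after the loop
    }
    return {l.size(), r.size(), o.size()};
}
```

For the same 3000 x 4000 by 4000 x 3000 example, a warp covering rows 64 to 95 of column 101 touches 32 sectors of `left` on every iteration, 1 sector of `right`, and 32 sectors of `out` for the final write, compared with the single broadcast read and one 128-byte segment per iteration described for `matMul()`.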
Conclusion
Threads in a 2D or 3D thread block are linearized in row-major order, with `threadIdx.x` varying fastest, then `threadIdx.y`, then `threadIdx.z`. Many third-party sources mention this, and a lot of code depends on it. Intuitively, this makes sense: matrices are conventionally stored in row-major order in C/C++ code, so linearizing thread blocks in the same order makes things “match up”. Take this into account when laying out data in memory and designing the access pattern of a CUDA kernel.
Note that this does not mean we should assume any ordering in how warps themselves are executed. Threads within a warp typically execute in lock-step (although on Volta and later GPUs, independent thread scheduling means even this is not guaranteed), but warps themselves (or the enclosing thread blocks) may be executed in any order, so you should not rely on certain warps or thread blocks finishing before or after others.
1. I left out details of thread block clusters, which are an optional level between thread blocks and the grid.
2. Ignoring details of warp divergence for the purpose of this article.
3. There are also memory alignment requirements, but again I will skip over those for the purpose of this article.
4. The `matMul()` kernel is a straightforward implementation (just for illustration purposes) and could likely be improved by utilizing techniques like tiling, which take advantage of a thread block’s shared memory.