LeetGPU-2: Matrix Multiplication
LeetGPU
Notes and solutions in PyTorch, Triton, and CUDA. Runtime shown for T4 GPU.
Note: This is part of a learning series on CUDA and Triton, focusing on correctness-first implementations rather than performance optimization.
Problem Statement
Write a program that multiplies two matrices of 32-bit floating point numbers on a GPU. Given matrix A of dimensions \(M \times N\) and matrix B of dimensions \(N \times K\), compute the product matrix C, which will have dimensions \(M \times K\). All matrices are stored in row-major format.
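To make the row-major indexing concrete, here is a minimal CPU reference in plain Python. It is only a sketch of the definition, not one of the submitted solutions, and the name `matmul_row_major` is made up for this illustration. Element (i, j) of C lives at flat index `i * K + j`, and the shared dimension N is the one that gets summed over.

```python
def matmul_row_major(A, B, M, N, K):
    """Reference product of A (M x N) and B (N x K), both flat row-major lists."""
    C = [0.0] * (M * K)
    for i in range(M):          # row of A and C
        for j in range(K):      # column of B and C
            acc = 0.0
            for n in range(N):  # shared dimension
                acc += A[i * N + n] * B[n * K + j]
            C[i * K + j] = acc
    return C


# A is 2 x 3 and B is 3 x 2, so C is 2 x 2.
A = [1.0, 2.0, 3.0,
     4.0, 5.0, 6.0]
B = [7.0, 8.0,
     9.0, 10.0,
     11.0, 12.0]
C = matmul_row_major(A, B, M=2, N=3, K=2)
print(C)  # [58.0, 64.0, 139.0, 154.0]
```

The same three index expressions (`i * N + n`, `n * K + j`, `i * K + j`) are what the Triton and CUDA kernels in this post use.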
Constraints
- 1 ≤ M, N, K ≤ 8192
- Performance is measured with M = 8192, N = 6144, K = 4096
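As a rough sense of scale, the benchmark size can be turned into an operation count and a memory footprint. This is back-of-the-envelope arithmetic under the usual convention of two floating point operations (one multiply, one add) per inner-loop step; the variable names are only for illustration.

```python
M, N, K = 8192, 6144, 4096  # benchmark size from the constraints
BYTES_PER_FLOAT = 4         # 32-bit floats

# One multiply and one add for every (i, j, n) triple.
flops = 2 * M * N * K

# Size of each matrix in bytes.
bytes_a = M * N * BYTES_PER_FLOAT
bytes_b = N * K * BYTES_PER_FLOAT
bytes_c = M * K * BYTES_PER_FLOAT

MIB = 2 ** 20
print(f"FLOPs: {flops:,}")
print(f"A: {bytes_a // MIB} MiB, B: {bytes_b // MIB} MiB, C: {bytes_c // MIB} MiB")
```

That works out to about 412 billion floating point operations over 192 MiB, 96 MiB and 128 MiB of data for A, B and C respectively.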
Solution
PyTorch
Note
- The solution is a single torch.matmul call. But can we do blocking (a tiled matmul) ourselves?
Solution-1
import torch

# A, B, C are tensors on the GPU
def solve(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, M: int, N: int, K: int):
    # Write the product directly into the preallocated output tensor C.
    torch.matmul(A, B, out=C)

Runtime: 97.48 ms
TODO: block matmul
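One way the blocked matmul from the TODO might be structured is sketched below in plain Python, on flat row-major lists. This is only an illustration of the tiling loop order on the CPU, not a PyTorch or GPU implementation, and `blocked_matmul` and `block` are hypothetical names. The output is computed tile by tile, and each output tile accumulates partial products from matching tiles along the shared dimension.

```python
def blocked_matmul(A, B, M, N, K, block=2):
    """Tiled product of flat row-major A (M x N) and B (N x K)."""
    C = [0.0] * (M * K)
    for i0 in range(0, M, block):          # tile rows of C
        for j0 in range(0, K, block):      # tile columns of C
            for n0 in range(0, N, block):  # tiles along the shared dimension
                # min() clips the last tile when a dimension
                # is not a multiple of the block size.
                for i in range(i0, min(i0 + block, M)):
                    for j in range(j0, min(j0 + block, K)):
                        acc = C[i * K + j]
                        for n in range(n0, min(n0 + block, N)):
                            acc += A[i * N + n] * B[n * K + j]
                        C[i * K + j] = acc
    return C


A = [1.0, 2.0, 3.0,
     4.0, 5.0, 6.0]       # 2 x 3
B = [7.0, 8.0,
     9.0, 10.0,
     11.0, 12.0]          # 3 x 2
print(blocked_matmul(A, B, M=2, N=3, K=2, block=2))  # [58.0, 64.0, 139.0, 154.0]
```

On a GPU the point of this structure is that a tile of A and a tile of B can be loaded once and reused for every output element in the tile, instead of being re-read from global memory for each element.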
Triton
Note
Solution
Correct but slow. Will come back to this later in sha Allah!
import torch
import triton
import triton.language as tl

@triton.jit
def matrix_multiplication_kernel(
    a, b, c, M, N, K, stride_am, stride_an, stride_bn, stride_bk, stride_cm, stride_ck
):
    # One program per output element: program (row, col) computes c[row, col].
    row = tl.program_id(0)
    col = tl.program_id(1)
    # BLOCK_SIZE: tl.constexpr = 16
    acc = 0.0
    # Dot product of row `row` of a with column `col` of b (both row-major).
    for i in range(0, N):
        a_val = tl.load(a + row * N + i)
        b_val = tl.load(b + col + i * K)
        acc += a_val * b_val
    tl.store(c + K * row + col, acc)

# a, b, c are tensors on the GPU
def solve(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, M: int, N: int, K: int):
    # Row-major strides. They are passed to the kernel, which currently
    # indexes with N and K directly instead of using them.
    stride_am, stride_an = N, 1
    stride_bn, stride_bk = K, 1
    stride_cm, stride_ck = K, 1
    grid = (M, K)
    matrix_multiplication_kernel[grid](
        a, b, c, M, N, K, stride_am, stride_an, stride_bn, stride_bk, stride_cm, stride_ck
    )

Runtime: 0.00 ms
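The kernel above launches one program per output element, so the grid is (M, K). A tiled kernel would instead launch one program per output tile, with the grid sized by ceiling division. The sketch below only does that arithmetic in plain Python; `BLOCK_M` and `BLOCK_K` are hypothetical tile sizes (16 is taken from the commented-out `BLOCK_SIZE` line), and `cdiv` is a local helper doing what `triton.cdiv` does.

```python
def cdiv(x, y):
    """Ceiling division for positive integers."""
    return (x + y - 1) // y


M, K = 8192, 4096            # output shape at the benchmark size

# Current kernel: one program per element of C.
programs_per_element = M * K

# Hypothetical tiled kernel: one program per 16 x 16 tile of C.
BLOCK_M, BLOCK_K = 16, 16
tiled_grid = (cdiv(M, BLOCK_M), cdiv(K, BLOCK_K))
programs_tiled = tiled_grid[0] * tiled_grid[1]

print(programs_per_element)  # 33554432
print(tiled_grid)            # (512, 256)
print(programs_tiled)        # 131072
```

The ceiling division matters when M or K is not a multiple of the tile size: the last tile in each dimension is then partial, and the kernel has to mask or bounds-check its loads and stores.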
CUDA
Note
Solution
#include <cuda_runtime.h>

__global__ void matrix_multiplication_kernel(const float* A, const float* B, float* C, int M, int N, int K)
{
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    // Threads outside the output matrix have nothing to do.
    if (row >= M || col >= K) {
        return;
    }
    // Take row `row` of A and column `col` of B and compute their dot product.
    float sum = 0.0f;
    for (int k = 0; k < N; k++)
    {
        sum += A[row * N + k] * B[k * K + col];
    }
    C[row * K + col] = sum;
}

// A, B, C are device pointers (i.e. pointers to memory on the GPU)
extern "C" void solve(const float* A, const float* B, float* C, int M, int N, int K) {
    dim3 threadsPerBlock(16, 16);
    // Ceiling division so that partial blocks at the edges are covered.
    dim3 blocksPerGrid((K + threadsPerBlock.x - 1) / threadsPerBlock.x,
                       (M + threadsPerBlock.y - 1) / threadsPerBlock.y);
    matrix_multiplication_kernel<<<blocksPerGrid, threadsPerBlock>>>(A, B, C, M, N, K);
    cudaDeviceSynchronize();
}

Note: The submission times out when C[row * K + col] is used as the accumulator instead of the local variable sum, most likely because every loop iteration then reads and writes global memory.
Runtime: 953.89 ms
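To put the reported runtimes in perspective, they can be converted into effective throughput at the benchmark size. This is a rough sketch that assumes 2 * M * N * K floating point operations and that the reported time covers the whole multiplication; `tflops` is a helper name made up here. The Triton figure is left out because 0.00 ms cannot be converted.

```python
M, N, K = 8192, 6144, 4096
flops = 2 * M * N * K  # one multiply and one add per inner-loop step


def tflops(runtime_ms):
    """Effective throughput in TFLOP/s for a given runtime in milliseconds."""
    return flops / (runtime_ms * 1e-3) / 1e12


pytorch_ms = 97.48   # torch.matmul runtime reported above
cuda_ms = 953.89     # naive CUDA kernel runtime reported above

print(f"PyTorch: {tflops(pytorch_ms):.2f} TFLOP/s")
print(f"CUDA:    {tflops(cuda_ms):.2f} TFLOP/s")
print(f"Ratio:   {cuda_ms / pytorch_ms:.1f}x")
```

Under these assumptions, torch.matmul reaches roughly 4.2 TFLOP/s and the naive CUDA kernel roughly 0.43 TFLOP/s, a gap of a little under 10x.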