LeetGPU-2: Matrix Multiplication

LeetGPU
Notes and solutions in PyTorch, Triton, and CUDA. Runtime shown for T4 GPU.
Author: Md Saidul Hoque Anik

Published: December 2025

Note: This is part of a learning series on CUDA and Triton, focusing on correctness-first implementations rather than performance optimization.

Problem Statement

Write a program that multiplies two matrices of 32-bit floating-point numbers on a GPU. Given matrix A of dimensions \(M \times N\) and matrix B of dimensions \(N \times K\), compute the product matrix C, which has dimensions \(M \times K\). All matrices are stored in row-major format.
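Element by element, with 0-based indices, the product and the row-major offsets that both kernels below compute by hand are:

\[
C_{ij} = \sum_{n=0}^{N-1} A_{in} B_{nj}, \qquad 0 \le i < M,\; 0 \le j < K
\]

\[
\operatorname{off}(A_{in}) = iN + n, \qquad \operatorname{off}(B_{nj}) = nK + j, \qquad \operatorname{off}(C_{ij}) = iK + j
\]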

Constraints

  • 1 ≤ M, N, K ≤ 8192

  • Performance is measured with M = 8192, N = 6144, K = 4096
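For the benchmark shape, the three matrices together take a few hundred MiB of GPU memory. A quick back-of-the-envelope check, as a sketch that assumes 4 bytes per FP32 element and dense row-major storage with no padding (the helper name mib is just for illustration):

```python
# Memory footprint of the benchmark matrices.
# Assumes dense FP32 storage (4 bytes per element) with no padding.
BYTES_PER_FLOAT32 = 4
M, N, K = 8192, 6144, 4096


def mib(num_elements: int) -> float:
    """Size in MiB of a dense FP32 array with num_elements entries."""
    return num_elements * BYTES_PER_FLOAT32 / 2**20


size_a = mib(M * N)  # A is M x N
size_b = mib(N * K)  # B is N x K
size_c = mib(M * K)  # C is M x K

print(f"A: {size_a:.0f} MiB, B: {size_b:.0f} MiB, C: {size_c:.0f} MiB, "
      f"total: {size_a + size_b + size_c:.0f} MiB")
# prints "A: 192 MiB, B: 96 MiB, C: 128 MiB, total: 416 MiB"
```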

Solution

PyTorch

Note
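  • The solution is straightforward: torch.matmul(A, B, out=C) writes the product directly into the preallocated output C. But can we implement blocking (tiling) ourselves?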

Solution-1

import torch

# A, B, C are tensors on the GPU
def solve(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, M: int, N: int, K: int):
    torch.matmul(A, B, out=C)

Runtime: 97.48 ms

TODO: block matmul
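As a starting point for this TODO, here is a minimal pure-Python sketch of a blocked (tiled) matmul over flat row-major lists. It is not a GPU implementation and not the author's planned version; the tile sizes TILE_M, TILE_N, TILE_K and the function names are hypothetical. Each tile of C is accumulated from products of matching tiles of A and B, which is the loop structure a tiled Triton or CUDA kernel uses so that data loaded once is reused for many outputs.

```python
# Blocked (tiled) matrix multiplication over flat row-major lists.
# Uses this page's naming: A is M x N, B is N x K, C is M x K.
# TILE_M, TILE_N, TILE_K are hypothetical tile sizes chosen for illustration.


def matmul_naive(A, B, M, N, K):
    """Reference: one dot product per output element."""
    C = [0.0] * (M * K)
    for i in range(M):
        for j in range(K):
            acc = 0.0
            for n in range(N):
                acc += A[i * N + n] * B[n * K + j]
            C[i * K + j] = acc
    return C


def matmul_blocked(A, B, M, N, K, TILE_M=2, TILE_N=3, TILE_K=2):
    """Compute C tile by tile; each C tile sums products of A and B tiles."""
    C = [0.0] * (M * K)
    for i0 in range(0, M, TILE_M):          # tile row of C
        for j0 in range(0, K, TILE_K):      # tile column of C
            # Accumulator for one TILE_M x TILE_K tile of C
            acc = [[0.0] * TILE_K for _ in range(TILE_M)]
            for n0 in range(0, N, TILE_N):  # walk the shared dimension N in chunks
                # min() clips the last tile when a size is not a multiple of the tile
                for i in range(i0, min(i0 + TILE_M, M)):
                    for n in range(n0, min(n0 + TILE_N, N)):
                        a = A[i * N + n]    # reused across the whole tile row
                        for j in range(j0, min(j0 + TILE_K, K)):
                            acc[i - i0][j - j0] += a * B[n * K + j]
            # Write the finished tile back to C once
            for i in range(i0, min(i0 + TILE_M, M)):
                for j in range(j0, min(j0 + TILE_K, K)):
                    C[i * K + j] = acc[i - i0][j - j0]
    return C


# Usage: (3 x 4) @ (4 x 5), sizes deliberately not multiples of the tiles
M, N, K = 3, 4, 5
A = [float(v) for v in range(M * N)]
B = [float(v % 7) for v in range(N * K)]
print(matmul_blocked(A, B, M, N, K) == matmul_naive(A, B, M, N, K))  # True
```

On a GPU the tile accumulator would live in registers and the A and B tiles in shared memory (CUDA) or in block-shaped tl.load results (Triton); the Python version only shows the indexing.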

Triton


Solution

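This is correct but slow: each program computes a single element of C with a scalar loop over N, so nothing loaded from A or B is reused within a program. Will come back to this later, in sha Allah!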

import torch
import triton
import triton.language as tl


@triton.jit
def matrix_multiplication_kernel(
    a, b, c, M, N, K, stride_am, stride_an, stride_bn, stride_bk, stride_cm, stride_ck
):
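    # One program per output element: program (row, col) computes C[row, col]
    row = tl.program_id(0)
    col = tl.program_id(1)

    # Dot product of row `row` of A with column `col` of B, accumulated in fp32.
    # Addresses use the strides passed in rather than hard-coded N and K.
    acc = 0.0
    for i in range(0, N):
        a_val = tl.load(a + row * stride_am + i * stride_an)
        b_val = tl.load(b + i * stride_bn + col * stride_bk)
        acc += a_val * b_val

    tl.store(c + row * stride_cm + col * stride_ck, acc)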


# a, b, c are tensors on the GPU
def solve(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, M: int, N: int, K: int):
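    # Strides in elements for contiguous row-major storage
    stride_am, stride_an = N, 1
    stride_bn, stride_bk = K, 1
    stride_cm, stride_ck = K, 1

    # One program per element of C: axis 0 spans the M rows, axis 1 the K columns
    grid = (M, K)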
    matrix_multiplication_kernel[grid](
        a, b, c, M, N, K, stride_am, stride_an, stride_bn, stride_bk, stride_cm, stride_ck
    )
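To make the addressing and the cost explicit, here is a pure-Python emulation of that launch (a sketch, not Triton; emulate_launch is a hypothetical name). Each (row, col) pair of the M x K grid plays the role of one program, and flat lists stand in for the device pointers. It also counts scalar loads: every program reads N elements of A and N of B, with no explicit reuse between programs.

```python
# Pure-Python emulation of the one-element-per-program Triton launch.
# Flat lists stand in for GPU pointers; strides are in elements (row-major).


def emulate_launch(a, b, M, N, K):
    stride_am, stride_an = N, 1
    stride_bn, stride_bk = K, 1
    stride_cm, stride_ck = K, 1
    c = [0.0] * (M * K)
    loads = 0
    for row in range(M):        # plays the role of tl.program_id(0)
        for col in range(K):    # plays the role of tl.program_id(1)
            acc = 0.0
            for i in range(N):
                a_val = a[row * stride_am + i * stride_an]
                b_val = b[i * stride_bn + col * stride_bk]
                loads += 2
                acc += a_val * b_val
            c[row * stride_cm + col * stride_ck] = acc
    return c, loads


c, loads = emulate_launch([1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], 2, 2, 2)
print(c, loads)  # [19.0, 22.0, 43.0, 50.0] 16

# Benchmark shape: programs launched by the (M, K) grid, and scalar loads per program
M, N, K = 8192, 6144, 4096
print(M * K, 2 * N)  # 33554432 12288
```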

Runtime: 0.00 ms (as reported; not a meaningful timing for a kernel described above as slow)

CUDA


Solution

#include <cuda_runtime.h>

__global__ void matrix_multiplication_kernel(const float* A, const float* B, float* C, int M, int N, int K)
{
    // Each thread computes one element C[row][col]
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    // Threads in the partial blocks at the bottom/right edges fall outside C
    if (row >= M || col >= K) {
        return;
    }

    // Dot product of row `row` of A with column `col` of B over the shared
    // dimension N, accumulated in a local variable
    float sum = 0.0f;
    for (int i = 0; i < N; i++)
    {
        sum += A[row * N + i] * B[i * K + col];
    }
    C[row * K + col] = sum;
}

// A, B, C are device pointers (i.e. pointers to memory on the GPU)
extern "C" void solve(const float* A, const float* B, float* C, int M, int N, int K) {
    dim3 threadsPerBlock(16, 16);
    dim3 blocksPerGrid((K + threadsPerBlock.x - 1) / threadsPerBlock.x,
                       (M + threadsPerBlock.y - 1) / threadsPerBlock.y);

    matrix_multiplication_kernel<<<blocksPerGrid, threadsPerBlock>>>(A, B, C, M, N, K);
    cudaDeviceSynchronize();
}
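The launch configuration rounds each grid dimension up, so partial blocks cover sizes that are not multiples of 16, and the bounds check at the top of the kernel discards the extra threads. A small Python sketch of that arithmetic (ceil_div and blocks_per_grid are illustrative names, not part of the solution):

```python
# Grid-size arithmetic for the 16 x 16 launch in solve().
# dim3 is (x, y): x spans the K columns of C, y spans the M rows.


def ceil_div(a: int, b: int) -> int:
    """Integer ceiling of a / b; same as (a + b - 1) / b in C for positive ints."""
    return (a + b - 1) // b


def blocks_per_grid(M: int, K: int, bx: int = 16, by: int = 16):
    return ceil_div(K, bx), ceil_div(M, by)


grid = blocks_per_grid(8192, 4096)
print(grid)  # (256, 512)

# With M = K = 17 the grid is 2 x 2 blocks = 32 x 32 threads,
# so 32*32 - 17*17 threads fail the bounds check and return early.
gx, gy = blocks_per_grid(17, 17)
launched = (gx * 16) * (gy * 16)
print(launched - 17 * 17)  # 735
```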

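Note: Using C[row * K + col] directly as the accumulator caused a timeout. That version does a global-memory read-modify-write of C on every iteration of the loop (and relies on C being zero-initialized); accumulating in the local variable sum and storing to C once at the end avoids both problems.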

Runtime: 953.89 ms
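The runtimes are easier to compare as throughput. Counting one multiply and one add per term of the dot product gives 2MNK floating-point operations for the benchmark shape; the sketch below divides that by the runtimes reported on this page. It assumes each reported time covers exactly one matmul.

```python
# Achieved FP32 throughput from the runtimes reported above (benchmark shape).
M, N, K = 8192, 6144, 4096
flops = 2 * M * N * K  # one multiply and one add per (i, j, n) triple


def gflops(runtime_ms: float) -> float:
    """Achieved GFLOP/s for one matmul of the benchmark shape."""
    return flops / (runtime_ms * 1e-3) / 1e9


pytorch = gflops(97.48)
cuda_naive = gflops(953.89)
print(f"{flops:,} FLOPs")                          # 412,316,860,416 FLOPs
print(f"torch.matmul: {pytorch:.0f} GFLOP/s")      # about 4230
print(f"naive CUDA:   {cuda_naive:.0f} GFLOP/s")   # about 432
print(f"ratio: {953.89 / 97.48:.1f}x")             # 9.8x
```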
