LeetGPU-12: Simple Inference

LeetGPU
Notes and solutions in PyTorch and JAX. Runtime shown for T4 GPU.
Author

Md Saidul Hoque Anik

Published

December 2025

Note: This is part of a learning series on CUDA and Triton, focusing on correctness-first implementations rather than performance optimization.

Problem Statement

Run inference on a PyTorch model. Given an input tensor and a trained torch.nn.Linear model, compute the forward pass and store the result in the output tensor.

The model performs a linear transformation: output = input @ weight.T + bias where weight has shape [output_size, input_size] and bias has shape [output_size].

Implementation Requirements

  • Use PyTorch’s built-in functions and operations

  • The solve function signature must remain unchanged

  • The final result must be stored in the output tensor

  • The model is already loaded and ready for inference

Example

 Input:  input = [[1.0, 2.0]]  (batch_size=1, input_size=2)
 model: Linear layer with weight=[[0.5, 1.0], [1.5, 0.5]], bias=[0.1, 0.2]
 Output: output = [[2.6, 2.7]]  (batch_size=1, output_size=2)
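
The example can be checked by hand against the formula output = input @ weight.T + bias. The sketch below is only a sanity check, not part of the submitted solution: it builds a CPU nn.Linear with the example's weight and bias (set manually here as an assumption about how the model was constructed) and compares the manual matrix product with the layer's own forward pass.

```python
import torch
import torch.nn as nn

# Build a 2 -> 2 linear layer and overwrite its parameters with the example values.
model = nn.Linear(in_features=2, out_features=2)
with torch.no_grad():
    model.weight.copy_(torch.tensor([[0.5, 1.0], [1.5, 0.5]]))
    model.bias.copy_(torch.tensor([0.1, 0.2]))

x = torch.tensor([[1.0, 2.0]])  # batch_size=1, input_size=2

with torch.no_grad():
    # Manual formula: weight is [output_size, input_size], so it is transposed.
    manual = x @ model.weight.T + model.bias
    # The layer's forward pass should agree with the manual formula.
    layer = model(x)

print(manual)
print(layer)
```

Row by row: 1.0 * 0.5 + 2.0 * 1.0 + 0.1 = 2.6 and 1.0 * 1.5 + 2.0 * 0.5 + 0.2 = 2.7.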

Constraints

  • 1 ≤ batch_size ≤ 1,000

  • 1 ≤ input_size ≤ 1,000

  • 1 ≤ output_size ≤ 1,000

  • -10.0 ≤ input values ≤ 10.0

Solution

PyTorch

Note

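Very straightforward: run the model on the input and copy the result into output. A few alternatives:

  • Use output.set_() instead of output.copy_(), which makes output share the result's storage rather than copying into the existing buffer

  • Call model.forward(input) instead of model(input); the two give the same result here, but model(input) also runs any registered hooks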
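
To make the difference between these options concrete, here is a minimal sketch on CPU tensors. The layer sizes and variable names are made up for illustration, and torch.no_grad() is added on the assumption that gradients are not needed for inference.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 3)
x = torch.randn(2, 4)

with torch.no_grad():
    result = model(x)

    # copy_ writes the values into the buffer that out_copy already owns.
    out_copy = torch.empty(2, 3)
    out_copy.copy_(result)

    # set_ rebinds out_set to the storage, size, and strides of result.
    out_set = torch.empty(2, 3)
    out_set.set_(result)

    # Calling forward() directly skips hooks but computes the same values.
    via_forward = model.forward(x)

# Compare the underlying data pointers to see which tensor shares memory.
copy_shares_memory = out_copy.data_ptr() == result.data_ptr()
set_shares_memory = out_set.data_ptr() == result.data_ptr()
print(copy_shares_memory, set_shares_memory)
```

With copy_ the output tensor keeps its original buffer, which is the safer choice when the caller holds a reference to that memory.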

Solution

import torch
import torch.nn as nn


# input, model, and output are on the GPU
def solve(input: torch.Tensor, model: nn.Module, output: torch.Tensor):
    output.copy_(model(input))

Runtime: 0.19ms

JAX

Note
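  • The input has to be copied to NumPy (host memory) to build the PyTorch tensor, and the result is copied back through NumPy to create the JAX array. This makes it slower

  • TODO: use DLPack for zero-copy exchange to avoid the GPU-CPU-GPU transfer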
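
The zero-copy idea in the TODO could look roughly like the sketch below. It is untimed and rests on assumptions: solve_dlpack is a hypothetical name, it relies on torch.from_dlpack and jax.dlpack.from_dlpack accepting each other's arrays (which needs reasonably recent versions of both libraries), and it only avoids copies when JAX and PyTorch hold the data on the same device.

```python
import jax
import jax.numpy as jnp
import torch


def solve_dlpack(input: jax.Array, model: torch.nn.Module) -> jax.Array:
    # Hypothetical variant of solve(): exchange buffers through DLPack
    # instead of going through NumPy on the host.
    torch_input = torch.from_dlpack(input)
    # No-op if the tensor already lives on the model's device.
    torch_input = torch_input.to(next(model.parameters()).device)
    with torch.no_grad():  # no autograd graph, so the result can be exported
        torch_output = model(torch_input)
    return jax.dlpack.from_dlpack(torch_output.contiguous())


# Small check using the example from the problem statement.
example_model = torch.nn.Linear(2, 2)
with torch.no_grad():
    example_model.weight.copy_(torch.tensor([[0.5, 1.0], [1.5, 0.5]]))
    example_model.bias.copy_(torch.tensor([0.1, 0.2]))

example_out = solve_dlpack(jnp.array([[1.0, 2.0]], dtype=jnp.float32), example_model)
print(example_out)
```

Whether this is actually faster than the NumPy round trip on the T4 has not been measured here.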

Solution

import jax
import jax.numpy as jnp
import numpy as np
import torch


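# input and model are on the GPU
def solve(input: jax.Array, model) -> jax.Array:
    # JAX array -> NumPy array (copies the data to host memory)
    np_array = np.asarray(input)
    # NumPy -> PyTorch tensor, moved to the device that holds the model's parameters
    torch_tensor = torch.from_numpy(np_array).to(next(model.parameters()).device)
    output = model(torch_tensor)
    # PyTorch -> NumPy on the host -> JAX array
    return jnp.asarray(output.detach().cpu().numpy())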

Runtime: 3.04ms