LeetGPU-12: Simple Inference
Note: This is part of a learning series on CUDA and Triton, focusing on correctness-first implementations rather than performance optimization.
Problem Statement
Run inference on a PyTorch model. Given an input tensor and a trained torch.nn.Linear model, compute the forward pass and store the result in the output tensor.
The model performs a linear transformation: output = input @ weight.T + bias where weight has shape [output_size, input_size] and bias has shape [output_size].
Implementation Requirements
- Use PyTorch's built-in functions and operations
- The solve function signature must remain unchanged
- The final result must be stored in the output tensor
- The model is already loaded and ready for inference
Example
Input: input = [[1.0, 2.0]] (batch_size=1, input_size=2)
model: Linear layer with weight=[[0.5, 1.0], [1.5, 0.5]], bias=[0.1, 0.2]
Output: output = [[2.6, 2.7]] (batch_size=1, output_size=2)
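To sanity-check the example by hand: input @ weight.T = [[1.0*0.5 + 2.0*1.0, 1.0*1.5 + 2.0*0.5]] = [[2.5, 2.5]], and adding the bias gives [[2.6, 2.7]]. A minimal sketch reproducing this with raw tensor ops (no nn.Linear):

```python
import torch

input = torch.tensor([[1.0, 2.0]])               # [batch_size=1, input_size=2]
weight = torch.tensor([[0.5, 1.0], [1.5, 0.5]])  # [output_size=2, input_size=2]
bias = torch.tensor([0.1, 0.2])                  # [output_size=2]

output = input @ weight.T + bias
print(output)  # tensor([[2.6000, 2.7000]])
```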
Constraints
- 1 ≤ batch_size ≤ 1,000
- 1 ≤ input_size ≤ 1,000
- 1 ≤ output_size ≤ 1,000
- -10.0 ≤ input values ≤ 10.0
Solution
PyTorch
Very straightforward. A few options (sketched after the solution below):
- Use set_ instead of copy_
- Use model.forward() explicitly
Solution
```python
import torch
import torch.nn as nn

# input, model, and output are on the GPU
def solve(input: torch.Tensor, model: nn.Module, output: torch.Tensor):
    output.copy_(model(input))
```
Runtime: 0.19ms
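For reference, a minimal sketch of the two options listed above (hypothetical function names, not benchmarked here):

```python
import torch
import torch.nn as nn

def solve_set(input: torch.Tensor, model: nn.Module, output: torch.Tensor):
    # set_ rebinds output to the result's storage instead of copying into it
    output.set_(model(input))

def solve_forward(input: torch.Tensor, model: nn.Module, output: torch.Tensor):
    # calling forward() directly skips nn.Module's __call__ hooks
    output.copy_(model.forward(input))
```

Note that set_ changes which storage the output tensor points to rather than writing into the provided buffer, so it only works if the checker reads the same Python object afterward.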
JAX
Here we need to copy through NumPy on the host to move data between the JAX array and the PyTorch model, which is slower.
TODO: use DLPack to do a zero-copy exchange and avoid the GPU→CPU→GPU round trip (a sketch follows the solution below).
Solution
```python
import jax
import jax.numpy as jnp
import numpy as np
import torch

# input and model are on the GPU
def solve(input: jax.Array, model) -> jax.Array:
    np_array = np.asarray(input)  # device -> host copy into NumPy
    torch_tensor = torch.from_numpy(np_array).to(next(model.parameters()).device)  # host -> device
    output = model(torch_tensor)
    return jnp.asarray(output.detach().cpu().numpy())  # device -> host -> JAX
```
Runtime: 3.04ms
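A sketch of the zero-copy TODO above, assuming recent JAX and PyTorch versions where both libraries consume objects implementing the DLPack protocol directly (untested here; older versions need the explicit capsule route via torch.utils.dlpack.to_dlpack and jax.dlpack.to_dlpack):

```python
import jax
import torch

# Hypothetical zero-copy variant: both tensors stay on the GPU and share memory.
def solve_dlpack(input: jax.Array, model) -> jax.Array:
    torch_tensor = torch.from_dlpack(input)  # wraps the JAX buffer, no copy
    output = model(torch_tensor)
    # detach before export: tensors requiring grad cannot cross the boundary
    return jax.dlpack.from_dlpack(output.detach())  # wraps the torch buffer, no copy
```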