Patch Conv: Patch Convolution to Avoid Large GPU Memory Usage of Conv2D

TL;DR

In this blog, we introduce Patch Conv to reduce the memory footprint when generating high-resolution images. Patch Conv cuts memory usage by over 2.4× compared to the existing PyTorch implementation. Code: https://github.com/mit-han-lab/patch_conv

Background

In current generative models, we usually apply convolutions over large activations to generate high-resolution content. However, PyTorch tends to use excessive memory for these operations, potentially leading to memory shortages even on 80GB A100 GPUs.

As shown in the above figure, the memory demands of standard PyTorch convolutions increase drastically once the input reaches 1B elements (channels × height × width). Notably, with a 7×7 kernel, even an 80GB A100 GPU triggers Out-of-Memory (OOM) errors. Inputs exceeding 2B elements cause even 3×3 convolutions to exhaust all the memory, and that's just for one layer! This memory bottleneck prevents users and the community from scaling up models to produce high-quality images.
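You can observe this blow-up yourself by measuring the peak CUDA memory of a single convolution call with PyTorch's memory statistics. Below is a minimal sketch; the shapes are illustrative (64 × 4096 × 4096 ≈ 1B elements), not the exact workloads in the figure:

import torch

# Illustrative micro-benchmark: peak memory of one Conv2d forward pass.
# Scaling up the resolution or the kernel size inflates the peak far beyond
# the activation size itself, eventually triggering OOM.
conv = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1,
                       device="cuda", dtype=torch.float16)
x = torch.randn(1, 64, 4096, 4096, device="cuda", dtype=torch.float16)

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    conv(x)
print(f"peak CUDA memory: {torch.cuda.max_memory_allocated() / 1e9:.1f} GB")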

Patch Conv

To bypass this issue and reduce memory consumption, we propose a simple yet effective solution: Patch Conv. As shown in the above figure, similar to SIGE, Patch Conv first divides the input into several smaller patches along the height dimension while keeping some overlap between them. These patches are then reorganized into the batch dimension and fed through the original convolution to produce output patches, which are finally concatenated together to form the final output.
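To make the mechanism concrete, here is a minimal sketch of the split-convolve-concatenate pattern for a stride-1 "same" convolution. This is our illustration, not the library's implementation (which reorganizes the patches into the batch dimension rather than looping over them), and the helper name patched_conv2d is hypothetical:

import torch
import torch.nn.functional as F

def patched_conv2d(x, weight, bias=None, splits=4):
    # Stride-1 "same" convolution computed patch by patch along height.
    k = weight.shape[-1]            # assumes a square kernel with odd size
    p = k // 2                      # "same" padding amount
    x_pad = F.pad(x, (p, p, p, p))  # zero-pad height and width once, up front

    H = x.shape[-2]
    bounds = [i * H // splits for i in range(splits + 1)]

    outs = []
    for a, b in zip(bounds[:-1], bounds[1:]):
        # Output rows [a, b) depend only on padded rows [a, b + 2p):
        # adjacent patches overlap by the kernel's halo of 2p rows.
        patch = x_pad[:, :, a : b + 2 * p, :]
        outs.append(F.conv2d(patch, weight, bias))  # no extra padding needed
    return torch.cat(outs, dim=-2)  # stitch output patches along height

# Sanity check against the vanilla op:
x = torch.randn(1, 8, 64, 64)
w = torch.randn(16, 8, 3, 3)
assert torch.allclose(F.conv2d(x, w, padding=1), patched_conv2d(x, w), atol=1e-5)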

Performance

As shown in the above figure, we compare the memory usage of vanilla convolution and Patch Conv across various kernel sizes (3, 5, and 7) for inputs at 2048×2048 and 4096×4096 resolutions. Notably, the 3×3 workloads shown here are exactly those used in the SDXL decoder to generate a 4096×4096 image. Patch Conv reduces memory usage by 2.4~4.4× while producing numerically identical results.
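This equivalence is easy to sanity-check with the released package. A short sketch, assuming convert_model replaces the nn.Conv2d children of whatever module it is given (as in the usage below):

import copy
import torch
from patch_conv import convert_model

# Compare a vanilla Conv2d against its Patch Conv counterpart on random input.
vanilla = torch.nn.Sequential(torch.nn.Conv2d(64, 64, 3, padding=1)).eval()
patched = convert_model(copy.deepcopy(vanilla), splits=4)  # same weights, patched conv

x = torch.randn(1, 64, 512, 512)
with torch.no_grad():
    assert torch.allclose(vanilla(x), patched(x), atol=1e-6)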

We also report the latency of vanilla convolution and Patch Conv. On an NVIDIA A100, Patch Conv is slightly slower than vanilla convolution; the overhead mainly comes from the additional memory operations required for patchifying. We are continuing to optimize the implementation to close this gap.

Usage

You can simply install Patch Conv with pip install patch_conv. Then you only need a single line of code to wrap nn.Conv2d into our Patch Conv with convert_model. The basic usage is as follows:


import torch
from patch_conv import convert_model

model = Model(...)  # Your PyTorch model
model = convert_model(model, splits=4)  # The only modification you need to make

with torch.no_grad():
    model(...)  # Run the model in the original way