Notes on Optimizing Torch Models

ML researchers are from Mars and the ML engineers responsible for deploying models are from Venus. The two have vastly different motivations. The ML researcher’s job, given a dataset and some compute, is to find the lowest possible loss on a task. In this pursuit, no engineering cost is too high and no tech debt is too large. Worse still, if they get published, they must include their code with the paper. Reproducible research means only that everyone should be able to reproduce the benchmarks and charts from the original paper. It usually has very little to do with production.

On the other hand, the world isn’t too kind to backend engineers whose job is to productionize models. If a paper claims to have developed a multimodal architecture that can identify a particular dog based on no more than its bark and a picture of its snout, then nobody would think twice before expecting an ML engineer to turn the paper into a mobile app. So now the engineer looks up the repository, hoping to find some structure in the code. But they’re first met with a convoluted shell script and an argument parser that only seems to show how to train the model. If they look harder, they may find the URLs to the checkpoints. If they’re particularly unfortunate, they may find modules upon modules of hacked-up code that serve only to introduce custom layers into the architecture, necessitated only by the original authors not having read torch’s documentation. The whole repository gives off a patronizing vibe, as if it’s written exclusively for novice programmers. In all likelihood, it’s written by them as well. For instance, the overuse of argparsers indicates that people are not comfortable writing well-documented functions.


There’s a host of optimization options available within torch, referred to variously as compiling, tracing, scripting and so on. But it all comes down to tracing, in that if a torch graph can be successfully traced, then it’s already likely to be optimized for computation. Tracing the graph of a neural network is like extracting a deterministic function from it. And it is actually a function in the strict mathematical sense - different inputs may map to the same output, but the same input never maps to different outputs. The torch.compile documentation even says that a graph break is a lost optimization opportunity.
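A quick way to see this in action is torch.jit.trace, which records one deterministic pass through a module. A minimal sketch (the module here is a toy of my own, not from any particular codebase):

```python
import torch

class Scale(torch.nn.Module):
    """A toy module: tracing extracts its graph as a pure function."""
    def forward(self, x):
        return x * 2.0

model = Scale().eval()
example = torch.randn(3)
traced = torch.jit.trace(model, example)

# The traced graph maps tensors to tensors deterministically: the same
# input always produces the same output, for any input like the example.
new_input = torch.randn(3)
assert torch.equal(traced(new_input), model(new_input))
```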

I’ve seen many such repositories over the years. I now have a recipe to deal with them. And now that I’m actually writing it down, it feels like there’s hardly any esoteric, domain-specific knowledge here. It’s just good software development practices. These are broadly the three steps involved in optimizing a model for production. They’re written for torch, but the ideas are applicable elsewhere too. Everything that needs to be done to make a graph traceable falls into one of these three categories.

1. Refactor Modules

All inference should be the result of a single instance of torch.nn.Module which accepts tensors and returns tensors. Move all pre- and post-processing outside this module. Often, a single entry-point module is hard to spot in a codebase because it’s buried under layers of helper functions and argparsers. You can probably tell that I absolutely abhor argparsers (especially in code that’s written for programmers by programmers). But I’ll admit that it’s quite fun to chip away at a codebase until a single torch.nn.Module is left standing in main.py, and the remaining modules are comfortably imported.

The bulk of this step is digging the entry-point module out of the original codebase, and then chasing down its submodules in the project tree. Thankfully, LLMs are great at this kind of work.
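As a sketch of the end state (all names and shapes here are hypothetical, not from any particular model): one nn.Module that maps tensors to tensors, with pre- and post-processing as plain functions outside it.

```python
import torch

class Detector(torch.nn.Module):
    """All tensor-to-tensor work lives in this single entry-point module."""
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Conv2d(3, 8, 3, padding=1)
        self.head = torch.nn.Linear(8, 4)

    def forward(self, images):                       # tensors in ...
        feats = self.backbone(images).mean(dim=(2, 3))
        return self.head(feats)                      # ... tensors out

def preprocess(raw):
    # Decoding, resizing, dtype conversion etc. stay outside the module.
    return raw.float() / 255.0

def postprocess(boxes):
    # Formatting and clean-up of outputs also stay outside.
    return boxes.clamp(min=0.0)

model = Detector().eval()
with torch.no_grad():
    out = postprocess(model(preprocess(torch.randint(0, 256, (1, 3, 16, 16)))))
```

Only the Detector needs to be traced; the functions around it remain ordinary Python.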

2. Vectorize Loops

This might sound like a no-brainer, but it needs to be said. I feel kids these days jump straight to torch tensors, skipping over NumPy arrays. To be fair, people don’t write loops all that often, but they do unpack a tensor into vectors in order to manipulate its individual rows and columns. A common example is normalizing bounding boxes in object detection models. Let’s say you have a 2-D tensor with four columns in the $(x_{min}, x_{max}, y_{min}, y_{max})$ format. You might want to normalize the matrix so that the coordinates are fractions of the height and width of the image. An innocent way to do that is as follows:

height, width = batch.shape[-2:]

bbox[:, 0] /= width
bbox[:, 1] /= width
bbox[:, 2] /= height
bbox[:, 3] /= height

# Or, better
bbox[:, :2] /= width
bbox[:, 2:] /= height

This isn’t wrong at all, but the tracer will warn you at the very first line, because unpacking a vector counts as an iteration. Worse, it creates standalone Python integers in the middle of the graph. So the correct, but perhaps less readable, way of doing this is:

hw = torch.tensor(batch.shape[-2:]).flip(0)  # slicing is fine, unpacking is not!
bbox /= hw.repeat_interleave(2)

Now, the slight problem here is that we’re creating a tensor right in the middle of the forward pass, which the tracer doesn’t like. Unfortunately there are tons of operations where we depend not just on the values but on the dimensions of a tensor. Thankfully there’s a private function, torch._shape_as_tensor, which does exactly what we need, and the tracer allows it. God only knows why it’s private.
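Putting the two fixes together, the normalization can be written without creating a tensor mid-graph (the function name is mine):

```python
import torch

def normalize_bbox(bbox, batch):
    # bbox: (N, 4) in (x_min, x_max, y_min, y_max); batch: (..., H, W).
    # torch._shape_as_tensor keeps the shape inside the graph as a tensor,
    # so no Python ints and no mid-graph torch.tensor(...) calls appear.
    hw = torch._shape_as_tensor(batch)[-2:].flip(0)  # (W, H)
    return bbox / hw.repeat_interleave(2)            # divide by (W, W, H, H)
```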

3. Bypass Conditional Flows

The tracer needs a single, deterministic path through the graph. Any if whose condition depends on tensor values or shapes is a potential graph break. The warning reads: “Converting a tensor to a Python boolean might cause the trace to be incorrect.” In practice, most of these conditionals fall into one of a few categories, and all of them have the same solution: remove the if.

Guards that are never false. Consider a padding operation:

pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
    x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))

The author was trying to be efficient by skipping the pad when it’s unnecessary. But F.pad with zero padding is a no-op anyway, so the guard buys you nothing (except perhaps at training time) and costs you a graph break. The condition is redundant.
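A sketch of the guard-free version, assuming the channels-last (B, H, W, C) layout that the pad tuple above implies (the function name is mine):

```python
import torch
import torch.nn.functional as F

def pad_to_window(x, window_size):
    # x is (B, H, W, C); pad H and W up to multiples of window_size.
    H, W = x.shape[1], x.shape[2]
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    # No guard: when pad_h and pad_w are zero, F.pad returns the input
    # unchanged, so the graph stays branch-free at no extra cost.
    return F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
```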

Dead booleans. These are runtime flags threaded through function signatures to control branches that are irrelevant during inference. In one model I worked on, a repeat_image boolean was passed through three levels of function calls to control this:

if repeat_image:
    src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
else:
    src = image_embeddings

The flag was always False during inference for a wide variety of test inputs. The fix was to delete the parameter, the branch, and every call site that passed it. This is also a good example of why step 1 matters — once you’ve isolated the inference module, you can be aggressive about ripping out training-only code paths.
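A before-and-after sketch (the function and its signature are hypothetical; only the branch is from the model in question):

```python
import torch

# Before: the flag exists only to select a branch that inference never takes.
def fuse(image_embeddings, tokens, repeat_image=False):
    if repeat_image:
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    else:
        src = image_embeddings
    return src + tokens

# After: with the flag deleted at every call site, the branch collapses.
def fuse_fixed(image_embeddings, tokens):
    return image_embeddings + tokens
```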

Assertions. Statements like assert image_embeddings.shape[0] == tokens.shape[0] also convert tensors to Python booleans. They’re a Python-level concern, not a graph-level one. If they were actually failing, there’s a much bigger problem at hand, and tracing would be the least of your concerns. Remove them.

Conditions that are trivially true for real inputs. In another model, a branch checked whether the pooling kernel was larger than the image itself:

if (self.kernel_size >= hw).all(): 
    return F.adaptive_avg_pool2d(x, 1)

There’s no way I’m using an image that’s smaller than a pooling kernel. The condition existed only to handle a degenerate edge case that doesn’t occur in production. And if it does, you have bigger problems.
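With the branch gone, pooling becomes a single unconditional op (the kernel size and shapes here are illustrative, not from the model in question):

```python
import torch

# Production inputs are always larger than the kernel, so the
# adaptive-pool fallback was dead code; the plain path traces cleanly.
pool = torch.nn.AvgPool2d(kernel_size=2)
x = torch.randn(1, 8, 32, 32)
y = pool(x)  # no data-dependent `if` for the tracer to trip on
```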

The common thread is that these conditionals are almost always vestigial. They’re left over from training loops, defensive coding habits, or edge cases that the original author encountered once during development. The tracer forces you to confront them, and in doing so, you usually find that the code is simpler without them.


Let’s not forget that ruthlessly optimizing a model has business implications. Ignoring edge cases because they break the graph is a risky move, and should not be a unilateral decision. Unfortunately, the cost you have to pay for optimization is prioritizing the norm over the exceptions.

If you’re interested in more real examples, take a look at the git log in this repository. It reads like a debugging diary, and contains many examples of this process of eliminating tracer warnings.
