How to build your own PyTorch Compiler

November 15, 2024

PyTorch 2.0 introduced TorchDynamo, which (sort of) replaces TorchScript!

TorchDynamo traces through your Python/PyTorch code at runtime and extracts a computational graph in an intermediate representation (a torch.fx graph). A compiler backend (such as Inductor) can then take that extracted graph and transform it to make it faster. However, there isn't much documentation on how to write such a backend yourself. That's understandable, because the average MLE isn't gonna write their own compiler to speed up their training, but I wanted to see how it's done :D
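To make that hand-off concrete, here is a minimal sketch of a "compiler" that does no optimization at all: it just records and prints the torch.fx graph TorchDynamo gives it, then runs that graph unchanged. It relies on the custom-backend convention where torch.compile accepts any callable taking `(gm: torch.fx.GraphModule, example_inputs)` and returning a callable. The names `inspect_backend`, `captured_graphs`, and `f` are hypothetical, chosen only for illustration.

```python
import torch
from typing import List

# Hypothetical list that collects every FX graph Dynamo hands to our backend
captured_graphs: List[torch.fx.GraphModule] = []


def inspect_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    """A do-nothing 'compiler': log the extracted graph, then run it as-is."""
    captured_graphs.append(gm)
    print(gm.graph)  # the torch.fx IR that TorchDynamo extracted
    # A real compiler would transform gm here; returning gm.forward keeps eager semantics
    return gm.forward


def f(x, y):
    return torch.relu(x @ y) + 1.0


compiled_f = torch.compile(f, backend=inspect_backend)
x, y = torch.randn(4, 4), torch.randn(4, 4)
out = compiled_f(x, y)  # tracing and the backend call happen on this first invocation
```

A real backend would rewrite or lower `gm` before returning it. Inductor, for example, generates Triton or C++ kernels from this same graph.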

(P.S. Check out the Hidet compiler! My coworkers at CentML worked really hard on it, and it's open-source.)

How to use torch.compile?

First, let's start with the basics: how torch.compile is used in general.

Here is a basic example of calling torch.compile on a PyTorch model:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

model = Model()
compiled_model = torch.compile(model, backend="inductor")
input_tensor = torch.randn(1, 10)
output = model(input_tensor)  # regular eager execution
# torch.compile returns a callable, so you call the compiled model like the original
output_compiled = compiled_model(input_tensor)

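You can simply pass torch.compile a callable (usually an nn.Module, but a plain function works too), and it returns a compiled callable that is meant to run faster than the original. Compilation is lazy: the actual tracing and compiling happen on the first call, not when torch.compile itself runs.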
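Since a plain function works too, torch.compile can also be used as a decorator. The sketch below assumes the built-in `"eager"` backend, which runs the captured graph without generating new kernels. That makes it a cheap way to check that Dynamo can trace your code before you pay for Inductor's compile time. The function name `gelu_mlp` is hypothetical.

```python
import torch


# Decorator form: torch.compile(**kwargs) with no model returns a decorator.
# backend="eager" traces with Dynamo but runs the graph with ordinary PyTorch ops.
@torch.compile(backend="eager")
def gelu_mlp(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(x @ w)


x, w = torch.randn(8, 16), torch.randn(16, 4)
y = gelu_mlp(x, w)  # first call triggers tracing; later calls reuse the cached graph
print(y.shape)
```

Swapping `backend="eager"` for `"inductor"` (the default) is the usual next step once tracing works.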

Here is the API for torch.compile:

torch.compile(model: Callable[[_InputT], _RetT], *, fullgraph: bool = False, dynamic: Optional[bool] = None, backend: Union[str, Callable] = 'inductor', mode: Optional[str] = None, options: Optional[Dict[str, Union[str, int, bool]]] = None, disable: bool = False) -> Callable[[_InputT], _RetT]
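A quick sketch of a few of these keyword arguments follows. `backend` accepts either a registered name (`torch._dynamo.list_backends()` lists the non-debug ones) or a callable like a custom compiler. `fullgraph=True` makes a graph break raise an error instead of silently splitting the graph. `dynamic=False` specializes on the exact input shapes. `disable=True` turns torch.compile into a no-op, which is handy for A/B debugging. `mode` (e.g. `"reduce-overhead"` or `"max-autotune"`) is Inductor-specific and is left out here. The function `trig_identity` is hypothetical, and the `"eager"` backend is used only to keep the example independent of a C++/Triton toolchain.

```python
import torch
import torch._dynamo  # explicit import so list_backends() is available

# Registered backend names accepted by backend="..."
backends = torch._dynamo.list_backends()
print(backends)


def trig_identity(x: torch.Tensor) -> torch.Tensor:
    return torch.sin(x) ** 2 + torch.cos(x) ** 2


# fullgraph=True: error out on a graph break; dynamic=False: specialize on shapes
strict_fn = torch.compile(trig_identity, backend="eager", fullgraph=True, dynamic=False)

# disable=True: skip compilation entirely and run the original function
noop_fn = torch.compile(trig_identity, disable=True)

x = torch.randn(5)
print(strict_fn(x), noop_fn(x))  # both should be ~1.0 everywhere
```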


Useful links: