With PyTorch 2.0, TorchDynamo was introduced to (sort of) replace TorchScript!
TorchDynamo dynamically traces through PyTorch code at runtime and captures the operations into a computational graph in an intermediate representation (torch.fx). A backend compiler (such as Inductor) can then take that extracted graph and transform it for optimization. However, there isn't a lot of documentation on how to do this, which is understandable because the average MLE isn't gonna write their own compiler to speed up their training, but I wanted to see how it's done :D
(P.S. Check out the Hidet compiler! My coworkers at CentML worked really hard on it and it's open-source.)
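To make the "a compiler takes the extracted graph" part concrete: a backend is just a function that receives the traced torch.fx.GraphModule plus example inputs and returns a callable. Here's a minimal sketch following the print-and-return pattern from the Dynamo docs (the backend name here is mine):

import torch

# a backend is just: (fx.GraphModule, example_inputs) -> callable
def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    gm.graph.print_tabular()  # dump the captured ops so we can see what Dynamo traced
    return gm.forward         # run the graph as-is, no optimization

@torch.compile(backend=inspect_backend)
def fn(x):
    return torch.relu(x) + 1

fn(torch.randn(4))  # the first call triggers tracing and invokes our backend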
How to use torch.compile?
First, let's start with the basics: how torch.compile is used in general.
Here is a basic example of calling torch.compile on a PyTorch model:
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

model = Model()
compiled_model = torch.compile(model, backend="inductor")

input_tensor = torch.randn(1, 10)
output = model(input_tensor)
# you can call the compiled model like the original; torch.compile returns a callable
output_compiled = compiled_model(input_tensor)
You should be able to simply give torch.compile a callable (usually an nn.Module, but a plain function works too), and it returns an optimized compiled callable in its place.
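For instance, compiling a plain function looks like this (a tiny sketch; the function is just something I made up):

import torch

# torch.compile works on plain functions as well as nn.Modules
@torch.compile
def pointwise(x):
    return torch.sin(x) + torch.cos(x)

print(pointwise(torch.randn(8)))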
Here is the API for torch.compile:
torch.compile(model: Callable[[_InputT], _RetT], *, fullgraph: bool = False, dynamic: Optional[bool] = None, backend: Union[str, Callable] = 'inductor', mode: Optional[str] = None, options: Optional[Dict[str, Union[str, int, bool]]] = None, disable: bool = False) → Callable[[_InputT], _RetT]
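A few of these knobs in action (these are all documented parameters; which ones actually help depends on your model), reusing the model from the example above:

# error out on graph breaks instead of silently falling back to eager
strict = torch.compile(model, fullgraph=True)

# spend longer compiling in exchange for potentially faster kernels
fast = torch.compile(model, mode="max-autotune")

# turn compilation off entirely (handy for debugging)
plain = torch.compile(model, disable=True)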
Useful links:
- Dynamo Overview
- Dynamo deepdive
- ezyang's deep dive vid (This one is especially funny, he's fixing bugs while explaining Dynamo)