Skip to content

Commit

Permalink
Patched docs for torch_compile_tutorial (#2936)
Browse files Browse the repository at this point in the history
* Patched docs for torch_compile_tutorial

---------

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
  • Loading branch information
2 people authored and c-p-i-o committed Sep 6, 2024
1 parent 5465f9b commit f45ddc2
Showing 1 changed file with 100 additions and 2 deletions.
102 changes: 100 additions & 2 deletions intermediate_source/torch_compile_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,21 @@ def foo(x, y):

######################################################################
# Alternatively, we can decorate the function.
t1 = torch.randn(10, 10)
t2 = torch.randn(10, 10)

@torch.compile
def opt_foo2(x, y):
a = torch.sin(x)
b = torch.cos(y)
return a + b
print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10)))
print(opt_foo2(t1, t2))

######################################################################
# We can also optimize ``torch.nn.Module`` instances.

t = torch.randn(10, 100)

class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -94,7 +98,101 @@ def forward(self, x):

mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(torch.randn(10, 100)))
print(opt_mod(t))

######################################################################
# torch.compile and Nested Calls
# ------------------------------
# Nested function calls within the decorated function will also be compiled.

def nested_function(x):
return torch.sin(x)

@torch.compile
def outer_function(x, y):
a = nested_function(x)
b = torch.cos(y)
return a + b

print(outer_function(t1, t2))

######################################################################
# In the same fashion, when compiling a module all sub-modules and methods
# within it, that are not in a skip list, are also compiled.

class OuterModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.inner_module = MyModule()
self.outer_lin = torch.nn.Linear(10, 2)

def forward(self, x):
x = self.inner_module(x)
return torch.nn.functional.relu(self.outer_lin(x))

outer_mod = OuterModule()
opt_outer_mod = torch.compile(outer_mod)
print(opt_outer_mod(t))

######################################################################
# We can also disable some functions from being compiled by using
# ``torch.compiler.disable``. Suppose you want to disable the tracing on just
# the ``complex_function`` function, but want to continue the tracing back in
# ``complex_conjugate``. In this case, you can use
# ``torch.compiler.disable(recursive=False)`` option. Otherwise, the default is
# ``recursive=True``.

def complex_conjugate(z):
return torch.conj(z)

@torch.compiler.disable(recursive=False)
def complex_function(real, imag):
# Assuming this function cause problems in the compilation
z = torch.complex(real, imag)
return complex_conjugate(z)

def outer_function():
real = torch.tensor([2, 3], dtype=torch.float32)
imag = torch.tensor([4, 5], dtype=torch.float32)
z = complex_function(real, imag)
return torch.abs(z)

# Try to compile the outer_function
try:
opt_outer_function = torch.compile(outer_function)
print(opt_outer_function())
except Exception as e:
print("Compilation of outer_function failed:", e)

######################################################################
# Best Practices and Recommendations
# ----------------------------------
#
# Behavior of ``torch.compile`` with Nested Modules and Function Calls
#
# When you use ``torch.compile``, the compiler will try to recursively compile
# every function call inside the target function or module inside the target
# function or module that is not in a skip list (such as built-ins, some functions in
# the torch.* namespace).
#
# **Best Practices:**
#
# 1. **Top-Level Compilation:** One approach is to compile at the highest level
# possible (i.e., when the top-level module is initialized/called) and
# selectively disable compilation when encountering excessive graph breaks or
# errors. If there are still many compile issues, compile individual
# subcomponents instead.
#
# 2. **Modular Testing:** Test individual functions and modules with ``torch.compile``
# before integrating them into larger models to isolate potential issues.
#
# 3. **Disable Compilation Selectively:** If certain functions or sub-modules
# cannot be handled by `torch.compile`, use the `torch.compiler.disable` context
# managers to recursively exclude them from compilation.
#
# 4. **Compile Leaf Functions First:** In complex models with multiple nested
# functions and modules, start by compiling the leaf functions or modules first.
# For more information see `TorchDynamo APIs for fine-grained tracing <https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html>`__.

######################################################################
# Demonstrating Speedups
Expand Down

0 comments on commit f45ddc2

Please sign in to comment.