65 changes: 64 additions & 1 deletion thunder/__init__.py
@@ -60,6 +60,7 @@
import thunder.extend as extend
from thunder.extend import Executor, add_default_executor
import thunder.transforms as transforms
from thunder.dynamo.utils import log_trace_or_graphmodule_to_torch_trace, TORCH_COMPILE_COMPILE_ID_KEY

# NOTE This import is intentionally pytorch so that thunder.torch doesn't import this
import torch as pytorch
@@ -349,7 +350,6 @@ def jit(

debug_options: optional :class:`thunder.DebugOptions` instance. See the doc string of :class:`DebugOptions` for supported debug options. Default: ``None``
"""

if "executors_list" in compile_options:
warnings.warn("outdated argument executors_list= in call, please use executors=")
if executors is None:
@@ -457,6 +457,15 @@ def acquire_initial_trace(fn, args, kwargs, cd, cs, ad_hoc_executor):
last_interpreter_log = jit_results.interpreter_log
cs.last_interpreter_log = last_interpreter_log
cs.last_interpreted_instructions = (i for i in last_interpreter_log if isinstance(i, dis.Instruction))

for name, trace in (
("prologue", prologue_trc),
("computation", computation_trc),
("epilogue", epilogue_trc),
):
log_trace_or_graphmodule_to_torch_trace(
name=name, m=trace, compile_id=compile_options.get(TORCH_COMPILE_COMPILE_ID_KEY, None)
)
return prologue_trc, computation_trc, epilogue_trc

def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, computation_trc, epilogue_trc):
@@ -521,10 +530,22 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
new_computation_trc,
new_epilogue_trc,
)
transform_name = type(transform).__name__
list_of_name_and_trace = [
(f"prologue_after_{transform_name}", prologue_trc),
(f"computation_after_{transform_name}", computation_trc),
]

prologue_traces.append(prologue_trc)
computation_traces.append(computation_trc)
if epilogue_trc is not None:
epilogue_traces.append(epilogue_trc)
list_of_name_and_trace.append((f"epilogue_after_{transform_name}", epilogue_trc))

for name, trc in list_of_name_and_trace:
log_trace_or_graphmodule_to_torch_trace(
name=name, m=trc, compile_id=compile_options.get(TORCH_COMPILE_COMPILE_ID_KEY, None)
)

prologue_traces += transform_for_execution(
prologue_trc,
@@ -583,6 +604,14 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com

computation_trc, backward_trc = split_into_forward_and_backward(computation_trc)

for name, trc in (
("computation_after_fwd_bwd_split", computation_trc),
("initial_backward", backward_trc),
):
log_trace_or_graphmodule_to_torch_trace(
name=name, m=trc, compile_id=compile_options.get(TORCH_COMPILE_COMPILE_ID_KEY, None)
)

computation_trc = thunder.executors.passes.del_last_used(computation_trc)
computation_traces.append(computation_trc)
if backward_trc is not None:
@@ -594,23 +623,39 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
computation_traces.append(computation_trc)

for transform in transforms:
transform_name = type(transform).__name__
# NOTE: `backward_trc` could be None.
new_computation_trc = transform.transform_trace_post_optimization(
computation_trc, executors_list=cd.executors_list
)
if new_computation_trc is not computation_trc:
computation_trc = new_computation_trc
computation_traces.append(computation_trc)
log_trace_or_graphmodule_to_torch_trace(
name=f"computation_after_{transform_name}",
m=computation_trc,
compile_id=compile_options.get(TORCH_COMPILE_COMPILE_ID_KEY, None),
)
if backward_trc is not None:
new_backward_trc = transform.transform_trace_post_optimization(
backward_trc, executors_list=cd.executors_list
)
if new_backward_trc is not backward_trc:
backward_trc = new_backward_trc
backward_traces.append(backward_trc)
log_trace_or_graphmodule_to_torch_trace(
name=f"backward_after_{transform_name}",
m=backward_trc,
compile_id=compile_options.get(TORCH_COMPILE_COMPILE_ID_KEY, None),
)

if backward_trc is not None:
backward_fn = backward_trc.python_callable()
log_trace_or_graphmodule_to_torch_trace(
name="ex_backward",
m=backward_trc,
compile_id=compile_options.get(TORCH_COMPILE_COMPILE_ID_KEY, None),
)
else:
backward_fn = None
# We do not have to return auxiliary tensors, which will only be useful in backward pass
Expand All @@ -619,6 +664,11 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com

computation_trc = transform_to_torch_types(computation_trc)
comp = computation_trc.python_callable()
log_trace_or_graphmodule_to_torch_trace(
name="ex_computation",
m=computation_trc,
compile_id=compile_options.get(TORCH_COMPILE_COMPILE_ID_KEY, None),
)

# TODO RC1 Update the cache
cache_entry = CacheEntry(
@@ -786,6 +836,19 @@ def get_computation_and_inputs(*args, **kwargs):
with compile_data_and_stats(cd, cs):
inps, pro_to_epi = cache_entry.prologue_fn(*args, **kwargs)

log_trace_or_graphmodule_to_torch_trace(
name="compile_data",
payload_fn=lambda compile_data=cd: {
"cache_option": str(compile_data.cache_option),
"sharp_edges": str(compile_data.sharp_edges),
"disable_torch_autograd_support": compile_data.disable_torch_autograd_support,
"executors_list": [str(ex) for ex in compile_data.executors_list],
"compile_options": {k: str(v) for k, v in compile_options.items()},
},
encoding="json",
compile_id=compile_options.get(TORCH_COMPILE_COMPILE_ID_KEY, None),
)

return cache_entry, inps, pro_to_epi

def host_execution_timer(fn):
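
Taken together, the `thunder/__init__.py` changes log every intermediate trace — the initial prologue/computation/epilogue, per-transform results, the forward/backward split, and the final execution traces — as PyTorch structured-trace artifacts, keyed by the torch.compile compile id when one was passed in through `compile_options`. Below is a minimal sketch of the pattern these hunks repeat; the `_log_stage_traces` wrapper is purely illustrative and not part of the PR.

```python
from thunder.dynamo.utils import (
    log_trace_or_graphmodule_to_torch_trace,
    TORCH_COMPILE_COMPILE_ID_KEY,
)

def _log_stage_traces(compile_options: dict, **named_traces) -> None:
    # Each non-None trace becomes one structured-trace artifact, tagged with the
    # torch.compile CompileId when thunder.jit was reached through the dynamo backend.
    compile_id = compile_options.get(TORCH_COMPILE_COMPILE_ID_KEY, None)
    for name, trc in named_traces.items():
        if trc is not None:
            log_trace_or_graphmodule_to_torch_trace(name=name, m=trc, compile_id=compile_id)

# e.g. _log_stage_traces(compile_options,
#                        prologue=prologue_trc,
#                        computation=computation_trc,
#                        epilogue=epilogue_trc)
```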
18 changes: 18 additions & 0 deletions thunder/dynamo/compiler.py
@@ -24,6 +24,7 @@
)
from thunder.dynamo.splitter import _splitter
from thunder.dynamo.benchmark_utils import ThunderCompileSpecification
from thunder.dynamo.utils import log_trace_or_graphmodule_to_torch_trace, TORCH_COMPILE_COMPILE_ID_KEY
from thunder.transforms.extraction_only_prologue_transform import ExtractionOnlyPrologueTransform

if TYPE_CHECKING:
@@ -144,12 +145,29 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor
current_thunder_options=self.thunder_options,
is_torch_compile_without_dynamic=is_in_torch_compile() and (not is_dynamic_inputs(sample_args)),
)
torch_compile_compile_id = torch._guards.CompileContext.current_compile_id()
thunder_options = {**thunder_options, TORCH_COMPILE_COMPILE_ID_KEY: torch_compile_compile_id}
split_module, subgraph_info = _splitter(
gm,
partial(jit, **thunder_options),
thunder_options,
**compile_options,
)
log_trace_or_graphmodule_to_torch_trace(
name="graphmodule_after_splitter",
m=split_module,
compile_id=torch_compile_compile_id,
)
if subgraph_info.split_reasons:
log_trace_or_graphmodule_to_torch_trace(
name="graph_split_reasons",
payload_fn=lambda: [
{"reason_type": reason.reason_type.name, "info": reason.info, "exception": reason.exception}
for reason in subgraph_info.split_reasons
],
encoding="json",
compile_id=torch_compile_compile_id,
)
self.subgraph_infos.append(subgraph_info)
return split_module
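
The compile id captured here is the one torch.compile assigned to the Dynamo graph currently being lowered, so thunder-side and Inductor-side artifacts for the same graph end up grouped together. A hedged usage sketch, not part of the diff: run a `ThunderCompiler`-backed `torch.compile` with structured trace logging enabled (e.g. `TORCH_TRACE=/tmp/torch_trace python script.py`), and the artifacts above — `graphmodule_after_splitter`, `graph_split_reasons`, and the thunder traces — should then be viewable with tlparse under that graph's compile id. The `ThunderCompiler` import path is an assumption here.

```python
import torch
from thunder.dynamo import ThunderCompiler  # assumed public entry point for this backend

def fn(x):
    return torch.nn.functional.gelu(x) * 2

# Compiling and running once triggers the splitter above and, with TORCH_TRACE set,
# emits the structured-trace artifacts added in this PR.
compiled = torch.compile(fn, backend=ThunderCompiler())
compiled(torch.randn(4, 4))
```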

5 changes: 4 additions & 1 deletion thunder/dynamo/splitter.py
@@ -23,6 +23,7 @@
_get_example_inputs_from_placeholder,
_ThunderSplitGraphModule,
translate_dtensor_ops,
TORCH_COMPILE_COMPILE_ID_KEY,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -224,9 +225,11 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
graph_module = getattr(split_gm, node.name)

fake_mode = torch._guards.detect_fake_mode()
# Extract compile_id from thunder_options to propagate it to inductor
compile_id = thunder_options.get(TORCH_COMPILE_COMPILE_ID_KEY) if thunder_options else None
# Delay Inductor compilation until invocation with real tensors,
# because we do not know the strides of tensors that Thunder-compiled submodules return.
jit_fn = LazyInductorModule(graph_module, fake_mode, **compile_options)
jit_fn = LazyInductorModule(graph_module, fake_mode, compile_id, **compile_options)

# Update the node name from "submod_*" to "inductor_*" for more user-friendly names
update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)
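
`compile_id` is passed positionally with a `None` default, so call sites that predate this PR keep working; only torch.compile-driven paths carry a real id. A minimal sketch of that contract under the stated assumption — the `build_lazy_inductor` helper is illustrative, not part of the diff:

```python
from thunder.dynamo.utils import LazyInductorModule, TORCH_COMPILE_COMPILE_ID_KEY

def build_lazy_inductor(graph_module, fake_mode, thunder_options=None, **compile_options):
    # thunder_options may be absent when the splitter runs outside torch.compile;
    # in that case the delayed Inductor compile is simply left untagged.
    compile_id = thunder_options.get(TORCH_COMPILE_COMPILE_ID_KEY) if thunder_options else None
    return LazyInductorModule(graph_module, fake_mode, compile_id, **compile_options)
```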
116 changes: 92 additions & 24 deletions thunder/dynamo/utils.py
@@ -2,7 +2,7 @@
from collections.abc import Callable, Sequence
from contextlib import contextmanager
from enum import Enum, auto
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, overload
import dataclasses
import inspect
import itertools
@@ -15,10 +15,12 @@
from looseversion import LooseVersion

import torch
from torch.fx.graph_module import GraphModule
from torch.nn.modules.module import _addindent
from torch.utils.weak import TensorWeakRef
from torch._guards import tracing, TracingContext
from torch._guards import tracing, TracingContext, compile_context, CompileContext
from torch._subclasses.fake_tensor import DynamicOutputShapeException
from torch._logging._internal import trace_structured_artifact

from torch._inductor import list_mode_options

@@ -38,10 +40,20 @@
from thunder.core.symbol import Symbol
from typing import Any
from collections.abc import Sequence
from torch._guards import CompileId
from thunder.core.trace import TraceCtx

auto_register_ops = set(itertools.chain(*torch_auto_registered_ops.values()))


if LooseVersion(torch.__version__) >= LooseVersion("2.9.0"):
wrapped_trace_structured_artifact = trace_structured_artifact
else:

def wrapped_trace_structured_artifact(*args, compile_id, **kwargs):
return trace_structured_artifact(*args, **kwargs)


# Currently, thunder has mapping for these torch function but they
# just raise a warning (or don't support the exact behaviour)
# Previously used for `torch._C._set_grad_enabled` when it just raised a warning.
@@ -162,12 +174,13 @@ def is_thunder_supported_partition(self, node: torch.fx.Node) -> bool:


class LazyInductorModule(torch.nn.Module):
def __init__(self, graph_module, fake_mode, **compile_options):
def __init__(self, graph_module, fake_mode, compile_id=None, **compile_options):
super().__init__()
self.graph_module = graph_module
self.compiled_fn = None
self.fake_mode = fake_mode
self.compile_options = compile_options
self.compile_id = compile_id

# For ease of debugging, we add graph attribute so GraphModule.print_readable will print it
self.graph = graph_module.graph
@@ -199,27 +212,30 @@ def fake_increment_toplevel(*args, **kwargs):
def forward(self, *args):
if self.compiled_fn is None:
with self._maybe_patch_increment_toplevel():
# Inductor needs fake_mode, particularly its shape_env, to handle SymInts
with tracing(TracingContext(fake_mode=self.fake_mode)):
try:
original_graph = copy.deepcopy(self.graph_module.graph)
# Extract and merge options from compile_options
options = self.compile_options.get("options", {}).copy()
mode = self.compile_options.get("mode")
if mode:
mode_options = list_mode_options().get(mode, {})
options.update(mode_options)

self.compiled_fn = torch._inductor.compile(self.graph_module, args, options=options)
except DynamicOutputShapeException as e:
# This exception is meant to be handled by Dynamo, which is responsible for graph break
# TODO: Use torch.compile for fallback. Ensure its correctness.
warnings.warn(f"Dynamic output shape operator encountered: {e}. Falling back to eager.")
# NOTE: torch._inductor.compile alters the output to always be a tuple.
# Restore original single-element return, if needed.
self.graph_module.graph = original_graph
self.graph_module.recompile()
self.compiled_fn = self.graph_module
# Restore the compile context so that torch._inductor.compile can log traces
# with the correct compile_id
with compile_context(CompileContext(self.compile_id) if self.compile_id is not None else None):
# Inductor needs fake_mode, particularly its shape_env, to handle SymInts
with tracing(TracingContext(fake_mode=self.fake_mode)):
try:
original_graph = copy.deepcopy(self.graph_module.graph)
# Extract and merge options from compile_options
options = self.compile_options.get("options", {}).copy()
mode = self.compile_options.get("mode")
if mode:
mode_options = list_mode_options().get(mode, {})
options.update(mode_options)

self.compiled_fn = torch._inductor.compile(self.graph_module, args, options=options)
except DynamicOutputShapeException as e:
# This exception is meant to be handled by Dynamo, which is responsible for graph break
# TODO: Use torch.compile for fallback. Ensure its correctness.
warnings.warn(f"Dynamic output shape operator encountered: {e}. Falling back to eager.")
# NOTE: torch._inductor.compile alters the output to always be a tuple.
# Restore original single-element return, if needed.
self.graph_module.graph = original_graph
self.graph_module.recompile()
self.compiled_fn = self.graph_module

return self.compiled_fn(*args)
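
The behavioral point of this hunk is that the delayed `torch._inductor.compile` call now runs inside the original compile context, so anything Inductor logs during the deferred compilation is attributed to the graph that produced this submodule. A stripped-down sketch of just that wrapping; `compile_later` is illustrative only:

```python
import torch
from torch._guards import compile_context, CompileContext

def compile_later(graph_module, args, compile_id=None):
    # Re-enter the torch.compile compile context (when we have one) so that Inductor's
    # own structured-trace artifacts land under the same compile id as the parent graph.
    ctx = CompileContext(compile_id) if compile_id is not None else None
    with compile_context(ctx):
        return torch._inductor.compile(graph_module, args)
```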

@@ -1222,3 +1238,55 @@ def dtensor_to_local_prim_wrapper(x):
node.target = dtensor_to_local_prim_wrapper
except Exception:
pass


@overload
def log_trace_or_graphmodule_to_torch_trace(
*,
name: str,
compile_id: CompileId | None = None,
payload_fn: Callable[[], str],
encoding: str,
) -> None: ...


@overload
def log_trace_or_graphmodule_to_torch_trace(
*,
name: str,
m: TraceCtx | GraphModule,
compile_id: CompileId | None = None,
) -> None: ...


def log_trace_or_graphmodule_to_torch_trace(
*,
name: str,
m: TraceCtx | GraphModule | None = None,
compile_id: CompileId | None = None,
payload_fn: Callable[[], str] | None = None,
encoding: str | None = None,
) -> None:
if payload_fn is None:
if isinstance(m, GraphModule):
payload_fn = lambda graph_module=m: graph_module.print_readable(
print_output=False,
include_stride=True,
include_device=True,
)
else:
# covers the case where m is a TraceCtx; fall back to its string representation
payload_fn = lambda mod=m: f"{mod}\n"
encoding = "string"
else:
check(encoding is not None, lambda: "`encoding` needs to be set")

wrapped_trace_structured_artifact(
name=name,
encoding=encoding,
payload_fn=payload_fn,
compile_id=compile_id,
)


TORCH_COMPILE_COMPILE_ID_KEY: str = "_torch_compile_compile_id_key"
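
The overloads above describe the two intended call forms: either hand the helper a `TraceCtx` or `GraphModule` and let it build a readable string payload itself, or supply `payload_fn` together with an explicit `encoding` for structured (e.g. JSON) metadata. A hedged usage sketch of both forms follows; the artifact names are illustrative, and without a structured-trace handler attached (no `TORCH_TRACE`), the underlying `trace_structured_artifact` call should effectively be a no-op.

```python
import torch
from thunder.dynamo.utils import log_trace_or_graphmodule_to_torch_trace

gm = torch.fx.symbolic_trace(torch.nn.Linear(2, 2))  # any GraphModule works for form (1)

# Form 1: pass a TraceCtx or GraphModule; the helper picks the string payload itself.
log_trace_or_graphmodule_to_torch_trace(name="example_graphmodule", m=gm, compile_id=None)

# Form 2: pass payload_fn and an explicit encoding (used for the JSON metadata in this PR).
log_trace_or_graphmodule_to_torch_trace(
    name="example_metadata",
    payload_fn=lambda: {"key": "value"},  # illustrative payload
    encoding="json",
    compile_id=None,
)
```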