diff --git a/thunder/__init__.py b/thunder/__init__.py
index c0483f8c3a..16ca84c24d 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -60,6 +60,7 @@
 import thunder.extend as extend
 from thunder.extend import Executor, add_default_executor
 import thunder.transforms as transforms
+from thunder.dynamo.utils import log_trace_or_graphmodule_to_torch_trace, TORCH_COMPILE_COMPILE_ID_KEY
 
 # NOTE This import is intentionally pytorch so that it thunder.torch doesn't import this
 import torch as pytorch
@@ -349,7 +350,6 @@ def jit(
 
         debug_options: optional :class:`thunder.DebugOptions` instance. See the doc string of :class:`DebugOptions` for supported debug options. Default: ``None``
     """
-
     if "executors_list" in compile_options:
         warnings.warn("outdated argument executors_list= in call, please use executors=")
     if executors is None:
@@ -457,6 +457,15 @@ def acquire_initial_trace(fn, args, kwargs, cd, cs, ad_hoc_executor):
     last_interpreter_log = jit_results.interpreter_log
     cs.last_interpreter_log = last_interpreter_log
     cs.last_interpreted_instructions = (i for i in last_interpreter_log if isinstance(i, dis.Instruction))
+
+    for name, trace in (
+        ("prologue", prologue_trc),
+        ("computation", computation_trc),
+        ("epilogue", epilogue_trc),
+    ):
+        log_trace_or_graphmodule_to_torch_trace(
+            name=name, m=trace, compile_id=compile_options.get(TORCH_COMPILE_COMPILE_ID_KEY, None)
+        )
     return prologue_trc, computation_trc, epilogue_trc
 
 def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, computation_trc, epilogue_trc):
@@ -521,10 +530,22 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
             new_computation_trc,
             new_epilogue_trc,
         )
+        transform_name = type(transform).__name__
+        list_of_name_and_trace = [
+            (f"prologue_after_{transform_name}", prologue_trc),
+            (f"computation_after_{transform_name}", computation_trc),
+        ]
+
         prologue_traces.append(prologue_trc)
        computation_traces.append(computation_trc)
         if epilogue_trc is not None:
             epilogue_traces.append(epilogue_trc)
+            list_of_name_and_trace.append((f"epilogue_after_{transform_name}", epilogue_trc))
+
+        for name, trc in list_of_name_and_trace:
+            log_trace_or_graphmodule_to_torch_trace(
+                name=name, m=trc, compile_id=compile_options.get(TORCH_COMPILE_COMPILE_ID_KEY, None)
+            )
 
     prologue_traces += transform_for_execution(
         prologue_trc,
@@ -583,6 +604,14 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
 
         computation_trc, backward_trc = split_into_forward_and_backward(computation_trc)
 
+    for name, trc in (
+        ("computation_after_fwd_bwd_split", computation_trc),
+        ("initial_backward", backward_trc),
+    ):
+        log_trace_or_graphmodule_to_torch_trace(
+            name=name, m=trc, compile_id=compile_options.get(TORCH_COMPILE_COMPILE_ID_KEY, None)
+        )
+
     computation_trc = thunder.executors.passes.del_last_used(computation_trc)
     computation_traces.append(computation_trc)
     if backward_trc is not None:
@@ -594,6 +623,7 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
     computation_traces.append(computation_trc)
 
     for transform in transforms:
+        transform_name = type(transform).__name__
         # NOTE: `backward_trc` could be None.
         new_computation_trc = transform.transform_trace_post_optimization(
             computation_trc, executors_list=cd.executors_list
@@ -601,6 +631,11 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
         if new_computation_trc is not computation_trc:
             computation_trc = new_computation_trc
             computation_traces.append(computation_trc)
+            log_trace_or_graphmodule_to_torch_trace(
+                name=f"computation_after_{transform_name}",
+                m=computation_trc,
+                compile_id=compile_options.get(TORCH_COMPILE_COMPILE_ID_KEY, None),
+            )
         if backward_trc is not None:
             new_backward_trc = transform.transform_trace_post_optimization(
                 backward_trc, executors_list=cd.executors_list
@@ -608,9 +643,19 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
             if new_backward_trc is not backward_trc:
                 backward_trc = new_backward_trc
                 backward_traces.append(backward_trc)
+                log_trace_or_graphmodule_to_torch_trace(
+                    name=f"backward_after_{transform_name}",
+                    m=backward_trc,
+                    compile_id=compile_options.get(TORCH_COMPILE_COMPILE_ID_KEY, None),
+                )
 
     if backward_trc is not None:
         backward_fn = backward_trc.python_callable()
+        log_trace_or_graphmodule_to_torch_trace(
+            name="ex_backward",
+            m=backward_trc,
+            compile_id=compile_options.get(TORCH_COMPILE_COMPILE_ID_KEY, None),
+        )
     else:
         backward_fn = None
         # We do not have to return auxiliary tensors, which will only be useful in backward pass
@@ -619,6 +664,11 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
 
     computation_trc = transform_to_torch_types(computation_trc)
     comp = computation_trc.python_callable()
+    log_trace_or_graphmodule_to_torch_trace(
+        name="ex_computation",
+        m=computation_trc,
+        compile_id=compile_options.get(TORCH_COMPILE_COMPILE_ID_KEY, None),
+    )
 
     # TODO RC1 Update the cache
     cache_entry = CacheEntry(
@@ -786,6 +836,19 @@ def get_computation_and_inputs(*args, **kwargs):
 
         with compile_data_and_stats(cd, cs):
             inps, pro_to_epi = cache_entry.prologue_fn(*args, **kwargs)
+            log_trace_or_graphmodule_to_torch_trace(
+                name="compile_data",
+                payload_fn=lambda compile_data=cd: {
+                    "cache_option": str(compile_data.cache_option),
+                    "sharp_edges": str(compile_data.sharp_edges),
+                    "disable_torch_autograd_support": compile_data.disable_torch_autograd_support,
+                    "executors_list": [str(ex) for ex in compile_data.executors_list],
+                    "compile_options": {k: str(v) for k, v in compile_options.items()},
+                },
+                encoding="json",
+                compile_id=compile_options.get(TORCH_COMPILE_COMPILE_ID_KEY, None),
+            )
+
         return cache_entry, inps, pro_to_epi
 
     def host_execution_timer(fn):
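The `thunder/__init__.py` changes above emit each trace as a PyTorch structured-trace artifact. A minimal sketch of how those artifacts might be inspected end to end; the `thunderfx` import path and the log-file layout are assumptions, and `tlparse` is an external viewer for TORCH_TRACE logs, not part of this PR:

```python
# Hypothetical end-to-end check of the new logging; not part of this diff.
# TORCH_TRACE is usually exported before Python starts so the structured-trace
# handler picks it up, e.g.:  TORCH_TRACE=/tmp/thunder_trace python repro.py
import torch

from thunder.dynamo import thunderfx  # assumed entry point, as used in the tests


def f(x):
    return torch.sin(x) + 1


cf = thunderfx(f)
cf(torch.ones(2, requires_grad=True))

# Afterwards, render the log with tlparse (pip install tlparse):
#   tlparse /tmp/thunder_trace/<log file>
# The "prologue"/"computation"/"epilogue" artifacts logged above should appear
# under the torch.compile compile id propagated via TORCH_COMPILE_COMPILE_ID_KEY.
```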
diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py
index fa60fbabbc..61fea5de2b 100644
--- a/thunder/dynamo/compiler.py
+++ b/thunder/dynamo/compiler.py
@@ -24,6 +24,7 @@
 )
 from thunder.dynamo.splitter import _splitter
 from thunder.dynamo.benchmark_utils import ThunderCompileSpecification
+from thunder.dynamo.utils import log_trace_or_graphmodule_to_torch_trace, TORCH_COMPILE_COMPILE_ID_KEY
 from thunder.transforms.extraction_only_prologue_transform import ExtractionOnlyPrologueTransform
 
 if TYPE_CHECKING:
@@ -144,12 +145,29 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor
             current_thunder_options=self.thunder_options,
             is_torch_compile_without_dynamic=is_in_torch_compile() and (not is_dynamic_inputs(sample_args)),
         )
+        torch_compile_compile_id = torch._guards.CompileContext.current_compile_id()
+        thunder_options = {**thunder_options, TORCH_COMPILE_COMPILE_ID_KEY: torch_compile_compile_id}
         split_module, subgraph_info = _splitter(
             gm,
             partial(jit, **thunder_options),
             thunder_options,
             **compile_options,
         )
+        log_trace_or_graphmodule_to_torch_trace(
+            name="graphmodule_after_splitter",
+            m=split_module,
+            compile_id=torch_compile_compile_id,
+        )
+        if subgraph_info.split_reasons:
+            log_trace_or_graphmodule_to_torch_trace(
+                name="graph_split_reasons",
+                payload_fn=lambda: [
+                    {"reason_type": reason.reason_type.name, "info": reason.info, "exception": reason.exception}
+                    for reason in subgraph_info.split_reasons
+                ],
+                encoding="json",
+                compile_id=torch_compile_compile_id,
+            )
         self.subgraph_infos.append(subgraph_info)
         return split_module
 
diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py
index c5d784a783..1febf2a968 100644
--- a/thunder/dynamo/splitter.py
+++ b/thunder/dynamo/splitter.py
@@ -23,6 +23,7 @@
     _get_example_inputs_from_placeholder,
     _ThunderSplitGraphModule,
     translate_dtensor_ops,
+    TORCH_COMPILE_COMPILE_ID_KEY,
 )
 
 if TYPE_CHECKING:
@@ -224,9 +225,11 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
 
             graph_module = getattr(split_gm, node.name)
             fake_mode = torch._guards.detect_fake_mode()
+            # Extract compile_id from thunder_options to propagate it to inductor
+            compile_id = thunder_options.get(TORCH_COMPILE_COMPILE_ID_KEY) if thunder_options else None
             # Delay Inductor compilation until invocation with real tensors,
             # because we do not know the strides of tensors that Thunder-compiled submodules return.
-            jit_fn = LazyInductorModule(graph_module, fake_mode, **compile_options)
+            jit_fn = LazyInductorModule(graph_module, fake_mode, compile_id, **compile_options)
             # Update the node name from "submod_*" to "inductor_*" for more user-friendly names
             update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)
 
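The two files above share one pattern: the Dynamo backend captures the active compile id, and the splitter threads it through to the lazily compiled Inductor submodules. A small sketch of that capture/replay pattern in isolation; the helper names are illustrative, only the `torch._guards` calls come from the diff:

```python
# Illustrative sketch of the compile-id plumbing; helper names are made up.
from torch._guards import CompileContext, compile_context


def capture_compile_id():
    # Inside a torch.compile backend a CompileContext is active, so this returns
    # the current compile id; outside of compilation it returns None.
    return CompileContext.current_compile_id()


def run_attributed_to(compile_id, fn, *args, **kwargs):
    # Re-enter a compile context later (as LazyInductorModule.forward does above)
    # so anything fn logs via the structured trace is attributed to compile_id.
    with compile_context(CompileContext(compile_id) if compile_id is not None else None):
        return fn(*args, **kwargs)
```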
diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
index e0d0610fcc..a8db974886 100644
--- a/thunder/dynamo/utils.py
+++ b/thunder/dynamo/utils.py
@@ -2,7 +2,7 @@
 from collections.abc import Callable, Sequence
 from contextlib import contextmanager
 from enum import Enum, auto
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, overload
 import dataclasses
 import inspect
 import itertools
@@ -15,10 +15,12 @@
 from looseversion import LooseVersion
 
 import torch
+from torch.fx.graph_module import GraphModule
 from torch.nn.modules.module import _addindent
 from torch.utils.weak import TensorWeakRef
-from torch._guards import tracing, TracingContext
+from torch._guards import tracing, TracingContext, compile_context, CompileContext
 from torch._subclasses.fake_tensor import DynamicOutputShapeException
+from torch._logging._internal import trace_structured_artifact
 
 from torch._inductor import list_mode_options
 
@@ -38,10 +40,20 @@
     from thunder.core.symbol import Symbol
     from typing import Any
     from collections.abc import Sequence
+    from torch._guards import CompileId
+    from thunder.core.trace import TraceCtx
 
 
 auto_register_ops = set(itertools.chain(*torch_auto_registered_ops.values()))
 
+if LooseVersion(torch.__version__) >= LooseVersion("2.9.0"):
+    wrapped_trace_structured_artifact = trace_structured_artifact
+else:
+
+    def wrapped_trace_structured_artifact(*args, compile_id, **kwargs):
+        return trace_structured_artifact(*args, **kwargs)
+
+
 # Currently, thunder has mapping for these torch function but they
 # just raise a warning (or don't support the exact behaviour)
 # Previously used for `torch._C._set_grad_enabled` when it just raised a warning.
@@ -162,12 +174,13 @@ def is_thunder_supported_partition(self, node: torch.fx.Node) -> bool:
 
 
 class LazyInductorModule(torch.nn.Module):
-    def __init__(self, graph_module, fake_mode, **compile_options):
+    def __init__(self, graph_module, fake_mode, compile_id=None, **compile_options):
         super().__init__()
         self.graph_module = graph_module
         self.compiled_fn = None
         self.fake_mode = fake_mode
         self.compile_options = compile_options
+        self.compile_id = compile_id
         # For ease of debugging, we add graph attribute so GraphModule.print_readable will print it
         self.graph = graph_module.graph
 
@@ -199,27 +212,30 @@ def fake_increment_toplevel(*args, **kwargs):
     def forward(self, *args):
         if self.compiled_fn is None:
             with self._maybe_patch_increment_toplevel():
-                # Inductor needs fake_mode, particularly its shape_env, to handle SymInts
-                with tracing(TracingContext(fake_mode=self.fake_mode)):
-                    try:
-                        original_graph = copy.deepcopy(self.graph_module.graph)
-                        # Extract and merge options from compile_options
-                        options = self.compile_options.get("options", {}).copy()
-                        mode = self.compile_options.get("mode")
-                        if mode:
-                            mode_options = list_mode_options().get(mode, {})
-                            options.update(mode_options)
-
-                        self.compiled_fn = torch._inductor.compile(self.graph_module, args, options=options)
-                    except DynamicOutputShapeException as e:
-                        # This exception is meant to be handled by Dynamo, which is responsible for graph break
-                        # TODO: Use torch.compile for fallback. Ensure its correctness.
-                        warnings.warn(f"Dynamic output shape operator encountered: {e}. Falling back to eager.")
-                        # NOTE: torch._inductor.compile alters the output to always be a tuple.
-                        # Restore original single-element return, if needed.
-                        self.graph_module.graph = original_graph
-                        self.graph_module.recompile()
-                        self.compiled_fn = self.graph_module
+                # Restore the compile context so that torch._inductor.compile can log traces
+                # with the correct compile_id
+                with compile_context(CompileContext(self.compile_id) if self.compile_id is not None else None):
+                    # Inductor needs fake_mode, particularly its shape_env, to handle SymInts
+                    with tracing(TracingContext(fake_mode=self.fake_mode)):
+                        try:
+                            original_graph = copy.deepcopy(self.graph_module.graph)
+                            # Extract and merge options from compile_options
+                            options = self.compile_options.get("options", {}).copy()
+                            mode = self.compile_options.get("mode")
+                            if mode:
+                                mode_options = list_mode_options().get(mode, {})
+                                options.update(mode_options)
+
+                            self.compiled_fn = torch._inductor.compile(self.graph_module, args, options=options)
+                        except DynamicOutputShapeException as e:
+                            # This exception is meant to be handled by Dynamo, which is responsible for graph break
+                            # TODO: Use torch.compile for fallback. Ensure its correctness.
+                            warnings.warn(f"Dynamic output shape operator encountered: {e}. Falling back to eager.")
+                            # NOTE: torch._inductor.compile alters the output to always be a tuple.
+                            # Restore original single-element return, if needed.
+                            self.graph_module.graph = original_graph
+                            self.graph_module.recompile()
+                            self.compiled_fn = self.graph_module
 
         return self.compiled_fn(*args)
 
@@ -1222,3 +1238,55 @@ def dtensor_to_local_prim_wrapper(x):
                 node.target = dtensor_to_local_prim_wrapper
         except Exception:
             pass
+
+
+@overload
+def log_trace_or_graphmodule_to_torch_trace(
+    *,
+    name: str,
+    compile_id: CompileId | None = None,
+    payload_fn: Callable[[], str],
+    encoding: str,
+) -> None: ...
+
+
+@overload
+def log_trace_or_graphmodule_to_torch_trace(
+    *,
+    name: str,
+    m: TraceCtx | GraphModule,
+    compile_id: CompileId | None = None,
+) -> None: ...
+
+
+def log_trace_or_graphmodule_to_torch_trace(
+    *,
+    name: str,
+    m: TraceCtx | GraphModule | None = None,
+    compile_id: CompileId | None = None,
+    payload_fn: Callable[[], str] | None = None,
+    encoding: str | None = None,
+) -> None:
+    if payload_fn is None:
+        if isinstance(m, GraphModule):
+            payload_fn = lambda graph_module=m: graph_module.print_readable(
+                print_output=False,
+                include_stride=True,
+                include_device=True,
+            )
+        else:
+            # includes m being TraceCtx
+            payload_fn = lambda mod=m: f"{mod}\n"
+        encoding = "string"
+    else:
+        check(encoding is not None, lambda: "`encoding` needs to be set")
+
+    wrapped_trace_structured_artifact(
+        name=name,
+        encoding=encoding,
+        payload_fn=payload_fn,
+        compile_id=compile_id,
+    )
+
+
+TORCH_COMPILE_COMPILE_ID_KEY: str = "_torch_compile_compile_id_key"
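For reference, a sketch of how the helper just defined can be called, mirroring its two overloads; the GraphModule here is a stand-in, and without TORCH_TRACE or an active trace handler the calls should simply be no-ops:

```python
# Hypothetical call sites for log_trace_or_graphmodule_to_torch_trace; not in the diff.
import torch

from thunder.dynamo.utils import log_trace_or_graphmodule_to_torch_trace


def f(x):
    return x + 1


gm = torch.fx.symbolic_trace(f)

# Overload 1: pass a GraphModule (or a thunder TraceCtx); it is rendered to a string artifact.
log_trace_or_graphmodule_to_torch_trace(name="example_graphmodule", m=gm)

# Overload 2: pass an explicit payload_fn and encoding, e.g. for JSON metadata.
log_trace_or_graphmodule_to_torch_trace(
    name="example_metadata",
    payload_fn=lambda: {"note": "free-form payload"},
    encoding="json",
)
```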
diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index 90b750e712..1f5bb212ba 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -17,6 +17,7 @@
 from hypothesis import HealthCheck
 import copy
 from functools import partial
+import tempfile
 
 import thunder
 from thunder import dtypes
@@ -80,6 +81,7 @@ def run_script(file_name, cmd):
     dtypes=NOTHING,
     executors=[DynamoThunderExecutor],
     decorators=(
+        pytest.mark.parametrize("run_with_torch_trace", (False, True), ids=("no_torch_trace", "with_torch_trace")),
         pytest.mark.parametrize("dynamic", (True, False, None), ids=("dynamic", "static", "auto")),
         pytest.mark.skipif(
             condition=IS_WINDOWS,
@@ -87,32 +89,48 @@ def run_script(file_name, cmd):
         ),
     ),
 )
-def test_basic(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
-    x = torch.ones(2, dtype=dtype, device=device, requires_grad=True)
-
-    def func(x):
-        x = torch.sin(x)
-        if x.sum() > 0:
-            return x + 1
-        else:
-            return x - 1
-
-    compiled = thunderfx(func, dynamic=dynamic)
-    out = compiled(x)
-
-    # out should have grad_fn and its name should be ThunderFunctionBackward
-    assert out.grad_fn is not None
-    assert out.grad_fn.name() == "ThunderFunctionBackward"
-
-    # We record the GraphModules that was compiled by ThunderCompiler
-    backend = compiled._backend
-    assert len(backend.subgraph_infos) == 2  # 2 due to data-dependent flow
-
-    for subgraph_info in backend.subgraph_infos:
-        assert isinstance(subgraph_info.original_graph_module, torch.fx.GraphModule)
-        assert len(subgraph_info.thunder_compiled_fns)  # There was atleast one function compiled with thunder.
-        for thunder_fn in subgraph_info.thunder_compiled_fns:
-            assert last_traces(thunder_fn)  # Verify that we can fetch last_traces
+def test_basic(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None, run_with_torch_trace: bool):
+    # Set up TORCH_TRACE environment variable if needed using mock.patch
+    if run_with_torch_trace:
+        tmp_path = tempfile.mkdtemp()
+        env_patch = {"TORCH_TRACE": tmp_path}
+    else:
+        tmp_path = None
+        env_patch = {}
+
+    try:
+        with patch.dict(os.environ, env_patch):
+            x = torch.ones(2, dtype=dtype, device=device, requires_grad=True)
+
+            def func(x):
+                x = torch.sin(x)
+                if x.sum() > 0:
+                    return x + 1
+                else:
+                    return x - 1
+
+            compiled = thunderfx(func, dynamic=dynamic)
+            out = compiled(x)
+
+            # out should have grad_fn and its name should be ThunderFunctionBackward
+            assert out.grad_fn is not None
+            assert out.grad_fn.name() == "ThunderFunctionBackward"
+
+            # We record the GraphModules that was compiled by ThunderCompiler
+            backend = compiled._backend
+            assert len(backend.subgraph_infos) == 2  # 2 due to data-dependent flow
+
+            for subgraph_info in backend.subgraph_infos:
+                assert isinstance(subgraph_info.original_graph_module, torch.fx.GraphModule)
+                assert len(subgraph_info.thunder_compiled_fns)  # There was at least one function compiled with thunder.
+                for thunder_fn in subgraph_info.thunder_compiled_fns:
+                    assert last_traces(thunder_fn)  # Verify that we can fetch last_traces
+    finally:
+        # Clean up temporary directory
+        if tmp_path is not None:
+            import shutil
+
+            shutil.rmtree(tmp_path, ignore_errors=True)
 
 
 @instantiate(
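A possible follow-up for the `with_torch_trace` parametrization, sketching how the test could also assert that a structured-trace log was actually produced; the directory layout under TORCH_TRACE is an assumption, and this helper is not part of the diff:

```python
# Hypothetical helper for test_basic; would run inside the patched-environment block.
import os


def assert_torch_trace_written(tmp_path: str) -> None:
    # Only checks that torch wrote *something* under the TORCH_TRACE directory;
    # the exact file names are deliberately not assumed.
    found = [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(tmp_path)
        for name in names
    ]
    assert found, f"expected structured trace logs under {tmp_path}"
```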