diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 25fc08b775..99a99a27ab 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1160,6 +1160,9 @@ def replace(self, **changes): def type_string(self): value_str = f"{self.value}" if self.value is not None else "?" type_str = "int" if self.python_type is int else "bool" + if not self.is_static_constrained(): + # For non-static values, only show the type + return f"symbolic {type_str}" return f"{type_str} {value_str}" def __repr__(self): @@ -1209,6 +1212,9 @@ def replace(self, **changes): def type_string(self): value_str = f"{self.value}" if self.value is not None else "?" + if not self.is_static_constrained(): + # For non-static values, only show the type + return "symbolic float" return f"float {value_str}" def __repr__(self): diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py index 3db72a1ec3..5f78d2c491 100644 --- a/thunder/tests/test_jit_general.py +++ b/thunder/tests/test_jit_general.py @@ -1630,6 +1630,44 @@ def test_cache_symbolic_values_nn_parameter_static_shape(): assert isinstance(bsym.output.shape[1], thunder.core.proxies.IntegerProxy) +def test_cache_symbolic_values_int_float_inputs(): + def foo(a, b): + return a + b + + jfoo = thunder_jit(foo, cache="symbolic values") + + a = 1 + b = 2.0 + actual = jfoo(a, b) + expected = foo(a, b) + + assert_close(actual, expected) + assert thunder.cache_misses(jfoo) == 1 + assert thunder.cache_hits(jfoo) == 0 + + a = 2 + b = 3.0 + actual = jfoo(a, b) + expected = foo(a, b) + + assert_close(actual, expected) + assert thunder.cache_misses(jfoo) == 1 + assert thunder.cache_hits(jfoo) == 1 + + trc = thunder.last_traces(jfoo)[-1] + for bsym in trc.bound_symbols: + if bsym.sym.name == prims.PrimIDs.UNPACK_TRIVIAL: + assert isinstance(bsym.output, (IntegerProxy, FloatProxy)) + + trc_str = str(trc) + # Verify that symbolic inputs are not baked in as constants in the trace string + assert 'a: "int 1"' not in trc_str + assert 'b: "float 2.0"' not in trc_str + + assert 'a: "symbolic int"' in trc_str + assert 'b: "symbolic float"' in trc_str + + def test_specific_dataclass_returns(): import transformers