diff --git a/docs/running_programmatically.rst b/docs/running_programmatically.rst index 95248f522..abfb63919 100644 --- a/docs/running_programmatically.rst +++ b/docs/running_programmatically.rst @@ -46,17 +46,25 @@ plot it directly without saving to a file first using plot_config = torax.import_module('plotting/configs/default_plot_config.py')['PLOT_CONFIG'] # Plot directly from the in-memory data_tree returned by run_simulation. - fig = torax.plot_run_from_data_tree(plot_config, data_tree) + fig = torax.plot_run_from_data_tree(plot_config, {"TORAX": data_tree}) -To compare two in-memory runs: +To compare multiple in-memory runs, pass a dictionary mapping labels to +DataTrees. The labels appear in the plot legends: .. code-block:: python - fig = torax.plot_run_from_data_tree(plot_config, data_tree, data_tree2) + fig = torax.plot_run_from_data_tree( + plot_config, + { + "TORAX": torax_data_tree, + "JINTRAC": jintrac_data_tree, + "EXPERIMENT": experimental_data_tree, + }, + ) If you have saved the output to a ``.nc`` file and want to plot from disk, use ``torax.plot_run`` instead: .. code-block:: python - fig = torax.plot_run(plot_config, PATH_TO_LOCAL_NC_FILE) + fig = torax.plot_run(plot_config, {"TORAX": PATH_TO_LOCAL_NC_FILE}) diff --git a/torax/_src/plotting/plotruns_lib.py b/torax/_src/plotting/plotruns_lib.py index 98b3e994a..fa7c56aae 100644 --- a/torax/_src/plotting/plotruns_lib.py +++ b/torax/_src/plotting/plotruns_lib.py @@ -15,8 +15,10 @@ """Utilities for plotting outputs of Torax runs. Public API: - plot_run: Main entry point. Loads data from file and returns a plotly Figure. - plot_run_from_data_tree: Plots from an in-memory xr.DataTree. + plot_run: Main entry point. Loads data from files and returns a plotly Figure. + plot_run_from_data_tree: Plots from in-memory xr.DataTrees. Accepts a + mapping of named DataTrees, enabling comparison of arbitrarily many runs + with user-specified legend labels. PlotData: Data container exposing all output variables as attributes. FigureProperties: Configuration for the overall figure layout. PlotProperties: Configuration for an individual subplot. @@ -393,61 +395,82 @@ def _get_file_path(outfile: str) -> str: raise ValueError(f'Could not find {outfile}. Tried {possible_paths}.') -def _get_title_from_paths(path1: str, path2: str | None) -> str: +def _get_title_from_paths(paths: Mapping[str, str]) -> str: """Gets the title for the plot.""" - names = [f'(1) {path.basename(path1)}'] - if path2: - names.append(f'(2) {path.basename(path2)}') + names = [ + f'({i+1}) {path.basename(filepath)}' + for i, filepath in enumerate(paths.values()) + ] return ' & '.join(names) def plot_run( plot_config: FigureProperties, - outfile: str, - outfile2: str | None = None, + outfiles: Mapping[str, str], interactive: bool = True, ) -> go.Figure: - """Plots a single run or comparison of two runs from output files.""" - fig_title = plot_config.figure_title or _get_title_from_paths( - outfile, outfile2 - ) + """Plots one or more runs from output files. + + Args: + plot_config: Configuration for the figure layout and subplots. + outfiles: A mapping ``{label: filepath}`` where *label* is the string + that will appear in the plot legend and *filepath* is the path to + the output ``.nc`` file. + interactive: If True, calls ``fig.show()`` before returning. - outfile = _get_file_path(outfile) - outfile2 = _get_file_path(outfile2) if outfile2 else None + Returns: + A plotly ``go.Figure``. + """ + fig_title = plot_config.figure_title or _get_title_from_paths(outfiles) - data_tree = output.load_state_file(outfile) - data_tree2 = output.load_state_file(outfile2) if outfile2 else None + named_data_trees: dict[str, xr.DataTree] = { + name: output.load_state_file(_get_file_path(filepath)) + for name, filepath in outfiles.items() + } return plot_run_from_data_tree( - plot_config, data_tree, data_tree2, interactive, fig_title + plot_config, + named_data_trees, + interactive=interactive, + fig_title=fig_title, ) def plot_run_from_data_tree( plot_config: FigureProperties, - data_tree: xr.DataTree, - data_tree2: xr.DataTree | None = None, + data_trees: Mapping[str, xr.DataTree], interactive: bool = True, fig_title: str = 'Torax Simulation Results', ) -> go.Figure: - """Plots a single run or comparison of two runs from in-memory DataTrees.""" - plotdata1 = _data_tree_to_plot_data(data_tree) - plotdata2 = _data_tree_to_plot_data(data_tree2) if data_tree2 else None + """Plots one or more runs from in-memory DataTrees. - datasets_to_check = [(plotdata1, 'data_tree')] - if plotdata2 is not None: - datasets_to_check.append((plotdata2, 'data_tree2')) + Args: + plot_config: Configuration for the figure layout and subplots. + data_trees: A mapping ``{label: xr.DataTree}`` where *label* is the + string that will appear in the plot legend for that dataset. + interactive: If True, calls ``fig.show()`` before returning. + fig_title: Title of the figure. - for plotdata, filename in datasets_to_check: + Returns: + A plotly ``go.Figure``. + """ + named_plot_data: dict[str, PlotData] = { + name: _data_tree_to_plot_data(dt) + for name, dt in data_trees.items() + } + + for name, plotdata in named_plot_data.items(): available_vars = plotdata.available_variables() for cfg in plot_config.axes: for attr in cfg.attrs: if attr not in available_vars: raise ValueError( f"Attribute '{attr}' in plot_config was not found in " - f'output file: {filename}' + f'output file: {name}' ) - fig = create_plotly_figure(plot_config, plotdata1, plotdata2, fig_title) + fig = create_plotly_figure( + plot_config, named_plot_data, title=fig_title + ) if interactive: fig.show() @@ -598,15 +621,27 @@ def _get_y_limits( return lower_bound, upper_bound +# Dash patterns cycled when plotting multiple datasets. +_DASH_PATTERNS: Final[tuple[str, ...]] = ( + 'solid', + 'dash', + 'dot', + 'dashdot', + 'longdash', + 'longdashdot', +) + + def _add_traces_and_update_axes( fig: go.Figure, plot_config: FigureProperties, - datasets: list[PlotData], + named_datasets: Mapping[str, PlotData], ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: """Adds traces to the figure and updates axes.""" spatial_traces_info = [] timestamp_line_info = [] trace_count = 0 + datasets = list(named_datasets.values()) for i, axis_config in enumerate(plot_config.axes): row, col = (i // plot_config.cols) + 1, (i % plot_config.cols) + 1 @@ -617,7 +652,7 @@ def _add_traces_and_update_axes( for attr, label in zip(axis_config.attrs, axis_config.labels): color = next(colors) - for idx, dataset in enumerate(datasets): + for idx, (ds_name, dataset) in enumerate(named_datasets.items()): if axis_config.suppress_zero_values and np.all( getattr(dataset, attr) == 0 ): @@ -639,7 +674,10 @@ def _add_traces_and_update_axes( x = dataset.t y = getattr(dataset, attr) - label_html = f"{prefix} {_transform_string(f'{label} (Data {idx+1})')}" + label_html = ( + f"{prefix} {_transform_string(f'{label} ({ds_name})')}" + ) + dash = _DASH_PATTERNS[idx % len(_DASH_PATTERNS)] fig.add_trace( go.Scatter( x=x, @@ -648,7 +686,7 @@ def _add_traces_and_update_axes( name=label_html, showlegend=True, legendgroup=prefix, - line=dict(color=color, dash='dash' if idx > 0 else 'solid'), + line=dict(color=color, dash=dash), ), row=row, col=col, @@ -786,20 +824,29 @@ def _update_global_layout( def create_plotly_figure( plot_config: FigureProperties, - data1: PlotData, - data2: PlotData | None = None, + datasets: Mapping[str, PlotData], title: str = 'Torax Simulation Results', ) -> go.Figure: - """Create a plotly figure.""" + """Create a plotly figure. + + Args: + plot_config: Configuration for the figure layout and subplots. + datasets: A mapping ``{label: PlotData}`` where *label* is the string + that will appear in the plot legend for that dataset. + title: Title of the figure. + Returns: + A plotly ``go.Figure``. + """ fig = _setup_subplots(plot_config) - datasets = [d for d in [data1, data2] if d is not None] spatial_traces_info, timestamp_line_info = _add_traces_and_update_axes( fig, plot_config, datasets ) - _build_slider(fig, data1, spatial_traces_info, timestamp_line_info) + # Use the first dataset's time as the master clock for the slider. + first_dataset = next(iter(datasets.values())) + _build_slider(fig, first_dataset, spatial_traces_info, timestamp_line_info) _update_global_layout(fig, plot_config, title) return fig diff --git a/torax/_src/plotting/plotruns_lib_test.py b/torax/_src/plotting/plotruns_lib_test.py index 830ecdd9f..5bd1723c0 100644 --- a/torax/_src/plotting/plotruns_lib_test.py +++ b/torax/_src/plotting/plotruns_lib_test.py @@ -71,7 +71,7 @@ def test_plot_config_all(self, config_name: str, data_file: str): plot_config = config_loader.import_module(config_path)['PLOT_CONFIG'] test_data_path = test_data_dir / data_file fig = plotruns_lib.plot_run( - plot_config, str(test_data_path), interactive=False + plot_config, {'Data 1': str(test_data_path)}, interactive=False ) self.assertIsInstance( fig, go.Figure, msg=f'Plotting of {test_data_path.name} failed' diff --git a/torax/plotting/plotruns.py b/torax/plotting/plotruns.py index ee56f9e37..b37a8c809 100644 --- a/torax/plotting/plotruns.py +++ b/torax/plotting/plotruns.py @@ -58,10 +58,10 @@ def main(args): 'Error loading plot config: %s: %s', plot_config_module_path, e ) raise - if len(args.outfile) == 1: - plotruns_lib.plot_run(plot_config, args.outfile[0]) - else: - plotruns_lib.plot_run(plot_config, args.outfile[0], args.outfile[1]) + outfiles = { + f'Data {i + 1}': f for i, f in enumerate(args.outfile) + } + plotruns_lib.plot_run(plot_config, outfiles) # Method used by the `plot_torax` binary. diff --git a/torax/run_simulation_main.py b/torax/run_simulation_main.py index e802d21a5..32d6ce783 100644 --- a/torax/run_simulation_main.py +++ b/torax/run_simulation_main.py @@ -329,16 +329,21 @@ def _post_run_plotting( return match input_text: case '0': - return plotruns_lib.plot_run(plot_config, output_files[-1]) + return plotruns_lib.plot_run( + plot_config, {'Run 1': output_files[-1]} + ) case '1': if len(output_files) == 1: simulation_app.log_to_stdout( 'Only one output run file found, only plotting the last run.', color=simulation_app.AnsiColors.RED, ) - return plotruns_lib.plot_run(plot_config, output_files[-1]) + return plotruns_lib.plot_run( + plot_config, {'Run 1': output_files[-1]} + ) return plotruns_lib.plot_run( - plot_config, output_files[-1], output_files[-2] + plot_config, + {'Run 1': output_files[-1], 'Run 2': output_files[-2]}, ) case '2': reference_run = _REFERENCE_RUN.value @@ -347,7 +352,10 @@ def _post_run_plotting( 'No reference run provided, only plotting the last run.', color=simulation_app.AnsiColors.RED, ) - return plotruns_lib.plot_run(plot_config, output_files[-1], reference_run) + outfiles = {'Run 1': output_files[-1]} + if reference_run is not None: + outfiles['Reference Run'] = reference_run + return plotruns_lib.plot_run(plot_config, outfiles) case _: raise ValueError('Unknown command')