Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions docs/running_programmatically.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,25 @@ plot it directly without saving to a file first using
plot_config = torax.import_module('plotting/configs/default_plot_config.py')['PLOT_CONFIG']

# Plot directly from the in-memory data_tree returned by run_simulation.
fig = torax.plot_run_from_data_tree(plot_config, data_tree)
fig = torax.plot_run_from_data_tree(plot_config, {"TORAX": data_tree})

To compare two in-memory runs:
To compare multiple in-memory runs, pass a dictionary mapping labels to
DataTrees. The labels appear in the plot legends:

.. code-block:: python

fig = torax.plot_run_from_data_tree(plot_config, data_tree, data_tree2)
fig = torax.plot_run_from_data_tree(
plot_config,
{
"TORAX": torax_data_tree,
"JINTRAC": jintrac_data_tree,
"EXPERIMENT": experimental_data_tree,
},
)

If you have saved the output to a ``.nc`` file and want to plot from disk,
use ``torax.plot_run`` instead:

.. code-block:: python

fig = torax.plot_run(plot_config, PATH_TO_LOCAL_NC_FILE)
fig = torax.plot_run(plot_config, {"TORAX": PATH_TO_LOCAL_NC_FILE})
121 changes: 84 additions & 37 deletions torax/_src/plotting/plotruns_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
"""Utilities for plotting outputs of Torax runs.

Public API:
plot_run: Main entry point. Loads data from file and returns a plotly Figure.
plot_run_from_data_tree: Plots from an in-memory xr.DataTree.
plot_run: Main entry point. Loads data from files and returns a plotly Figure.
plot_run_from_data_tree: Plots from in-memory xr.DataTrees. Accepts a
mapping of named DataTrees, enabling comparison of arbitrarily many runs
with user-specified legend labels.
PlotData: Data container exposing all output variables as attributes.
FigureProperties: Configuration for the overall figure layout.
PlotProperties: Configuration for an individual subplot.
Expand Down Expand Up @@ -393,61 +395,82 @@ def _get_file_path(outfile: str) -> str:
raise ValueError(f'Could not find {outfile}. Tried {possible_paths}.')


def _get_title_from_paths(path1: str, path2: str | None) -> str:
def _get_title_from_paths(paths: Mapping[str, str]) -> str:
"""Gets the title for the plot."""
names = [f'(1) {path.basename(path1)}']
if path2:
names.append(f'(2) {path.basename(path2)}')
names = [
f'({i+1}) {path.basename(filepath)}'
for i, filepath in enumerate(paths.values())
]
return ' & '.join(names)


def plot_run(
plot_config: FigureProperties,
outfile: str,
outfile2: str | None = None,
outfiles: Mapping[str, str],
interactive: bool = True,
) -> go.Figure:
"""Plots a single run or comparison of two runs from output files."""
fig_title = plot_config.figure_title or _get_title_from_paths(
outfile, outfile2
)
"""Plots one or more runs from output files.

Args:
plot_config: Configuration for the figure layout and subplots.
outfiles: A mapping ``{label: filepath}`` where *label* is the string
that will appear in the plot legend and *filepath* is the path to
the output ``.nc`` file.
interactive: If True, calls ``fig.show()`` before returning.

outfile = _get_file_path(outfile)
outfile2 = _get_file_path(outfile2) if outfile2 else None
Returns:
A plotly ``go.Figure``.
"""
fig_title = plot_config.figure_title or _get_title_from_paths(outfiles)

data_tree = output.load_state_file(outfile)
data_tree2 = output.load_state_file(outfile2) if outfile2 else None
named_data_trees: dict[str, xr.DataTree] = {
name: output.load_state_file(_get_file_path(filepath))
for name, filepath in outfiles.items()
}
return plot_run_from_data_tree(
plot_config, data_tree, data_tree2, interactive, fig_title
plot_config,
named_data_trees,
interactive=interactive,
fig_title=fig_title,
)


def plot_run_from_data_tree(
plot_config: FigureProperties,
data_tree: xr.DataTree,
data_tree2: xr.DataTree | None = None,
data_trees: Mapping[str, xr.DataTree],
interactive: bool = True,
fig_title: str = 'Torax Simulation Results',
) -> go.Figure:
"""Plots a single run or comparison of two runs from in-memory DataTrees."""
plotdata1 = _data_tree_to_plot_data(data_tree)
plotdata2 = _data_tree_to_plot_data(data_tree2) if data_tree2 else None
"""Plots one or more runs from in-memory DataTrees.

datasets_to_check = [(plotdata1, 'data_tree')]
if plotdata2 is not None:
datasets_to_check.append((plotdata2, 'data_tree2'))
Args:
plot_config: Configuration for the figure layout and subplots.
data_trees: A mapping ``{label: xr.DataTree}`` where *label* is the
string that will appear in the plot legend for that dataset.
interactive: If True, calls ``fig.show()`` before returning.
fig_title: Title of the figure.

for plotdata, filename in datasets_to_check:
Returns:
A plotly ``go.Figure``.
"""
named_plot_data: dict[str, PlotData] = {
name: _data_tree_to_plot_data(dt)
for name, dt in data_trees.items()
}

for name, plotdata in named_plot_data.items():
available_vars = plotdata.available_variables()
for cfg in plot_config.axes:
for attr in cfg.attrs:
if attr not in available_vars:
raise ValueError(
f"Attribute '{attr}' in plot_config was not found in "
f'output file: {filename}'
f'output file: {name}'
)

fig = create_plotly_figure(plot_config, plotdata1, plotdata2, fig_title)
fig = create_plotly_figure(
plot_config, named_plot_data, title=fig_title
)
if interactive:
fig.show()

Expand Down Expand Up @@ -598,15 +621,27 @@ def _get_y_limits(
return lower_bound, upper_bound


# Dash patterns cycled when plotting multiple datasets.
_DASH_PATTERNS: Final[tuple[str, ...]] = (
'solid',
'dash',
'dot',
'dashdot',
'longdash',
'longdashdot',
)


def _add_traces_and_update_axes(
fig: go.Figure,
plot_config: FigureProperties,
datasets: list[PlotData],
named_datasets: Mapping[str, PlotData],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""Adds traces to the figure and updates axes."""
spatial_traces_info = []
timestamp_line_info = []
trace_count = 0
datasets = list(named_datasets.values())

for i, axis_config in enumerate(plot_config.axes):
row, col = (i // plot_config.cols) + 1, (i % plot_config.cols) + 1
Expand All @@ -617,7 +652,7 @@ def _add_traces_and_update_axes(
for attr, label in zip(axis_config.attrs, axis_config.labels):

color = next(colors)
for idx, dataset in enumerate(datasets):
for idx, (ds_name, dataset) in enumerate(named_datasets.items()):
if axis_config.suppress_zero_values and np.all(
getattr(dataset, attr) == 0
):
Expand All @@ -639,7 +674,10 @@ def _add_traces_and_update_axes(
x = dataset.t
y = getattr(dataset, attr)

label_html = f"{prefix} {_transform_string(f'{label} (Data {idx+1})')}"
label_html = (
f"{prefix} {_transform_string(f'{label} ({ds_name})')}"
)
dash = _DASH_PATTERNS[idx % len(_DASH_PATTERNS)]
fig.add_trace(
go.Scatter(
x=x,
Expand All @@ -648,7 +686,7 @@ def _add_traces_and_update_axes(
name=label_html,
showlegend=True,
legendgroup=prefix,
line=dict(color=color, dash='dash' if idx > 0 else 'solid'),
line=dict(color=color, dash=dash),
),
row=row,
col=col,
Expand Down Expand Up @@ -786,20 +824,29 @@ def _update_global_layout(

def create_plotly_figure(
plot_config: FigureProperties,
data1: PlotData,
data2: PlotData | None = None,
datasets: Mapping[str, PlotData],
title: str = 'Torax Simulation Results',
) -> go.Figure:
"""Create a plotly figure."""
"""Create a plotly figure.

Args:
plot_config: Configuration for the figure layout and subplots.
datasets: A mapping ``{label: PlotData}`` where *label* is the string
that will appear in the plot legend for that dataset.
title: Title of the figure.

Returns:
A plotly ``go.Figure``.
"""
fig = _setup_subplots(plot_config)
datasets = [d for d in [data1, data2] if d is not None]

spatial_traces_info, timestamp_line_info = _add_traces_and_update_axes(
fig, plot_config, datasets
)

_build_slider(fig, data1, spatial_traces_info, timestamp_line_info)
# Use the first dataset's time as the master clock for the slider.
first_dataset = next(iter(datasets.values()))
_build_slider(fig, first_dataset, spatial_traces_info, timestamp_line_info)
_update_global_layout(fig, plot_config, title)

return fig
Expand Down
2 changes: 1 addition & 1 deletion torax/_src/plotting/plotruns_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_plot_config_all(self, config_name: str, data_file: str):
plot_config = config_loader.import_module(config_path)['PLOT_CONFIG']
test_data_path = test_data_dir / data_file
fig = plotruns_lib.plot_run(
plot_config, str(test_data_path), interactive=False
plot_config, {'Data 1': str(test_data_path)}, interactive=False
)
self.assertIsInstance(
fig, go.Figure, msg=f'Plotting of {test_data_path.name} failed'
Expand Down
8 changes: 4 additions & 4 deletions torax/plotting/plotruns.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ def main(args):
'Error loading plot config: %s: %s', plot_config_module_path, e
)
raise
if len(args.outfile) == 1:
plotruns_lib.plot_run(plot_config, args.outfile[0])
else:
plotruns_lib.plot_run(plot_config, args.outfile[0], args.outfile[1])
outfiles = {
f'Data {i + 1}': f for i, f in enumerate(args.outfile)
}
plotruns_lib.plot_run(plot_config, outfiles)


# Method used by the `plot_torax` binary.
Expand Down
16 changes: 12 additions & 4 deletions torax/run_simulation_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,16 +329,21 @@ def _post_run_plotting(
return
match input_text:
case '0':
return plotruns_lib.plot_run(plot_config, output_files[-1])
return plotruns_lib.plot_run(
plot_config, {'Run 1': output_files[-1]}
)
case '1':
if len(output_files) == 1:
simulation_app.log_to_stdout(
'Only one output run file found, only plotting the last run.',
color=simulation_app.AnsiColors.RED,
)
return plotruns_lib.plot_run(plot_config, output_files[-1])
return plotruns_lib.plot_run(
plot_config, {'Run 1': output_files[-1]}
)
return plotruns_lib.plot_run(
plot_config, output_files[-1], output_files[-2]
plot_config,
{'Run 1': output_files[-1], 'Run 2': output_files[-2]},
)
case '2':
reference_run = _REFERENCE_RUN.value
Expand All @@ -347,7 +352,10 @@ def _post_run_plotting(
'No reference run provided, only plotting the last run.',
color=simulation_app.AnsiColors.RED,
)
return plotruns_lib.plot_run(plot_config, output_files[-1], reference_run)
outfiles = {'Run 1': output_files[-1]}
if reference_run is not None:
outfiles['Reference Run'] = reference_run
return plotruns_lib.plot_run(plot_config, outfiles)
case _:
raise ValueError('Unknown command')

Expand Down
Loading