diff --git a/docs/source/examples/getting-started.ipynb b/docs/source/examples/getting-started.ipynb index 071ca765..2d22ddeb 100644 --- a/docs/source/examples/getting-started.ipynb +++ b/docs/source/examples/getting-started.ipynb @@ -224,6 +224,56 @@ "print(\"Voltage = \", voltage)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A mistyped column will raise an error and suggest close matches if available:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "raises-exception" + ] + }, + "outputs": [], + "source": [ + "current, voltage = (\n", + " cell.procedure[\"Sample\"]\n", + " .experiment(\"Break-in Cycles\")\n", + " .charge(0)\n", + " .get(\"Crrent [A]\", \"Voltge [V]\")\n", + ")\n", + "print(\"Current [A] = \", current)\n", + "print(\"Voltage [V]= \", voltage)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If the column is completely mistyped an error will be thrown and all available columns will be listed:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "raises-exception" + ] + }, + "outputs": [], + "source": [ + "voltage = (\n", + " cell.procedure[\"Sample\"].experiment(\"Break-in Cycles\").charge(0).get(\"valoolashaka\")\n", + ")\n", + "print(type(voltage), voltage)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -332,6 +382,11 @@ } ], "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -342,7 +397,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.8" + "version": "3.11.13" } }, "nbformat": 4, diff --git a/pyprobe/result.py b/pyprobe/result.py index f725b49e..7481617f 100644 --- a/pyprobe/result.py +++ b/pyprobe/result.py @@ -1,5 +1,6 @@ """A module for the Result class.""" +import difflib import re from collections.abc import Callable from functools import wraps @@ -332,8 +333,7 @@ def __getitem__(self, *column_names: str) -> "Result": ) def get( - self, - *column_names: str, + self, *column_names: str ) -> NDArray[np.float64] | tuple[NDArray[np.float64], ...]: """Return one or more columns of the data as separate 1D numpy arrays. @@ -341,22 +341,45 @@ def get( column_names (str): The column name(s) to return. Returns: - Union[NDArray[np.float64], Tuple[NDArray[np.float64], ...]]: + Union[NDArray[np.float64], tuple[NDArray[np.float64],...]]: The column(s) as numpy array(s). Raises: - ValueError: If no column names are provided. - ValueError: If a column name is not in the data. + ValueError: If no column names are provided + ValueError: If a column is not in the data. Includes suggested close matches + if available. """ - array = self.data_with_columns(*column_names).to_numpy() if len(column_names) == 0: error_msg = "At least one column name must be provided." logger.error(error_msg) raise ValueError(error_msg) - elif len(column_names) == 1: - return array.T[0] - else: - return tuple(array.T) + + try: + return ( + self.data_with_columns(*column_names).to_numpy().T[0] + if len(column_names) == 1 + else tuple(self.data_with_columns(*column_names).to_numpy().T) + ) + except ValueError: + error_msgs = [] + for name in column_names: + matches = difflib.get_close_matches( + name, self.column_list, n=1, cutoff=0.5 + ) + if matches: + error_msg = ( + f'Column "{name}" not found. Did you mean "{matches[0]}"?' + ) + logger.error(error_msg) + error_msgs.append(error_msg) + else: + error_msg = ( + f'Column "{name}" not found and no close match found. ' + f"Available columns: {', '.join(self.column_list)}" + ) + logger.error(error_msg) + error_msgs.append(error_msg) + raise ValueError("\n" + "\n".join(f"- {msg}" for msg in error_msgs)) @property def contains_lazyframe(self) -> bool: diff --git a/tests/test_result.py b/tests/test_result.py index 4bc9f383..01b2a31b 100644 --- a/tests/test_result.py +++ b/tests/test_result.py @@ -189,6 +189,13 @@ def test_get(Result_fixture): voltage, Result_fixture.data["Voltage [V]"].to_numpy(), ) + # Test with a mistyped column + with pytest.raises(ValueError): + current = Result_fixture.get("Crrent [A]") + np_testing.assert_array_equal( + current, + Result_fixture.data["Current [A]"].to_numpy(), + ) def test_get_only(Result_fixture): @@ -703,3 +710,6 @@ def test_from_polars_io_python_object(): assert isinstance(result.base_dataframe, pl.DataFrame) assert result.info == info pl_testing.assert_frame_equal(result.data, test_df, check_column_order=False) + + +Result_fixture.get("Voltage [V]") # Ensure Result_fixture is used