diff --git a/.github/DISCUSSION_TEMPLATE/questions.yml b/.github/DISCUSSION_TEMPLATE/questions.yml index 92d4d7a143..c4f3256d44 100644 --- a/.github/DISCUSSION_TEMPLATE/questions.yml +++ b/.github/DISCUSSION_TEMPLATE/questions.yml @@ -36,8 +36,6 @@ body: required: true - label: I already read and followed all the tutorials in the docs and didn't find an answer. required: true - - label: I already checked if it is not related to Typer but to [Click](https://github.com/pallets/click). - required: true - type: checkboxes id: help attributes: diff --git a/CITATION.cff b/CITATION.cff index 5fe4aa2d34..b4902c6ab1 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -18,5 +18,6 @@ abstract: >- Typer, build great CLIs. Easy to code. Based on Python type hints. keywords: - typer - - click + - cli + - python license: MIT diff --git a/README.md b/README.md index 80b0b20ebd..43a46c5aa0 100644 --- a/README.md +++ b/README.md @@ -352,11 +352,20 @@ For a more complete example including more features, see the vendored Click version 8.3.1 and adds a layer on top of it. -**Typer** mainly adds a layer on top of Click, making the code simpler and easier to use, with autocompletion everywhere, etc, but providing all the powerful features of Click underneath. +Typer aims to make the code simpler and easier to use, with autocompletion everywhere, etc, while still providing many of the powerful features of Click underneath. As someone pointed out: ["Nice to see it is built on Click but adds the type stuff. Me gusta!"](https://twitter.com/fishnets88/status/1210126833745838080) diff --git a/docs/features.md b/docs/features.md index bd11fa5a9d..25490ec610 100644 --- a/docs/features.md +++ b/docs/features.md @@ -63,14 +63,6 @@ Auto completion works when you create a package (installable with `pip`). Or whe /// -/// tip - -**Typer**'s completion is implemented internally, it uses ideas and components from Click and ideas from `click-completion`, but it doesn't use `click-completion` and re-implements some of the relevant parts of Click. - -Then it extends those ideas with features and bug fixes. For example, **Typer** programs also support modern versions of PowerShell (e.g. in Windows 10) among all the other shells. - -/// - ## Tested * 100% test coverage. diff --git a/docs/index.md b/docs/index.md index 1add7a39dc..fec3eb402e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -360,11 +360,20 @@ For a more complete example including more features, see the = 8.2.1, < 8.4", "shellingham >=1.3.0", "rich >=13.8.0", "annotated-doc >=0.0.2", + "colorama; platform_system == 'Windows'", ] readme = "README.md" @@ -189,7 +189,7 @@ ignore = [ "docs_src/*" = ["TID"] [tool.ruff.lint.isort] -known-third-party = ["typer", "click"] +known-third-party = ["typer"] # For docs_src/subcommands/tutorial003/ known-first-party = ["reigns", "towns", "lands", "items", "users"] diff --git a/tests/assets/completion_argument.py b/tests/assets/completion_argument.py index f91e2b7cfb..e2754c4357 100644 --- a/tests/assets/completion_argument.py +++ b/tests/assets/completion_argument.py @@ -1,10 +1,10 @@ -import click import typer +from typer import _click app = typer.Typer() -def shell_complete(ctx: click.Context, param: click.Parameter, incomplete: str): +def shell_complete(ctx: _click.Context, param: _click.Parameter, incomplete: str): typer.echo(f"ctx: {ctx.info_name}", err=True) typer.echo(f"arg is: {param.name}", err=True) typer.echo(f"incomplete is: {incomplete}", err=True) diff --git a/tests/atomic_write_example.py b/tests/atomic_write_example.py new file mode 100644 index 0000000000..a190cc154f --- /dev/null +++ b/tests/atomic_write_example.py @@ -0,0 +1,63 @@ +import time + +import typer + +app = typer.Typer() + + +@app.command() +def write_atomic( + config: typer.FileText = typer.Option(..., mode="w", atomic=True), + pause: float = typer.Option(0.3), +) -> None: + config.write("atomic-content-1\n") + config.flush() + typer.echo("halfway") + time.sleep(pause) + config.write("atomic-content-2\n") + config.flush() + typer.echo("written atomically") + + +@app.command() +def write_atomic_binary( + config: typer.FileBinaryWrite = typer.Option(..., atomic=True, lazy=False), +) -> None: + config.write(b"\x00\x01binary-atomic\n") + typer.echo("written binary atomically") + + +@app.command() +def api_atomic( + config: typer.FileText = typer.Option(..., mode="w", atomic=True, lazy=False), +) -> None: + typer.echo(f"name={config.name}") + typer.echo(f"repr={repr(config)}") + with config as entered: + typer.echo(f"entered={entered is config}") + entered.write("atomic-api-done\n") + + +@app.command() +def invalid_atomic_append( + config: typer.FileText = typer.Option(..., mode="a", atomic=True, lazy=False), +) -> None: + typer.echo(config.name) # pragma: no cover + + +@app.command() +def invalid_atomic_exclusive( + config: typer.FileText = typer.Option(..., mode="x", atomic=True, lazy=False), +) -> None: + typer.echo(config.name) # pragma: no cover + + +@app.command() +def invalid_atomic_read( + config: typer.FileText = typer.Option(..., mode="r", atomic=True, lazy=False), +) -> None: + typer.echo(config.name) # pragma: no cover + + +if __name__ == "__main__": + app() diff --git a/tests/test_annotated.py b/tests/test_annotated.py index c487eae9d0..af3cc0680b 100644 --- a/tests/test_annotated.py +++ b/tests/test_annotated.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Annotated +import pytest import typer from typer.testing import CliRunner @@ -95,3 +96,14 @@ def custom_parser( result = runner.invoke(app, "/some/quirky/path/implementation") assert result.exit_code == 0 + + +def test_annotated_option_invalid(): + app = typer.Typer() + + @app.command() + def cmd(value: Annotated[str, typer.Option(..., "foo-bar")]): + print(value) # pragma: no cover + + with pytest.raises(ValueError, match="Invalid start character for option"): + runner.invoke(app, ["--help"], catch_exceptions=False) diff --git a/tests/test_atomic_file.py b/tests/test_atomic_file.py new file mode 100644 index 0000000000..010f899034 --- /dev/null +++ b/tests/test_atomic_file.py @@ -0,0 +1,122 @@ +import subprocess +import sys +from pathlib import Path + +import pytest + +from . import atomic_write_example as mod + + +def test_atomic_write(tmp_path: Path) -> None: + original_content = "existing-content\n" + output_file = tmp_path / "atomic-write-target.txt" + output_file.write_text(original_content, encoding="utf-8") + + process = subprocess.Popen( + [ + sys.executable, + "-m", + "coverage", + "run", + mod.__file__, + "write-atomic", + f"--config={output_file}", + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + assert process.stdout is not None + + # Halfway of writing the file, check that the original content is still there + halfway_line = process.stdout.readline().strip() + assert halfway_line == "halfway" + assert output_file.read_text(encoding="utf-8") == original_content + + # Only at the end, the full new content is visible + stdout, stderr = process.communicate(timeout=5) + assert process.returncode == 0, stderr + assert "written atomically" in stdout + assert ( + output_file.read_text(encoding="utf-8") + == "atomic-content-1\natomic-content-2\n" + ) + + +def test_atomic_binary_write(tmp_path: Path) -> None: + output_file = tmp_path / "atomic-binary.bin" + + result = subprocess.run( + [ + sys.executable, + "-m", + "coverage", + "run", + mod.__file__, + "write-atomic-binary", + f"--config={output_file}", + ], + capture_output=True, + encoding="utf-8", + ) + + assert result.returncode == 0, result.stderr + assert "written binary atomically" in result.stdout + assert output_file.read_bytes() == b"\x00\x01binary-atomic\n" + + +def test_atomic_api(tmp_path: Path) -> None: + output_file = tmp_path / "atomic-api.txt" + + result = subprocess.run( + [ + sys.executable, + "-m", + "coverage", + "run", + mod.__file__, + "api-atomic", + f"--config={output_file}", + ], + capture_output=True, + encoding="utf-8", + ) + + assert result.returncode == 0, result.stderr + assert f"name={output_file}" in result.stdout + assert "repr=<_io.TextIOWrapper" in result.stdout + assert "entered=True" in result.stdout + assert output_file.read_text(encoding="utf-8") == "atomic-api-done\n" + + +@pytest.mark.parametrize( + ("command_name", "expected_message"), + [ + ("invalid-atomic-append", "Appending to an existing file is not supported"), + ("invalid-atomic-exclusive", "Use the `overwrite`-parameter instead."), + ("invalid-atomic-read", "Atomic writes only make sense with `w`-mode."), + ], +) +def test_atomic_mode_invalid_options( + tmp_path: Path, command_name: str, expected_message: str +) -> None: + output_file = tmp_path / "atomic-invalid-mode.txt" + output_file.write_text("existing-content\n", encoding="utf-8") + + result = subprocess.run( + [ + sys.executable, + "-m", + "coverage", + "run", + mod.__file__, + command_name, + f"--config={output_file}", + ], + capture_output=True, + encoding="utf-8", + ) + + assert result.returncode != 0 + combined_output = f"{result.stdout}\n{result.stderr}" + assert expected_message in combined_output diff --git a/tests/test_cli/test_help.py b/tests/test_cli/test_help.py index 64e5495c9a..e829c5801b 100644 --- a/tests/test_cli/test_help.py +++ b/tests/test_cli/test_help.py @@ -1,6 +1,11 @@ import subprocess import sys +import typer +from typer.testing import CliRunner + +runner = CliRunner() + def test_script_help(): result = subprocess.run( @@ -36,3 +41,119 @@ def test_not_python(): encoding="utf-8", ) assert "Could not import as Python file" in result.stderr + + +def test_short_help() -> None: + app = typer.Typer( + rich_markup_mode=None, + context_settings={"max_content_width": 50}, + ) + + @app.command(help=" \n\t ") + def empty() -> None: + pass # pragma: no cover + + @app.command(help="\b first sentence.") + def marker() -> None: + pass # pragma: no cover + + # Forcing truncation + @app.command(help=f"{'x' * 30} {'y' * 5} z trailing") + def long() -> None: + pass # pragma: no cover + + result = runner.invoke(app, ["--help"], terminal_width=50) + assert result.exit_code == 0 + assert "empty" in result.output + assert "marker" in result.output + assert "long" in result.output + assert "first sentence." in result.output + assert f"{'x' * 30}..." in result.output + + +def test_help_wrapping() -> None: + app = typer.Typer( + rich_markup_mode=None, + context_settings={"max_content_width": 50}, + ) + + @app.command( + help=( + "Wrapped paragraph has enough words to wrap in help output.\n" + "\n" + "\n" + "\b\n" + "RAW-LINE-ONE stays on one line even with many many many words.\n" + "RAW-LINE-TWO keeps original formatting.\n" + "\n" + "Final paragraph wraps normally as well." + ) + ) + def cmd() -> None: + pass # pragma: no cover + + result = runner.invoke(app, ["cmd", "--help"], terminal_width=50) + assert result.exit_code == 0 + assert "Wrapped paragraph has enough words to wrap" in result.output + assert ( + "RAW-LINE-ONE stays on one line even with many many many words." + in result.output + ) + assert "RAW-LINE-TWO keeps original formatting." in result.output + assert "Final paragraph wraps normally as well." in result.output + + +def test_help_wrapping_long_name() -> None: + app = typer.Typer(rich_markup_mode=None) + + @app.command() + def cmd(value: str) -> None: + pass # pragma: no cover + + result = runner.invoke( + app, + ["cmd", "--help"], + terminal_width=40, + prog_name="very-long-program-name-that-forces-wrap", + ) + assert result.exit_code == 0 + + output_lines = result.output.splitlines() + usage_idx = output_lines.index("Usage: very-long-program-name-that-forces-wrap ") + args_line = output_lines[usage_idx + 1] + assert args_line.lstrip() == "[OPTIONS] VALUE" + assert args_line.startswith(" ") + + +def test_format_long_help_option() -> None: + app = typer.Typer(rich_markup_mode=None) + + @app.command() + def cmd( + very_long: str = typer.Option( + ..., + "--this-is-a-very-very-very-long-option-name", + help="Description is rendered in the next line for long option labels.", + ), + ) -> None: + pass # pragma: no cover + + result = runner.invoke(app, ["cmd", "--help"], terminal_width=80) + assert result.exit_code == 0 + + output_lines = result.output.splitlines() + option_idx = next( + i + for i, line in enumerate(output_lines) + if "--this-is-a-very-very-very-long-option-name" in line + ) + assert "Description is rendered" not in output_lines[option_idx] + first_desc_line = output_lines[option_idx + 1] + assert first_desc_line.lstrip().startswith("Description is rendered") + continuation_block = " ".join( + line.strip() for line in output_lines[option_idx + 1 :] if line.startswith(" ") + ) + assert ( + "Description is rendered in the next line for long option labels." + in continuation_block + ) diff --git a/tests/test_cli/test_parser.py b/tests/test_cli/test_parser.py new file mode 100644 index 0000000000..9d9b641e4b --- /dev/null +++ b/tests/test_cli/test_parser.py @@ -0,0 +1,67 @@ +import subprocess +import sys + +import typer +from typer.testing import CliRunner + +runner = CliRunner() + + +def test_double_dash() -> None: + result = subprocess.run( + [ + sys.executable, + "-m", + "coverage", + "run", + "-m", + "typer", + "tests/assets/cli/sample.py", + "run", + "hello", + "--", + "--name", + "Camila", + ], + capture_output=True, + encoding="utf-8", + ) + assert "Got unexpected extra argument" in result.stderr + assert "--name Camila" in result.stderr + + +def test_unknown_short_option() -> None: + result = subprocess.run( + [ + sys.executable, + "-m", + "coverage", + "run", + "-m", + "typer", + "tests/assets/cli/sample.py", + "run", + "hello", + "-x", + ], + capture_output=True, + encoding="utf-8", + ) + assert "No such option: -x" in result.stderr + + +def test_ignore_unknown_short_option() -> None: + app = typer.Typer( + context_settings={"ignore_unknown_options": True, "allow_extra_args": True} + ) + + @app.command() + def main( + ctx: typer.Context, all_: bool = typer.Option(False, "--all", "-a") + ) -> None: + assert all_ + print(ctx.args) + + result = runner.invoke(app, ["-azq"]) + assert result.exit_code == 0 + assert "['-zq']" in result.output diff --git a/tests/test_cli/test_program_name.py b/tests/test_cli/test_program_name.py new file mode 100644 index 0000000000..4c20942596 --- /dev/null +++ b/tests/test_cli/test_program_name.py @@ -0,0 +1,13 @@ +from typer import _click + + +def test_detect_program_name_submodule_path() -> None: + class MainModule: + __package__ = "example" + + program_name = _click.utils._detect_program_name( + path="/tmp/cli.py", + _main=MainModule(), + ) + + assert program_name == "python -m example.cli" diff --git a/tests/test_completion/choice_case_insensitive_example.py b/tests/test_completion/choice_case_insensitive_example.py new file mode 100644 index 0000000000..1f208dd0b5 --- /dev/null +++ b/tests/test_completion/choice_case_insensitive_example.py @@ -0,0 +1,19 @@ +from enum import Enum + +import typer + +app = typer.Typer() + + +class User(str, Enum): + rick = "rick" + morty = "morty" + + +@app.command() +def main(name: User = typer.Option(User.rick, "--name", case_sensitive=False)): + print(name.value) + + +if __name__ == "__main__": + app() diff --git a/tests/test_completion/choice_example.py b/tests/test_completion/choice_example.py new file mode 100644 index 0000000000..2c47d0fdc4 --- /dev/null +++ b/tests/test_completion/choice_example.py @@ -0,0 +1,19 @@ +from enum import Enum + +import typer + +app = typer.Typer() + + +class User(str, Enum): + rick = "rick" + morty = "morty" + + +@app.command() +def main(name: User = typer.Option(User.rick, "--name")): + print(name.value) + + +if __name__ == "__main__": + app() diff --git a/tests/test_completion/completion_option_then_argument.py b/tests/test_completion/completion_option_then_argument.py new file mode 100644 index 0000000000..ae5abd8d2e --- /dev/null +++ b/tests/test_completion/completion_option_then_argument.py @@ -0,0 +1,23 @@ +import typer + +app = typer.Typer() + + +def complete_name(ctx, args, incomplete): + return ["opt-choice"] # pragma: no cover + + +def complete_target(ctx, args, incomplete): + return ["arg-choice"] + + +@app.command() +def main( + name: str = typer.Option(..., "--name", autocompletion=complete_name), + target: str = typer.Argument(..., autocompletion=complete_target), +): + print(name, target) # pragma: no cover + + +if __name__ == "__main__": + app() diff --git a/tests/test_completion/file_example.py b/tests/test_completion/file_example.py new file mode 100644 index 0000000000..56556d2006 --- /dev/null +++ b/tests/test_completion/file_example.py @@ -0,0 +1,12 @@ +import typer + +app = typer.Typer() + + +@app.command() +def main(config: typer.FileText = typer.Option(...)): + print(config.read()) + + +if __name__ == "__main__": + app() diff --git a/tests/test_completion/test_completion.py b/tests/test_completion/test_completion.py index 049ec4f6af..e9be4e25d1 100644 --- a/tests/test_completion/test_completion.py +++ b/tests/test_completion/test_completion.py @@ -3,9 +3,12 @@ import sys from pathlib import Path +from typer._click.shell_completion import CompletionItem + from docs_src.typer_app import tutorial001_py310 as mod from ..utils import needs_bash, needs_linux, requires_completion_permission +from . import completion_option_then_argument as mod_option_arg @needs_bash @@ -164,3 +167,27 @@ def test_completion_source_pwsh(): "Register-ArgumentCompleter -Native -CommandName tutorial001_py310.py -ScriptBlock $scriptblock" in result.stdout ) + + +def test_completion_option_argument() -> None: + file_name = Path(mod_option_arg.__file__).name + result = subprocess.run( + [sys.executable, "-m", "coverage", "run", mod_option_arg.__file__, " "], + capture_output=True, + encoding="utf-8", + env={ + **os.environ, + f"_{file_name.upper()}_COMPLETE": "complete_bash", + "COMP_WORDS": f"{file_name} --name chosen ", + "COMP_CWORD": "3", + }, + ) + assert "arg-choice" in result.stdout + assert "opt-choice" not in result.stdout + + +def test_completion_item_getattr() -> None: + item = CompletionItem("demo", source="envvar") + + assert item.source == "envvar" + assert item.missing is None diff --git a/tests/test_completion/test_completion_choice.py b/tests/test_completion/test_completion_choice.py new file mode 100644 index 0000000000..7cc4d81235 --- /dev/null +++ b/tests/test_completion/test_completion_choice.py @@ -0,0 +1,32 @@ +import os +import subprocess +import sys + +from . import choice_example as mod + + +def test_script() -> None: + result = subprocess.run( + [sys.executable, "-m", "coverage", "run", mod.__file__, "--name", "rick"], + capture_output=True, + encoding="utf-8", + ) + assert result.returncode == 0 + assert "rick" in result.stdout + + +def test_completion_choice_bash() -> None: + result = subprocess.run( + [sys.executable, "-m", "coverage", "run", mod.__file__, " "], + capture_output=True, + encoding="utf-8", + env={ + **os.environ, + "_CHOICE_EXAMPLE.PY_COMPLETE": "complete_bash", + "COMP_WORDS": "choice_example.py --name mo", + "COMP_CWORD": "2", + }, + ) + assert result.returncode == 0 + assert "morty" in result.stdout + assert "rick" not in result.stdout diff --git a/tests/test_completion/test_completion_choice_no_case.py b/tests/test_completion/test_completion_choice_no_case.py new file mode 100644 index 0000000000..f14097761d --- /dev/null +++ b/tests/test_completion/test_completion_choice_no_case.py @@ -0,0 +1,32 @@ +import os +import subprocess +import sys + +from . import choice_case_insensitive_example as mod + + +def test_script() -> None: + result = subprocess.run( + [sys.executable, "-m", "coverage", "run", mod.__file__, "--name", "rick"], + capture_output=True, + encoding="utf-8", + ) + assert result.returncode == 0 + assert "rick" in result.stdout + + +def test_completion_choice_bash_case_insensitive() -> None: + result = subprocess.run( + [sys.executable, "-m", "coverage", "run", mod.__file__, " "], + capture_output=True, + encoding="utf-8", + env={ + **os.environ, + "_CHOICE_CASE_INSENSITIVE_EXAMPLE.PY_COMPLETE": "complete_bash", + "COMP_WORDS": "choice_case_insensitive_example.py --name MO", + "COMP_CWORD": "2", + }, + ) + assert result.returncode == 0 + assert "morty" in result.stdout + assert "rick" not in result.stdout diff --git a/tests/test_completion/test_completion_file.py b/tests/test_completion/test_completion_file.py new file mode 100644 index 0000000000..782d3a04bc --- /dev/null +++ b/tests/test_completion/test_completion_file.py @@ -0,0 +1,39 @@ +import os +import subprocess +import sys + +from . import file_example as mod + + +def test_script() -> None: + result = subprocess.run( + [ + sys.executable, + "-m", + "coverage", + "run", + mod.__file__, + "--config", + mod.__file__, + ], + capture_output=True, + encoding="utf-8", + ) + assert result.returncode == 0 + assert "def main(config: typer.FileText = typer.Option(...)):" in result.stdout + + +def test_completion_file_bash() -> None: + result = subprocess.run( + [sys.executable, "-m", "coverage", "run", mod.__file__, " "], + capture_output=True, + encoding="utf-8", + env={ + **os.environ, + "_FILE_EXAMPLE.PY_COMPLETE": "complete_bash", + "COMP_WORDS": "file_example.py --config file_ex", + "COMP_CWORD": "2", + }, + ) + assert result.returncode == 0 + assert "file_ex" in result.stdout diff --git a/tests/test_completion/test_completion_option_colon.py b/tests/test_completion/test_completion_option_colon.py index 6b65d786e5..8818ee4a1a 100644 --- a/tests/test_completion/test_completion_option_colon.py +++ b/tests/test_completion/test_completion_option_colon.py @@ -177,6 +177,39 @@ def test_completion_colon_powershell_single(): assert "nvidia/cuda:10.0-devel-ubuntu18.04" not in result.stdout +def test_completion_powershell_option_equals_value() -> None: + result = subprocess.run( + [sys.executable, "-m", "coverage", "run", mod.__file__, " "], + capture_output=True, + encoding="utf-8", + env={ + **os.environ, + "_COLON_EXAMPLE.PY_COMPLETE": "complete_powershell", + "_TYPER_COMPLETE_ARGS": "colon_example.py --name=alpine", + "_TYPER_COMPLETE_WORD_TO_COMPLETE": "--name=alpine", + }, + ) + assert "alpine:hello" in result.stdout + assert "alpine:latest" in result.stdout + assert "nvidia/cuda:10.0-devel-ubuntu18.04" not in result.stdout + + +def test_completion_powershell_option_equals_only() -> None: + result = subprocess.run( + [sys.executable, "-m", "coverage", "run", mod.__file__, " "], + capture_output=True, + encoding="utf-8", + env={ + **os.environ, + "_COLON_EXAMPLE.PY_COMPLETE": "complete_powershell", + "_TYPER_COMPLETE_ARGS": "colon_example.py --name=", + "_TYPER_COMPLETE_WORD_TO_COMPLETE": "=", + }, + ) + assert result.returncode == 0 + assert result.stdout.strip() == "" + + def test_completion_colon_pwsh_all(): result = subprocess.run( [sys.executable, "-m", "coverage", "run", mod.__file__, " "], diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 0000000000..2cb759918b --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,379 @@ +from typing import Annotated + +import pytest +import typer +import typer._completion_shared +import typer.completion +from typer import _click +from typer.core import TyperArgument, TyperCommand, TyperGroup, TyperOption, _split_opt +from typer.testing import CliRunner + +runner = CliRunner() + + +def test_human_readable_name() -> None: + app = typer.Typer() + + @app.command() + def main( + my_arg_1: Annotated[str, typer.Argument()], + my_arg_2: Annotated[str, typer.Argument(metavar="META_ARG")], + my_opt: Annotated[str, typer.Option()], + ): + pass # pragma: no cover + + command = typer.main.get_command(app) + params = {param.name: param for param in command.params} + + assert params["my_arg_1"].human_readable_name == "MY_ARG_1" + assert params["my_arg_2"].human_readable_name == "META_ARG" + assert params["my_opt"].human_readable_name == "my_opt" + + +def test_parameter_metavar() -> None: + app = typer.Typer(rich_markup_mode=None) + + @app.command() + def cmd(name: Annotated[str, typer.Option(metavar="CUSTOM")]) -> None: + pass # pragma: no cover + + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 + assert "--name CUSTOM" in result.output + + +def test_parameter_nargs_gt_1() -> None: + param = TyperArgument(param_decls=["value"], type=str, nargs=2) + ctx = _click.Context(TyperCommand(name="cmd")) + + assert param.type_cast_value(ctx, ("one", "two")) == ("one", "two") + + with pytest.raises( + _click.exceptions.BadParameter, match="Takes 2 values but 1 given." + ): + param.type_cast_value(ctx, ("one",)) + + +def test_parameter_constructor() -> None: + # no param_decl and expose_value is False: sets name to None + arg = TyperArgument(param_decls=[], expose_value=False) + assert arg.name is None + assert arg.opts == [] + assert arg.secondary_opts == [] + + # no param_decl and expose_value is True: raises + with pytest.raises(TypeError, match="does not have a name."): + TyperArgument(param_decls=[], expose_value=True) + + # len(param_decl) > 1: raises + with pytest.raises(TypeError, match="take exactly one parameter declaration"): + TyperArgument(param_decls=["first", "second"]) + + # duplicated identifier in option declarations: raises + with pytest.raises(TypeError, match="Name 'name' defined twice"): + TyperOption(param_decls=["name", "name"], required=False) + + # same true/false flag in boolean option declaration: raises + with pytest.raises(ValueError, match="cannot use the same flag for true/false"): + TyperOption(param_decls=["flag", "--flag/--flag"], required=False, is_flag=True) + + # inferred name is not a valid identifier: sets name to None + unnamed_option = TyperOption(param_decls=["--123"], required=False) + assert unnamed_option.name is None + + # no param_decl and prompt=True: raises + with pytest.raises(TypeError, match="'name' is required with 'prompt=True'."): + TyperOption(param_decls=[], expose_value=False, prompt=True, required=False) + + # count works + option = TyperOption( + param_decls=["verbose", "--verbose", "-v"], + type=None, + default=0, + required=False, + count=True, + ) + assert isinstance(option.type, _click.types.IntRange) + assert option.type.min == 0 + + +def test_option_error_hint() -> None: + option = TyperOption( + param_decls=["name", "--name"], + required=False, + show_envvar=True, + envvar="APP_NAME", + ) + hint = option.get_error_hint(_click.Context(TyperCommand(name="cmd"))) + assert "(env var: 'APP_NAME')" in hint + + +def test_group_init() -> None: + group_no_commands = TyperGroup(name="root", commands=None) + assert group_no_commands.commands == {} + + named = TyperCommand(name="named") + unnamed = TyperCommand(name=None) + group_command_sequence = TyperGroup(name="root", commands=[named, unnamed]) + assert group_command_sequence.commands == {"named": named} + + +@pytest.mark.parametrize("with_result_callback", [False, True]) +def test_group_result_callback(with_result_callback: bool) -> None: + called = {"child": False, "result_callback": False} + + def child_callback() -> None: + called["child"] = True + return None + + def result_callback(value, **kwargs): # type: ignore[no-untyped-def] + called["result_callback"] = True + return value + + child = TyperCommand(name="child", callback=child_callback) + group = TyperGroup( + name="root", + commands={"child": child}, + result_callback=result_callback if with_result_callback else None, + ) + ctx = group.make_context("root", ["child"]) + + result = group.invoke(ctx) + + assert result is None + assert called["child"] is True + assert called["result_callback"] is with_result_callback + assert ctx.invoked_subcommand == "child" + + +def test_group_add_command() -> None: + group = TyperGroup(name="root") + unnamed_command = TyperCommand(name=None) + + with pytest.raises(TypeError, match="Command has no name."): + group.add_command(unnamed_command) + + +def test_group_click_resolve_command() -> None: + child = TyperCommand(name="child") + group = TyperGroup(name="root", commands={"child": child}) + ctx = group.make_context("root", ["CHILD"], token_normalize_func=str.lower) + + cmd_name, cmd, remaining = group._click_resolve_command(ctx, ["CHILD"]) + + assert cmd_name == "child" + assert cmd is child + assert remaining == [] + + +@pytest.mark.parametrize( + ("envvar", "auto_prefix", "set_env", "expected"), + [ + ("APP_NAME", None, True, "my-precious"), + (None, "APP", True, "my-precious"), + (None, None, False, None), + ], +) +def test_option_resolve_envvar( + monkeypatch: pytest.MonkeyPatch, + envvar: str | None, + auto_prefix: str | None, + set_env: bool, + expected: str | None, +) -> None: + option = TyperOption( + param_decls=["name", "--name"], + required=False, + envvar=envvar, + ) + if set_env: + monkeypatch.setenv("APP_NAME", "my-precious") + + ctx = _click.Context(TyperCommand(name="cmd"), auto_envvar_prefix=auto_prefix) + assert option.resolve_envvar_value(ctx) == expected + + +def test_option_resolve_envvar_list( + monkeypatch: pytest.MonkeyPatch, +) -> None: + option = TyperOption( + param_decls=["name", "--name"], + required=False, + envvar=["APP_NAME_1", "APP_NAME_2"], + ) + monkeypatch.delenv("APP_NAME_1", raising=False) + monkeypatch.delenv("APP_NAME_2", raising=False) + ctx = _click.Context(TyperCommand(name="cmd")) + + assert option.resolve_envvar_value(ctx) is None + + +def test_context_auto_envvar() -> None: + app = typer.Typer(context_settings={"auto_envvar_prefix": "APP"}) + sub_app = typer.Typer() + + @sub_app.command() + def clone(ctx: typer.Context) -> None: + print(ctx.auto_envvar_prefix) + + app.add_typer(sub_app, name="beth") + + result = runner.invoke(app, ["beth", "clone"]) + assert result.exit_code == 0 + assert "APP_BETH_CLONE" in result.stdout + + +def test_context_with_resource() -> None: + events: list[str] = [] + + class DemoResource: + def __enter__(self) -> str: + events.append("enter") + return "pickle-rick" + + def __exit__(self, *args: object) -> None: + events.append("exit") + + app = typer.Typer() + + @app.command() + def cmd(ctx: typer.Context) -> None: + value = ctx.with_resource(DemoResource()) + assert value == "pickle-rick" + assert events == ["enter"] + print("I'm a pickle") + + result = runner.invoke(app) + + assert result.exit_code == 0 + assert "I'm a pickle" in result.stdout + assert events == ["enter", "exit"] + + +def test_context_find_root() -> None: + app = typer.Typer() + sub_app = typer.Typer() + + @sub_app.command() + def child(ctx: typer.Context) -> None: + root = ctx.find_root() + assert root.parent is None + assert root is ctx.parent.parent + print("ok") + + app.add_typer(sub_app, name="sub") + + result = runner.invoke(app, ["sub", "child"]) + assert result.exit_code == 0 + assert "ok" in result.stdout + + +def test_context_find_object() -> None: + class Marker: + pass + + marker = Marker() + app = typer.Typer() + + @app.callback() + def callback(ctx: typer.Context) -> None: + ctx.obj = marker + + @app.command() + def child(ctx: typer.Context) -> None: + assert ctx.find_object(Marker) is marker + assert ctx.find_object(str) is None + print("ok") + + result = runner.invoke(app, ["child"]) + assert result.exit_code == 0 + assert "ok" in result.stdout + + +def test_context_lookup_default_callable() -> None: + app = typer.Typer() + + @app.command() + def child(ctx: typer.Context) -> None: + ctx.default_map = {"planet": lambda: "Earth"} + assert ctx.lookup_default("planet") == "Earth" + value = ctx.lookup_default("planet", call=False) + assert callable(value) + print("ok") + + result = runner.invoke(app) + assert result.exit_code == 0 + assert "ok" in result.stdout + + +def test_context_abort() -> None: + app = typer.Typer() + + @app.command() + def cmd(ctx: typer.Context) -> None: + ctx.abort() + + result = runner.invoke(app, standalone_mode=False) + assert result.exit_code == 1 + assert isinstance(result.exception, _click.core.Abort) + + +def test_command_help_disabled() -> None: + app = typer.Typer() + + @app.command(add_help_option=False) + def cmd() -> None: + pass # pragma: no cover + + result = runner.invoke(app, ["--help"], standalone_mode=False) + assert result.exit_code == 1 + assert isinstance(result.exception, _click.exceptions.NoSuchOption) + assert result.exception.option_name == "--help" + + +def test_command_help_deprecated() -> None: + app = typer.Typer(rich_markup_mode=None, epilog="Built with love") + + @app.command(short_help="Shorty", help="Regular help text.", deprecated=True) + def one() -> None: + pass # pragma: no cover + + @app.command() + def two() -> None: + pass # pragma: no cover + + result = runner.invoke(app, ["one", "--help"]) + assert result.exit_code == 0 + assert "Regular help text. (DEPRECATED)" in result.output + + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 + assert "Built with love" in result.output + assert "oneShorty(DEPRECATED)" in result.output.replace(" ", "") + + +@pytest.mark.parametrize( + ("value", "expected_prefix", "expected_opt"), + [ + ("--verbose", "--", "verbose"), + ("//verbose", "//", "verbose"), + ("-verbose", "-", "verbose"), + ("verbose", "", "verbose"), + ], +) +def test_split_opt(value: str, expected_prefix: str, expected_opt: str) -> None: + prefix, opt = _split_opt(value) + assert prefix == expected_prefix + assert opt == expected_opt + + +def test_nargs_default_map(): + app = typer.Typer() + + @app.command() + def main(names: list[str] = typer.Option(None)): + print(names) # pragma: no cover + + result = runner.invoke(app, [], default_map={"names": "not-a-list"}) + assert result.exit_code == 2 + assert "Invalid value" in result.output diff --git a/tests/test_launch.py b/tests/test_launch.py index c15d6c57da..1a97c197aa 100644 --- a/tests/test_launch.py +++ b/tests/test_launch.py @@ -1,10 +1,11 @@ +import io import subprocess from unittest.mock import patch import pytest import typer -from tests.utils import needs_windows +from tests.utils import needs_linux, needs_macos, needs_windows url = "http://example.com" @@ -51,19 +52,94 @@ def test_launch_url_no_xdg_open(): mock_webbrowser_open.assert_called_once_with(url) -def test_calls_original_launch_when_not_passing_urls(): - with patch("typer.main.click.launch", return_value=0) as launch_mock: - typer.launch("not a url") +@pytest.fixture +def allow_dev_null(monkeypatch): + real_open = open - launch_mock.assert_called_once_with("not a url", wait=False, locate=False) + def fake_open(path, *args, **kwargs): + if path == "/dev/null": + return io.StringIO() + return real_open(path, *args, **kwargs) # pragma: no cover + + monkeypatch.setattr("builtins.open", fake_open) + + +@needs_macos +def test_open_url_macos(monkeypatch, allow_dev_null): + recorded: list[list[str]] = [] + + class Proc: + def wait(self) -> int: + return 42 + + def fake_popen(args, **kwargs): + recorded.append(list(args)) + return Proc() + + monkeypatch.setattr(subprocess, "Popen", fake_popen) + + assert typer.launch("/path/to/file", wait=True, locate=True) == 42 + assert recorded[0][:3] == ["open", "-W", "-R"] + assert recorded[0][-1] == "/path/to/file" + + +@needs_windows +def test_launch_files_windows(monkeypatch): + calls: list[list[str]] = [] + + def fake_call(args): + calls.append(list(args)) + return 0 + + monkeypatch.setattr(subprocess, "call", fake_call) + + assert typer.launch("C:/Tools/readme.txt", wait=True, locate=False) == 0 + assert typer.launch("file:///C:/tmp/a.txt", wait=False, locate=True) == 0 + assert calls.pop(0) == ["start", "/WAIT", "", "C:/Tools/readme.txt"] + assert calls.pop(0) == ["explorer", "/select,/C:/tmp/a.txt"] + + monkeypatch.setattr(subprocess, "call", lambda a: (_ for _ in ()).throw(OSError())) + assert typer.launch("D:/no/such/file.txt", wait=False, locate=False) == 127 + + +@needs_linux +def test_open_url_linux_wait(monkeypatch): + class Proc: + def __init__(self, code: int = 0) -> None: + self._code = code + + def wait(self) -> int: + return self._code + + monkeypatch.setattr(subprocess, "Popen", lambda *a, **k: Proc(7)) + + assert typer.launch("/file", wait=True, locate=False) == 7 + + +@needs_linux +def test_open_url_linux_locate(monkeypatch): + recorded: list[list[str]] = [] + + class Proc: + def wait(self) -> int: + return 0 # pragma: no cover + + def fake_popen(args, **kwargs): + recorded.append(list(args)) + return Proc() + + monkeypatch.setattr(subprocess, "Popen", fake_popen) + + assert typer.launch("/tmp/sub/file.txt", wait=False, locate=True) == 0 + assert recorded[-1] == ["xdg-open", "/tmp/sub"] @needs_windows def test_launch_file(): with ( - patch("click._termui_impl.sys.platform", "win32"), - patch("click._termui_impl.WIN", True), - patch("click._termui_impl.CYGWIN", False), + patch("typer._click._termui_impl.sys.platform", "win32"), + patch("typer._click._termui_impl.WIN", True), + patch("typer._click._termui_impl.CYGWIN", False), patch("subprocess.call", return_value=0) as call_mock, ): result = typer.launch("C:/tmp/file.txt", locate=True) diff --git a/tests/test_others.py b/tests/test_others.py index b389ed353f..bd83e60a07 100644 --- a/tests/test_others.py +++ b/tests/test_others.py @@ -6,12 +6,11 @@ from typing import Annotated from unittest import mock -import click import pytest import typer import typer._completion_shared import typer.completion -from typer.core import _split_opt +from typer import _click from typer.main import solve_typer_info_defaults, solve_typer_info_help from typer.models import ParameterInfo, TyperInfo from typer.testing import CliRunner @@ -37,14 +36,14 @@ def test_too_many_parsers(): def custom_parser(value: str) -> int: return int(value) # pragma: no cover - class CustomClickParser(click.ParamType): + class CustomClickParser(_click.types.ParamType): name = "custom_parser" def convert( self, value: str, - param: click.Parameter | None, - ctx: click.Context | None, + param: _click.Parameter | None, + ctx: _click.Context | None, ) -> typing.Any: return int(value) # pragma: no cover @@ -61,14 +60,14 @@ def test_valid_parser_permutations(): def custom_parser(value: str) -> int: return int(value) # pragma: no cover - class CustomClickParser(click.ParamType): + class CustomClickParser(_click.types.ParamType): name = "custom_parser" def convert( self, value: str, - param: click.Parameter | None, - ctx: click.Context | None, + param: _click.Parameter | None, + ctx: _click.Context | None, ) -> typing.Any: return int(value) # pragma: no cover @@ -104,7 +103,7 @@ def name_callback(ctx, param, val1, val2): def main(name: str = typer.Option(..., callback=name_callback)): pass # pragma: no cover - with pytest.raises(click.ClickException) as exc_info: + with pytest.raises(_click.ClickException) as exc_info: runner.invoke(app, ["--name", "Camila"]) assert ( exc_info.value.message == "Too many CLI parameter callback function parameters" @@ -127,6 +126,135 @@ def main(name: str = typer.Option(..., callback=name_callback)): assert "value is: Camila" in result.stdout +@pytest.mark.parametrize( + ("param_hint", "option_decls", "expected_message"), + [ + ("--name", (), "Invalid value for --name"), + (None, ("--name", "-n"), "Invalid value for '--name' / '-n'"), + ], +) +def test_bad_parameter_callback( + param_hint: str | None, option_decls: tuple[str, ...], expected_message: str +) -> None: + app = typer.Typer() + + def my_bad(value: str) -> str: + kwargs = {"param_hint": param_hint} if param_hint is not None else {} + raise typer.BadParameter("custom validation failed", **kwargs) + + @app.command() + def main(name: str = typer.Option(..., *option_decls, callback=my_bad)) -> None: + typer.echo(name) # pragma: no cover + + result = runner.invoke(app, ["--name", "Camila"]) + assert result.exit_code == 2 + assert expected_message in result.stderr + assert "custom validation failed" in result.stderr + + +def test_bad_parameter_main() -> None: + app = typer.Typer() + + @app.command() + def main() -> None: + raise typer.BadParameter("custom validation failed") + + result = runner.invoke(app, []) + assert result.exit_code == 2 + assert "Invalid value: custom validation failed" in result.stderr + + +@pytest.mark.parametrize( + ("kw", "msg"), + [ + ( + {"param_hint": ["--name", "-n"], "param_type": "parameter"}, + "Missing parameter '--name' / '-n'.", + ), + ({"param_type": "value"}, "Missing value."), + ], +) +def test_missing_parameter_msg(kw: dict[str, object], msg: str) -> None: + app = typer.Typer(rich_markup_mode=None) + + @app.command() + def main() -> None: + raise typer._click.exceptions.MissingParameter(**kw) + + result = runner.invoke(app, []) + assert result.exit_code == 2 + assert msg in result.stderr + + +def test_missing_parameter_callback_msg() -> None: + def my_cb(ctx: typer.Context, param: typer.CallbackParam, value: str) -> str: + raise typer._click.exceptions.MissingParameter( + message="My bad", ctx=ctx, param=param, param_type="parameter" + ) + + app = typer.Typer(rich_markup_mode=None) + + @app.command() + def main( + mode: Annotated[ + typing.Literal["alpha", "beta"], + typer.Option(..., "--mode", callback=my_cb), + ], + ) -> None: + typer.echo(mode) # pragma: no cover + + result = runner.invoke(app, ["--mode", "alpha"]) + assert result.exit_code == 2 + assert "Missing parameter '--mode'." in result.stderr + assert "My bad. Choose from:" in result.stderr + assert "alpha" in result.stderr + assert "beta" in result.stderr + result_msg = runner.invoke(app, ["--mode", "alpha"], standalone_mode=False) + assert isinstance(result_msg.exception, typer._click.exceptions.MissingParameter) + assert str(result_msg.exception) == "My bad" + + +def test_missing_parameter_str() -> None: + def my_cb(ctx: typer.Context, param: typer.CallbackParam, value: str) -> str: + raise typer._click.exceptions.MissingParameter(ctx=ctx, param=param) + + app = typer.Typer() + + @app.command() + def main(mode: str = typer.Option(..., "--mode", callback=my_cb)) -> None: + typer.echo(mode) # pragma: no cover + + result2 = runner.invoke(app, ["--mode", "alpha"], standalone_mode=False) + assert isinstance(result2.exception, typer._click.exceptions.MissingParameter) + assert str(result2.exception) == "Missing parameter: mode" + + +def test_click_exception_show_default_file() -> None: + app = typer.Typer(rich_markup_mode=None) + + @app.command() + def main() -> None: + raise typer._click.ClickException("custom click failure") + + result = runner.invoke(app, []) + assert result.exit_code == 1 + assert "custom click" in result.stderr + assert "failure" in result.stderr + + +def test_no_args_is_help_show() -> None: + app = typer.Typer(rich_markup_mode=None) + + @app.callback(invoke_without_command=True, no_args_is_help=True) + def main() -> None: + return None # pragma: no cover + + result = runner.invoke(app, []) + assert result.exit_code == 2 + assert "Usage:" in result.stderr + assert "Show this message and exit." in result.stderr + + def test_callback_3_untyped_parameters(): app = typer.Typer() @@ -169,6 +297,18 @@ def main( assert "Hello World" in result.stdout +def test_multiple_bool_flags() -> None: + app = typer.Typer() + + @app.command() + def main(choices: list[bool] = typer.Option([], "--accept/--reject")) -> None: + print(choices) + + result = runner.invoke(app, ["--accept", "--reject", "--accept"]) + assert result.exit_code == 0 + assert "[True, False, True]" in result.stdout + + def test_empty_list_default_generator(): def empty_list() -> list[str]: return [] @@ -266,7 +406,7 @@ def name_callback(ctx, args, incomplete, val2): def main(name: str = typer.Option(..., autocompletion=name_callback)): pass # pragma: no cover - with pytest.raises(click.ClickException) as exc_info: + with pytest.raises(_click.ClickException) as exc_info: runner.invoke(app, ["--name", "Camila"]) assert exc_info.value.message == "Invalid autocompletion callback parameters: val2" @@ -303,24 +443,6 @@ def main(name: str): assert "Show this message and exit." in result.stdout -def test_split_opt(): - prefix, opt = _split_opt("--verbose") - assert prefix == "--" - assert opt == "verbose" - - prefix, opt = _split_opt("//verbose") - assert prefix == "//" - assert opt == "verbose" - - prefix, opt = _split_opt("-verbose") - assert prefix == "-" - assert opt == "verbose" - - prefix, opt = _split_opt("verbose") - assert prefix == "" - assert opt == "verbose" - - def test_options_metadata_typer_default(): app = typer.Typer(options_metavar="[options]") diff --git a/tests/test_progress_bar.py b/tests/test_progress_bar.py new file mode 100644 index 0000000000..c684476dec --- /dev/null +++ b/tests/test_progress_bar.py @@ -0,0 +1,397 @@ +""" +Tests for the Progress bar functionality. +Created after vendoring Click to ensure test coverage is back up to 100%. +""" + +import io +import shutil + +import pytest +import typer +from typer import progressbar +from typer._click import _termui_impl +from typer.testing import CliRunner + +runner = CliRunner() + + +def _fake_clock(monkeypatch: pytest.MonkeyPatch) -> list[float]: + clock = [0.0] + monkeypatch.setattr(_termui_impl.time, "time", lambda: clock[0]) + return clock + + +def _pbar(**kw): + return progressbar(file=kw.pop("file", io.StringIO()), **kw) + + +@pytest.mark.parametrize( + ("iterable", "length", "hidden", "label", "expected_count"), + [ + (["a", "b"], None, False, "Processing", 2), + (None, 3, False, "Counting", 3), + (["x", "y"], None, True, "Hidden", 2), + ], +) +def test_progressbar(iterable, length, hidden, label, expected_count): + app = typer.Typer() + + @app.command() + def main(): + bar_out = io.StringIO() + count = 0 + with progressbar( + iterable, length=length, hidden=hidden, label=label, file=bar_out + ) as bar: + for _ in bar: + count += 1 + typer.echo(f"count={count}") + typer.echo(f"bar={bar_out.getvalue()!r}") + + result = runner.invoke(app, []) + assert result.exit_code == 0, result.output + assert f"count={expected_count}" in result.stdout + assert (label in result.stdout) == (not hidden) + + +@pytest.mark.parametrize( + ("label", "pbar_kw", "must_contain", "must_not_contain"), + [ + pytest.param( + "TTY", + { + "show_pos": True, + "show_percent": True, + "item_show_func": lambda item: f"item={item}", + }, + ("TTY", "1/1", "100%", "item=x"), + (), + ), + pytest.param( + "HeurPct", + {}, + ("HeurPct", "100%"), + ("1/1",), + ), + pytest.param( + "HeurPos", + {"show_pos": True}, + ("HeurPos", "1/1"), + ("100%",), + ), + ], +) +def test_progressbar_tty( + monkeypatch, label: str, pbar_kw: dict, must_contain, must_not_contain +): + monkeypatch.setattr(_termui_impl, "isatty", lambda f: True) + _fake_clock(monkeypatch) + + app = typer.Typer() + + @app.command() + def main(): + bar_out = io.StringIO() + with progressbar( + ["x"], + label=label, + file=bar_out, + bar_template="%(label)s %(info)s", + width=1, + **pbar_kw, + ) as bar: + for _ in bar: + pass + typer.echo(bar_out.getvalue()) + + result = runner.invoke(app, []) + assert result.exit_code == 0, result.output + for part in must_contain: + assert part in result.stdout + for part in must_not_contain: + assert part not in result.stdout + + +def test_progressbar_tty_show_eta(monkeypatch): + monkeypatch.setattr(_termui_impl, "isatty", lambda f: True) + clock = _fake_clock(monkeypatch) + clock[0] = 1_000.0 + + app = typer.Typer() + + @app.command() + def main(): + bar_out = io.StringIO() + with progressbar( + ["a", "b"], + label="ETA", + file=bar_out, + show_pos=True, + show_percent=False, + show_eta=True, + bar_template="%(label)s %(info)s", + width=1, + ) as bar: + for i, _ in enumerate(bar): + if i == 0: + clock[0] = 1_001.0 + typer.echo(bar_out.getvalue()) + + result = runner.invoke(app, []) + assert result.exit_code == 0, result.output + for part in ("ETA", "1/2", "00:00:01"): + assert part in result.stdout + + +def test_progressbar_autowidth(monkeypatch): + monkeypatch.setattr(_termui_impl, "isatty", lambda f: True) + call = [0] + real_get_terminal_size = shutil.get_terminal_size + + def fake_get_terminal_size(*args, **kwargs): + # Pytest (and others) call get_terminal_size(fallback=...); only stub no-arg calls + if args or kwargs: + return real_get_terminal_size(*args, **kwargs) + col = 120 if call[0] == 0 else 40 + call[0] += 1 + return type("TS", (), {"columns": col, "lines": 24})() + + monkeypatch.setattr(shutil, "get_terminal_size", fake_get_terminal_size) + + state: dict[str, object] = {} + + app = typer.Typer() + + @app.command() + def main(): + out = io.StringIO() + with progressbar(["a", "b"], width=0, label="AW", file=out) as bar: + state["autowidth"] = bar.autowidth + for _ in bar: + pass + state["call_count"] = call[0] + state["out"] = out.getvalue() + typer.echo("done") + + result = runner.invoke(app, []) + assert result.exit_code == 0, result.output + assert state["autowidth"] is True + assert state["call_count"] >= 2 + out = str(state["out"]) + assert "\r" in out and "AW" in out + assert "0%" in out and "50%" in out and "100%" in out + + +def test_progress_bar_iter(): + not_entered = _pbar(iterable=[1, 2], length=2) + with pytest.raises(RuntimeError, match="with block"): + iter(not_entered) + + entered = _pbar(iterable=[10, 20], length=2) + with entered: + iterator = iter(entered) + assert next(iterator) == 10 + assert next(entered) == 20 + with pytest.raises(StopIteration): + next(iterator) + + +def test_progress_bar_time(monkeypatch): + clock = _fake_clock(monkeypatch) + state: dict[str, object] = {} + clock[0] = 1_000.0 + + app = typer.Typer() + + @app.command() + def main(): + bar = _pbar(iterable=None, length=10) + state["tpi0"] = bar.time_per_iteration + clock[0] = 1_000.5 + bar.make_step(1) + state["avg_after_one"] = list(bar.avg) + state["tpi_after_one"] = bar.time_per_iteration + clock[0] = 1_001.0 + bar.make_step(1) + state["pos2"] = bar.pos + state["avg2"] = list(bar.avg) + state["tpi2"] = bar.time_per_iteration + clock[0] = 1_002.0 + bar.make_step(1) + state["avg3"] = list(bar.avg) + state["tpi3"] = bar.time_per_iteration + typer.echo("ok") + + result = runner.invoke(app, []) + assert result.exit_code == 0, result.output + assert state["tpi0"] == 0.0 + assert state["avg_after_one"] == [] and state["tpi_after_one"] == 0.0 + assert state["pos2"] == 2 + assert state["avg2"] == [(1_001.0 - 1_000.0) / 2.0] + assert state["tpi2"] == pytest.approx(0.5) + assert state["avg3"] == [0.5, (1_002.0 - 1_000.0) / 3.0] + assert state["tpi3"] == pytest.approx(sum(state["avg3"]) / 2.0) # type: ignore[arg-type] + + +def test_progress_bar_time_zero_steps(monkeypatch): + clock = _fake_clock(monkeypatch) + state: dict[str, object] = {} + clock[0] = 2_000.0 + + app = typer.Typer() + + @app.command() + def main(): + bar = _pbar(iterable=None, length=3) + clock[0] = 2_001.0 + bar.make_step(0) + state["pos"] = bar.pos + state["avg"] = list(bar.avg) + state["tpi"] = bar.time_per_iteration + typer.echo("ok") + + result = runner.invoke(app, []) + assert result.exit_code == 0, result.output + assert state["pos"] == 0 + assert state["avg"] == [1.0] + assert state["tpi"] == pytest.approx(1.0) + + +def test_progress_bar_eta(monkeypatch): + state: dict[str, object] = {} + + app = typer.Typer() + + @app.command() + def main(): + state["eta0"] = _pbar(iterable=[1, 2], length=None).eta + + done = _pbar(iterable=None, length=5) + done.pos, done.finished = 2, True + state["eta_done"] = done.eta + + fresh = _pbar(iterable=None, length=5) + state["eta_known_fresh"] = fresh.eta_known + state["fmt_eta_fresh"] = fresh.format_eta() + + clock = _fake_clock(monkeypatch) + clock[0] = 5_000.0 + bar = _pbar(iterable=None, length=10) + clock[0] = 5_001.0 + bar.make_step(3) + state["pos3"] = bar.pos + state["eta_after"] = bar.eta + state["tpi"] = bar.time_per_iteration + + cases_out = [] + for t0, t1, length, n_steps, _expected_fmt, expected_eta_int in ( + (9_000.0, 9_001.0, 10, 1, "00:00:09", None), + (1_000.0, 100_000.0, 2, 1, "1d 03:30:00", 99_000), + ): + clock2 = _fake_clock(monkeypatch) + clock2[0] = t0 + b = _pbar(iterable=None, length=length) + clock2[0] = t1 + b.make_step(n_steps) + cases_out.append( + ( + b.eta_known, + b.format_eta(), + int(b.eta) if expected_eta_int is not None else None, + ) + ) + + state["cases"] = cases_out + + clock3 = _fake_clock(monkeypatch) + clock3[0] = 3_000.0 + b2 = _pbar(iterable=None, length=2) + clock3[0] = 3_001.0 + b2.make_step(1) + state["fmt_before_finish"] = b2.format_eta() + b2.finish() + state["fmt_after_finish"] = b2.format_eta() + typer.echo("ok") + + result = runner.invoke(app, []) + assert result.exit_code == 0, result.output + assert state["eta0"] == 0.0 + assert state["eta_done"] == 0.0 + assert not state["eta_known_fresh"] and state["fmt_eta_fresh"] == "" + assert state["pos3"] == 3 + assert state["eta_after"] == pytest.approx(state["tpi"] * (10 - 3)) # type: ignore[operator] + + (ek1, fmt1, ei1), (ek2, fmt2, ei2) = state["cases"] # type: ignore[misc] + assert ek1 and fmt1 == "00:00:09" and ei1 is None + assert ek2 and fmt2 == "1d 03:30:00" and ei2 == 99_000 + + assert state["fmt_before_finish"] != "" + assert state["fmt_after_finish"] == "" + + +@pytest.mark.parametrize( + ("width", "fill_char", "empty_char", "expected_bar", "finished", "sample_timing"), + [ + pytest.param(4, "X", "-", "XXXX", True, False, id="finished"), + pytest.param(4, "#", "-", "----", False, False, id="no_timing_yet"), + pytest.param(5, "*", ".", None, False, True, id="indeterminate"), + ], +) +def test_progress_bar_unknown_length( + monkeypatch, + width: int, + fill_char: str, + empty_char: str, + expected_bar: str | None, + finished: bool, + sample_timing: bool, +): + clock: list[float] | None = _fake_clock(monkeypatch) if sample_timing else None + if clock is not None: + clock[0] = 100.0 + + state: dict[str, object] = {} + + class _IterableWithoutLength: + def __iter__(self): + return iter((1, 2, 3)) + + app = typer.Typer() + + @app.command() + def main(): + bar = _pbar( + iterable=_IterableWithoutLength(), + length=None, + width=width, + fill_char=fill_char, + empty_char=empty_char, + ) + assert bar.length is None + + if sample_timing: + assert clock is not None + clock[0] = 101.0 + bar.make_step(1) + assert bar.time_per_iteration > 0 + rendered = bar.format_bar() + assert len(rendered) == width + assert rendered.count(fill_char) == 1 + assert rendered.count(empty_char) == width - 1 + state["branch"] = "sample_timing" + elif finished: + bar.finished = True + state["bar"] = bar.format_bar() + state["branch"] = "finished" + else: + assert bar.time_per_iteration == 0.0 + state["bar"] = bar.format_bar() + state["branch"] = "no_timing" + typer.echo("ok") + + result = runner.invoke(app, []) + assert result.exit_code == 0, result.output + if sample_timing: + assert state["branch"] == "sample_timing" + else: + assert state["bar"] == expected_bar diff --git a/tests/test_termui.py b/tests/test_termui.py new file mode 100644 index 0000000000..f6e04199a7 --- /dev/null +++ b/tests/test_termui.py @@ -0,0 +1,458 @@ +""" +Tests for the termui, echo, and CliRunner isolation functionality. +Created after vendoring Click to ensure test coverage is back up to 100%. +""" + +import io +import os +from contextlib import contextmanager +from typing import Literal + +import pytest +import typer +from typer._click import _termui_impl, termui +from typer.testing import CliRunner + +from tests.utils import needs_windows, skip_if_windows + + +def test_raw_terminal(monkeypatch): + runner = CliRunner() + app = typer.Typer() + state = {"entered": 0, "exited": 0} + + @contextmanager + def fake_raw_terminal(): + state["entered"] += 1 + try: + yield 42 + finally: + state["exited"] += 1 + + monkeypatch.setattr(_termui_impl, "raw_terminal", fake_raw_terminal) + + @app.command() + def main(): + with termui.raw_terminal() as fd: + typer.echo(f"fd={fd}") + + result = runner.invoke(app, []) + assert result.exit_code == 0, result.output + assert "fd=42" in result.stdout + assert state["entered"] == 1 + assert state["exited"] == 1 + + +def test_getchar(monkeypatch): + # Cached path: call the existing _getchar directly. + cached_state = {"echo": None} + + def cached_getchar(echo: bool) -> str: + cached_state["echo"] = echo + return "x" + + monkeypatch.setattr(termui, "_getchar", cached_getchar) + assert termui.getchar(echo=True) == "x" + assert cached_state["echo"] is True + + # Lazy-load path: _getchar is None, so import/cache _termui_impl.getchar. + lazy_state = {"calls": 0} + + def lazy_getchar(echo: bool) -> str: + lazy_state["calls"] += 1 + return "y" if not echo else "z" + + monkeypatch.setattr(termui, "_getchar", None) + monkeypatch.setattr(_termui_impl, "getchar", lazy_getchar) + + assert termui.getchar(echo=False) == "y" + assert termui._getchar is lazy_getchar + assert termui.getchar(echo=True) == "z" + assert lazy_state["calls"] == 2 + + +def test_clirunner_getchar(monkeypatch) -> None: + runner = CliRunner() + app = typer.Typer() + + @app.command() + def main() -> None: + first = termui.getchar(echo=False) + second = termui.getchar(echo=True) + typer.echo(f"\nfirst={first};second={second}") + + monkeypatch.setattr(termui, "_getchar", None) + result = runner.invoke(app, [], input="ab") + assert result.exit_code == 0, result.output + assert result.stdout.splitlines() == ["b", "first=a;second=b"] + + +def test_clirunner_env_none(monkeypatch) -> None: + runner = CliRunner() + app = typer.Typer() + env_key = "TYPER_TEST_ENV_REMOVE" + monkeypatch.setenv(env_key, "present") + + @app.command() + def main() -> None: + typer.echo(f"inside={os.environ.get(env_key)}") + + result = runner.invoke(app, [], env={env_key: None}) + assert result.exit_code == 0, result.output + assert "inside=None" in result.stdout + assert os.environ.get(env_key) == "present" + + +@pytest.mark.parametrize( + ("runner_exc", "invoke_exc"), + [ + (False, None), + (True, False), + ], +) +def test_clirunner_invoke_catch_exceptions( + runner_exc: bool, invoke_exc: bool | None +) -> None: + runner = CliRunner(catch_exceptions=runner_exc) + app = typer.Typer() + + @app.command() + def main() -> None: + raise RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + runner.invoke(app, [], catch_exceptions=invoke_exc) + + +@pytest.mark.parametrize( + ("exit_value", "expected_exit_code", "expected_stdout"), + [ + (None, 0, ""), + ("bad-exit", 1, "bad-exit\n"), + ], +) +def test_clirunner_invoke_system_exit_branches( + exit_value: object, + expected_exit_code: int, + expected_stdout: str, +) -> None: + runner = CliRunner() + app = typer.Typer() + + @app.command() + def main() -> None: + raise SystemExit(exit_value) + + result = runner.invoke(app, []) + assert result.exit_code == expected_exit_code + assert result.stdout == expected_stdout + if expected_exit_code: + assert isinstance(result.exception, SystemExit) + else: + assert result.exception is None + + +@needs_windows +def test_termui_impl_windows_raw_terminal(): + with _termui_impl.raw_terminal() as fd: + assert fd == -1 + with termui.raw_terminal() as fd: + assert fd == -1 + + +@needs_windows +def test_termui_impl_windows_getchar(monkeypatch): + monkeypatch.setattr(_termui_impl.msvcrt, "getwch", lambda: "a") + monkeypatch.setattr(_termui_impl.msvcrt, "getwche", lambda: "b") + assert _termui_impl.getchar(echo=False) == "a" + assert _termui_impl.getchar(echo=True) == "b" + + seq_null = iter(["\x00", "K"]) + monkeypatch.setattr(_termui_impl.msvcrt, "getwch", lambda: next(seq_null)) + assert _termui_impl.getchar(echo=False) == "\x00K" + + seq_e0 = iter(["\xe0", "H"]) + monkeypatch.setattr(_termui_impl.msvcrt, "getwch", lambda: next(seq_e0)) + assert _termui_impl.getchar(echo=False) == "\xe0H" + + seq_echo = iter(["\x00", "M"]) + monkeypatch.setattr(_termui_impl.msvcrt, "getwche", lambda: next(seq_echo)) + assert _termui_impl.getchar(echo=True) == "\x00M" + + seq_e0_echo = iter(["\xe0", "Z"]) + monkeypatch.setattr(_termui_impl.msvcrt, "getwche", lambda: next(seq_e0_echo)) + assert _termui_impl.getchar(echo=True) == "\xe0Z" + + monkeypatch.setattr(_termui_impl.msvcrt, "getwch", lambda: "\x03") + with pytest.raises(KeyboardInterrupt): + _termui_impl.getchar(echo=False) + + monkeypatch.setattr(_termui_impl.msvcrt, "getwch", lambda: "\x1a") + with pytest.raises(EOFError): + _termui_impl.getchar(echo=False) + + +@skip_if_windows +@pytest.mark.parametrize("use_stdin_tty", [True, False]) +def test_termui_impl_posix_raw_terminal(monkeypatch, use_stdin_tty: bool): + state: dict[str, object] = {} + flushed: list[None] = [] + fake_tty = None + + if use_stdin_tty: + expected_fd = 14 + old_termios = "old_settings" + monkeypatch.setattr( + _termui_impl, "isatty", lambda s: s is _termui_impl.sys.stdin + ) + monkeypatch.setattr(_termui_impl.sys.stdin, "fileno", lambda: expected_fd) + else: + expected_fd = 27 + old_termios = "old" + monkeypatch.setattr( + _termui_impl, + "isatty", + lambda s: s is not _termui_impl.sys.stdin, + ) + + class FakeTTY: + def __init__(self) -> None: + self.closed = False + + def fileno(self) -> int: + return expected_fd + + def close(self) -> None: + self.closed = True + + fake_tty = FakeTTY() + real_open = open + + def fake_open(path, *args, **kwargs): + if path == "/dev/tty": + return fake_tty + return real_open(path, *args, **kwargs) # pragma: no cover + + monkeypatch.setattr("builtins.open", fake_open) + + def tcgetattr(fd: int) -> str: + state["tcgetattr_fd"] = fd + return old_termios + + def setraw(fd: int) -> None: + state["setraw_fd"] = fd + + def tcsetattr(fd: int, when: int, old: str) -> None: + state["tcsetattr"] = (fd, when, old) + + monkeypatch.setattr(_termui_impl.termios, "tcgetattr", tcgetattr) + monkeypatch.setattr(_termui_impl.tty, "setraw", setraw) + monkeypatch.setattr(_termui_impl.termios, "tcsetattr", tcsetattr) + monkeypatch.setattr( + _termui_impl.sys.stdout, "flush", lambda *a, **k: flushed.append(None) + ) + + with _termui_impl.raw_terminal() as fd: + assert fd == expected_fd + + assert state["tcgetattr_fd"] == expected_fd + assert state["setraw_fd"] == expected_fd + assert state["tcsetattr"] == ( + expected_fd, + _termui_impl.termios.TCSADRAIN, + old_termios, + ) + assert flushed == [None] + if fake_tty is not None: + assert fake_tty.closed is True + + +@skip_if_windows +def test_termui_impl_posix_getchar(monkeypatch): + @contextmanager + def fake_raw(): + yield 7 + + monkeypatch.setattr(_termui_impl, "raw_terminal", fake_raw) + monkeypatch.setattr(_termui_impl.os, "read", lambda fd, n: b"q") + monkeypatch.setattr(_termui_impl, "get_best_encoding", lambda stdin: "utf-8") + monkeypatch.setattr(_termui_impl, "isatty", lambda f: f is _termui_impl.sys.stdout) + written: list[str] = [] + monkeypatch.setattr(_termui_impl.sys.stdout, "write", lambda s: written.append(s)) + + assert _termui_impl.getchar(echo=True) == "q" + assert written == ["q"] + + +@skip_if_windows +def test_termui_impl_posix_getchar_eof(monkeypatch): + @contextmanager + def fake_raw(): + yield 5 + + monkeypatch.setattr(_termui_impl, "raw_terminal", fake_raw) + monkeypatch.setattr(_termui_impl.os, "read", lambda fd, n: b"\x04") + monkeypatch.setattr(_termui_impl, "get_best_encoding", lambda stdin: "utf-8") + monkeypatch.setattr(_termui_impl, "isatty", lambda f: False) + + with pytest.raises(EOFError): + _termui_impl.getchar(echo=False) + + +def test_prompt(): + runner = CliRunner() + app = typer.Typer() + fake_file = io.StringIO("data") + fake_file.name = "demo.txt" + + @app.command() + def main( + accept: bool = typer.Option(True, prompt=True), + name: str = typer.Option(..., prompt=True), + flavor: Literal["a", "b"] = typer.Option(..., prompt=True), + city: str = typer.Option("London", prompt=True), + config: str = typer.Option(fake_file, prompt=True), + password: str = typer.Option( + ..., + prompt=True, + hide_input=True, + confirmation_prompt=True, + ), + ): + typer.echo( + f"accept={accept};name={name};flavor={flavor};city={city};config={config};pass_len={len(password)}" + ) + + result = runner.invoke(app, [], input="\nAda\na\n\ncustom.ini\nsecret\nsecret\n") + assert result.exit_code == 0, result.output + assert ( + "accept=True;name=Ada;flavor=a;city=London;config=custom.ini;pass_len=6" + in result.stdout + ) + assert "(a, b): " in result.stdout + assert "[demo.txt]: " in result.stdout + + +def test_hidden_prompt_func(monkeypatch): + monkeypatch.setattr("getpass.getpass", lambda prompt: "secret") + assert termui.hidden_prompt_func("Password: ") == "secret" + + +def test_echo_stdout_missing(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("sys.stdout", None) + typer.echo("ignored") + + +def test_echo_stringifies() -> None: + stream = io.StringIO() + typer.echo(123, file=stream, nl=False) + assert stream.getvalue() == "123" + + +def test_echo_bytes() -> None: + buffer = io.BytesIO() + stream = io.TextIOWrapper(buffer, encoding="utf-8") + typer.echo(b"abc", file=stream, nl=True) + assert buffer.getvalue() == b"abc\n" + + +def test_echo_empty_output() -> None: + class FlushTrackingTextStream(io.StringIO): + def __init__(self) -> None: + super().__init__() + self.flush_count = 0 + + def flush(self) -> None: + self.flush_count += 1 + super().flush() + + def write(self, s: str) -> int: + raise AssertionError("Empty output") # pragma: no cover + + stream = FlushTrackingTextStream() + typer.echo("", file=stream, nl=False) + assert stream.flush_count == 1 + + +@needs_windows +def test_echo_windows_color_none( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class TtyStream(io.StringIO): + def isatty(self) -> bool: + return True + + stream = TtyStream() + monkeypatch.setattr("typer._click.utils.auto_wrap_for_ansi", None) + typer.echo("\x1b[31mred\x1b[0m", file=stream, nl=False, color=None) + assert stream.getvalue() == "red" + + +@pytest.mark.parametrize( + ("flag", "true_code", "false_code"), + [ + ("bold", "\x1b[1m", "\x1b[22m"), + ("dim", "\x1b[2m", "\x1b[22m"), + ("underline", "\x1b[4m", "\x1b[24m"), + ("overline", "\x1b[53m", "\x1b[55m"), + ("italic", "\x1b[3m", "\x1b[23m"), + ("blink", "\x1b[5m", "\x1b[25m"), + ("reverse", "\x1b[7m", "\x1b[27m"), + ("strikethrough", "\x1b[9m", "\x1b[29m"), + ], +) +def test_style(flag, true_code, false_code): + runner = CliRunner() + app = typer.Typer() + + @app.command() + def main(): + # testing an int and a str on purpose + typer.echo("TRUE=" + typer.style("42", **{flag: True}), color=True) + typer.echo("FALSE=" + typer.style(666, **{flag: False}), color=True) + + result = runner.invoke(app, []) + assert result.exit_code == 0, result.output + lines = [line for line in result.stdout.splitlines() if line] + true_line = next(line for line in lines if line.startswith("TRUE=")) + false_line = next(line for line in lines if line.startswith("FALSE=")) + assert "42" in true_line + assert "666" in false_line + + assert true_code in true_line + assert true_code not in false_line + assert false_code in false_line + assert false_code not in true_line + + +def test_style_color(): + fg_int = typer.style("x", fg=123) + assert "\x1b[38;5;123m" in fg_int + + bg_list = typer.style("x", bg=[1, 2, 3]) + assert "\x1b[48;2;1;2;3m" in bg_list + + with pytest.raises(TypeError, match="Unknown color"): + typer.style("x", fg="not-a-color") + + with pytest.raises(TypeError, match="Unknown color"): + typer.style("x", bg="not-a-color") + + +def test_termui_launch(monkeypatch): + captured = {} + + def fake_open_url(url, wait=False, locate=False): + captured["url"] = url + captured["wait"] = wait + captured["locate"] = locate + return 7 + + monkeypatch.setattr(_termui_impl, "open_url", fake_open_url) + rv = termui.launch("https://example.com", wait=True, locate=True) + assert rv == 7 + assert captured == { + "url": "https://example.com", + "wait": True, + "locate": True, + } diff --git a/tests/test_tutorial/test_parameter_types/test_bool/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_bool/test_tutorial001.py index 6941a35c37..32c07214c0 100644 --- a/tests/test_tutorial/test_parameter_types/test_bool/test_tutorial001.py +++ b/tests/test_tutorial/test_parameter_types/test_bool/test_tutorial001.py @@ -4,6 +4,7 @@ from types import ModuleType import pytest +import typer from typer.testing import CliRunner runner = CliRunner() @@ -22,6 +23,12 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: return mod +def test_type_repr(mod: ModuleType): + command = typer.main.get_command(mod.app) + force_param = next(param for param in command.params if param.name == "force") + assert repr(force_param.type) == "BOOL" + + def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 diff --git a/tests/test_tutorial/test_parameter_types/test_datetime/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_datetime/test_tutorial001.py index 9dfcdf19e3..eea63c5a8b 100644 --- a/tests/test_tutorial/test_parameter_types/test_datetime/test_tutorial001.py +++ b/tests/test_tutorial/test_parameter_types/test_datetime/test_tutorial001.py @@ -1,6 +1,8 @@ import subprocess import sys +from datetime import datetime +import typer from typer.testing import CliRunner from docs_src.parameter_types.datetime import tutorial001_py310 as mod @@ -9,6 +11,12 @@ app = mod.app +def test_type_repr(): + command = typer.main.get_command(app) + birth_param = next(param for param in command.params if param.name == "birth") + assert repr(birth_param.type) == "DateTime" + + def test_help(): result = runner.invoke(app, ["--help"]) assert result.exit_code == 0 @@ -22,6 +30,15 @@ def test_main(): assert "Birth hour: 10" in result.output +def test_main_datetime_object(): + result = runner.invoke( + app, [], default_map={"birth": datetime(1956, 1, 31, 10, 0, 0)} + ) + assert result.exit_code == 0 + assert "Interesting day to be born: 1956-01-31 10:00:00" in result.output + assert "Birth hour: 10" in result.output + + def test_invalid(): result = runner.invoke(app, ["july-19-1989"]) assert result.exit_code != 0 diff --git a/tests/test_tutorial/test_parameter_types/test_index/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_index/test_tutorial001.py index 4b2f422fbb..ee0daa9f06 100644 --- a/tests/test_tutorial/test_parameter_types/test_index/test_tutorial001.py +++ b/tests/test_tutorial/test_parameter_types/test_index/test_tutorial001.py @@ -1,6 +1,7 @@ import subprocess import sys +import typer from typer.testing import CliRunner from docs_src.parameter_types.index import tutorial001_py310 as mod @@ -9,6 +10,16 @@ app = mod.app +def test_type_repr(): + command = typer.main.get_command(app) + age_param = next(param for param in command.params if param.name == "age") + height_meters_param = next( + param for param in command.params if param.name == "height_meters" + ) + assert repr(age_param.type) == "INT" + assert repr(height_meters_param.type) == "FLOAT" + + def test_help(): result = runner.invoke(app, ["--help"]) assert result.exit_code == 0 diff --git a/tests/test_tutorial/test_parameter_types/test_number/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_number/test_tutorial001.py index 6b9445a97f..ffe3ad09a0 100644 --- a/tests/test_tutorial/test_parameter_types/test_number/test_tutorial001.py +++ b/tests/test_tutorial/test_parameter_types/test_number/test_tutorial001.py @@ -24,6 +24,19 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: return mod +def test_type_repr(mod: ModuleType): + command = typer.main.get_command(mod.app) + + id_param = next(param for param in command.params if param.name == "id") + assert repr(id_param.type) == "" + + age_param = next(param for param in command.params if param.name == "age") + assert repr(age_param.type) == "=18>" + + score_param = next(param for param in command.params if param.name == "score") + assert repr(score_param.type) == "" + + def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 diff --git a/tests/test_tutorial/test_parameter_types/test_uuid/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_uuid/test_tutorial001.py index cad8c69cc4..7b79e81405 100644 --- a/tests/test_tutorial/test_parameter_types/test_uuid/test_tutorial001.py +++ b/tests/test_tutorial/test_parameter_types/test_uuid/test_tutorial001.py @@ -1,6 +1,8 @@ import subprocess import sys +import uuid +import typer from typer.testing import CliRunner from docs_src.parameter_types.uuid import tutorial001_py310 as mod @@ -9,6 +11,12 @@ app = mod.app +def test_type_repr(): + command = typer.main.get_command(app) + user_id_param = next(param for param in command.params if param.name == "user_id") + assert repr(user_id_param.type) == "UUID" + + def test_main(): result = runner.invoke(app, ["d48edaa6-871a-4082-a196-4daab372d4a1"]) assert result.exit_code == 0 @@ -16,6 +24,14 @@ def test_main(): assert "UUID version is: 4" in result.output +def test_main_with_uuid_object(): + user_id = uuid.UUID("d48edaa6-871a-4082-a196-4daab372d4a1") + result = runner.invoke(app, [], default_map={"user_id": user_id}) + assert result.exit_code == 0 + assert "USER_ID is d48edaa6-871a-4082-a196-4daab372d4a1" in result.output + assert "UUID version is: 4" in result.output + + def test_invalid_uuid(): result = runner.invoke(app, ["7479706572-72756c6573"]) assert result.exit_code != 0 diff --git a/tests/test_type_conversion.py b/tests/test_type_conversion.py index 83edd1ecb5..26709db0e2 100644 --- a/tests/test_type_conversion.py +++ b/tests/test_type_conversion.py @@ -1,12 +1,15 @@ +import os from enum import Enum from pathlib import Path from typing import Any -import click import pytest import typer +from typer import _click, models from typer.testing import CliRunner +from tests.utils import needs_linux, needs_windows + runner = CliRunner() @@ -133,6 +136,18 @@ def tuple_recursive_conversion(container: type_annotation): assert result.exit_code == 0 +def test_tuple_wrong_arity(): + app = typer.Typer() + + @app.command() + def tuple_arity(value: tuple[str, str] = typer.Option(...)): + print(value) # pragma: no cover + + result = runner.invoke(app, [], default_map={"value": ("only-one",)}) + assert result.exit_code == 2 + assert "2 values are required, but 1 given." in result.output + + def test_custom_parse(): app = typer.Typer() @@ -146,15 +161,29 @@ def custom_parser( assert result.exit_code == 0 +def test_custom_parse_value_error(): + app = typer.Typer() + + @app.command() + def custom_parser( + hex_value: int = typer.Argument(None, parser=lambda x: int(x, 0)), + ): + print(hex_value) # pragma: no cover + + result = runner.invoke(app, ["not-a-hex"]) + assert result.exit_code == 2 + assert "Invalid value" in result.output + + def test_custom_click_type(): - class BaseNumberParamType(click.ParamType): + class BaseNumberParamType(_click.types.ParamType): name = "base_integer" def convert( self, value: Any, - param: click.Parameter | None, - ctx: click.Context | None, + param: _click.Parameter | None, + ctx: _click.Context | None, ) -> Any: return int(value, 0) @@ -168,3 +197,204 @@ def custom_click_type( result = runner.invoke(app, ["0x56"]) assert result.exit_code == 0 + + +def test_int_range_open_bound_clamp(): + app = typer.Typer() + + @app.command() + def custom_click_type( + value: int = typer.Argument( + ..., + click_type=_click.types.IntRange(min=1, min_open=True, clamp=True), + ), + ): + print(value) + + result = runner.invoke(app, ["1"]) + assert result.exit_code == 0 + assert "2" in result.output + + +def test_bool_convert_invalid(): + app = typer.Typer() + + @app.command() + def main(value: bool): + print(value) # pragma: no cover + + result = runner.invoke(app, ["maybe"]) + assert result.exit_code == 2 + assert "is not a valid boolean" in result.output + assert "yes" in result.output + assert "false" in result.output + + +@pytest.mark.parametrize( + ("arg_enc", "system_enc", "raw_value", "expected_output"), + [ + pytest.param("latin-1", "utf-8", b"\xff", "ÿ"), + pytest.param("ascii", "latin-1", b"\xff", "ÿ"), + pytest.param("ascii", "utf-16", b"\xff", "�"), + pytest.param("ascii", "ascii", b"\xff", "�"), + ], +) +def test_string_param_type_converts_bytes( + monkeypatch: pytest.MonkeyPatch, + arg_enc: str, + system_enc: str, + raw_value: bytes, + expected_output: str, +): + app = typer.Typer() + + @app.command() + def show(name: str = typer.Option(...)): + print(name) + + command = typer.main.get_command(app) + name_param = next(param for param in command.params if param.name == "name") + assert repr(name_param.type) == "STRING" + + monkeypatch.setattr(_click.types, "_get_argv_encoding", lambda: arg_enc) + monkeypatch.setattr(_click.types.sys, "getfilesystemencoding", lambda: system_enc) + + result = runner.invoke(app, [], default_map={"name": raw_value}) + assert result.exit_code == 0 + assert expected_output in result.output + + +@pytest.mark.parametrize("path_type", [str, bytes, Path]) +def test_path_coerced(path_type) -> None: + # Ensure coerce_path_result works correctly + app = typer.Typer() + + @app.command() + def show(path: Any = typer.Option(..., path_type=path_type)): + print(path) + + result = runner.invoke(app, ["--path", "dir/my_awesome_file.txt"]) + assert result.exit_code == 0 + assert "my_awesome_file" in result.output + + +@pytest.mark.parametrize( + ("create_file", "option_kwargs", "deny_mode", "expected_error"), + [ + (True, {"file_okay": False, "dir_okay": True}, None, "is a file"), + (False, {"file_okay": True, "dir_okay": False}, None, "is a directory"), + (True, {"readable": True}, os.R_OK, "is not readable"), + (True, {"readable": False, "writable": True}, os.W_OK, "is not writable"), + ], +) +def test_path_convert_failures( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + create_file: bool, + option_kwargs: dict[str, bool], + deny_mode: int | None, + expected_error: str, +) -> None: + app = typer.Typer() + + @app.command() + def show(path: Path = typer.Option(..., **option_kwargs)): + print(path) # pragma: no cover + + if deny_mode is not None: + original_access = os.access + + def fake_access(path: str, mode: int) -> bool: + if mode == deny_mode: + return False + return original_access(path, mode) # pragma: no cover + + monkeypatch.setattr(models.os, "access", fake_access) + + path = tmp_path / "some_path" + if create_file: + path.write_text("hello") + else: + path.mkdir() + result = runner.invoke(app, ["--path", str(path)]) + + assert result.exit_code != 0 + assert expected_error in result.output + + +def test_convert_type(): + from typer._click.types import convert_type + + # str + assert convert_type(str) is _click.types.STRING + assert convert_type(None) is _click.types.STRING + assert convert_type(None, default=["a"]) is _click.types.STRING + + # tuples + tuple_type = convert_type((str, int)) + assert isinstance(tuple_type, _click.types.Tuple) + assert [type(item) for item in tuple_type.types] == [ + type(_click.types.STRING), + type(_click.types.INT), + ] + + guessed_tuple = convert_type(None, default=[(1, "x")]) + assert isinstance(guessed_tuple, _click.types.Tuple) + assert [type(item) for item in guessed_tuple.types] == [ + type(_click.types.INT), + type(_click.types.STRING), + ] + + # numbers + assert convert_type(int) is _click.types.INT + assert convert_type(float) is _click.types.FLOAT + assert convert_type(bool) is _click.types.BOOL + + param_type = _click.types.IntRange(min=0, max=10) + assert convert_type(param_type) is param_type + + guessed_int = convert_type(None, default=42) + assert guessed_int is _click.types.INT + + # custom type + class CustomType: + pass + + guessed_unknown = convert_type(None, default=CustomType()) + assert guessed_unknown is _click.types.STRING + + func_type = convert_type(CustomType) + assert isinstance(func_type, _click.types.FuncParamType) + assert func_type.name == "CustomType" + + +@pytest.mark.parametrize( + ("platform_case", "stdin_encoding", "filesystem_encoding"), + [ + pytest.param("windows", None, "utf-8", marks=needs_windows), + pytest.param("linux", "latin-1", "utf-8", marks=needs_linux), + pytest.param("linux", None, "latin-1", marks=needs_linux), + ], +) +def test_argv_encoding( + monkeypatch: pytest.MonkeyPatch, + platform_case: str, + stdin_encoding: str | None, + filesystem_encoding: str, +) -> None: + sys = _click._compat.sys + if platform_case == "windows": + import locale + + monkeypatch.setattr(locale, "getpreferredencoding", lambda: "latin-1") + else: + + class FakeStdin: + def __init__(self, encoding: str | None) -> None: + self.encoding = encoding + + monkeypatch.setattr(sys, "stdin", FakeStdin(stdin_encoding)) + monkeypatch.setattr(sys, "getfilesystemencoding", lambda: filesystem_encoding) + + converted = _click.types.STRING.convert(b"\xff", None, None) + assert converted == "ÿ" diff --git a/tests/test_types.py b/tests/test_types.py index adb100eb82..caeef451aa 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1,6 +1,8 @@ from enum import Enum +import pytest import typer +from typer import _click from typer.testing import CliRunner app = typer.Typer(context_settings={"token_normalize_func": str.lower}) @@ -12,23 +14,108 @@ class User(str, Enum): @app.command() -def hello(name: User = User.rick) -> None: +def hello_option(name: User = User.rick) -> None: print(f"Hello {name.value}!") +@app.command() +def hello_argument(name: User) -> None: + print(f"Hello {name.value}!") + + +@app.command() +def hello_no_choices( + name: User = typer.Option(..., "--name", show_choices=False), +): + print(f"Hello {name.value}!") + + +@app.command() +def hello_all(names: list[str] = typer.Argument(["World"], envvar="NAMES")) -> None: + for name in names: + print(f"Hello {name}!") + + +@app.command() +def split_variadic_and_pair(items: list[str], pair: tuple[str, str]) -> None: + print(f"items={items}") + print(f"pair={pair}") + + runner = CliRunner() def test_enum_choice() -> None: - # This test is only for coverage of the new custom TyperChoice class - result = runner.invoke(app, ["--name", "morty"], catch_exceptions=False) + result = runner.invoke( + app, ["hello-option", "--name", "morty"], catch_exceptions=False + ) assert result.exit_code == 0 assert "Hello Morty!" in result.output - result = runner.invoke(app, ["--name", "Rick"]) + result = runner.invoke(app, ["hello-option", "--name", "Rick"]) + assert result.exit_code == 0 + assert "Hello Rick!" in result.output + + result = runner.invoke(app, ["hello-option", "--name", "RICK"]) assert result.exit_code == 0 assert "Hello Rick!" in result.output - result = runner.invoke(app, ["--name", "RICK"]) + result = runner.invoke(app, ["hello-no-choices", "--name", "RICK"]) assert result.exit_code == 0 assert "Hello Rick!" in result.output + + result = runner.invoke(app, ["hello-argument", "RICK"]) + assert result.exit_code == 0 + assert "Hello Rick!" in result.output + + +def test_enum_choice_repr() -> None: + root_command = typer.main.get_command(app) + command = root_command.commands["hello-option"] + name_param = next(param for param in command.params if param.name == "name") + assert repr(name_param.type).startswith("Choice([") + + +def test_enum_choice_help() -> None: + result = runner.invoke(app, ["hello-argument", "--help"]) + assert result.exit_code == 0 + assert "{rick|morty}" in result.output + + result = runner.invoke(app, ["hello-option", "--help"]) + assert result.exit_code == 0 + assert "[rick|morty]" in result.output + + result = runner.invoke(app, ["hello-no-choices", "--help"]) + assert result.exit_code == 0 + assert "--name" in result.output + assert "rick|morty" not in result.output + + +def test_enum_choice_missing_message() -> None: + result = runner.invoke(app, ["hello-argument"]) + assert result.exit_code != 0 + assert "Missing argument" in result.output + assert "Choose from:" in result.output + assert "rick" in result.output + assert "morty" in result.output + + +def test_split_envvar_value(monkeypatch) -> None: + # This will use split_envvar_value to produce two strings from the envvar + monkeypatch.setenv("NAMES", "Rick Morty") + result = runner.invoke(app, ["hello-all"]) + assert result.exit_code == 0 + assert "Hello Rick!" in result.output + assert "Hello Morty!" in result.output + + +def test_list_pair() -> None: + result = runner.invoke(app, ["split-variadic-and-pair", "a", "b", "c", "x", "y"]) + assert result.exit_code == 0 + assert "items=['a', 'b', 'c']" in result.output + assert "pair=('x', 'y')" in result.output + + +def test_float_range_open_bounds_with_clamp_not_allowed(): + with pytest.raises(TypeError, match="Clamping is not supported for open bounds."): + _click.types.FloatRange(min=0.0, min_open=True, clamp=True) diff --git a/tests/test_types_file.py b/tests/test_types_file.py new file mode 100644 index 0000000000..61fd71c300 --- /dev/null +++ b/tests/test_types_file.py @@ -0,0 +1,395 @@ +import subprocess +import sys +from io import BytesIO, StringIO, TextIOWrapper +from pathlib import Path + +import pytest +import typer +from typer._click._compat import get_best_encoding, should_strip_ansi +from typer._click.testing import make_input_stream +from typer._click.utils import PacifyFlushWrapper +from typer.testing import CliRunner + +from tests.utils import needs_linux, needs_windows + +app = typer.Typer() + + +@app.command() +def read_text(file_in: typer.FileText = typer.Option(..., lazy=True)): + data = file_in.read() + typer.echo(f"text-len={len(data)}") + + +@app.command() +def write_text(file_out: typer.FileTextWrite = typer.Option(..., lazy=None)): + file_out.write("This is a single line\n") + typer.echo("1 line written") + + +@app.command() +def write_lazy(file_out: typer.FileTextWrite = typer.Option(..., lazy=True)): + file_out.write("This is a single lazy line\n") + typer.echo("1 line written") + + +@app.command() +def probe_lazy_file_behaviors( + file_in: typer.FileText = typer.Option(..., lazy=True), + file_out: typer.FileTextWrite = typer.Option(..., lazy=True), +): + typer.echo(f"repr-before={repr(file_out)}") + file_out.write("repr-opened\n") + typer.echo(f"repr-after={repr(file_out)}") + with file_in as stream: + typer.echo(f"context-len={len(stream.read())}") + stream.seek(0) + first_line = next(iter(stream), "") + typer.echo(f"first-line={first_line.rstrip()}") + + +@app.command() +def write_binary(file_out: typer.FileBinaryWrite = typer.Option(...)): + file_out.write(b"binary-written\n") + + +@app.command() +def write_binary_stderr(): + stream = typer.get_binary_stream("stderr") + stream.write(b"binary-stderr\n") + stream.flush() + + +@app.command() +def read_binary(file_in: typer.FileBinaryRead = typer.Option(...)): + data = file_in.read() + typer.echo(f"binary-len={len(data)}") + + +runner = CliRunner() + + +def test_text_stdin_dash() -> None: + result = runner.invoke(app, ["read-text", "--file-in=-"], input="hello\n") + assert result.exit_code == 0 + assert "text-len=6" in result.output + + +def test_lazy_file(tmp_path: Path) -> None: + # dash: written to stdout + result = runner.invoke(app, ["write-text", "--file-out=-"]) + assert result.exit_code == 0 + assert "This is a single line" in result.output + assert "1 line written" in result.output + + # lazy + file + file_path = tmp_path / "example.txt" + result = runner.invoke(app, ["write-lazy", f"--file-out={file_path}"]) + assert result.exit_code == 0 + assert "This is a single lazy line" not in result.output + assert "1 line written" in result.output + assert file_path.exists() + assert file_path.read_text() == "This is a single lazy line\n" + + # lazy probe: unopened/opened repr, context manager, and iteration. + result = runner.invoke( + app, + [ + "probe-lazy-file-behaviors", + f"--file-in={file_path}", + f"--file-out={tmp_path / 'repr-opened.txt'}", + ], + ) + assert result.exit_code == 0 + assert "repr-before= None: + stream = StringIO() + result = runner.invoke( + app, ["write-text"], default_map={"write-text": {"file_out": stream}} + ) + assert result.exit_code == 0 + assert "1 line written" in result.output + assert stream.getvalue() == "This is a single line\n" + + +def test_input_stream() -> None: + binary_stream = BytesIO(b"hello") + converted = make_input_stream(binary_stream, charset="utf-8") + assert converted is binary_stream + + text_stream = TextIOWrapper(BytesIO(b"hello"), encoding="utf-8") + converted = make_input_stream(text_stream, charset="utf-8") + assert converted is text_stream.buffer + + +def test_binary_dash() -> None: + result = runner.invoke(app, ["write-binary", "--file-out=-"]) + assert result.exit_code == 0 + assert result.stdout_bytes == b"binary-written\n" + + result = runner.invoke( + app, ["read-binary", "--file-in=-"], input=b"\x00\x01\x02abc" + ) + assert result.exit_code == 0 + assert "binary-len=6" in result.output + + +def test_binary_stderr() -> None: + result = subprocess.run( + [ + sys.executable, + "-c", + "from tests.test_types_file import app; app()", + "write-binary-stderr", + ], + capture_output=True, + ) + assert result.returncode == 0 + assert result.stderr == b"binary-stderr\n" + + +@pytest.mark.parametrize( + ("errors_arg", "expected_errors"), + [ + (None, "replace"), + ("strict", "strict"), + ], +) +def test_get_text_stream_errors( + monkeypatch, + errors_arg: str | None, + expected_errors: str, +) -> None: + class BinaryStdout(BytesIO): + pass + + binary_stdout = BinaryStdout() + monkeypatch.setattr(sys, "stdout", binary_stdout) + + text_stream = typer.get_text_stream("stdout", encoding=None, errors=errors_arg) + text_stream.write("stream-text") + text_stream.flush() + + assert text_stream.errors == expected_errors + assert text_stream.writable() is True + assert binary_stdout.getvalue() == b"stream-text" + + +def test_get_best_encoding() -> None: + """Test that ASCII is being transformed into UTF-8""" + + class AsciiStream: + encoding = "ascii" + + class Utf8Stream: + encoding = "utf-8" + + class UnknownStream: + encoding = "unknown" + + assert get_best_encoding(AsciiStream()) == "utf-8" + assert get_best_encoding(Utf8Stream()) == "utf-8" + assert get_best_encoding(UnknownStream()) == "unknown" + + +def test_pacify_flush_wrapper() -> None: + class Wrapped: + def __init__(self) -> None: + self.name = "wrapped-stream" + + def flush(self) -> None: + return None # pragma: no cover + + wrapped = PacifyFlushWrapper(Wrapped()) + assert wrapped.name == "wrapped-stream" + + +def test_text_stream_isatty(monkeypatch) -> None: + class BinaryStdout(BytesIO): + def isatty(self) -> bool: + return True + + binary_stdout = BinaryStdout() + monkeypatch.setattr(sys, "stdout", binary_stdout) + text_stream = typer.get_text_stream("stdout", encoding="utf-8", errors=None) + assert text_stream.isatty() is True + + +def test_text_stream_buffer_read1(monkeypatch) -> None: + class BinaryStdinNoRead1: + def __init__(self, data: bytes) -> None: + self._data = data + self._pos = 0 + + def read(self, size: int = -1) -> bytes: + if size < 0: + size = len(self._data) - self._pos # pragma: no cover + chunk = self._data[self._pos : self._pos + size] + self._pos += len(chunk) + return chunk + + binary_stdin = BinaryStdinNoRead1(b"hello") + monkeypatch.setattr(sys, "stdin", binary_stdin) + text_stream = typer.get_text_stream("stdin", encoding="utf-8", errors=None) + assert text_stream._stream.read1(4) == b"hell" + + +def test_binary_stream(monkeypatch) -> None: + binary_stdin = BytesIO(b"hello") + binary_stdout = BytesIO() + monkeypatch.setattr(sys, "stdin", binary_stdin) + monkeypatch.setattr(sys, "stdout", binary_stdout) + + assert typer.get_binary_stream("stdin") is binary_stdin + assert typer.get_binary_stream("stdout") is binary_stdout + + +def test_binary_stream_raises(monkeypatch) -> None: + class TextOnlyStdin: + def read(self, n: int = -1) -> str: + return "hello" + + monkeypatch.setattr(sys, "stdin", TextOnlyStdin()) + with pytest.raises(RuntimeError, match="Was not able to determine binary stream"): + typer.get_binary_stream("stdin") + + +def test_stream_unknown() -> None: + with pytest.raises(TypeError, match="Unknown standard stream 'Plumbus'"): + typer.get_binary_stream("Plumbus") # type: ignore[arg-type] + + with pytest.raises(TypeError, match="Unknown standard stream 'Fleeb'"): + typer.get_text_stream("Fleeb") # type: ignore[arg-type] + + +def test_format_filename() -> None: + filename = b"folder/subdir/demo.txt" + assert typer.format_filename(filename, shorten=True) == "demo.txt" + + +def test_file_error(monkeypatch, tmp_path: Path) -> None: + file_path = tmp_path / "cannot-open.txt" + + def fake_open(path, *args, **kwargs): + if Path(path) == file_path: + raise OSError() + + monkeypatch.setattr("builtins.open", fake_open) + result = runner.invoke(app, ["write-text", f"--file-out={file_path}"]) + assert result.exit_code == 1 + assert "Could not open file" in result.output + assert "cannot-open.txt" in result.output + assert "unknown error" in result.output + + +@needs_windows +def test_app_dir_windows_fallback( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("APPDATA", raising=False) + monkeypatch.setattr("os.path.expanduser", lambda _path: r"C:\Users\Tester") + + assert typer.get_app_dir("My App", roaming=True) == r"C:\Users\Tester\My App" + + +@needs_linux +def test_app_dir_force_posix(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("os.path.expanduser", lambda _path: "/home/tester/.my-app") + + assert typer.get_app_dir("My App", force_posix=True) == "/home/tester/.my-app" + + +def test_text_stream_binary_buffer(monkeypatch) -> None: + class TextStdinWithBinaryBuffer: + def __init__(self, data: bytes) -> None: + self.buffer = BytesIO(data) + self.encoding = "latin-1" + + def read(self, n: int = -1) -> str: + raise OSError("text stream is not readable directly") + + class TextStdoutWithBinaryBuffer: + def __init__(self) -> None: + self.buffer = BytesIO() + self.encoding = "latin-1" + + def write(self, s: str) -> int: + raise OSError("text stream is not writable directly") + + stdin = TextStdinWithBinaryBuffer(b"hello") + stdout = TextStdoutWithBinaryBuffer() + + monkeypatch.setattr(sys, "stdin", stdin) + monkeypatch.setattr(sys, "stdout", stdout) + + text_stdin = typer.get_text_stream("stdin", encoding="utf-8", errors=None) + text_stdout = typer.get_text_stream("stdout", encoding="utf-8", errors=None) + + assert text_stdin.read() == "hello" + text_stdout.write("ok") + text_stdout.flush() + assert stdout.buffer.getvalue() == b"ok" + + +def test_text_stream_binary_stream(monkeypatch) -> None: + binary_stdout = BytesIO() + monkeypatch.setattr(sys, "stdout", binary_stdout) + text_stream = typer.get_text_stream("stdout", encoding="utf-8", errors=None) + text_stream.write("ok") + text_stream.flush() + assert binary_stdout.getvalue() == b"ok" + + +def test_text_stream_stdout_no_binary( + monkeypatch, +) -> None: + class TextStdoutNoBinaryFallback: + encoding = "utf-8" + errors = "strict" + + def write(self, s: str) -> int: + if isinstance(s, bytes): + raise TypeError("bytes not supported") + return len(s) + + stdout = TextStdoutNoBinaryFallback() + monkeypatch.setattr(sys, "stdout", stdout) + text_stream = typer.get_text_stream("stdout", encoding="utf-8", errors="replace") + assert text_stream is stdout + + +def test_jupyter_wrapped_stream(monkeypatch) -> None: + class JupyterLikeStdout(BytesIO): + __module__ = "ipykernel.iostream" + + def isatty(self) -> bool: + return False + + binary_stdout = JupyterLikeStdout() + monkeypatch.setattr(sys, "stdout", binary_stdout) + text_stream = typer.get_text_stream("stdout", encoding="utf-8", errors=None) + assert should_strip_ansi(text_stream, color=None) is False + + +def test_should_strip_ansi(monkeypatch) -> None: + class NonTtyStdin(BytesIO): + def isatty(self) -> bool: + return False + + stdin = NonTtyStdin() + monkeypatch.setattr(sys, "stdin", stdin) + assert should_strip_ansi(stream=None, color=None) is True + assert should_strip_ansi(stream=None, color=True) is False + assert should_strip_ansi(stream=None, color=False) is True diff --git a/tests/test_win_console.py b/tests/test_win_console.py new file mode 100644 index 0000000000..13c390cfe2 --- /dev/null +++ b/tests/test_win_console.py @@ -0,0 +1,326 @@ +""" +Tests for the Windows console functionality. +Created after vendoring Click to ensure test coverage is back up to 100%. +""" + +import ctypes +import io +import sys + +import pytest +import typer +from typer.testing import CliRunner + +from .utils import needs_windows + +pytestmark = needs_windows + + +if sys.platform == "win32": + from typer._click import _compat, _winconsole + + +def _identity_buffer(obj, writable=False): # noqa: ARG001 + return obj + + +def _route_console_stream(target_name, wrapper, state=None): + def patched_windows_stream(stream, encoding, errors): # noqa: ARG001 + current_target = getattr(sys, target_name) + if stream is current_target: + if state is not None and target_name == "stderr": + state["stderr_wrap_calls"] += 1 + buffer = getattr(stream, "buffer", None) + return wrapper(buffer) if buffer else None + return None + + return patched_windows_stream + + +def _capture_write_console(state): + def fake_write_console(handle, buffer, units_to_write, units_written_ptr, reserved): # noqa: ARG001 + state["write_calls"] += 1 + bytes_to_write = units_to_write * 2 + state["written"].extend(buffer[:bytes_to_write]) + units_written_ptr._obj.value = units_to_write + return 1 + + return fake_write_console + + +def test_winconsole_stdin(monkeypatch): + runner = CliRunner() + app = typer.Typer() + + @app.command() + def read_name(config: typer.FileText = typer.Option(...)) -> None: + name = config.readline().strip() + typer.echo(f"Hello {name}") + + utf16_data = bytearray("Rick\r\n".encode("utf-16-le")) + state = {"pos": 0, "read_calls": 0} + + def fake_read_console(handle, buffer, units_to_read, units_read_ptr, reserved): # noqa: ARG001 + state["read_calls"] += 1 + max_bytes = units_to_read * 2 + chunk = utf16_data[state["pos"] : state["pos"] + max_bytes] + if chunk: + buffer[0 : len(chunk)] = chunk + state["pos"] += len(chunk) + units_read_ptr._obj.value = len(chunk) // 2 + return 1 + + return 1 # pragma: no cover + + monkeypatch.setattr(_winconsole, "get_buffer", _identity_buffer) + monkeypatch.setattr(_winconsole, "ReadConsoleW", fake_read_console) + monkeypatch.setattr(_winconsole, "GetLastError", lambda: 0) + monkeypatch.setattr( + _compat, + "_get_windows_console_stream", + _route_console_stream("stdin", _winconsole._get_text_stdin), + ) + + result = runner.invoke(app, ["--config", "-"]) + assert result.exit_code == 0, result.output + assert "Hello Rick" in result.stdout + assert state["read_calls"] > 0 + + +def test_winconsole_stdout(monkeypatch): + runner = CliRunner() + app = typer.Typer() + state = {"write_calls": 0, "written": bytearray()} + + @app.command() + def write_message(out: typer.FileTextWrite = typer.Option(...)) -> None: + out.write("Hello Summer\n") + + monkeypatch.setattr(_winconsole, "get_buffer", _identity_buffer) + monkeypatch.setattr(_winconsole, "WriteConsoleW", _capture_write_console(state)) + monkeypatch.setattr(_winconsole, "GetLastError", lambda: 0) + monkeypatch.setattr( + _compat, + "_get_windows_console_stream", + _route_console_stream("stdout", _winconsole._get_text_stdout), + ) + + result = runner.invoke(app, ["--out", "-"]) + assert result.exit_code == 0, result.output + assert state["write_calls"] > 0 + assert _winconsole._WindowsConsoleWriter(1).isatty() is True + decoded = state["written"].decode("utf-16-le", errors="ignore") + assert "Hello Summer\r\n" in decoded + + +def test_winconsole_stderr(monkeypatch): + runner = CliRunner() + app = typer.Typer() + state = {"write_calls": 0, "written": bytearray(), "stderr_wrap_calls": 0} + + @app.command() + def main() -> None: + typer.echo("Ran out of adventure time!", err=True) + + monkeypatch.setattr(_winconsole, "get_buffer", _identity_buffer) + monkeypatch.setattr(_winconsole, "WriteConsoleW", _capture_write_console(state)) + monkeypatch.setattr(_winconsole, "GetLastError", lambda: 0) + monkeypatch.setattr( + _compat, + "_get_windows_console_stream", + _route_console_stream("stderr", _winconsole._get_text_stderr, state), + ) + + result = runner.invoke(app) + assert result.exit_code == 0, result.output + assert state["stderr_wrap_calls"] > 0 + assert state["write_calls"] > 0 + decoded = state["written"].decode("utf-16-le", errors="ignore") + assert "Ran out of adventure time!\r\n" in decoded + + +@pytest.mark.parametrize( + ("writable", "source", "expected_flags"), + [ + (True, bytearray(b"python"), 1), # PyBUF_WRITABLE + (False, b"python", 0), # PyBUF_SIMPLE + ], +) +def test_get_buffer(monkeypatch, writable, source, expected_flags): + state = {"flags": None, "released": 0} + if writable: + backing = (_winconsole.c_char * len(source)).from_buffer(source) + else: + backing = (_winconsole.c_char * len(source)).from_buffer_copy(source) + backing_ptr = _winconsole.c_void_p(ctypes.addressof(backing)) + + def fake_object_get_buffer(obj, buf_ref, flags): # noqa: ARG001 + state["flags"] = flags + buf = buf_ref._obj + buf.buf = backing_ptr + buf.len = len(source) + + def fake_buffer_release(buf_ref): # noqa: ARG001 + state["released"] += 1 + + monkeypatch.setattr(_winconsole, "PyObject_GetBuffer", fake_object_get_buffer) + monkeypatch.setattr(_winconsole, "PyBuffer_Release", fake_buffer_release) + + probe = source if writable else b"x" + out = _winconsole.get_buffer(probe, writable=writable) + if writable: + # mutate the first byte of "python" to obtain another beloved programming language + out[0] = b"c" + assert source == bytearray(b"cython") + else: + assert bytes(out[: len(source)]) == source + assert state["flags"] == expected_flags + assert state["released"] == 1 + + +def test_isatty(): + assert _winconsole._WindowsConsoleRawIOBase(None).isatty() is True + assert _winconsole._WindowsConsoleReader(0).isatty() is True + assert _winconsole._WindowsConsoleReader(1).isatty() is True + + +def test_console_stream(): + class NamedBytesIO(io.BytesIO): + name = "fake-buffer" + + def isatty(self): + return False + + stream = _winconsole.ConsoleStream( + io.TextIOWrapper(io.BytesIO(), encoding="utf-8"), NamedBytesIO() + ) + assert stream.isatty() is False + assert stream.name == "fake-buffer" + assert "fake-buffer" in repr(stream) + assert "utf-8" in repr(stream) + + # test writelines + stream.writelines(["hello", " ", "world"]) + stream._text_stream.flush() + assert stream._text_stream.buffer.getvalue().decode("utf-8") == "hello world" + + # Cover bytes write path. + assert stream.write(b"!") == 1 + assert stream.buffer.getvalue().endswith(b"!") + + +@pytest.mark.parametrize( + ("error", "msg"), + [ + (0, "ERROR_SUCCESS"), # ERROR_SUCCESS + (8, "ERROR_NOT_ENOUGH_MEMORY"), # ERROR_NOT_ENOUGH_MEMORY + (342, "Windows error 342"), + ], +) +def test_error_message(error, msg): + writer = _winconsole._WindowsConsoleWriter + assert writer._get_error_message(error) == msg + + +def test_is_console(): + assert _winconsole._is_console(object()) is False + + +def test_get_windows_console_stream_factory_and_buffer_paths(monkeypatch): + monkeypatch.setattr(_winconsole, "_is_console", lambda f: True) + monkeypatch.setattr(_winconsole, "get_buffer", object()) + + class FakeStream: + def __init__(self, fd, buffer=None): + self._fd = fd + self.buffer = buffer + + def fileno(self): + return self._fd + + wrapped = {"called": False, "buffer": None} + + def fake_factory(buffer): + wrapped["called"] = True + wrapped["buffer"] = buffer + return "wrapped", buffer + + monkeypatch.setattr(_winconsole, "_stream_factories", {7: fake_factory}) + + # Known console stream preconditions pass, but no stream factory for this fd. + get_stream = _winconsole._get_windows_console_stream + assert get_stream(FakeStream(99, object()), "utf-16-le", "strict") is None + + # Factory exists, but stream has no usable .buffer. + assert get_stream(FakeStream(7, None), "utf-16-le", "strict") is None + + # Factory exists and buffer is present, so wrapper result is returned. + raw_buffer = object() + out = get_stream(FakeStream(7, raw_buffer), "utf-16-le", "strict") + assert out == ("wrapped", raw_buffer) + assert wrapped["called"] is True + assert wrapped["buffer"] is raw_buffer + + +def test_windows_console_reader(monkeypatch): + reader = _winconsole._WindowsConsoleReader(42) + + # Empty input buffer returns early + assert reader.readinto(bytearray()) == 0 + + # Require an even number of bytes + with pytest.raises(ValueError): + reader.readinto(bytearray(3)) + + def writable_buffer(obj, writable=False): # noqa: ARG001 + return (ctypes.c_char * len(obj)).from_buffer(obj) + + monkeypatch.setattr(_winconsole, "get_buffer", writable_buffer) + + def patch_console(read_console, error): + monkeypatch.setattr(_winconsole, "ReadConsoleW", read_console) + monkeypatch.setattr(_winconsole, "GetLastError", lambda: error) + + def make_read(payload=b"", rv=1, units_read=None): + def read_console(handle, buffer, units_to_read, units_read_ptr, reserved): # noqa: ARG001 + bytes_to_copy = min(len(payload), units_to_read * 2) + if bytes_to_copy: + buffer[0:bytes_to_copy] = payload[:bytes_to_copy] + read_units = units_read if units_read is not None else bytes_to_copy // 2 + units_read_ptr._obj.value = read_units + return rv + + return read_console + + # Normal successful read returns the number of bytes read + patch_console(make_read(payload=b"A\x00B\x00"), _winconsole.ERROR_SUCCESS) + assert reader.readinto(bytearray(4)) == 4 + + # CTRL+Z (EOF) should be translated into an empty read + patch_console( + make_read(payload=_winconsole.EOF + b"\x00", units_read=1), + _winconsole.ERROR_SUCCESS, + ) + assert reader.readinto(bytearray(2)) == 0 + + # An aborted read should sleep briefly while waiting for KeyboardInterrupt + sleep_state = {"calls": 0} + + def fake_sleep(seconds): + sleep_state["calls"] += 1 + assert seconds == 0.1 + + monkeypatch.setattr( + _winconsole, "time", type("FakeTime", (), {"sleep": fake_sleep}) + ) + patch_console( + make_read(payload=b"Z\x00", units_read=1), + _winconsole.ERROR_OPERATION_ABORTED, + ) + assert reader.readinto(bytearray(2)) == 2 + assert sleep_state["calls"] == 1 + + # Failed reads propagate a Windows error + patch_console(make_read(rv=0), _winconsole.ERROR_NOT_ENOUGH_MEMORY) + with pytest.raises(OSError, match="Windows error"): + reader.readinto(bytearray(2)) diff --git a/tests/utils.py b/tests/utils.py index 35c441d365..7cb285bc53 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,9 +9,16 @@ needs_linux = pytest.mark.skipif( not sys.platform.startswith("linux"), reason="Test requires Linux" ) +needs_macos = pytest.mark.skipif( + not sys.platform.startswith("darwin"), reason="Test requires macOS" +) needs_windows = pytest.mark.skipif( not sys.platform.startswith("win"), reason="Test requires Windows" ) +skip_if_windows = pytest.mark.skipif( + sys.platform == "win32", + reason="Test should not be run on Windows", +) needs_rich = pytest.mark.skipif(not HAS_RICH, reason="Test requires Rich") diff --git a/typer/.agents/skills/typer/SKILL.md b/typer/.agents/skills/typer/SKILL.md index 20e6cdd589..45bf622407 100644 --- a/typer/.agents/skills/typer/SKILL.md +++ b/typer/.agents/skills/typer/SKILL.md @@ -259,7 +259,7 @@ if __name__ == "__main__": ## Click -Originally, Typer was built on Click. However, going forward Typer will vendor Click. As such, Click extensions should not be used anymore. +Originally, Typer was built on Click. However, since version 0.26.0, Typer has vendored Click. As such, Click extensions should not be used anymore. Other settings of `Option` and `Argument` that came from Click but shouldn't be used in Typer anymore, include: `expose_value`, `shell_complete`, `show_choices`, `errors`, `prompt_required`, `is_flag`, `flag_value` and `allow_from_autoenv`. diff --git a/typer/__init__.py b/typer/__init__.py index 548408fe98..392bd006ff 100644 --- a/typer/__init__.py +++ b/typer/__init__.py @@ -4,28 +4,21 @@ from shutil import get_terminal_size as get_terminal_size -from click.exceptions import Abort as Abort -from click.exceptions import BadParameter as BadParameter -from click.exceptions import Exit as Exit -from click.termui import clear as clear -from click.termui import confirm as confirm -from click.termui import echo_via_pager as echo_via_pager -from click.termui import edit as edit -from click.termui import getchar as getchar -from click.termui import pause as pause -from click.termui import progressbar as progressbar -from click.termui import prompt as prompt -from click.termui import secho as secho -from click.termui import style as style -from click.termui import unstyle as unstyle -from click.utils import echo as echo -from click.utils import format_filename as format_filename -from click.utils import get_app_dir as get_app_dir -from click.utils import get_binary_stream as get_binary_stream -from click.utils import get_text_stream as get_text_stream -from click.utils import open_file as open_file - from . import colors as colors +from ._click.exceptions import Abort as Abort +from ._click.exceptions import BadParameter as BadParameter +from ._click.exceptions import Exit as Exit +from ._click.termui import confirm as confirm +from ._click.termui import getchar as getchar +from ._click.termui import progressbar as progressbar +from ._click.termui import prompt as prompt +from ._click.termui import secho as secho +from ._click.termui import style as style +from ._click.utils import echo as echo +from ._click.utils import format_filename as format_filename +from ._click.utils import get_app_dir as get_app_dir +from ._click.utils import get_binary_stream as get_binary_stream +from ._click.utils import get_text_stream as get_text_stream from .main import Typer as Typer from .main import launch as launch from .main import run as run diff --git a/typer/_click/LICENSE.txt b/typer/_click/LICENSE.txt new file mode 100644 index 0000000000..d12a849186 --- /dev/null +++ b/typer/_click/LICENSE.txt @@ -0,0 +1,28 @@ +Copyright 2014 Pallets + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/typer/_click/__init__.py b/typer/_click/__init__.py new file mode 100644 index 0000000000..48c9edcddd --- /dev/null +++ b/typer/_click/__init__.py @@ -0,0 +1,11 @@ +""" +Code taken and adapted from Click: https://github.com/pallets/click/releases/tag/8.3.1 +""" + +from .core import Command as Command +from .core import Context as Context +from .core import Parameter as Parameter +from .exceptions import ClickException as ClickException +from .formatting import HelpFormatter as HelpFormatter +from .termui import launch as launch +from .utils import echo as echo diff --git a/typer/_click/_compat.py b/typer/_click/_compat.py new file mode 100644 index 0000000000..6ed0ceb8ac --- /dev/null +++ b/typer/_click/_compat.py @@ -0,0 +1,569 @@ +import codecs +import io +import os +import re +import sys +from collections.abc import Callable, Mapping, MutableMapping +from types import TracebackType +from typing import ( + IO, + Any, + BinaryIO, + TextIO, + cast, +) +from weakref import WeakKeyDictionary + +CYGWIN = sys.platform.startswith("cygwin") +WIN = sys.platform.startswith("win") +auto_wrap_for_ansi: Callable[[TextIO], TextIO] | None = None +_ansi_re = re.compile(r"\033\[[;?0-9]*[a-zA-Z]") + + +def _make_text_stream( + stream: BinaryIO, + encoding: str | None, + errors: str, +) -> TextIO: + if encoding is None: + encoding = get_best_encoding(stream) + return _NonClosingTextIOWrapper( + stream, + encoding, + errors, + line_buffering=True, + ) + + +def is_ascii_encoding(encoding: str) -> bool: + """Checks if a given encoding is ascii.""" + try: + return codecs.lookup(encoding).name == "ascii" + except LookupError: + return False + + +def get_best_encoding(stream: IO[Any]) -> str: + """Returns the default stream encoding if not found.""" + rv = getattr(stream, "encoding", None) or sys.getdefaultencoding() + if is_ascii_encoding(rv): + return "utf-8" + return rv + + +class _NonClosingTextIOWrapper(io.TextIOWrapper): + def __init__( + self, + stream: BinaryIO, + encoding: str | None, + errors: str | None, + **extra: Any, + ) -> None: + self._stream = stream = cast(BinaryIO, _FixupStream(stream)) + super().__init__(stream, encoding, errors, **extra) + + def __del__(self) -> None: + try: + self.detach() + except Exception: # pragma: no cover + pass + + def isatty(self) -> bool: + # https://bitbucket.org/pypy/pypy/issue/1803 + return self._stream.isatty() + + +class _FixupStream: + """The new io interface needs more from streams than streams + traditionally implement. As such, this fix-up code is necessary in + some circumstances. + """ + + def __init__( + self, + stream: BinaryIO, + ): + self._stream = stream + + def __getattr__(self, name: str) -> Any: + return getattr(self._stream, name) + + def read1(self, size: int) -> bytes: + f = getattr(self._stream, "read1", None) + + if f is not None: + return cast(bytes, f(size)) + + return self._stream.read(size) + + def readable(self) -> bool: + return True + + def writable(self) -> bool: + return True + + def seekable(self) -> bool: + x = getattr(self._stream, "seekable", None) + if x is not None: + return cast(bool, x()) + return False + + +def _is_binary_reader(stream: IO[Any], default: bool = False) -> bool: + try: + return isinstance(stream.read(0), bytes) + except Exception: # pragma: no cover + return default + # This happens in some cases where the stream was already + # closed. In this case, we assume the default. + + +def _is_binary_writer(stream: IO[Any], default: bool = False) -> bool: + try: + stream.write(b"") + except Exception: # pragma: no cover + try: + stream.write("") + return False + except Exception: + pass + return default + return True + + +def _find_binary_reader(stream: IO[Any]) -> BinaryIO | None: + # We need to figure out if the given stream is already binary. + # This can happen because the official docs recommend detaching + # the streams to get binary streams. Some code might do this, so + # we need to deal with this case explicitly. + if _is_binary_reader(stream, False): + return cast(BinaryIO, stream) + + buf = getattr(stream, "buffer", None) + + # Same situation here; this time we assume that the buffer is + # actually binary in case it's closed. + if buf is not None and _is_binary_reader(buf, True): + return cast(BinaryIO, buf) + + return None + + +def _find_binary_writer(stream: IO[Any]) -> BinaryIO | None: + # We need to figure out if the given stream is already binary. + # This can happen because the official docs recommend detaching + # the streams to get binary streams. Some code might do this, so + # we need to deal with this case explicitly. + if _is_binary_writer(stream, False): + return cast(BinaryIO, stream) + + buf = getattr(stream, "buffer", None) + + # Same situation here; this time we assume that the buffer is + # actually binary in case it's closed. + if buf is not None and _is_binary_writer(buf, True): + return cast(BinaryIO, buf) + + return None + + +def _stream_is_misconfigured(stream: TextIO) -> bool: + """A stream is misconfigured if its encoding is ASCII.""" + # If the stream does not have an encoding set, we assume it's set + # to ASCII. This appears to happen in certain unittest + # environments. It's not quite clear what the correct behavior is + # but this at least will force Click to recover somehow. + return is_ascii_encoding(getattr(stream, "encoding", None) or "ascii") + + +def _is_compat_stream_attr(stream: TextIO, attr: str, value: str | None) -> bool: + """A stream attribute is compatible if it is equal to the + desired value or the desired value is unset and the attribute + has a value. + """ + stream_value = getattr(stream, attr, None) + return stream_value == value or (value is None and stream_value is not None) + + +def _is_compatible_text_stream( + stream: TextIO, encoding: str | None, errors: str | None +) -> bool: + """Check if a stream's encoding and errors attributes are + compatible with the desired values. + """ + return _is_compat_stream_attr( + stream, "encoding", encoding + ) and _is_compat_stream_attr(stream, "errors", errors) + + +def _force_correct_text_stream( + text_stream: IO[Any], + encoding: str | None, + errors: str | None, + is_binary: Callable[[IO[Any], bool], bool], + find_binary: Callable[[IO[Any]], BinaryIO | None], +) -> TextIO: + if is_binary(text_stream, False): + binary_reader = cast(BinaryIO, text_stream) + else: + text_stream = cast(TextIO, text_stream) + # If the stream looks compatible, and won't default to a + # misconfigured ascii encoding, return it as-is. + if _is_compatible_text_stream(text_stream, encoding, errors) and not ( + encoding is None and _stream_is_misconfigured(text_stream) + ): + return text_stream + + # Otherwise, get the underlying binary reader. + possible_binary_reader = find_binary(text_stream) + + # If that's not possible, silently use the original reader + # and get mojibake instead of exceptions. + if possible_binary_reader is None: + return text_stream + + binary_reader = possible_binary_reader + + # Default errors to replace instead of strict in order to get + # something that works. + if errors is None: + errors = "replace" + + # Wrap the binary stream in a text stream with the correct + # encoding parameters. + return _make_text_stream( + binary_reader, + encoding, + errors, + ) + + +def _force_correct_text_reader( + text_reader: IO[Any], + encoding: str | None, + errors: str | None, +) -> TextIO: + return _force_correct_text_stream( + text_reader, + encoding, + errors, + _is_binary_reader, + _find_binary_reader, + ) + + +def _force_correct_text_writer( + text_writer: IO[Any], + encoding: str | None, + errors: str | None, +) -> TextIO: + return _force_correct_text_stream( + text_writer, + encoding, + errors, + _is_binary_writer, + _find_binary_writer, + ) + + +def get_binary_stdin() -> BinaryIO: + reader = _find_binary_reader(sys.stdin) + if reader is None: # pragma: no cover + raise RuntimeError("Was not able to determine binary stream for sys.stdin.") + return reader + + +def get_binary_stdout() -> BinaryIO: + writer = _find_binary_writer(sys.stdout) + if writer is None: # pragma: no cover + raise RuntimeError("Was not able to determine binary stream for sys.stdout.") + return writer + + +def get_binary_stderr() -> BinaryIO: + writer = _find_binary_writer(sys.stderr) + if writer is None: # pragma: no cover + raise RuntimeError("Was not able to determine binary stream for sys.stderr.") + return writer + + +def get_text_stdin(encoding: str | None = None, errors: str | None = None) -> TextIO: + rv = _get_windows_console_stream(sys.stdin, encoding, errors) + if rv is not None: + return rv + return _force_correct_text_reader(sys.stdin, encoding, errors) + + +def get_text_stdout(encoding: str | None = None, errors: str | None = None) -> TextIO: + rv = _get_windows_console_stream(sys.stdout, encoding, errors) + if rv is not None: + return rv + return _force_correct_text_writer(sys.stdout, encoding, errors) + + +def get_text_stderr(encoding: str | None = None, errors: str | None = None) -> TextIO: + rv = _get_windows_console_stream(sys.stderr, encoding, errors) + if rv is not None: + return rv + return _force_correct_text_writer(sys.stderr, encoding, errors) + + +def _wrap_io_open( + file: str | os.PathLike[str] | int, + mode: str, + encoding: str | None, + errors: str | None, +) -> IO[Any]: + """Handles not passing ``encoding`` and ``errors`` in binary mode.""" + if "b" in mode: + return open(file, mode) + + return open(file, mode, encoding=encoding, errors=errors) + + +def open_stream( + filename: str | os.PathLike[str], + mode: str = "r", + encoding: str | None = None, + errors: str | None = "strict", + atomic: bool = False, +) -> tuple[IO[Any], bool]: + binary = "b" in mode + filename = os.fspath(filename) + + # Standard streams first, ignoring the atomic flag. + if os.fsdecode(filename) == "-": + if any(m in mode for m in ["w", "a", "x"]): + if binary: + return get_binary_stdout(), False + return get_text_stdout(encoding=encoding, errors=errors), False + if binary: + return get_binary_stdin(), False + return get_text_stdin(encoding=encoding, errors=errors), False + + # Non-atomic writes directly go out through the regular open functions. + if not atomic: + return _wrap_io_open(filename, mode, encoding, errors), True + + # Some usability stuff for atomic writes + if "a" in mode: + raise ValueError( + "Appending to an existing file is not supported, because that" + " would involve an expensive `copy`-operation to a temporary" + " file. Open the file in normal `w`-mode and copy explicitly" + " if that's what you're after." + ) + if "x" in mode: + raise ValueError("Use the `overwrite`-parameter instead.") + if "w" not in mode: + raise ValueError("Atomic writes only make sense with `w`-mode.") + + # Atomic writes are more complicated. They work by opening a file + # as a proxy in the same folder and then using the fdopen + # functionality to wrap it in a Python file. Then we wrap it in an + # atomic file that moves the file over on close. + import errno + import random + + try: + perm: int | None = os.stat(filename).st_mode + except OSError: # pragma: no cover + perm = None + + flags = os.O_RDWR | os.O_CREAT | os.O_EXCL + + if binary: + flags |= getattr(os, "O_BINARY", 0) + + while True: + tmp_filename = os.path.join( + os.path.dirname(filename), + f".__atomic-write{random.randrange(1 << 32):08x}", + ) + try: + fd = os.open(tmp_filename, flags, 0o666 if perm is None else perm) + break + except OSError as e: # pragma: no cover + if e.errno == errno.EEXIST or ( + os.name == "nt" + and e.errno == errno.EACCES + and os.path.isdir(e.filename) + and os.access(e.filename, os.W_OK) + ): + continue + raise + + if perm is not None: + os.chmod(tmp_filename, perm) # in case perm includes bits in umask + + f = _wrap_io_open(fd, mode, encoding, errors) + af = _AtomicFile(f, tmp_filename, os.path.realpath(filename)) + return cast(IO[Any], af), True + + +class _AtomicFile: + def __init__(self, f: IO[Any], tmp_filename: str, real_filename: str) -> None: + self._f = f + self._tmp_filename = tmp_filename + self._real_filename = real_filename + self.closed = False + + @property + def name(self) -> str: + return self._real_filename + + def close(self, delete: bool = False) -> None: + if self.closed: + return # pragma: no cover + self._f.close() + os.replace(self._tmp_filename, self._real_filename) + self.closed = True + + def __getattr__(self, name: str) -> Any: + return getattr(self._f, name) + + def __enter__(self) -> "_AtomicFile": + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.close(delete=exc_type is not None) + + def __repr__(self) -> str: + return repr(self._f) + + +def strip_ansi(value: str) -> str: + return _ansi_re.sub("", value) + + +def _is_jupyter_kernel_output(stream: IO[Any]) -> bool: + while isinstance(stream, (_FixupStream, _NonClosingTextIOWrapper)): + stream = stream._stream + + return stream.__class__.__module__.startswith("ipykernel.") + + +def should_strip_ansi(stream: IO[Any] | None = None, color: bool | None = None) -> bool: + if color is None: + if stream is None: + stream = sys.stdin + return not isatty(stream) and not _is_jupyter_kernel_output(stream) + return not color + + +# On Windows, wrap the output streams with colorama to support ANSI +# color codes. +# NOTE: double check is needed so mypy does not analyze this on Linux +if sys.platform.startswith("win") and WIN: + from ._winconsole import _get_windows_console_stream + + def _get_argv_encoding() -> str: + import locale + + return locale.getpreferredencoding() + + _ansi_stream_wrappers: MutableMapping[TextIO, TextIO] = WeakKeyDictionary() + + def auto_wrap_for_ansi(stream: TextIO, color: bool | None = None) -> TextIO: + """Support ANSI color and style codes on Windows by wrapping a + stream with colorama. + """ + try: + cached = _ansi_stream_wrappers.get(stream) + except Exception: # pragma: no cover + cached = None + + if cached is not None: + return cached + + import colorama + + strip = should_strip_ansi(stream, color) + ansi_wrapper = colorama.AnsiToWin32(stream, strip=strip) + rv = cast(TextIO, ansi_wrapper.stream) + _write = rv.write + + def _safe_write(s: str) -> int: + try: + return _write(s) + except BaseException: # pragma: no cover + ansi_wrapper.reset_all() + raise + + rv.write = _safe_write # type: ignore[method-assign] # ty: ignore[invalid-assignment] + + try: + _ansi_stream_wrappers[stream] = rv + except Exception: # pragma: no cover + pass + + return rv + +else: + + def _get_argv_encoding() -> str: + return getattr(sys.stdin, "encoding", None) or sys.getfilesystemencoding() + + def _get_windows_console_stream( + f: TextIO, encoding: str | None, errors: str | None + ) -> TextIO | None: + return None + + +def term_len(x: str) -> int: + return len(strip_ansi(x)) + + +def isatty(stream: IO[Any]) -> bool: + try: + return stream.isatty() + except Exception: # pragma: no cover + return False + + +def _make_cached_stream_func( + src_func: Callable[[], TextIO], + wrapper_func: Callable[[], TextIO], +) -> Callable[[], TextIO]: + cache: MutableMapping[TextIO, TextIO] = WeakKeyDictionary() + + def func() -> TextIO: + stream = src_func() + + try: + rv = cache.get(stream) + except Exception: # pragma: no cover + rv = None + if rv is not None: + return rv + rv = wrapper_func() + try: + cache[stream] = rv + except Exception: # pragma: no cover + pass + return rv + + return func + + +_default_text_stdin = _make_cached_stream_func(lambda: sys.stdin, get_text_stdin) +_default_text_stdout = _make_cached_stream_func(lambda: sys.stdout, get_text_stdout) +_default_text_stderr = _make_cached_stream_func(lambda: sys.stderr, get_text_stderr) + + +binary_streams: Mapping[str, Callable[[], BinaryIO]] = { + "stdin": get_binary_stdin, + "stdout": get_binary_stdout, + "stderr": get_binary_stderr, +} + +text_streams: Mapping[str, Callable[[str | None, str | None], TextIO]] = { + "stdin": get_text_stdin, + "stdout": get_text_stdout, + "stderr": get_text_stderr, +} diff --git a/typer/_click/_termui_impl.py b/typer/_click/_termui_impl.py new file mode 100644 index 0000000000..c621810cbf --- /dev/null +++ b/typer/_click/_termui_impl.py @@ -0,0 +1,522 @@ +""" +To keep the import times down, some infrequently used termui functionality +is placed here and only imported as needed. +""" + +import contextlib +import math +import os +import sys +import time +from collections.abc import Callable, Iterable, Iterator +from io import StringIO +from types import TracebackType +from typing import Generic, TextIO, TypeVar, cast + +from ._compat import ( + CYGWIN, + WIN, + _default_text_stdout, + get_best_encoding, + isatty, + term_len, +) +from .utils import echo + +V = TypeVar("V") + +if os.name == "nt": + BEFORE_BAR = "\r" + AFTER_BAR = "\n" +else: + BEFORE_BAR = "\r\033[?25l" + AFTER_BAR = "\033[?25h\n" + + +class ProgressBar(Generic[V]): + def __init__( + self, + iterable: Iterable[V] | None, + length: int | None = None, + fill_char: str = "#", + empty_char: str = " ", + bar_template: str = "%(bar)s", + info_sep: str = " ", + hidden: bool = False, + show_eta: bool = True, + show_percent: bool | None = None, + show_pos: bool = False, + item_show_func: Callable[[V | None], str | None] | None = None, + label: str | None = None, + file: TextIO | None = None, + color: bool | None = None, + update_min_steps: int = 1, + width: int = 30, + ) -> None: + self.fill_char = fill_char + self.empty_char = empty_char + self.bar_template = bar_template + self.info_sep = info_sep + self.hidden = hidden + self.show_eta = show_eta + self.show_percent = show_percent + self.show_pos = show_pos + self.item_show_func = item_show_func + self.label: str = label or "" + + if file is None: + file = _default_text_stdout() + + # There are no standard streams attached to write to. For example, + # pythonw on Windows. + if file is None: # pragma: no cover + file = StringIO() + + self.file = file + self.color = color + self.update_min_steps = update_min_steps + self._completed_intervals = 0 + self.width: int = width + self.autowidth: bool = width == 0 + + if length is None: + from operator import length_hint + + length = length_hint(iterable, -1) + + if length == -1: # pragma: no cover + length = None + if iterable is None: + if length is None: # pragma: no cover + raise TypeError("iterable or length is required") + iterable = cast("Iterable[V]", range(length)) + self.iter: Iterable[V] = iter(iterable) + self.length = length + self.pos: int = 0 + self.avg: list[float] = [] + self.last_eta: float + self.start: float + self.start = self.last_eta = time.time() + self.eta_known: bool = False + self.finished: bool = False + self.max_width: int | None = None + self.entered: bool = False + self.current_item: V | None = None + self._is_atty = isatty(self.file) + self._last_line: str | None = None + + def __enter__(self) -> "ProgressBar[V]": + self.entered = True + self.render_progress() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.render_finish() + + def __iter__(self) -> Iterator[V]: + if not self.entered: + raise RuntimeError("You need to use progress bars in a with block.") + self.render_progress() + return self.generator() + + def __next__(self) -> V: + # Iteration is defined in terms of a generator function, + # returned by iter(self); use that to define next(). This works + # because `self.iter` is an iterable consumed by that generator, + # so it is re-entry safe. Calling `next(self.generator())` + # twice works and does "what you want". + return next(iter(self)) + + def render_finish(self) -> None: + if self.hidden or not self._is_atty: + return + self.file.write(AFTER_BAR) + self.file.flush() + + @property + def pct(self) -> float: + if self.finished: + return 1.0 + return min(self.pos / (float(self.length or 1) or 1), 1.0) + + @property + def time_per_iteration(self) -> float: + if not self.avg: + return 0.0 + return sum(self.avg) / float(len(self.avg)) + + @property + def eta(self) -> float: + if self.length is not None and not self.finished: + return self.time_per_iteration * (self.length - self.pos) + return 0.0 + + def format_eta(self) -> str: + if self.eta_known: + t = int(self.eta) + seconds = t % 60 + t //= 60 + minutes = t % 60 + t //= 60 + hours = t % 24 + t //= 24 + if t > 0: + return f"{t}d {hours:02}:{minutes:02}:{seconds:02}" + else: + return f"{hours:02}:{minutes:02}:{seconds:02}" + return "" + + def format_pos(self) -> str: + pos = str(self.pos) + if self.length is not None: + pos += f"/{self.length}" + return pos + + def format_pct(self) -> str: + return f"{int(self.pct * 100): 4}%"[1:] + + def format_bar(self) -> str: + if self.length is not None: + bar_length = int(self.pct * self.width) + bar = self.fill_char * bar_length + bar += self.empty_char * (self.width - bar_length) + elif self.finished: + bar = self.fill_char * self.width + else: + chars = list(self.empty_char * (self.width or 1)) + if self.time_per_iteration != 0: + chars[ + int( + (math.cos(self.pos * self.time_per_iteration) / 2.0 + 0.5) + * self.width + ) + ] = self.fill_char + bar = "".join(chars) + return bar + + def format_progress_line(self) -> str: + show_percent = self.show_percent + + info_bits = [] + if self.length is not None and show_percent is None: + show_percent = not self.show_pos + + if self.show_pos: + info_bits.append(self.format_pos()) + if show_percent: + info_bits.append(self.format_pct()) + if self.show_eta and self.eta_known and not self.finished: + info_bits.append(self.format_eta()) + if self.item_show_func is not None: + item_info = self.item_show_func(self.current_item) + if item_info is not None: + info_bits.append(item_info) + + return ( + self.bar_template + % { + "label": self.label, + "bar": self.format_bar(), + "info": self.info_sep.join(info_bits), + } + ).rstrip() + + def render_progress(self) -> None: + if self.hidden: + return + + if not self._is_atty: + # Only output the label once if the output is not a TTY. + if self._last_line != self.label: + self._last_line = self.label + echo(self.label, file=self.file, color=self.color) + return + + buf = [] + # Update width in case the terminal has been resized + if self.autowidth: + import shutil + + old_width = self.width + self.width = 0 + clutter_length = term_len(self.format_progress_line()) + new_width = max(0, shutil.get_terminal_size().columns - clutter_length) + if new_width < old_width and self.max_width is not None: + buf.append(BEFORE_BAR) + buf.append(" " * self.max_width) + self.max_width = new_width + self.width = new_width + + clear_width = self.width + if self.max_width is not None: + clear_width = self.max_width + + buf.append(BEFORE_BAR) + line = self.format_progress_line() + line_len = term_len(line) + if self.max_width is None or self.max_width < line_len: + self.max_width = line_len + + buf.append(line) + buf.append(" " * (clear_width - line_len)) + line = "".join(buf) + # Render the line only if it changed. + + if line != self._last_line: + self._last_line = line + echo(line, file=self.file, color=self.color, nl=False) + self.file.flush() + + def make_step(self, n_steps: int) -> None: + self.pos += n_steps + if self.length is not None and self.pos >= self.length: + self.finished = True + + if (time.time() - self.last_eta) < 1.0: + return + + self.last_eta = time.time() + + # self.avg is a rolling list of length <= 7 of steps where steps are + # defined as time elapsed divided by the total progress through + # self.length. + if self.pos: + step = (time.time() - self.start) / self.pos + else: + step = time.time() - self.start + + self.avg = self.avg[-6:] + [step] + + self.eta_known = self.length is not None + + def update(self, n_steps: int) -> None: + """Update the progress bar by advancing a specified number of steps.""" + self._completed_intervals += n_steps + + if self._completed_intervals >= self.update_min_steps: + self.make_step(self._completed_intervals) + self.render_progress() + self._completed_intervals = 0 + + def finish(self) -> None: + self.eta_known = False + self.current_item = None + self.finished = True + + def generator(self) -> Iterator[V]: + """Return a generator which yields the items added to the bar + during construction, and updates the progress bar *after* the + yielded block returns. + """ + # WARNING: the iterator interface for `ProgressBar` relies on + # this and only works because this is a simple generator which + # doesn't create or manage additional state. If this function + # changes, the impact should be evaluated both against + # `iter(bar)` and `next(bar)`. `next()` in particular may call + # `self.generator()` repeatedly, and this must remain safe in + # order for that interface to work. + if not self.entered: # pragma: no cover + raise RuntimeError("You need to use progress bars in a with block.") + + if not self._is_atty: + yield from self.iter + else: + for rv in self.iter: + self.current_item = rv + + # This allows show_item_func to be updated before the + # item is processed. Only trigger at the beginning of + # the update interval. + if self._completed_intervals == 0: + self.render_progress() + + yield rv + self.update(1) + + self.finish() + self.render_progress() + + +def open_url(url: str, wait: bool = False, locate: bool = False) -> int: + import subprocess + + def _unquote_file(url: str) -> str: + from urllib.parse import unquote + + if url.startswith("file://"): + url = unquote(url[7:]) + + return url + + if sys.platform == "darwin": + args = ["open"] + if wait: + args.append("-W") + if locate: + args.append("-R") + args.append(_unquote_file(url)) + null = open("/dev/null", "w") + try: + return subprocess.Popen(args, stderr=null).wait() + finally: + null.close() + elif WIN: + if locate: + url = _unquote_file(url) + args = ["explorer", f"/select,{url}"] + else: + args = ["start"] + if wait: + args.append("/WAIT") + args.append("") + args.append(url) + try: + return subprocess.call(args) + except OSError: + # Command not found + return 127 + elif CYGWIN: # pragma: no cover + if locate: + url = _unquote_file(url) + args = ["cygstart", os.path.dirname(url)] + else: + args = ["cygstart"] + if wait: + args.append("-w") + args.append(url) + try: + return subprocess.call(args) + except OSError: + # Command not found + return 127 + + try: + if locate: + url = os.path.dirname(_unquote_file(url)) or "." + else: + url = _unquote_file(url) + c = subprocess.Popen(["xdg-open", url]) + if wait: + return c.wait() + return 0 + except OSError: # pragma: no cover + # TODO: remove this part, doesn't get hit by Typer code paths? + if url.startswith(("http://", "https://")) and not locate and not wait: + import webbrowser + + webbrowser.open(url) + return 0 + return 1 + + +def _translate_ch_to_exc(ch: str) -> None: + if ch == "\x03": + raise KeyboardInterrupt() + + if ch == "\x04" and not WIN: # Unix-like, Ctrl+D + raise EOFError() + + if ch == "\x1a" and WIN: # Windows, Ctrl+Z + raise EOFError() + + return None + + +if sys.platform == "win32": + import msvcrt + + @contextlib.contextmanager + def raw_terminal() -> Iterator[int]: + yield -1 + + def getchar(echo: bool) -> str: + # The function `getch` will return a bytes object corresponding to + # the pressed character. Since Windows 10 build 1803, it will also + # return \x00 when called a second time after pressing a regular key. + # + # `getwch` does not share this probably-bugged behavior. Moreover, it + # returns a Unicode object by default, which is what we want. + # + # Either of these functions will return \x00 or \xe0 to indicate + # a special key, and you need to call the same function again to get + # the "rest" of the code. The fun part is that \u00e0 is + # "latin small letter a with grave", so if you type that on a French + # keyboard, you _also_ get a \xe0. + # E.g., consider the Up arrow. This returns \xe0 and then \x48. The + # resulting Unicode string reads as "a with grave" + "capital H". + # This is indistinguishable from when the user actually types + # "a with grave" and then "capital H". + # + # When \xe0 is returned, we assume it's part of a special-key sequence + # and call `getwch` again, but that means that when the user types + # the \u00e0 character, `getchar` doesn't return until a second + # character is typed. + # The alternative is returning immediately, but that would mess up + # cross-platform handling of arrow keys and others that start with + # \xe0. Another option is using `getch`, but then we can't reliably + # read non-ASCII characters, because return values of `getch` are + # limited to the current 8-bit codepage. + # + # Anyway, Click doesn't claim to do this Right(tm), and using `getwch` + # is doing the right thing in more situations than with `getch`. + + if echo: + func = cast(Callable[[], str], msvcrt.getwche) + else: + func = cast(Callable[[], str], msvcrt.getwch) + + rv = func() + + if rv in ("\x00", "\xe0"): + # \x00 and \xe0 are control characters that indicate special key, + # see above. + rv += func() + + _translate_ch_to_exc(rv) + return rv + +else: + import termios + import tty + + @contextlib.contextmanager + def raw_terminal() -> Iterator[int]: + f: TextIO | None + fd: int + + if not isatty(sys.stdin): + f = open("/dev/tty") + fd = f.fileno() + else: + fd = sys.stdin.fileno() + f = None + + try: + old_settings = termios.tcgetattr(fd) + + try: + tty.setraw(fd) + yield fd + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + sys.stdout.flush() + + if f is not None: + f.close() + except termios.error: # pragma: no cover + pass + + def getchar(echo: bool) -> str: + with raw_terminal() as fd: + ch = os.read(fd, 32).decode(get_best_encoding(sys.stdin), "replace") + + if echo and isatty(sys.stdout): # pragma: no cover + sys.stdout.write(ch) + + _translate_ch_to_exc(ch) + return ch diff --git a/typer/_click/_textwrap.py b/typer/_click/_textwrap.py new file mode 100644 index 0000000000..9f9636b31f --- /dev/null +++ b/typer/_click/_textwrap.py @@ -0,0 +1,46 @@ +import textwrap +from collections.abc import Iterator +from contextlib import contextmanager + + +class TextWrapper(textwrap.TextWrapper): + def _handle_long_word( + self, + reversed_chunks: list[str], + cur_line: list[str], + cur_len: int, + width: int, + ) -> None: + space_left = max(width - cur_len, 1) + + last = reversed_chunks[-1] + cut = last[:space_left] + res = last[space_left:] + cur_line.append(cut) + reversed_chunks[-1] = res + + @contextmanager + def extra_indent(self, indent: str) -> Iterator[None]: + old_initial_indent = self.initial_indent + old_subsequent_indent = self.subsequent_indent + self.initial_indent += indent + self.subsequent_indent += indent + + try: + yield + finally: + self.initial_indent = old_initial_indent + self.subsequent_indent = old_subsequent_indent + + def indent_only(self, text: str) -> str: + rv = [] + + for idx, line in enumerate(text.splitlines()): + indent = self.initial_indent + + if idx > 0: + indent = self.subsequent_indent + + rv.append(f"{indent}{line}") + + return "\n".join(rv) diff --git a/typer/_click/_winconsole.py b/typer/_click/_winconsole.py new file mode 100644 index 0000000000..f6dedbfd6e --- /dev/null +++ b/typer/_click/_winconsole.py @@ -0,0 +1,300 @@ +# This module is based on the excellent work by Adam Bartoš who +# provided a lot of what went into the implementation here in +# the discussion to issue1602 in the Python bug tracker. +# +# There are some general differences in regards to how this works +# compared to the original patches as we do not need to patch +# the entire interpreter but just work in our little world of +# echo and prompt. +import io +import sys +import time +from collections.abc import Callable, Iterable, Mapping +from ctypes import ( + POINTER, + Array, + Structure, + byref, + c_char, + c_char_p, + c_int, + c_ssize_t, + c_ulong, + c_void_p, + py_object, +) +from ctypes.wintypes import DWORD, HANDLE, LPCWSTR, LPWSTR +from typing import ( + TYPE_CHECKING, + Any, + AnyStr, + BinaryIO, + Literal, + TextIO, + cast, +) + +from ._compat import _NonClosingTextIOWrapper + +assert sys.platform == "win32" +import msvcrt # noqa: E402 +from ctypes import WINFUNCTYPE, windll # noqa: E402 + +c_ssize_p = POINTER(c_ssize_t) + +kernel32 = windll.kernel32 +GetStdHandle = kernel32.GetStdHandle +ReadConsoleW = kernel32.ReadConsoleW +WriteConsoleW = kernel32.WriteConsoleW +GetConsoleMode = kernel32.GetConsoleMode +GetLastError = kernel32.GetLastError +GetCommandLineW = WINFUNCTYPE(LPWSTR)(("GetCommandLineW", windll.kernel32)) +CommandLineToArgvW = WINFUNCTYPE(POINTER(LPWSTR), LPCWSTR, POINTER(c_int))( + ("CommandLineToArgvW", windll.shell32) +) +LocalFree = WINFUNCTYPE(c_void_p, c_void_p)(("LocalFree", windll.kernel32)) + +STDIN_HANDLE = GetStdHandle(-10) +STDOUT_HANDLE = GetStdHandle(-11) +STDERR_HANDLE = GetStdHandle(-12) + +PyBUF_SIMPLE = 0 +PyBUF_WRITABLE = 1 + +ERROR_SUCCESS = 0 +ERROR_NOT_ENOUGH_MEMORY = 8 +ERROR_OPERATION_ABORTED = 995 + +STDIN_FILENO = 0 +STDOUT_FILENO = 1 +STDERR_FILENO = 2 + +EOF = b"\x1a" +MAX_BYTES_WRITTEN = 32767 + +if TYPE_CHECKING: + try: + # Using `typing_extensions.Buffer` instead of `collections.abc` + # on Windows for some reason does not have `Sized` implemented. + from collections.abc import Buffer # type: ignore + except ImportError: + from typing_extensions import Buffer + +try: + from ctypes import pythonapi +except ImportError: # pragma: no cover + # On PyPy we cannot get buffers so our ability to operate here is + # severely limited. + get_buffer = None +else: + + class Py_buffer(Structure): + _fields_ = [ # noqa: RUF012 + ("buf", c_void_p), + ("obj", py_object), + ("len", c_ssize_t), + ("itemsize", c_ssize_t), + ("readonly", c_int), + ("ndim", c_int), + ("format", c_char_p), + ("shape", c_ssize_p), + ("strides", c_ssize_p), + ("suboffsets", c_ssize_p), + ("internal", c_void_p), + ] + + PyObject_GetBuffer = pythonapi.PyObject_GetBuffer + PyBuffer_Release = pythonapi.PyBuffer_Release + + def get_buffer(obj: "Buffer", writable: bool = False) -> Array[c_char]: + buf = Py_buffer() + flags: int = PyBUF_WRITABLE if writable else PyBUF_SIMPLE + PyObject_GetBuffer(py_object(obj), byref(buf), flags) + + try: + buffer_type = c_char * buf.len + out: Array[c_char] = buffer_type.from_address(buf.buf) + return out + finally: + PyBuffer_Release(byref(buf)) + + +class _WindowsConsoleRawIOBase(io.RawIOBase): + def __init__(self, handle: int | None) -> None: + self.handle = handle + + def isatty(self) -> Literal[True]: + super().isatty() + return True + + +class _WindowsConsoleReader(_WindowsConsoleRawIOBase): + def readable(self) -> Literal[True]: + return True + + def readinto(self, b: "Buffer") -> int: + bytes_to_be_read = len(b) + if not bytes_to_be_read: + return 0 + elif bytes_to_be_read % 2: + raise ValueError( + "cannot read odd number of bytes from UTF-16-LE encoded console" + ) + + buffer = get_buffer(b, writable=True) + code_units_to_be_read = bytes_to_be_read // 2 + code_units_read = c_ulong() + + rv = ReadConsoleW( + HANDLE(self.handle), + buffer, + code_units_to_be_read, + byref(code_units_read), + None, + ) + if GetLastError() == ERROR_OPERATION_ABORTED: + # wait for KeyboardInterrupt + time.sleep(0.1) + if not rv: + raise OSError(f"Windows error: {GetLastError()}") + + if buffer[0] == EOF: + return 0 + return 2 * code_units_read.value + + +class _WindowsConsoleWriter(_WindowsConsoleRawIOBase): + def writable(self) -> Literal[True]: + return True + + @staticmethod + def _get_error_message(errno: int) -> str: + if errno == ERROR_SUCCESS: + return "ERROR_SUCCESS" + elif errno == ERROR_NOT_ENOUGH_MEMORY: + return "ERROR_NOT_ENOUGH_MEMORY" + return f"Windows error {errno}" + + def write(self, b: "Buffer") -> int: + bytes_to_be_written = len(b) + buf = get_buffer(b) + code_units_to_be_written = min(bytes_to_be_written, MAX_BYTES_WRITTEN) // 2 + code_units_written = c_ulong() + + WriteConsoleW( + HANDLE(self.handle), + buf, + code_units_to_be_written, + byref(code_units_written), + None, + ) + bytes_written = 2 * code_units_written.value + + if bytes_written == 0 and bytes_to_be_written > 0: + raise OSError(self._get_error_message(GetLastError())) # pragma: no cover + return bytes_written + + +class ConsoleStream: + def __init__(self, text_stream: TextIO, byte_stream: BinaryIO) -> None: + self._text_stream = text_stream + self.buffer = byte_stream + + @property + def name(self) -> str: + return self.buffer.name + + def write(self, x: AnyStr) -> int: + if isinstance(x, str): + return self._text_stream.write(x) + try: + self.flush() + except Exception: # pragma: no cover + pass + return self.buffer.write(x) + + def writelines(self, lines: Iterable[AnyStr]) -> None: + for line in lines: + self.write(line) + + def __getattr__(self, name: str) -> Any: + return getattr(self._text_stream, name) + + def isatty(self) -> bool: + return self.buffer.isatty() + + def __repr__(self) -> str: + return f"" + + +def _get_text_stdin(buffer_stream: BinaryIO) -> TextIO: + text_stream = _NonClosingTextIOWrapper( + io.BufferedReader(_WindowsConsoleReader(STDIN_HANDLE)), + "utf-16-le", + "strict", + line_buffering=True, + ) + return cast(TextIO, ConsoleStream(text_stream, buffer_stream)) + + +def _get_text_stdout(buffer_stream: BinaryIO) -> TextIO: + text_stream = _NonClosingTextIOWrapper( + io.BufferedWriter(_WindowsConsoleWriter(STDOUT_HANDLE)), + "utf-16-le", + "strict", + line_buffering=True, + ) + return cast(TextIO, ConsoleStream(text_stream, buffer_stream)) + + +def _get_text_stderr(buffer_stream: BinaryIO) -> TextIO: + text_stream = _NonClosingTextIOWrapper( + io.BufferedWriter(_WindowsConsoleWriter(STDERR_HANDLE)), + "utf-16-le", + "strict", + line_buffering=True, + ) + return cast(TextIO, ConsoleStream(text_stream, buffer_stream)) + + +_stream_factories: Mapping[int, Callable[[BinaryIO], TextIO]] = { + 0: _get_text_stdin, + 1: _get_text_stdout, + 2: _get_text_stderr, +} + + +def _is_console(f: TextIO) -> bool: + if not hasattr(f, "fileno"): + return False + + try: + fileno = f.fileno() + except (OSError, io.UnsupportedOperation): + return False + + handle = msvcrt.get_osfhandle(fileno) + return bool(GetConsoleMode(handle, byref(DWORD()))) + + +def _get_windows_console_stream( + f: TextIO, encoding: str | None, errors: str | None +) -> TextIO | None: + if ( + get_buffer is None + or encoding not in {"utf-16-le", None} + or errors not in {"strict", None} + or not _is_console(f) + ): + return None + + func = _stream_factories.get(f.fileno()) + if func is None: + return None + + b = getattr(f, "buffer", None) + + if b is None: + return None + + return func(b) diff --git a/typer/_click/core.py b/typer/_click/core.py new file mode 100644 index 0000000000..580b558b9f --- /dev/null +++ b/typer/_click/core.py @@ -0,0 +1,1111 @@ +import enum +import inspect +import os +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterator, Mapping, MutableMapping, Sequence +from contextlib import AbstractContextManager, ExitStack, contextmanager +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Literal, + NoReturn, + TypeVar, + Union, + cast, + overload, +) + +from . import types +from .exceptions import ( + Abort, + BadParameter, + Exit, + MissingParameter, + NoArgsIsHelpError, + UsageError, +) +from .formatting import HelpFormatter +from .globals import pop_context, push_context +from .parser import _OptionParser +from .termui import style +from .utils import echo, make_default_short_help + +if TYPE_CHECKING: + from ..core import TyperOption + from .shell_completion import CompletionItem + +F = TypeVar("F", bound="Callable[..., Any]") +V = TypeVar("V") + + +def _complete_visible_commands( + ctx: "Context", incomplete: str +) -> Iterator[tuple[str, "Command"]]: + """List all the subcommands of a group that start with the + incomplete value and aren't hidden. + """ + # avoid circular imports + from ..core import TyperGroup + + multi = cast(TyperGroup, ctx.command) + + for name in multi.list_commands(ctx): + if name.startswith(incomplete): + command = multi.get_command(ctx, name) + + if command is not None and not command.hidden: + yield name, command + + +@contextmanager +def augment_usage_errors( + ctx: "Context", param: Union["Parameter", None] = None +) -> Iterator[None]: + """Context manager that attaches extra information to exceptions.""" + try: + yield + except BadParameter as e: + if e.ctx is None: + e.ctx = ctx + if param is not None and e.param is None: + e.param = param + raise + except UsageError as e: # pragma: no cover + if e.ctx is None: + e.ctx = ctx + raise + + +def iter_params_for_processing( + invocation_order: Sequence["Parameter"], + declaration_order: Sequence["Parameter"], +) -> list["Parameter"]: + """Returns all declared parameters in the order they should be processed. + + The declared parameters are re-shuffled depending on the order in which + they were invoked, as well as the eagerness of each parameters. + + The invocation order takes precedence over the declaration order. I.e. the + order in which the user provided them to the CLI is respected. + + This behavior and its effect on callback evaluation is detailed at: + https://click.palletsprojects.com/en/stable/advanced/#callback-evaluation-order + """ + + def sort_key(item: Parameter) -> tuple[bool, float]: + try: + idx: float = invocation_order.index(item) + except ValueError: + idx = float("inf") + + return not item.is_eager, idx + + return sorted(declaration_order, key=sort_key) + + +class ParameterSource(enum.Enum): + """This is an `Enum` that indicates the source of a + parameter's value. + """ + + COMMANDLINE = enum.auto() + """The value was provided by the command line args.""" + ENVIRONMENT = enum.auto() + """The value was provided with an environment variable.""" + DEFAULT = enum.auto() + """Used the default specified by the parameter.""" + DEFAULT_MAP = enum.auto() + """Used a default provided by `Context.default_map`.""" + PROMPT = enum.auto() + """Used a prompt to confirm a default or provide a value.""" + + +class Context: + """The context is a special internal object that holds state relevant + for the script execution at every single level. It's normally invisible + to commands unless they opt-in to getting access to it. + + The context is useful as it can pass internal objects around and can + control special execution features such as reading data from + environment variables. + + A context can be used as context manager in which case it will call + `close` on teardown. + """ + + formatter_class: type[HelpFormatter] = HelpFormatter + + def __init__( + self, + command: "Command", + parent: Union["Context", None] = None, + info_name: str | None = None, + obj: Any | None = None, + auto_envvar_prefix: str | None = None, + default_map: MutableMapping[str, Any] | None = None, + terminal_width: int | None = None, + max_content_width: int | None = None, + resilient_parsing: bool = False, + allow_extra_args: bool | None = None, + allow_interspersed_args: bool | None = None, + ignore_unknown_options: bool | None = None, + help_option_names: list[str] | None = None, + token_normalize_func: Callable[[str], str] | None = None, + color: bool | None = None, + show_default: bool | None = None, + ) -> None: + self.parent = parent + self.command = command + self.info_name = info_name + # Map of parameter names to their parsed values. + self.params: dict[str, Any] = {} + # the leftover arguments. + self.args: list[str] = [] + # protected arguments. used to implement nested parsing. + self._protected_args: list[str] = [] + # the collected prefixes of the command's options. + self._opt_prefixes: set[str] = set(parent._opt_prefixes) if parent else set() + + if obj is None and parent is not None: + obj = parent.obj + + self.obj: Any = obj + self._meta: dict[str, Any] = getattr(parent, "meta", {}) + + # A dictionary (-like object) with defaults for parameters. + if ( + default_map is None + and info_name is not None + and parent is not None + and parent.default_map is not None + ): + default_map = parent.default_map.get(info_name) + + self.default_map: MutableMapping[str, Any] | None = default_map + + # This flag indicates if a subcommand is going to be executed. + self.invoked_subcommand: str | None = None + + if terminal_width is None and parent is not None: + terminal_width = parent.terminal_width + + # The width of the terminal (None is autodetection). + self.terminal_width: int | None = terminal_width + + if max_content_width is None and parent is not None: + max_content_width = parent.max_content_width + + self.max_content_width: int | None = max_content_width + + if allow_extra_args is None: + allow_extra_args = command.allow_extra_args + + self.allow_extra_args = allow_extra_args + + if allow_interspersed_args is None: + allow_interspersed_args = command.allow_interspersed_args + + self.allow_interspersed_args: bool = allow_interspersed_args + + if ignore_unknown_options is None: + ignore_unknown_options = command.ignore_unknown_options + + self.ignore_unknown_options: bool = ignore_unknown_options + + if help_option_names is None: + if parent is not None: + help_option_names = parent.help_option_names + else: + help_option_names = ["--help"] + + self.help_option_names: list[str] = help_option_names + + if token_normalize_func is None and parent is not None: + token_normalize_func = parent.token_normalize_func + + # An optional normalization function for tokens. (options, choices, commands etc.) + self.token_normalize_func: Callable[[str], str] | None = token_normalize_func + + # Indicates if resilient parsing is enabled. + self.resilient_parsing: bool = resilient_parsing + + # If there is no envvar prefix yet, but the parent has one and + # the command on this level has a name, we can expand the envvar + # prefix automatically. + if auto_envvar_prefix is None: + if ( + parent is not None + and parent.auto_envvar_prefix is not None + and self.info_name is not None + ): + auto_envvar_prefix = ( + f"{parent.auto_envvar_prefix}_{self.info_name.upper()}" + ) + else: + auto_envvar_prefix = auto_envvar_prefix.upper() + + if auto_envvar_prefix is not None: + auto_envvar_prefix = auto_envvar_prefix.replace("-", "_") + + self.auto_envvar_prefix: str | None = auto_envvar_prefix + + if color is None and parent is not None: + color = parent.color + + # Controls if styling output is wanted or not. + self.color: bool | None = color + + if show_default is None and parent is not None: + show_default = parent.show_default + + # Show option default values when formatting help text. + self.show_default: bool | None = show_default + + self._close_callbacks: list[Callable[[], Any]] = [] + self._depth = 0 + self._parameter_source: dict[str, ParameterSource] = {} + self._exit_stack = ExitStack() + + def __enter__(self) -> "Context": + self._depth += 1 + push_context(self) + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + tb: TracebackType | None, + ) -> bool | None: + self._depth -= 1 + exit_result: bool | None = None + if self._depth == 0: + exit_result = self._close_with_exception_info(exc_type, exc_value, tb) + pop_context() + + return exit_result + + @contextmanager + def scope(self, cleanup: bool = True) -> Iterator["Context"]: + """This helper method can be used with the context object to promote + it to the current thread local (see `get_current_context`). + The default behavior of this is to invoke the cleanup functions which + can be disabled by setting `cleanup` to `False`. The cleanup + functions are typically used for things such as closing file handles. + + If the cleanup is intended the context object can also be directly + used as a context manager. + """ + if not cleanup: + self._depth += 1 + try: + with self as rv: + yield rv + finally: + if not cleanup: + self._depth -= 1 + + @property + def meta(self) -> dict[str, Any]: + """This is a dictionary which is shared with all the contexts + that are nested. It exists so that click utilities can store some + state here if they need to. It is however the responsibility of + that code to manage this dictionary well. + + The keys are supposed to be unique dotted strings. For instance + module paths are a good choice for it. What is stored in there is + irrelevant for the operation of click. However what is important is + that code that places data here adheres to the general semantics of + the system. + """ + return self._meta + + def make_formatter(self) -> HelpFormatter: + """Creates the HelpFormatter for the help and + usage output. + """ + return self.formatter_class( + width=self.terminal_width, max_width=self.max_content_width + ) + + def with_resource(self, context_manager: AbstractContextManager[V]) -> V: + """Register a resource as if it were used in a ``with`` + statement. The resource will be cleaned up when the context is + popped. + + Uses `contextlib.ExitStack.enter_context`. It calls the + resource's ``__enter__()`` method and returns the result. When + the context is popped, it closes the stack, which calls the + resource's ``__exit__()`` method. + + To register a cleanup function for something that isn't a + context manager, use `call_on_close`. Or use something + from `contextlib` to turn it into a context manager first. + """ + return self._exit_stack.enter_context(context_manager) + + def call_on_close(self, f: Callable[..., Any]) -> Callable[..., Any]: + """Register a function to be called when the context tears down. + + This can be used to close resources opened during the script + execution. Resources that support Python's context manager + protocol which would be used in a ``with`` statement should be + registered with `with_resource` instead. + """ + return self._exit_stack.callback(f) + + def close(self) -> None: + """Invoke all close callbacks registered with `call_on_close`, + and exit all context managers entered with `with_resource`. + """ + self._close_with_exception_info(None, None, None) + + def _close_with_exception_info( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + tb: TracebackType | None, + ) -> bool | None: + """Unwind the exit stack by calling its `__exit__` providing the exception + information to allow for exception handling by the various resources registered + using `with_resource` + """ + exit_result = self._exit_stack.__exit__(exc_type, exc_value, tb) + # In case the context is reused, create a new exit stack. + self._exit_stack = ExitStack() + + return exit_result + + @property + def command_path(self) -> str: + """The computed command path. This is used for the ``usage`` + information on the help page. It's automatically created by + combining the info names of the chain of contexts to the root. + """ + rv = "" + if self.info_name is not None: + rv = self.info_name + if self.parent is not None: + parent_command_path = [self.parent.command_path] + + if isinstance(self.parent.command, Command): + for param in self.parent.command.get_params(self): + parent_command_path.extend(param.get_usage_pieces(self)) + + rv = f"{' '.join(parent_command_path)} {rv}" + return rv.lstrip() + + def find_root(self) -> "Context": + """Finds the outermost context.""" + node = self + while node.parent is not None: + node = node.parent + return node + + def find_object(self, object_type: type[V]) -> V | None: + """Finds the closest object of a given type.""" + node: Context | None = self + + while node is not None: + if isinstance(node.obj, object_type): + return node.obj + + node = node.parent + + return None + + def ensure_object(self, object_type: type[V]) -> V: + """Like `find_object` but sets the innermost object to a + new instance of `object_type` if it does not exist. + """ + rv = self.find_object(object_type) + if rv is None: + self.obj = rv = object_type() + return rv + + @overload + def lookup_default(self, name: str, call: Literal[True] = True) -> Any | None: ... + + @overload + def lookup_default( + self, name: str, call: Literal[False] = ... + ) -> Any | Callable[[], Any] | None: ... + + def lookup_default(self, name: str, call: bool = True) -> Any | None: + """Get the default for a parameter from `default_map`.""" + if self.default_map is not None: + value = self.default_map.get(name) + + if call and callable(value): + return value() + + return value + + return None + + def fail(self, message: str) -> NoReturn: + """Aborts the execution of the program with a specific error + message. + """ + raise UsageError(message, self) + + def abort(self) -> NoReturn: + """Aborts the script.""" + raise Abort() + + def exit(self, code: int = 0) -> NoReturn: + """Exits the application with a given exit code.""" + self.close() + raise Exit(code) + + def get_usage(self) -> str: + """Helper method to get formatted usage string for the current + context and command. + """ + return self.command.get_usage(self) + + def get_help(self) -> str: + """Helper method to get formatted help page for the current + context and command. + """ + return self.command.get_help(self) + + def invoke(self, callback: Callable[..., V], /, *args: Any, **kwargs: Any) -> V: + """Invokes a command callback in exactly the way it expects. There + are two ways to invoke this method: + + 1. the first argument can be a callback and all other arguments and + keyword arguments are forwarded directly to the function. + 2. the first argument is a click command object. In that case all + arguments are forwarded as well but proper click parameters + (options and click arguments) must be keyword arguments and Click + will fill in defaults. + """ + ctx = self + + with augment_usage_errors(self): + with ctx: + return callback(*args, **kwargs) + + def set_parameter_source(self, name: str, source: ParameterSource) -> None: + """Set the source of a parameter. This indicates the location + from which the value of the parameter was obtained. + """ + self._parameter_source[name] = source + + def get_parameter_source(self, name: str) -> ParameterSource | None: + """Get the source of a parameter. This indicates the location + from which the value of the parameter was obtained. + + This can be useful for determining when a user specified a value + on the command line that is the same as the default value. It + will be `ParameterSource.DEFAULT` only if the + value was actually taken from the default. + """ + return self._parameter_source.get(name) + + +class Command(ABC): + """Commands are the basic building block of command line interfaces in + Click. A basic command handles command line parsing and might dispatch + more parsing to commands nested below it. + """ + + context_class: type[Context] = Context + allow_extra_args = False + allow_interspersed_args = True + ignore_unknown_options = False + + def __init__( + self, + name: str | None, + context_settings: MutableMapping[str, Any] | None = None, + callback: Callable[..., Any] | None = None, + params: list["Parameter"] | None = None, + help: str | None = None, + epilog: str | None = None, + short_help: str | None = None, + options_metavar: str | None = "[OPTIONS]", + add_help_option: bool = True, + no_args_is_help: bool = False, + hidden: bool = False, + deprecated: bool | str = False, + ) -> None: + self.name = name + + if context_settings is None: + context_settings = {} + + self.context_settings: MutableMapping[str, Any] = context_settings + + self.callback = callback + self.params: list[Parameter] = params or [] + self.help = help + self.epilog = epilog + self.options_metavar = options_metavar + self.short_help = short_help + self.add_help_option = add_help_option + self._help_option: TyperOption | None = None + self.no_args_is_help = no_args_is_help + self.hidden = hidden + self.deprecated = deprecated + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self.name}>" + + def get_usage(self, ctx: Context) -> str: + """Formats the usage line into a string and returns it.""" + formatter = ctx.make_formatter() + self.format_usage(ctx, formatter) + return formatter.getvalue().rstrip("\n") + + def get_params(self, ctx: Context) -> list["Parameter"]: + params = self.params + help_option = self.get_help_option(ctx) + + if help_option is not None: + params = [*params, help_option] + + return params + + def format_usage(self, ctx: Context, formatter: HelpFormatter) -> None: + """Writes the usage line into the formatter.""" + pieces = self.collect_usage_pieces(ctx) + formatter.write_usage(ctx.command_path, " ".join(pieces)) + + def collect_usage_pieces(self, ctx: Context) -> list[str]: + """Returns all the pieces that go into the usage line and returns + it as a list of strings. + """ + rv = [self.options_metavar] if self.options_metavar else [] + + for param in self.get_params(ctx): + rv.extend(param.get_usage_pieces(ctx)) + + return rv + + def get_help_option_names(self, ctx: Context) -> list[str]: + """Returns the names for the help option.""" + all_names = set(ctx.help_option_names) + for param in self.params: + all_names.difference_update(param.opts) + all_names.difference_update(param.secondary_opts) + return list(all_names) + + def get_help_option(self, ctx: Context) -> Union["TyperOption", None]: + """Returns the help option object.""" + help_option_names = self.get_help_option_names(ctx) + + if not help_option_names or not self.add_help_option: + return None + + # Cache the help option object in private _help_option attribute to + # avoid creating it multiple times. Not doing this will break the + # callback ordering by iter_params_for_processing(), which relies on + # object comparison. + if self._help_option is None: + # Avoid circular import. + from .decorators import help_option + + # Apply help_option decorator and pop resulting option + help_option(help_option_names)(self) + self._help_option = cast("TyperOption", self.params.pop()) + + return self._help_option + + def make_parser(self, ctx: Context) -> _OptionParser: + """Creates the underlying option parser for this command.""" + parser = _OptionParser(ctx) + for param in self.get_params(ctx): + param.add_to_parser(parser, ctx) + return parser + + def get_help(self, ctx: Context) -> str: + """Formats the help into a string and returns it.""" + formatter = ctx.make_formatter() + self.format_help(ctx, formatter) + return formatter.getvalue().rstrip("\n") + + def get_short_help_str(self, limit: int = 45) -> str: + """Gets short help for the command or makes it by shortening the + long help string. + """ + if self.short_help: + text = inspect.cleandoc(self.short_help) + elif self.help: + text = make_default_short_help(self.help, limit) + else: + text = "" + + if self.deprecated: + deprecated_message = ( + f"(DEPRECATED: {self.deprecated})" + if isinstance(self.deprecated, str) + else "(DEPRECATED)" + ) + text = f"{text} {deprecated_message}" + + return text.strip() + + def format_help(self, ctx: Context, formatter: HelpFormatter) -> None: + """Writes the help into the formatter if it exists.""" + self.format_usage(ctx, formatter) + self.format_help_text(ctx, formatter) + self.format_options(ctx, formatter) + self.format_epilog(ctx, formatter) + + def format_help_text(self, ctx: Context, formatter: HelpFormatter) -> None: + """Writes the help text to the formatter if it exists.""" + if self.help is not None: + # truncate the help text to the first form feed + text = inspect.cleandoc(self.help).partition("\f")[0] + else: + text = "" + + if self.deprecated: + deprecated_message = ( + f"(DEPRECATED: {self.deprecated})" + if isinstance(self.deprecated, str) + else "(DEPRECATED)" + ) + text = f"{text} {deprecated_message}" + + if text: + formatter.write_paragraph() + + with formatter.indentation(): + formatter.write_text(text) + + @abstractmethod + def format_options(self, ctx: Context, formatter: HelpFormatter) -> None: + pass # pragma: no cover + + def format_epilog(self, ctx: Context, formatter: HelpFormatter) -> None: + """Writes the epilog into the formatter if it exists.""" + if self.epilog: + epilog = inspect.cleandoc(self.epilog) + formatter.write_paragraph() + + with formatter.indentation(): + formatter.write_text(epilog) + + def make_context( + self, + info_name: str | None, + args: list[str], + parent: Context | None = None, + **extra: Any, + ) -> Context: + """This function when given an info name and arguments will kick + off the parsing and create a new `Context`. It does not + invoke the actual command callback though. + + To quickly customize the context class used without overriding + this method, set the `context_class` attribute. + """ + for key, value in self.context_settings.items(): + if key not in extra: + extra[key] = value + + ctx = self.context_class(self, info_name=info_name, parent=parent, **extra) + + with ctx.scope(cleanup=False): + self.parse_args(ctx, args) + return ctx + + def parse_args(self, ctx: Context, args: list[str]) -> list[str]: + if not args and self.no_args_is_help and not ctx.resilient_parsing: + raise NoArgsIsHelpError(ctx) # pragma: no cover + + parser = self.make_parser(ctx) + opts, args, param_order = parser.parse_args(args=args) + + for param in iter_params_for_processing(param_order, self.get_params(ctx)): + _, args = param.handle_parse_result(ctx, opts, args) + + if args and not ctx.allow_extra_args and not ctx.resilient_parsing: + ctx.fail(f"Got unexpected extra argument(s) ({' '.join(map(str, args))})") + + ctx.args = args + ctx._opt_prefixes.update(parser._opt_prefixes) + return args + + def invoke(self, ctx: Context) -> Any: + """Given a context, this invokes the attached callback (if it exists) + in the right way. + """ + if self.deprecated: + extra_message = ( + f" {self.deprecated}" if isinstance(self.deprecated, str) else "" + ) + message = f"DeprecationWarning: The command {self.name!r} is deprecated.{extra_message}" + echo(style(message, fg="red"), err=True) + + if self.callback is not None: + return ctx.invoke(self.callback, **ctx.params) + + def shell_complete(self, ctx: Context, incomplete: str) -> list["CompletionItem"]: + """Return a list of completions for the incomplete value. Looks + at the names of options and chained multi-commands. + + Any command could be part of a chained multi-command, so sibling + commands are valid at any point during command completion. + """ + # avoid circular imports + from .shell_completion import CompletionItem + + results: list[CompletionItem] = [] + + if incomplete and not incomplete[0].isalnum(): + # avoid circular imports + from ..core import TyperOption + + for param in self.get_params(ctx): + if ( + not isinstance(param, TyperOption) + or param.hidden + or ( + not param.multiple + and ctx.get_parameter_source(param.name) # type: ignore + is ParameterSource.COMMANDLINE + ) + ): + continue + + results.extend( + CompletionItem(name, help=param.help) + for name in [*param.opts, *param.secondary_opts] + if name.startswith(incomplete) + ) + + return results + + @abstractmethod + def main( + self, + args: Sequence[str] | None = None, + prog_name: str | None = None, + complete_var: str | None = None, + standalone_mode: bool = True, + windows_expand_args: bool = True, + **extra: Any, + ) -> Any: + pass # pragma: no cover + + @abstractmethod + def _main_shell_completion( + self, + ctx_args: MutableMapping[str, Any], + prog_name: str, + complete_var: str | None = None, + ) -> None: + pass # pragma: no cover + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Alias for self.main""" + return self.main(*args, **kwargs) + + +class Parameter(ABC): + r"""A parameter to a command comes in two versions: they are either + `Option`\s or `Argument`\s. + + Some settings are supported by both options and arguments. + """ + + param_type_name = "parameter" + + def __init__( + self, + param_decls: Sequence[str] | None = None, + type: types.ParamType | Any | None = None, + required: bool = False, + default: Any | Callable[[], Any] | None = None, + callback: Callable[[Context, "Parameter", Any], Any] | None = None, + nargs: int | None = None, + multiple: bool = False, + metavar: str | None = None, + expose_value: bool = True, + is_eager: bool = False, + envvar: str | Sequence[str] | None = None, + shell_complete: Callable[ + [Context, "Parameter", str], list["CompletionItem"] | list[str] + ] + | None = None, + ) -> None: + self.name: str | None + self.opts: list[str] + self.secondary_opts: list[str] + self.name, self.opts, self.secondary_opts = self._parse_decls( + param_decls or (), expose_value + ) + self.type: types.ParamType = types.convert_type(type, default) + + # Default nargs to what the type tells us if we have that + # information available. + if nargs is None: + if self.type.is_composite: + nargs = self.type.arity + else: + nargs = 1 + + self.required = required + self.callback = callback + self.nargs = nargs + self.multiple = multiple + self.expose_value = expose_value + self.default: Any | Callable[[], Any] | None = default + self.is_eager = is_eager + self.metavar = metavar + self.envvar = envvar + self._custom_shell_complete = shell_complete + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self.name}>" + + @abstractmethod + def _parse_decls( + self, decls: Sequence[str], expose_value: bool + ) -> tuple[str | None, list[str], list[str]]: + pass # pragma: no cover + + @property + def human_readable_name(self) -> str: + """Returns the human readable name of this parameter. This is the + same as the name for options, but the metavar for arguments. + """ + assert self.name is not None, "self.name should be set" + return self.name + + def make_metavar(self, ctx: Context) -> str: + if self.metavar is not None: + return self.metavar + + metavar = self.type.get_metavar(param=self, ctx=ctx) + + if metavar is None: + metavar = self.type.name.upper() + + if self.nargs != 1: + metavar += "..." + + return metavar + + @overload + def get_default(self, ctx: Context, call: Literal[True] = True) -> Any | None: ... + + @overload + def get_default( + self, ctx: Context, call: bool = ... + ) -> Any | Callable[[], Any] | None: ... + + def get_default( + self, ctx: Context, call: bool = True + ) -> Any | Callable[[], Any] | None: + """Get the default for the parameter""" + value = ctx.lookup_default(self.name, call=False) # type: ignore + + if value is None: + value = self.default + + if call and callable(value): + value = value() + + return value + + @abstractmethod + def add_to_parser(self, parser: _OptionParser, ctx: Context) -> None: + pass # pragma: no cover + + def consume_value( + self, ctx: Context, opts: Mapping[str, Any] + ) -> tuple[Any, ParameterSource]: + value = opts.get(self.name) # type: ignore + source = ParameterSource.COMMANDLINE + + if value is None: + value = self.value_from_envvar(ctx) + source = ParameterSource.ENVIRONMENT + + if value is None: + value = ctx.lookup_default(self.name) # type: ignore + source = ParameterSource.DEFAULT_MAP + + if value is None: + value = self.get_default(ctx) + source = ParameterSource.DEFAULT + + return value, source + + def type_cast_value(self, ctx: Context, value: Any) -> Any: + """Convert and validate a value against the parameter's + `type`, `multiple`, and `nargs`. + """ + if value is None: + return () if self.multiple or self.nargs == -1 else None + + def check_iter(value: Any) -> Iterator[Any]: + if isinstance(value, str): + raise BadParameter("Value must be an iterable.", ctx=ctx, param=self) + else: + return iter(value) + + # Define the conversion function based on nargs and type. + if self.nargs == 1 or self.type.is_composite: + + def convert(value: Any) -> Any: + return self.type(value, param=self, ctx=ctx) + + elif self.nargs == -1: + + def convert(value: Any) -> Any: # tuple[t.Any, ...] + return tuple(self.type(x, self, ctx) for x in check_iter(value)) + + # TODO: evaluate whether we need to keep this in Typer + else: # nargs > 1 + + def convert(value: Any) -> Any: # tuple[t.Any, ...] + value = tuple(check_iter(value)) + + if len(value) != self.nargs: + raise BadParameter( + f"Takes {self.nargs} values but {len(value)} given.", + ctx=ctx, + param=self, + ) + + return tuple(self.type(x, self, ctx) for x in value) + + if self.multiple: + return tuple(convert(x) for x in check_iter(value)) + + return convert(value) + + @abstractmethod + def value_is_missing(self, value: Any) -> bool: + pass # pragma: no cover + + def process_value(self, ctx: Context, value: Any) -> Any: + """Process the value of this parameter""" + value = self.type_cast_value(ctx, value) + + if self.required and self.value_is_missing(value): + raise MissingParameter(ctx=ctx, param=self) + + if self.callback is not None: + value = self.callback(ctx, self, value) + + return value + + def resolve_envvar_value(self, ctx: Context) -> str | None: + """Returns the value found in the environment variable(s) attached to this + parameter. + + Environment variables values are `always returned as strings + `_. + + This method returns ``None`` if: + + - the `envvar` property is not set on `Parameter`, + - the environment variable is not found in the environment, + - the variable is found in the environment but its value is empty (i.e. the + environment variable is present but has an empty string). + + If `envvar` is setup with multiple environment variables, + then only the first non-empty value is returned. + """ + if self.envvar is None: + return None + + if isinstance(self.envvar, str): + rv = os.environ.get(self.envvar) + + if rv: + return rv + else: + for envvar in self.envvar: + rv = os.environ.get(envvar) + + # Return the first non-empty value of the list of environment variables. + if rv: + return rv + # Else, absence of value is interpreted as an environment variable that + # is not set, so proceed to the next one. + + return None + + def value_from_envvar(self, ctx: Context) -> str | Sequence[str] | None: + """Process the raw environment variable string for this parameter. + + Returns the string as-is or splits it into a sequence of strings if the + parameter is expecting multiple values (i.e. its `nargs` property is set + to a value other than ``1``). + """ + rv: Any | None = self.resolve_envvar_value(ctx) + + if rv is not None and self.nargs != 1: + rv = self.type.split_envvar_value(rv) + + return rv + + def handle_parse_result( + self, ctx: Context, opts: Mapping[str, Any], args: list[str] + ) -> tuple[Any, list[str]]: + """Process the value produced by the parser from user input. + + Always process the value through the Parameter's `type`, wherever it + comes from. + + If the parameter is deprecated, this method warn the user about it. But only if + the value has been explicitly set by the user (and as such, is not coming from + a default). + """ + with augment_usage_errors(ctx, param=self): + value, source = self.consume_value(ctx, opts) + + ctx.set_parameter_source(self.name, source) # type: ignore + + # Process the value through the parameter's type. + try: + value = self.process_value(ctx, value) + except Exception: + if not ctx.resilient_parsing: + raise + value = None + + if self.expose_value: + ctx.params[self.name] = value # type: ignore + + return value, args + + @abstractmethod + def get_help_record(self, ctx: Context) -> tuple[str, str] | None: + pass # pragma: no cover + + def get_usage_pieces(self, ctx: Context) -> list[str]: + return [] + + def get_error_hint(self, ctx: Context) -> str: + """Get a stringified version of the param for use in error messages to + indicate which param caused the error. + """ + hint_list = self.opts or [self.human_readable_name] + return " / ".join(f"'{x}'" for x in hint_list) + + def shell_complete(self, ctx: Context, incomplete: str) -> list["CompletionItem"]: + """Return a list of completions for the incomplete value. If a + ``shell_complete`` function was given during init, it is used. + Otherwise, the `type` `ParamType.shell_complete` function is used. + """ + if self._custom_shell_complete is not None: + results = self._custom_shell_complete(ctx, self, incomplete) + + if results and isinstance(results[0], str): + from .shell_completion import CompletionItem + + results = [CompletionItem(c) for c in results] + + return cast("list[CompletionItem]", results) + + return self.type.shell_complete(ctx, self, incomplete) diff --git a/typer/_click/decorators.py b/typer/_click/decorators.py new file mode 100644 index 0000000000..28ad656a8c --- /dev/null +++ b/typer/_click/decorators.py @@ -0,0 +1,60 @@ +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar + +from .core import Command, Context, Parameter +from .utils import echo + +if TYPE_CHECKING: + from ..core import TyperGroup, TyperOption + + GrpType = TypeVar("GrpType", bound=TyperGroup) + + +P = ParamSpec("P") + +R = TypeVar("R") +T = TypeVar("T") +_AnyCallable = Callable[..., Any] + + +CmdType = TypeVar("CmdType", bound=Command) + + +def option( + param_decls: list[str], cls: type["TyperOption"] | None = None, **attrs: Any +) -> Callable[[Command], Command]: + """Attaches an option to the command.""" + if cls is None: + # avoid circular imports + from ..core import TyperOption + + cls = TyperOption + + def decorator(f: Command) -> Command: + param = cls(param_decls=param_decls, **attrs) + f.params.append(param) + return f + + return decorator + + +def help_option(param_decls: list[str]) -> Callable[[Command], Command]: + """Help option which prints the help page and exits the program.""" + + def show_help(ctx: Context, param: Parameter, value: bool) -> None: + """Callback that print the help page on ```` and exits.""" + if value and not ctx.resilient_parsing: + echo(ctx.get_help(), color=ctx.color) + ctx.exit() + + assert len(param_decls) > 0, "At least one help option should be provided" + + return option( + param_decls, + is_flag=True, + expose_value=False, + is_eager=True, + help="Show this message and exit.", + callback=show_help, + required=False, + ) diff --git a/typer/_click/exceptions.py b/typer/_click/exceptions.py new file mode 100644 index 0000000000..af2af260a6 --- /dev/null +++ b/typer/_click/exceptions.py @@ -0,0 +1,260 @@ +from collections.abc import Sequence +from typing import IO, TYPE_CHECKING, Any, Union + +from ._compat import get_text_stderr +from .globals import resolve_color_default +from .utils import echo, format_filename + +if TYPE_CHECKING: + from .core import Command, Context, Parameter + + +def _join_param_hints(param_hint: Sequence[str] | str | None) -> str | None: + if param_hint is not None and not isinstance(param_hint, str): + return " / ".join(repr(x) for x in param_hint) + + return param_hint + + +class ClickException(Exception): + """An exception that Click can handle and show to the user.""" + + exit_code = 1 + + def __init__(self, message: str) -> None: + super().__init__(message) + # The context will be removed by the time we print the message, so cache + # the color settings here to be used later on (in `show`) + self.show_color: bool | None = resolve_color_default() + self.message = message + + def format_message(self) -> str: + return self.message + + def __str__(self) -> str: + return self.message + + def show(self, file: IO[Any] | None = None) -> None: + if file is None: + file = get_text_stderr() + + echo( + f"Error: {self.format_message()}", + file=file, + color=self.show_color, + ) + + +class UsageError(ClickException): + """An internal exception that signals a usage error. This typically + aborts any further handling. + """ + + exit_code = 2 + + def __init__(self, message: str, ctx: Union["Context", None] = None) -> None: + super().__init__(message) + self.ctx = ctx + self.cmd: Command | None = self.ctx.command if self.ctx else None + + def show(self, file: IO[Any] | None = None) -> None: + if file is None: + file = get_text_stderr() + color = None + hint = "" + if ( + self.ctx is not None + and self.ctx.command.get_help_option(self.ctx) is not None + ): + command = self.ctx.command_path + option = self.ctx.help_option_names[0] + hint = f"Try '{command} {option}' for help.\n" + if self.ctx is not None: + color = self.ctx.color + echo(f"{self.ctx.get_usage()}\n{hint}", file=file, color=color) + echo( + f"Error: {self.format_message()}", + file=file, + color=color, + ) + + +class BadParameter(UsageError): + """An exception that formats out a standardized error message for a + bad parameter. This is useful when thrown from a callback or type as + Click will attach contextual information to it (for instance, which + parameter it is). + """ + + def __init__( + self, + message: str, + ctx: Union["Context", None] = None, + param: Union["Parameter", None] = None, + param_hint: Sequence[str] | str | None = None, + ) -> None: + super().__init__(message, ctx) + self.param = param + self.param_hint = param_hint + + def format_message(self) -> str: + if self.param_hint is not None: + param_hint = self.param_hint + elif self.param is not None: + param_hint = self.param.get_error_hint(self.ctx) # type: ignore + else: + return f"Invalid value: {self.message}" + + hint = _join_param_hints(param_hint) + return f"Invalid value for {hint}: {self.message}" + + +class MissingParameter(BadParameter): + """Raised if click required an option or argument but it was not + provided when invoking the script. + """ + + def __init__( + self, + message: str | None = None, + ctx: Union["Context", None] = None, + param: Union["Parameter", None] = None, + param_hint: Sequence[str] | str | None = None, + param_type: str | None = None, + ) -> None: + super().__init__(message or "", ctx, param, param_hint) + self.param_type = param_type + + def format_message(self) -> str: + if self.param_hint is not None: + param_hint: Sequence[str] | str | None = self.param_hint + elif self.param is not None: + param_hint = self.param.get_error_hint(self.ctx) # type: ignore + else: + param_hint = None + + param_hint = _join_param_hints(param_hint) + param_hint = f" {param_hint}" if param_hint else "" + + param_type = self.param_type + if param_type is None and self.param is not None: + param_type = self.param.param_type_name + + msg = self.message + if self.param is not None: + msg_extra = self.param.type.get_missing_message( + param=self.param, ctx=self.ctx + ) + if msg_extra: + if msg: + msg += f". {msg_extra}" + else: + msg = msg_extra + + msg = f" {msg}" if msg else "" + + # Translate param_type for known types. + if param_type == "argument": + missing = "Missing argument" + elif param_type == "option": + missing = "Missing option" + elif param_type == "parameter": + missing = "Missing parameter" + else: + missing = f"Missing {param_type}" + + return f"{missing}{param_hint}.{msg}" + + def __str__(self) -> str: + if not self.message: + param_name = self.param.name if self.param else None + return f"Missing parameter: {param_name}" + else: + return self.message + + +class NoSuchOption(UsageError): + """Raised if click attempted to handle an option that does not + exist. + """ + + def __init__( + self, + option_name: str, + message: str | None = None, + possibilities: Sequence[str] | None = None, + ctx: Union["Context", None] = None, + ) -> None: + if message is None: + message = f"No such option: {option_name}" + + super().__init__(message, ctx) + self.option_name = option_name + self.possibilities = possibilities + + def format_message(self) -> str: + if not self.possibilities: + return self.message + + possibility_str = ", ".join(sorted(self.possibilities)) + suggest = (f"(Possible options: {possibility_str})",) + return f"{self.message} {suggest}" + + +class BadOptionUsage(UsageError): + """Raised if an option is generally supplied but the use of the option + was incorrect. This is for instance raised if the number of arguments + for an option is not correct. + """ + + def __init__( + self, option_name: str, message: str, ctx: Union["Context", None] = None + ) -> None: + super().__init__(message, ctx) + self.option_name = option_name + + +class BadArgumentUsage(UsageError): + """Raised if an argument is generally supplied but the use of the argument + was incorrect. This is for instance raised if the number of values + for an argument is not correct. + """ + + +class NoArgsIsHelpError(UsageError): + def __init__(self, ctx: "Context") -> None: + self.ctx: Context + super().__init__(ctx.get_help(), ctx=ctx) + + def show(self, file: IO[Any] | None = None) -> None: + echo(self.format_message(), file=file, err=True, color=self.ctx.color) + + +class FileError(ClickException): + """Raised if a file cannot be opened.""" + + def __init__(self, filename: str, hint: str | None = None) -> None: + if hint is None: + hint = "unknown error" + + super().__init__(hint) + self.ui_filename: str = format_filename(filename) + self.filename = filename + + def format_message(self) -> str: + return f"Could not open file {self.ui_filename!r}: {self.message}" + + +class Abort(RuntimeError): + """An internal signalling exception that signals Click to abort.""" + + +class Exit(RuntimeError): + """An exception that indicates that the application should exit with some + status code. + """ + + __slots__ = ("exit_code",) + + def __init__(self, code: int = 0) -> None: + self.exit_code: int = code diff --git a/typer/_click/formatting.py b/typer/_click/formatting.py new file mode 100644 index 0000000000..b5eaab3bd0 --- /dev/null +++ b/typer/_click/formatting.py @@ -0,0 +1,272 @@ +from collections.abc import Iterable, Iterator, Sequence +from contextlib import contextmanager + +from ._compat import term_len +from .parser import _split_opt + +# Can force a width. This is used by the test system +FORCED_WIDTH: int | None = None + + +def measure_table(rows: Iterable[tuple[str, str]]) -> tuple[int, ...]: + widths: dict[int, int] = {} + + for row in rows: + for idx, col in enumerate(row): + widths[idx] = max(widths.get(idx, 0), term_len(col)) + + return tuple(y for x, y in sorted(widths.items())) + + +def iter_rows( + rows: Iterable[tuple[str, str]], col_count: int +) -> Iterator[tuple[str, ...]]: + for row in rows: + yield row + ("",) * (col_count - len(row)) + + +def wrap_text( + text: str, + width: int = 78, + initial_indent: str = "", + subsequent_indent: str = "", + preserve_paragraphs: bool = False, +) -> str: + """A helper function that intelligently wraps text. By default, it + assumes that it operates on a single paragraph of text but if the + `preserve_paragraphs` parameter is provided it will intelligently + handle paragraphs (defined by two empty lines). + + If paragraphs are handled, a paragraph can be prefixed with an empty + line containing the ``\\b`` character (``\\x08``) to indicate that + no rewrapping should happen in that block. + """ + from ._textwrap import TextWrapper + + text = text.expandtabs() + wrapper = TextWrapper( + width, + initial_indent=initial_indent, + subsequent_indent=subsequent_indent, + replace_whitespace=False, + ) + if not preserve_paragraphs: + return wrapper.fill(text) + + p: list[tuple[int, bool, str]] = [] + buf: list[str] = [] + indent = None + + def _flush_par() -> None: + if not buf: + return + if buf[0].strip() == "\b": + p.append((indent or 0, True, "\n".join(buf[1:]))) + else: + p.append((indent or 0, False, " ".join(buf))) + del buf[:] + + for line in text.splitlines(): + if not line: + _flush_par() + indent = None + else: + if indent is None: + orig_len = term_len(line) + line = line.lstrip() + indent = orig_len - term_len(line) + buf.append(line) + _flush_par() + + rv = [] + for indent, raw, text in p: + with wrapper.extra_indent(" " * indent): + if raw: + rv.append(wrapper.indent_only(text)) + else: + rv.append(wrapper.fill(text)) + + return "\n\n".join(rv) + + +class HelpFormatter: + """This class helps with formatting text-based help pages. It's + usually just needed for very special internal cases, but it's also + exposed so that developers can write their own fancy outputs. + + At present, it always writes into memory. + """ + + def __init__( + self, + indent_increment: int = 2, + width: int | None = None, + max_width: int | None = None, + ) -> None: + self.indent_increment = indent_increment + if max_width is None: + max_width = 80 + if width is None: + import shutil + + width = FORCED_WIDTH + if width is None: + width = max(min(shutil.get_terminal_size().columns, max_width) - 2, 50) + self.width = width + self.current_indent: int = 0 + self.buffer: list[str] = [] + + def write(self, string: str) -> None: + """Writes a unicode string into the internal buffer.""" + self.buffer.append(string) + + def indent(self) -> None: + """Increases the indentation.""" + self.current_indent += self.indent_increment + + def dedent(self) -> None: + """Decreases the indentation.""" + self.current_indent -= self.indent_increment + + def write_usage(self, prog: str, args: str = "", prefix: str | None = None) -> None: + """Writes a usage line into the buffer.""" + if prefix is None: + prefix = "Usage: " + + usage_prefix = f"{prefix:>{self.current_indent}}{prog} " + text_width = self.width - self.current_indent + + if text_width >= (term_len(usage_prefix) + 20): + # The arguments will fit to the right of the prefix. + indent = " " * term_len(usage_prefix) + self.write( + wrap_text( + args, + text_width, + initial_indent=usage_prefix, + subsequent_indent=indent, + ) + ) + else: + # The prefix is too long, put the arguments on the next line. + self.write(usage_prefix) + self.write("\n") + indent = " " * (max(self.current_indent, term_len(prefix)) + 4) + self.write( + wrap_text( + args, text_width, initial_indent=indent, subsequent_indent=indent + ) + ) + + self.write("\n") + + def write_heading(self, heading: str) -> None: + """Writes a heading into the buffer.""" + self.write(f"{'':>{self.current_indent}}{heading}:\n") + + def write_paragraph(self) -> None: + """Writes a paragraph into the buffer.""" + if self.buffer: + self.write("\n") + + def write_text(self, text: str) -> None: + """Writes re-indented text into the buffer. This rewraps and + preserves paragraphs. + """ + indent = " " * self.current_indent + self.write( + wrap_text( + text, + self.width, + initial_indent=indent, + subsequent_indent=indent, + preserve_paragraphs=True, + ) + ) + self.write("\n") + + def write_dl( + self, + rows: Sequence[tuple[str, str]], + col_max: int = 30, + col_spacing: int = 2, + ) -> None: + """Writes a definition list into the buffer. This is how options + and commands are usually formatted. + """ + rows = list(rows) + widths = measure_table(rows) + if len(widths) != 2: # pragma: no cover + raise TypeError("Expected two columns for definition list") + + first_col = min(widths[0], col_max) + col_spacing + + for first, second in iter_rows(rows, len(widths)): + self.write(f"{'':>{self.current_indent}}{first}") + if not second: + self.write("\n") + continue + if term_len(first) <= first_col - col_spacing: + self.write(" " * (first_col - term_len(first))) + else: + self.write("\n") + self.write(" " * (first_col + self.current_indent)) + + text_width = max(self.width - first_col - 2, 10) + wrapped_text = wrap_text(second, text_width, preserve_paragraphs=True) + lines = wrapped_text.splitlines() + + if lines: + self.write(f"{lines[0]}\n") + + for line in lines[1:]: + self.write(f"{'':>{first_col + self.current_indent}}{line}\n") + else: # pragma: no cover + self.write("\n") + + @contextmanager + def section(self, name: str) -> Iterator[None]: + """Helpful context manager that writes a paragraph, a heading, + and the indents. + """ + self.write_paragraph() + self.write_heading(name) + self.indent() + try: + yield + finally: + self.dedent() + + @contextmanager + def indentation(self) -> Iterator[None]: + """A context manager that increases the indentation.""" + self.indent() + try: + yield + finally: + self.dedent() + + def getvalue(self) -> str: + """Returns the buffer contents.""" + return "".join(self.buffer) + + +def join_options(options: Sequence[str]) -> tuple[str, bool]: + """Given a list of option strings this joins them in the most appropriate + way and returns them in the form ``(formatted_string, + any_prefix_is_slash)`` where the second item in the tuple is a flag that + indicates if any of the option prefixes was a slash. + """ + rv = [] + any_prefix_is_slash = False + + for opt in options: + prefix = _split_opt(opt)[0] + + if prefix == "/": + any_prefix_is_slash = True + + rv.append((len(prefix), opt)) + + rv.sort(key=lambda x: x[0]) + return ", ".join(x[1] for x in rv), any_prefix_is_slash diff --git a/typer/_click/globals.py b/typer/_click/globals.py new file mode 100644 index 0000000000..372dc40749 --- /dev/null +++ b/typer/_click/globals.py @@ -0,0 +1,61 @@ +from threading import local +from typing import TYPE_CHECKING, Literal, Union, cast, overload + +if TYPE_CHECKING: + from .core import Context + +_local = local() + + +@overload +def get_current_context(silent: Literal[False] = False) -> "Context": ... + + +@overload +def get_current_context(silent: bool = ...) -> Union["Context", None]: ... + + +def get_current_context(silent: bool = False) -> Union["Context", None]: + """Returns the current click context. This can be used as a way to + access the current context object from anywhere. This is a more implicit + alternative to the `pass_context` decorator. This function is + primarily useful for helpers such as `echo` which might be + interested in changing its behavior based on the current context. + + To push the current context, `Context.scope` can be used. + """ + try: + return cast("Context", _local.stack[-1]) + except (AttributeError, IndexError) as e: + if not silent: + raise RuntimeError( + "There is no active click context." + ) from e # pragma: no cover + + return None + + +def push_context(ctx: "Context") -> None: + """Pushes a new context to the current stack.""" + _local.__dict__.setdefault("stack", []).append(ctx) + + +def pop_context() -> None: + """Removes the top level from the stack.""" + _local.stack.pop() + + +def resolve_color_default(color: bool | None = None) -> bool | None: + """Internal helper to get the default value of the color flag. If a + value is passed it's returned unchanged, otherwise it's looked up from + the current context. + """ + if color is not None: + return color + + ctx = get_current_context(silent=True) + + if ctx is not None: + return ctx.color + + return None diff --git a/typer/_click/parser.py b/typer/_click/parser.py new file mode 100644 index 0000000000..71eb3003cc --- /dev/null +++ b/typer/_click/parser.py @@ -0,0 +1,459 @@ +""" +This module started out as largely a copy paste from the stdlib's +optparse module with the features removed that we do not need from +optparse because we implement them in Click on a higher level (for +instance type handling, help formatting and a lot more). + +The plan is to remove more and more from here over time. + +The reason this is a different module and not optparse from the stdlib +is that there are differences in 2.x and 3.x about the error messages +generated and optparse in the stdlib uses gettext for no good reason +and might cause us issues. + +Click uses parts of optparse written by Gregory P. Ward and maintained +by the Python Software Foundation. This is limited to code in parser.py. + +Copyright 2001-2006 Gregory P. Ward. All rights reserved. +Copyright 2002-2006 Python Software Foundation. All rights reserved. +""" + +# This code uses parts of optparse written by Gregory P. Ward and +# maintained by the Python Software Foundation. +# Copyright 2001-2006 Gregory P. Ward +# Copyright 2002-2006 Python Software Foundation +from collections import deque +from collections.abc import Sequence +from typing import ( + TYPE_CHECKING, + Any, + TypeVar, + Union, +) + +from .exceptions import BadArgumentUsage, BadOptionUsage, NoSuchOption, UsageError + +if TYPE_CHECKING: + from typer.core import TyperArgument as CoreArgument + from typer.core import TyperOption as CoreOption + + from .core import Context + from .core import Parameter as CoreParameter + +V = TypeVar("V") + + +def _unpack_args( + args: Sequence[str], nargs_spec: Sequence[int] +) -> tuple[Sequence[str | Sequence[str | None] | None], list[str]]: + """Given an iterable of arguments and an iterable of nargs specifications, + it returns a tuple with all the unpacked arguments at the first index + and all remaining arguments as the second. + + The nargs specification is the number of arguments that should be consumed + or `-1` to indicate that this position should eat up all the remainders. + """ + args = deque(args) + nargs_spec = deque(nargs_spec) + rv: list[str | tuple[str | None, ...] | None] = [] + spos: int | None = None + + def _fetch(c: deque[V]) -> V | None: + try: + if spos is None: + return c.popleft() + else: + return c.pop() + except IndexError: + return None + + while nargs_spec: + nargs = _fetch(nargs_spec) + assert nargs is not None + + if nargs == 1: + rv.append(_fetch(args)) + elif nargs > 1: + x = [_fetch(args) for _ in range(nargs)] + + # If we're reversed, we're pulling in the arguments in reverse, + # so we need to turn them around. + if spos is not None: + x.reverse() + + rv.append(tuple(x)) + elif nargs < 0: + if spos is not None: # pragma: no cover + raise TypeError("Cannot have two nargs < 0") + + spos = len(rv) + rv.append(None) + + # spos is the position of the wildcard (star). If it's not `None`, + # we fill it with the remainder. + if spos is not None: + rv[spos] = tuple(args) + args = [] + rv[spos + 1 :] = reversed(rv[spos + 1 :]) + + return tuple(rv), list(args) + + +def _split_opt(opt: str) -> tuple[str, str]: + first = opt[:1] + if first.isalnum(): + return "", opt + if opt[1:2] == first: + return opt[:2], opt[2:] + return first, opt[1:] + + +def _normalize_opt(opt: str, ctx: Union["Context", None]) -> str: + if ctx is None or ctx.token_normalize_func is None: + return opt + prefix, opt = _split_opt(opt) + return f"{prefix}{ctx.token_normalize_func(opt)}" + + +class _Option: + def __init__( + self, + obj: "CoreOption", + opts: Sequence[str], + dest: str | None, + action: str = "store", + nargs: int = 1, + const: Any | None = None, + ): + self._short_opts = [] + self._long_opts = [] + self.prefixes: set[str] = set() + + for opt in opts: + prefix, value = _split_opt(opt) + if not prefix: # pragma: no cover + raise ValueError(f"Invalid start character for option ({opt})") + self.prefixes.add(prefix[0]) + if len(prefix) == 1 and len(value) == 1: + self._short_opts.append(opt) + else: + self._long_opts.append(opt) + self.prefixes.add(prefix) + + self.dest = dest + self.action = action + self.nargs = nargs + self.const = const + self.obj = obj + + @property + def takes_value(self) -> bool: + return self.action in ("store", "append") + + def process(self, value: Any, state: "_ParsingState") -> None: + if self.action == "store": + state.opts[self.dest] = value # type: ignore + elif self.action == "store_const": + state.opts[self.dest] = self.const # type: ignore + elif self.action == "append": + state.opts.setdefault(self.dest, []).append(value) # type: ignore + elif self.action == "append_const": + state.opts.setdefault(self.dest, []).append(self.const) # type: ignore + elif self.action == "count": + state.opts[self.dest] = state.opts.get(self.dest, 0) + 1 # type: ignore + else: # pragma: no cover + raise ValueError(f"unknown action '{self.action}'") + state.order.append(self.obj) + + +class _Argument: + def __init__(self, obj: "CoreArgument", dest: str | None, nargs: int = 1): + self.dest = dest + self.nargs = nargs + self.obj = obj + + def process( + self, + value: str | Sequence[str | None] | None, + state: "_ParsingState", + ) -> None: + if self.nargs > 1: + assert value is not None + holes = sum(1 for x in value if x is None) + if holes == len(value): + value = None + elif holes != 0: + raise BadArgumentUsage( + f"Argument {self.dest!r} takes {self.nargs} values." + ) + + if self.nargs == -1 and self.obj.envvar is not None and value == (): + # Replace empty tuple with None so that a value from the + # environment may be tried. + value = None + + state.opts[self.dest] = value # type: ignore + state.order.append(self.obj) + + +class _ParsingState: + def __init__(self, rargs: list[str]) -> None: + self.opts: dict[str, Any] = {} + self.largs: list[str] = [] + self.rargs = rargs + self.order: list[CoreParameter] = [] + + +class _OptionParser: + """The option parser is an internal class that is ultimately used to + parse options and arguments. It's modelled after optparse and brings + a similar but vastly simplified API. It should generally not be used + directly as the high level Click classes wrap it for you. + + It's not nearly as extensible as optparse or argparse as it does not + implement features that are implemented on a higher level (such as + types or defaults). + """ + + def __init__(self, ctx: Union["Context", None] = None) -> None: + self.ctx = ctx + # This controls how the parser deals with interspersed arguments. + # If this is set to `False`, the parser will stop on the first + # non-option. Click uses this to implement nested subcommands + # safely. + self.allow_interspersed_args: bool = True + # This tells the parser how to deal with unknown options. By + # default it will error out (which is sensible), but there is a + # second mode where it will ignore it and continue processing + # after shifting all the unknown options into the resulting args. + self.ignore_unknown_options: bool = False + + if ctx is not None: + self.allow_interspersed_args = ctx.allow_interspersed_args + self.ignore_unknown_options = ctx.ignore_unknown_options + + self._short_opt: dict[str, _Option] = {} + self._long_opt: dict[str, _Option] = {} + self._opt_prefixes = {"-", "--"} + self._args: list[_Argument] = [] + + def add_option( + self, + obj: "CoreOption", + opts: Sequence[str], + dest: str | None, + action: str = "store", + nargs: int = 1, + const: Any | None = None, + ) -> None: + """Adds a new option named `dest` to the parser. The destination + is not inferred (unlike with optparse) and needs to be explicitly + provided. Action can be any of ``store``, ``store_const``, + ``append``, ``append_const`` or ``count``. + + The `obj` can be used to identify the option in the order list + that is returned from the parser. + """ + opts = [_normalize_opt(opt, self.ctx) for opt in opts] + option = _Option(obj, opts, dest, action=action, nargs=nargs, const=const) + self._opt_prefixes.update(option.prefixes) + for opt in option._short_opts: + self._short_opt[opt] = option + for opt in option._long_opts: + self._long_opt[opt] = option + + def add_argument( + self, obj: "CoreArgument", dest: str | None, nargs: int = 1 + ) -> None: + """Adds a positional argument named `dest` to the parser. + + The `obj` can be used to identify the option in the order list + that is returned from the parser. + """ + self._args.append(_Argument(obj, dest=dest, nargs=nargs)) + + def parse_args( + self, args: list[str] + ) -> tuple[dict[str, Any], list[str], list["CoreParameter"]]: + """Parses positional arguments and returns ``(values, args, order)`` + for the parsed options and arguments as well as the leftover + arguments if there are any. The order is a list of objects as they + appear on the command line. If arguments appear multiple times they + will be memorized multiple times as well. + """ + state = _ParsingState(args) + try: + self._process_args_for_options(state) + self._process_args_for_args(state) + except UsageError: + if self.ctx is None or not self.ctx.resilient_parsing: + raise + return state.opts, state.largs, state.order + + def _process_args_for_args(self, state: _ParsingState) -> None: + pargs, args = _unpack_args( + state.largs + state.rargs, [x.nargs for x in self._args] + ) + + for idx, arg in enumerate(self._args): + arg.process(pargs[idx], state) + + state.largs = args + state.rargs = [] + + def _process_args_for_options(self, state: _ParsingState) -> None: + while state.rargs: + arg = state.rargs.pop(0) + arglen = len(arg) + # Double dashes always handled explicitly regardless of what + # prefixes are valid. + if arg == "--": + return + elif arg[:1] in self._opt_prefixes and arglen > 1: + self._process_opts(arg, state) + elif self.allow_interspersed_args: + state.largs.append(arg) + else: + state.rargs.insert(0, arg) + return + + # Say this is the original argument list: + # [arg0, arg1, ..., arg(i-1), arg(i), arg(i+1), ..., arg(N-1)] + # ^ + # (we are about to process arg(i)). + # + # Then rargs is [arg(i), ..., arg(N-1)] and largs is a *subset* of + # [arg0, ..., arg(i-1)] (any options and their arguments will have + # been removed from largs). + # + # The while loop will usually consume 1 or more arguments per pass. + # If it consumes 1 (eg. arg is an option that takes no arguments), + # then after _process_arg() is done the situation is: + # + # largs = subset of [arg0, ..., arg(i)] + # rargs = [arg(i+1), ..., arg(N-1)] + # + # If allow_interspersed_args is false, largs will always be + # *empty* -- still a subset of [arg0, ..., arg(i-1)], but + # not a very interesting subset! + + def _match_long_opt( + self, opt: str, explicit_value: str | None, state: _ParsingState + ) -> None: + if opt not in self._long_opt: + from difflib import get_close_matches + + possibilities = get_close_matches(opt, self._long_opt) + raise NoSuchOption(opt, possibilities=possibilities, ctx=self.ctx) + + option = self._long_opt[opt] + if option.takes_value: + # At this point it's safe to modify rargs by injecting the + # explicit value, because no exception is raised in this + # branch. This means that the inserted value will be fully + # consumed. + if explicit_value is not None: + state.rargs.insert(0, explicit_value) + + value = self._get_value_from_state(opt, option, state) + + elif explicit_value is not None: # pragma: no cover + raise BadOptionUsage(opt, f"Option {opt!r} does not take a value.") + + else: + value = None + + option.process(value, state) + + def _match_short_opt(self, arg: str, state: _ParsingState) -> None: + stop = False + i = 1 + prefix = arg[0] + unknown_options = [] + + for ch in arg[1:]: + opt = _normalize_opt(f"{prefix}{ch}", self.ctx) + option = self._short_opt.get(opt) + i += 1 + + if not option: + if self.ignore_unknown_options: + unknown_options.append(ch) + continue + raise NoSuchOption(opt, ctx=self.ctx) + if option.takes_value: + # Any characters left in arg? Pretend they're the + # next arg, and stop consuming characters of arg. + if i < len(arg): + state.rargs.insert(0, arg[i:]) + stop = True + + value = self._get_value_from_state(opt, option, state) + + else: + value = None + + option.process(value, state) + + if stop: + break + + # If we got any unknown options we recombine the string of the + # remaining options and re-attach the prefix, then report that + # to the state as new 'largs'. This way there is basic combinatorics + # that can be achieved while still ignoring unknown arguments. + if self.ignore_unknown_options and unknown_options: + state.largs.append(f"{prefix}{''.join(unknown_options)}") + + def _get_value_from_state( + self, option_name: str, option: _Option, state: _ParsingState + ) -> str | Sequence[str]: + nargs = option.nargs + + value: str | Sequence[str] + + if len(state.rargs) < nargs: + msg = "an argument." if nargs == 1 else f"{nargs} arguments." + raise BadOptionUsage( + option_name, + f"Option {option_name!r} requires {msg}", + ) + elif nargs == 1: + value = state.rargs.pop(0) + else: + value = tuple(state.rargs[:nargs]) + del state.rargs[:nargs] + + return value + + def _process_opts(self, arg: str, state: _ParsingState) -> None: + explicit_value = None + # Long option handling happens in two parts. The first part is + # supporting explicitly attached values. In any case, we will try + # to long match the option first. + if "=" in arg: + long_opt, explicit_value = arg.split("=", 1) + else: + long_opt = arg + norm_long_opt = _normalize_opt(long_opt, self.ctx) + + # At this point we will match the (assumed) long option through + # the long option matching code. Note that this allows options + # like "-foo" to be matched as long options. + try: + self._match_long_opt(norm_long_opt, explicit_value, state) + except NoSuchOption: + # At this point the long option matching failed, and we need + # to try with short options. However there is a special rule + # which says, that if we have a two character options prefix + # (applies to "--foo" for instance), we do not dispatch to the + # short option code and will instead raise the no option + # error. + if arg[:2] not in self._opt_prefixes: + self._match_short_opt(arg, state) + return + + if not self.ignore_unknown_options: + raise + + state.largs.append(arg) diff --git a/typer/_click/py.typed b/typer/_click/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/typer/_click/shell_completion.py b/typer/_click/shell_completion.py new file mode 100644 index 0000000000..4490f18810 --- /dev/null +++ b/typer/_click/shell_completion.py @@ -0,0 +1,306 @@ +import re +from abc import ABC, abstractmethod +from collections.abc import MutableMapping +from typing import Any, ClassVar, TypeVar + +from .core import Command, Context, Parameter, ParameterSource + + +class CompletionItem: + """Represents a completion value and metadata about the value. The + default metadata is ``type`` to indicate special shell handling, + and ``help`` if a shell supports showing a help string next to the + value. + + Arbitrary parameters can be passed when creating the object, and + accessed using ``item.attr``. If an attribute wasn't passed, + accessing it returns ``None``. + """ + + __slots__ = ("value", "type", "help", "_info") + + def __init__( + self, + value: Any, + type: str = "plain", + help: str | None = None, + **kwargs: Any, + ) -> None: + self.value: Any = value + self.type: str = type + self.help: str | None = help + self._info = kwargs + + def __getattr__(self, name: str) -> Any: + return self._info.get(name) + + +class ShellComplete(ABC): + """Base class for providing shell completion support. A subclass for + a given shell will override attributes and methods to implement the + completion instructions (``source`` and ``complete``). + """ + + name: ClassVar[str] + """Name to register the shell as with `add_completion_class`. + This is used in completion instructions (``{name}_source`` and + ``{name}_complete``). + """ + + source_template: ClassVar[str] + """Completion script template formatted by `source`. This must + be provided by subclasses. + """ + + def __init__( + self, + cli: Command, + ctx_args: MutableMapping[str, Any], + prog_name: str, + complete_var: str, + ) -> None: + self.cli = cli + self.ctx_args = ctx_args + self.prog_name = prog_name + self.complete_var = complete_var + + @property + def func_name(self) -> str: + """The name of the shell function defined by the completion + script. + """ + safe_name = re.sub(r"\W*", "", self.prog_name.replace("-", "_"), flags=re.ASCII) + return f"_{safe_name}_completion" + + @abstractmethod + def source_vars(self) -> dict[str, Any]: + """Vars for formatting `source_template`.""" + pass # pragma: no cover + + def source(self) -> str: + """Produce the shell script that defines the completion + function. By default this ``%``-style formats + `source_template` with the dict returned by `source_vars`. + """ + return self.source_template % self.source_vars() + + @abstractmethod + def get_completion_args(self) -> tuple[list[str], str]: + """Use the env vars defined by the shell script to return a + tuple of ``args, incomplete``. This must be implemented by + subclasses. + """ + pass # pragma: no cover + + def get_completions(self, args: list[str], incomplete: str) -> list[CompletionItem]: + """Determine the context and last complete command or parameter + from the complete args. Call that object's ``shell_complete`` + method to get the completions for the incomplete value. + """ + ctx = _resolve_context(self.cli, self.ctx_args, self.prog_name, args) + obj, incomplete = _resolve_incomplete(ctx, args, incomplete) + return obj.shell_complete(ctx, incomplete) + + @abstractmethod + def format_completion(self, item: CompletionItem) -> str: + """Format a completion item into the form recognized by the + shell script. This must be implemented by subclasses. + """ + pass # pragma: no cover + + def complete(self) -> str: + """Produce the completion data to send back to the shell. + + By default this calls `get_completion_args`, gets the + completions, then calls `format_completion` for each + completion. + """ + args, incomplete = self.get_completion_args() + completions = self.get_completions(args, incomplete) + out = [self.format_completion(item) for item in completions] + return "\n".join(out) + + +ShellCompleteType = TypeVar("ShellCompleteType", bound="type[ShellComplete]") + + +_available_shells: dict[str, type[ShellComplete]] = {} + + +def add_completion_class(cls: ShellCompleteType, name: str) -> ShellCompleteType: + """Register a `ShellComplete` subclass under the given name. + The name will be provided by the completion instruction environment + variable during completion. + """ + _available_shells[name] = cls + + return cls + + +def get_completion_class(shell: str) -> type[ShellComplete] | None: + """Look up a registered `ShellComplete` subclass by the name + provided by the completion instruction environment variable. If the + name isn't registered, returns ``None``. + """ + return _available_shells.get(shell) + + +def split_arg_string(string: str) -> list[str]: + """Split an argument string as with `shlex.split`, but don't + fail if the string is incomplete. Ignores a missing closing quote or + incomplete escape sequence and uses the partial token as-is. + """ + import shlex + + lex = shlex.shlex(string, posix=True) + lex.whitespace_split = True + lex.commenters = "" + out = [] + + try: + for token in lex: + out.append(token) + except ValueError: # pragma: no cover + # Raised when end-of-string is reached in an invalid state. Use + # the partial token as-is. The quote or escape character is in + # lex.state, not lex.token. + out.append(lex.token) + + return out + + +def _is_incomplete_argument(ctx: Context, param: Parameter) -> bool: + """Determine if the given parameter is an argument that can still + accept values. + """ + # avoid circular imports + from ..core import TyperArgument + + if not isinstance(param, TyperArgument): + return False + + assert param.name is not None + # Will be None if expose_value is False. + value = ctx.params.get(param.name) + return ( + param.nargs == -1 + or ctx.get_parameter_source(param.name) is not ParameterSource.COMMANDLINE + or ( + param.nargs > 1 + and isinstance(value, (tuple, list)) + and len(value) < param.nargs + ) + ) + + +def _start_of_option(ctx: Context, value: str) -> bool: + """Check if the value looks like the start of an option.""" + if not value: + return False + + c = value[0] + return c in ctx._opt_prefixes + + +def _is_incomplete_option(ctx: Context, args: list[str], param: Parameter) -> bool: + """Determine if the given parameter is an option that needs a value.""" + # avoid circular imports + from ..core import TyperOption + + if not isinstance(param, TyperOption): + return False + + if param.is_flag or param.count: + return False + + last_option = None + + for index, arg in enumerate(reversed(args)): + if index + 1 > param.nargs: + break + + if _start_of_option(ctx, arg): + last_option = arg + break + + return last_option is not None and last_option in param.opts + + +def _resolve_context( + cli: Command, + ctx_args: MutableMapping[str, Any], + prog_name: str, + args: list[str], +) -> Context: + """Produce the context hierarchy starting with the command and + traversing the complete arguments. This only follows the commands, + it doesn't trigger input prompts or callbacks. + """ + # avoid circular imports + from ..core import TyperGroup + + ctx_args["resilient_parsing"] = True + with cli.make_context(prog_name, args.copy(), **ctx_args) as ctx: + args = ctx._protected_args + ctx.args + + while args: + command = ctx.command + + if isinstance(command, TyperGroup): + # if not command.chain: + name, cmd, args = command.resolve_command(ctx, args) + + if cmd is None: + return ctx + + with cmd.make_context( + name, args, parent=ctx, resilient_parsing=True + ) as sub_ctx: + ctx = sub_ctx + args = ctx._protected_args + ctx.args + else: # pragma: no cover + break + + return ctx + + +def _resolve_incomplete( + ctx: Context, args: list[str], incomplete: str +) -> tuple[Command | Parameter, str]: + """Find the Click object that will handle the completion of the + incomplete value. Return the object and the incomplete value. + """ + # Different shells treat an "=" between a long option name and + # value differently. Might keep the value joined, return the "=" + # as a separate item, or return the split name and value. Always + # split and discard the "=" to make completion easier. + if incomplete == "=": + incomplete = "" + elif "=" in incomplete and _start_of_option(ctx, incomplete): + name, _, incomplete = incomplete.partition("=") + args.append(name) + + # The "--" marker tells Click to stop treating values as options + # even if they start with the option character. If it hasn't been + # given and the incomplete arg looks like an option, the current + # command will provide option name completions. + if "--" not in args and _start_of_option(ctx, incomplete): + return ctx.command, incomplete + + params = ctx.command.get_params(ctx) + + # If the last complete arg is an option name with an incomplete + # value, the option will provide value completions. + for param in params: + if _is_incomplete_option(ctx, args, param): + return param, incomplete + + # It's not an option name or value. The first argument without a + # parsed value will provide value completions. + for param in params: + if _is_incomplete_argument(ctx, param): + return param, incomplete + + # There were no unparsed arguments, the command may be a group that + # will provide command name completions. + return ctx.command, incomplete diff --git a/typer/_click/termui.py b/typer/_click/termui.py new file mode 100644 index 0000000000..0a8c82574d --- /dev/null +++ b/typer/_click/termui.py @@ -0,0 +1,430 @@ +import io +from collections.abc import Callable, Iterable +from contextlib import AbstractContextManager +from typing import IO, TYPE_CHECKING, Any, AnyStr, TextIO, TypeVar, overload + +from .exceptions import Abort, UsageError +from .globals import resolve_color_default +from .types import ParamType, convert_type +from .utils import LazyFile, echo + +if TYPE_CHECKING: + from ._termui_impl import ProgressBar + +V = TypeVar("V") + +# The prompt functions to use. The doc tools currently override these +# functions to customize how they work. +visible_prompt_func: Callable[[str], str] = input + +_ansi_colors = { + "black": 30, + "red": 31, + "green": 32, + "yellow": 33, + "blue": 34, + "magenta": 35, + "cyan": 36, + "white": 37, + "reset": 39, + "bright_black": 90, + "bright_red": 91, + "bright_green": 92, + "bright_yellow": 93, + "bright_blue": 94, + "bright_magenta": 95, + "bright_cyan": 96, + "bright_white": 97, +} +_ansi_reset_all = "\033[0m" + + +def hidden_prompt_func(prompt: str) -> str: + import getpass + + return getpass.getpass(prompt) + + +def _build_prompt( + text: str, + suffix: str, + show_default: bool = False, + default: Any | None = None, + show_choices: bool = True, + type: ParamType | None = None, +) -> str: + # prevent circular imports + from .._types import TyperChoice + + prompt = text + if type is not None and show_choices and isinstance(type, TyperChoice): + prompt += f" ({', '.join(map(str, type.choices))})" + if default is not None and show_default: + prompt = f"{prompt} [{_format_default(default)}]" + return f"{prompt}{suffix}" + + +def _format_default(default: Any) -> Any: + if isinstance(default, (io.IOBase, LazyFile)) and hasattr(default, "name"): + return default.name + + return default + + +def prompt( + text: str, + default: Any | None = None, + hide_input: bool = False, + confirmation_prompt: bool | str = False, + type: ParamType | Any | None = None, + value_proc: Callable[[str], Any] | None = None, + prompt_suffix: str = ": ", + show_default: bool = True, + err: bool = False, + show_choices: bool = True, +) -> Any: + """Prompts a user for input. This is a convenience function that can + be used to prompt a user for input later. + + If the user aborts the input by sending an interrupt signal, this + function will catch it and raise an `Abort` exception. + """ + + def prompt_func(text: str) -> str: + f = hidden_prompt_func if hide_input else visible_prompt_func + try: + # Write the prompt separately so that we get nice + # coloring through colorama on Windows + echo(text[:-1], nl=False, err=err) + # Echo the last character to stdout to work around an issue where + # readline causes backspace to clear the whole line. + return f(text[-1:]) + except (KeyboardInterrupt, EOFError): # pragma: no cover + # getpass doesn't print a newline if the user aborts input with ^C. + # Allegedly this behavior is inherited from getpass(3). + # A doc bug has been filed at https://bugs.python.org/issue24711 + if hide_input: + echo(None, err=err) + raise Abort() from None + + if value_proc is None: + value_proc = convert_type(type, default) + + prompt = _build_prompt( + text, prompt_suffix, show_default, default, show_choices, type + ) + + if confirmation_prompt: + if confirmation_prompt is True: + confirmation_prompt = "Repeat for confirmation" + + confirmation_prompt = _build_prompt(confirmation_prompt, prompt_suffix) + + while True: + while True: + value = prompt_func(prompt) + if value: + break + elif default is not None: + value = default + break + try: + result = value_proc(value) + except UsageError as e: # pragma: no cover + if hide_input: + echo("Error: The value you entered was invalid.", err=err) + else: + echo(f"Error: {e.message}", err=err) + continue + if not confirmation_prompt: + return result + while True: + value2 = prompt_func(confirmation_prompt) + is_empty = not value and not value2 + if value2 or is_empty: + break + if value == value2: + return result + echo("Error: The two entered values do not match.", err=err) + + +def confirm( + text: str, + default: bool | None = False, + abort: bool = False, + prompt_suffix: str = ": ", + show_default: bool = True, + err: bool = False, +) -> bool: + """Prompts for confirmation (yes/no question). + + If the user aborts the input by sending a interrupt signal this + function will catch it and raise an `Abort` exception. + """ + prompt = _build_prompt( + text, + prompt_suffix, + show_default, + "y/n" if default is None else ("Y/n" if default else "y/N"), + ) + + while True: + try: + # Write the prompt separately so that we get nice + # coloring through colorama on Windows + echo(prompt[:-1], nl=False, err=err) + # Echo the last character to stdout to work around an issue where + # readline causes backspace to clear the whole line. + value = visible_prompt_func(prompt[-1:]).lower().strip() + except (KeyboardInterrupt, EOFError): # pragma: no cover + raise Abort() from None + if value in ("y", "yes"): + rv = True + elif value in ("n", "no"): + rv = False + elif default is not None and value == "": + rv = default + else: # pragma: no cover + echo("Error: invalid input", err=err) + continue + break + if abort and not rv: + raise Abort() + return rv + + +@overload +def progressbar( + *, + length: int, + label: str | None = None, + hidden: bool = False, + show_eta: bool = True, + show_percent: bool | None = None, + show_pos: bool = False, + fill_char: str = "#", + empty_char: str = "-", + bar_template: str = "%(label)s [%(bar)s] %(info)s", + info_sep: str = " ", + width: int = 36, + file: TextIO | None = None, + color: bool | None = None, + update_min_steps: int = 1, +) -> "ProgressBar[int]": ... + + +@overload +def progressbar( + iterable: Iterable[V] | None = None, + length: int | None = None, + label: str | None = None, + hidden: bool = False, + show_eta: bool = True, + show_percent: bool | None = None, + show_pos: bool = False, + item_show_func: Callable[[V | None], str | None] | None = None, + fill_char: str = "#", + empty_char: str = "-", + bar_template: str = "%(label)s [%(bar)s] %(info)s", + info_sep: str = " ", + width: int = 36, + file: TextIO | None = None, + color: bool | None = None, + update_min_steps: int = 1, +) -> "ProgressBar[V]": ... + + +def progressbar( + iterable: Iterable[V] | None = None, + length: int | None = None, + label: str | None = None, + hidden: bool = False, + show_eta: bool = True, + show_percent: bool | None = None, + show_pos: bool = False, + item_show_func: Callable[[V | None], str | None] | None = None, + fill_char: str = "#", + empty_char: str = "-", + bar_template: str = "%(label)s [%(bar)s] %(info)s", + info_sep: str = " ", + width: int = 36, + file: TextIO | None = None, + color: bool | None = None, + update_min_steps: int = 1, +) -> "ProgressBar[V]": + """This function creates an iterable context manager that can be used + to iterate over something while showing a progress bar. It will + either iterate over the `iterable` or `length` items (that are counted + up). While iteration happens, this function will print a rendered + progress bar to the given `file` (defaults to stdout) and will attempt + to calculate remaining time and more. By default, this progress bar + will not be rendered if the file is not a terminal. + + The context manager creates the progress bar. When the context + manager is entered the progress bar is already created. With every + iteration over the progress bar, the iterable passed to the bar is + advanced and the bar is updated. When the context manager exits, + a newline is printed and the progress bar is finalized on screen. + + Note: The progress bar is currently designed for use cases where the + total progress can be expected to take at least several seconds. + Because of this, the ProgressBar class object won't display + progress that is considered too fast, and progress where the time + between steps is less than a second. + + No printing must happen or the progress bar will be unintentionally + destroyed. + """ + from ._termui_impl import ProgressBar + + color = resolve_color_default(color) + return ProgressBar( + iterable=iterable, + length=length, + hidden=hidden, + show_eta=show_eta, + show_percent=show_percent, + show_pos=show_pos, + item_show_func=item_show_func, + fill_char=fill_char, + empty_char=empty_char, + bar_template=bar_template, + info_sep=info_sep, + file=file, + label=label, + width=width, + color=color, + update_min_steps=update_min_steps, + ) + + +def _interpret_color(color: int | tuple[int, int, int] | str, offset: int = 0) -> str: + if isinstance(color, int): + return f"{38 + offset};5;{color:d}" + + if isinstance(color, (tuple, list)): + r, g, b = color + return f"{38 + offset};2;{r:d};{g:d};{b:d}" + + return str(_ansi_colors[color] + offset) + + +def style( + text: Any, + fg: int | tuple[int, int, int] | str | None = None, + bg: int | tuple[int, int, int] | str | None = None, + bold: bool | None = None, + dim: bool | None = None, + underline: bool | None = None, + overline: bool | None = None, + italic: bool | None = None, + blink: bool | None = None, + reverse: bool | None = None, + strikethrough: bool | None = None, + reset: bool = True, +) -> str: + """Styles a text with ANSI styles and returns the new string. By + default the styling is self contained which means that at the end + of the string a reset code is issued. This can be prevented by + passing ``reset=False``. + """ + if not isinstance(text, str): + text = str(text) + + bits = [] + + if fg: + try: + bits.append(f"\033[{_interpret_color(fg)}m") + except KeyError: + raise TypeError(f"Unknown color {fg!r}") from None + + if bg: + try: + bits.append(f"\033[{_interpret_color(bg, 10)}m") + except KeyError: + raise TypeError(f"Unknown color {bg!r}") from None + + if bold is not None: + bits.append(f"\033[{1 if bold else 22}m") + if dim is not None: + bits.append(f"\033[{2 if dim else 22}m") + if underline is not None: + bits.append(f"\033[{4 if underline else 24}m") + if overline is not None: + bits.append(f"\033[{53 if overline else 55}m") + if italic is not None: + bits.append(f"\033[{3 if italic else 23}m") + if blink is not None: + bits.append(f"\033[{5 if blink else 25}m") + if reverse is not None: + bits.append(f"\033[{7 if reverse else 27}m") + if strikethrough is not None: + bits.append(f"\033[{9 if strikethrough else 29}m") + bits.append(text) + if reset: + bits.append(_ansi_reset_all) + return "".join(bits) + + +def secho( + message: Any | None = None, + file: IO[AnyStr] | None = None, + nl: bool = True, + err: bool = False, + color: bool | None = None, + **styles: Any, +) -> None: + """This function combines `echo` and `style` into one call.""" + if message is not None and not isinstance(message, (bytes, bytearray)): + message = style(message, **styles) + + return echo(message, file=file, nl=nl, err=err, color=color) + + +def launch(url: str, wait: bool = False, locate: bool = False) -> int: + """This function launches the given URL (or filename) in the default + viewer application for this file type. If this is an executable, it + might launch the executable in a new session. The return value is + the exit code of the launched application. Usually, ``0`` indicates + success. + """ + from ._termui_impl import open_url + + return open_url(url, wait=wait, locate=locate) + + +# If this is provided, getchar() calls into this instead. This is used +# for unittesting purposes. +_getchar: Callable[[bool], str] | None = None + + +def getchar(echo: bool = False) -> str: + """Fetches a single character from the terminal and returns it. This + will always return a unicode character and under certain rare + circumstances this might return more than one character. The + situations which more than one character is returned is when for + whatever reason multiple characters end up in the terminal buffer or + standard input was not actually a terminal. + + Note that this will always read from the terminal, even if something + is piped into the standard input. + + Note for Windows: in rare cases when typing non-ASCII characters, this + function might wait for a second character and then return both at once. + This is because certain Unicode characters look like special-key markers. + """ + global _getchar + + if _getchar is None: + from ._termui_impl import getchar as f + + _getchar = f + + return _getchar(echo) + + +def raw_terminal() -> AbstractContextManager[int]: + from ._termui_impl import raw_terminal as f + + return f() diff --git a/typer/_click/testing.py b/typer/_click/testing.py new file mode 100644 index 0000000000..0d8c03a790 --- /dev/null +++ b/typer/_click/testing.py @@ -0,0 +1,366 @@ +import contextlib +import io +import os +import shlex +import sys +from collections.abc import Iterator, Mapping, Sequence +from types import TracebackType +from typing import IO, TYPE_CHECKING, Any, BinaryIO, cast + +from . import _compat, formatting, termui, utils +from ._compat import _find_binary_reader +from .core import Command + +if TYPE_CHECKING: + from _typeshed import ReadableBuffer + + +class BytesIOCopy(io.BytesIO): + """Patch ``io.BytesIO`` to let the written stream be copied to another.""" + + def __init__(self, copy_to: io.BytesIO) -> None: + super().__init__() + self.copy_to = copy_to + + def flush(self) -> None: + super().flush() + self.copy_to.flush() + + def write(self, b: "ReadableBuffer") -> int: + self.copy_to.write(b) + return super().write(b) + + +class StreamMixer: + """Mixes `` and `` streams. + + The result is available in the ``output`` attribute. + """ + + def __init__(self) -> None: + self.output: io.BytesIO = io.BytesIO() + self.stdout: io.BytesIO = BytesIOCopy(copy_to=self.output) + self.stderr: io.BytesIO = BytesIOCopy(copy_to=self.output) + + def __del__(self) -> None: + """ + Guarantee that embedded file-like objects are closed in a + predictable order, protecting against races between + self.output being closed and other streams being flushed on close + """ + self.stderr.close() + self.stdout.close() + self.output.close() + + +class _NamedTextIOWrapper(io.TextIOWrapper): + def __init__(self, buffer: BinaryIO, name: str, mode: str, **kwargs: Any) -> None: + super().__init__(buffer, **kwargs) + self._name = name + self._mode = mode + + @property + def name(self) -> str: + return self._name # pragma: no cover + + @property + def mode(self) -> str: + return self._mode # pragma: no cover + + +def make_input_stream(input: str | bytes | IO[Any] | None, charset: str) -> BinaryIO: + # Is already an input stream. + if hasattr(input, "read"): + rv = _find_binary_reader(cast("IO[Any]", input)) + + if rv is not None: + return rv + + raise TypeError( + "Could not find binary reader for input stream." + ) # pragma: no cover + + if input is None: + input = b"" + elif isinstance(input, str): + input = input.encode(charset) + + return io.BytesIO(input) + + +class Result: + """Holds the captured result of an invoked CLI script.""" + + def __init__( + self, + runner: "CliRunner", + stdout_bytes: bytes, + stderr_bytes: bytes, + output_bytes: bytes, + return_value: Any, + exit_code: int, + exception: BaseException | None, + exc_info: tuple[type[BaseException], BaseException, TracebackType] + | None = None, + ): + self.runner = runner + self.stdout_bytes = stdout_bytes + self.stderr_bytes = stderr_bytes + self.output_bytes = output_bytes + self.return_value = return_value + self.exit_code = exit_code + self.exception = exception + self.exc_info = exc_info + + @property + def output(self) -> str: + """The terminal output as unicode string, as the user would see it.""" + return self.output_bytes.decode(self.runner.charset, "replace").replace( + "\r\n", "\n" + ) + + @property + def stdout(self) -> str: + """The standard output as unicode string.""" + return self.stdout_bytes.decode(self.runner.charset, "replace").replace( + "\r\n", "\n" + ) + + @property + def stderr(self) -> str: + """The standard error as unicode string.""" + return self.stderr_bytes.decode(self.runner.charset, "replace").replace( + "\r\n", "\n" + ) + + def __repr__(self) -> str: + exc_str = repr(self.exception) if self.exception else "okay" + return f"<{type(self).__name__} {exc_str}>" + + +class CliRunner: + """The CLI runner provides functionality to invoke a Click command line + script for unittesting purposes in a isolated environment. This only + works in single-threaded systems without any concurrency as it changes the + global interpreter state. + """ + + def __init__( + self, + charset: str = "utf-8", + env: Mapping[str, str | None] | None = None, + catch_exceptions: bool = True, + ) -> None: + self.charset = charset + self.env: Mapping[str, str | None] = env or {} + self.catch_exceptions = catch_exceptions + + def get_default_prog_name(self, cli: Command) -> str: + """Given a command object it will return the default program name + for it. The default is the `name` attribute or ``"root"`` if not + set. + """ + return cli.name or "root" + + def make_env( + self, overrides: Mapping[str, str | None] | None = None + ) -> Mapping[str, str | None]: + """Returns the environment overrides for invoking a script.""" + rv = dict(self.env) + if overrides: + rv.update(overrides) + return rv + + @contextlib.contextmanager + def isolation( + self, + input: str | bytes | IO[Any] | None = None, + env: Mapping[str, str | None] | None = None, + color: bool = False, + ) -> Iterator[tuple[io.BytesIO, io.BytesIO, io.BytesIO]]: + """A context manager that sets up the isolation for invoking of a + command line tool. This sets up `` with the given input data + and `os.environ` with the overrides from the given dictionary. + This also rebinds some internals in Click to be mocked (like the + prompt functionality). + """ + bytes_input = make_input_stream(input, self.charset) + + old_stdin = sys.stdin + old_stdout = sys.stdout + old_stderr = sys.stderr + old_forced_width = formatting.FORCED_WIDTH + formatting.FORCED_WIDTH = 80 + + env = self.make_env(env) + + stream_mixer = StreamMixer() + + sys.stdin = text_input = _NamedTextIOWrapper( + bytes_input, encoding=self.charset, name="", mode="r" + ) + + sys.stdout = _NamedTextIOWrapper( + stream_mixer.stdout, encoding=self.charset, name="", mode="w" + ) + + sys.stderr = _NamedTextIOWrapper( + stream_mixer.stderr, + encoding=self.charset, + name="", + mode="w", + errors="backslashreplace", + ) + + def visible_input(prompt: str | None = None) -> str: + sys.stdout.write(prompt or "") + try: + val = next(text_input).rstrip("\r\n") + except StopIteration as e: # pragma: no cover + raise EOFError() from e + sys.stdout.write(f"{val}\n") + sys.stdout.flush() + return val + + def hidden_input(prompt: str | None = None) -> str: + sys.stdout.write(f"{prompt or ''}\n") + sys.stdout.flush() + try: + return next(text_input).rstrip("\r\n") + except StopIteration as e: # pragma: no cover + raise EOFError() from e + + def _getchar(echo: bool) -> str: + char = sys.stdin.read(1) + + if echo: + sys.stdout.write(char) + + sys.stdout.flush() + return char + + default_color = color + + def should_strip_ansi( + stream: IO[Any] | None = None, color: bool | None = None + ) -> bool: + if color is None: + return not default_color + return not color + + old_visible_prompt_func = termui.visible_prompt_func + old_hidden_prompt_func = termui.hidden_prompt_func + old__getchar_func = termui._getchar + old_should_strip_ansi = utils.should_strip_ansi # type: ignore[attr-defined] + old__compat_should_strip_ansi = _compat.should_strip_ansi + termui.visible_prompt_func = visible_input + termui.hidden_prompt_func = hidden_input # ty: ignore[invalid-assignment] + termui._getchar = _getchar + utils.should_strip_ansi = should_strip_ansi # type: ignore + _compat.should_strip_ansi = should_strip_ansi # ty: ignore[invalid-assignment] + + old_env = {} + try: + for key, value in env.items(): + old_env[key] = os.environ.get(key) + if value is None: + try: + del os.environ[key] + except Exception: # pragma: no cover + pass + else: + os.environ[key] = value + yield (stream_mixer.stdout, stream_mixer.stderr, stream_mixer.output) + finally: + for key, value in old_env.items(): + if value is None: + try: + del os.environ[key] + except Exception: # pragma: no cover + pass + else: + os.environ[key] = value + sys.stdout = old_stdout + sys.stderr = old_stderr + sys.stdin = old_stdin + termui.visible_prompt_func = old_visible_prompt_func + termui.hidden_prompt_func = old_hidden_prompt_func + termui._getchar = old__getchar_func + utils.should_strip_ansi = old_should_strip_ansi # type: ignore[attr-defined] + _compat.should_strip_ansi = old__compat_should_strip_ansi + formatting.FORCED_WIDTH = old_forced_width + + def invoke( + self, + cli: Command, + args: str | Sequence[str] | None = None, + input: str | bytes | IO[Any] | None = None, + env: Mapping[str, str | None] | None = None, + catch_exceptions: bool | None = None, + color: bool = False, + **extra: Any, + ) -> Result: + """Invokes a command in an isolated environment. The arguments are + forwarded directly to the command line script, the `extra` keyword + arguments are passed to the `Command.main` function of + the command. + """ + exc_info = None + if catch_exceptions is None: + catch_exceptions = self.catch_exceptions + + with self.isolation(input=input, env=env, color=color) as outstreams: + return_value = None + exception: BaseException | None = None + exit_code = 0 + + if isinstance(args, str): + args = shlex.split(args) + + try: + prog_name = extra.pop("prog_name") + except KeyError: + prog_name = self.get_default_prog_name(cli) + + try: + return_value = cli.main(args=args or (), prog_name=prog_name, **extra) + except SystemExit as e: + exc_info = sys.exc_info() + e_code = cast("int | Any | None", e.code) + + if e_code is None: + e_code = 0 + + if e_code != 0: + exception = e + + if not isinstance(e_code, int): + sys.stdout.write(str(e_code)) + sys.stdout.write("\n") + e_code = 1 + + exit_code = e_code + + except Exception as e: + if not catch_exceptions: + raise + exception = e + exit_code = 1 + exc_info = sys.exc_info() + finally: + sys.stdout.flush() + sys.stderr.flush() + stdout = outstreams[0].getvalue() + stderr = outstreams[1].getvalue() + output = outstreams[2].getvalue() + + return Result( + runner=self, + stdout_bytes=stdout, + stderr_bytes=stderr, + output_bytes=output, + return_value=return_value, + exit_code=exit_code, + exception=exception, + exc_info=exc_info, # type: ignore + ) diff --git a/typer/_click/types.py b/typer/_click/types.py new file mode 100644 index 0000000000..5ccf15fe1b --- /dev/null +++ b/typer/_click/types.py @@ -0,0 +1,695 @@ +import os +import sys +from collections.abc import Callable, Sequence +from datetime import datetime +from typing import ( + IO, + TYPE_CHECKING, + Any, + ClassVar, + Literal, + NoReturn, + TypedDict, + TypeGuard, + TypeVar, + Union, + cast, +) + +from ._compat import _get_argv_encoding, open_stream +from .exceptions import BadParameter +from .utils import LazyFile, format_filename, safecall + +if TYPE_CHECKING: + from .core import Context, Parameter + from .shell_completion import CompletionItem + +ParamTypeValue = TypeVar("ParamTypeValue") + + +class ParamType: + """Represents the type of a parameter. Validates and converts values + from the command line or Python into the correct type. + + To implement a custom type, subclass and implement at least the + following: + + - The `name` class attribute must be set. + - Calling an instance of the type with ``None`` must return + ``None``. This is already implemented by default. + - `convert` must convert string values to the correct type. + - `convert` must accept values that are already the correct + type. + - It must be able to convert a value if the ``ctx`` and ``param`` + arguments are ``None``. This can occur when converting prompt + input. + """ + + is_composite: ClassVar[bool] = False + arity: ClassVar[int] = 1 + name: str + + # if a list of this type is expected and the value is pulled from a + # string environment variable, this is what splits it up. `None` + # means any whitespace. For all parameters the general rule is that + # whitespace splits them up. The exception are paths and files which + # are split by ``os.path.pathsep`` by default (":" on Unix and ";" on + # Windows). + envvar_list_splitter: ClassVar[str | None] = None + + def __call__( + self, + value: Any, + param: Union["Parameter", None] = None, + ctx: Union["Context", None] = None, + ) -> Any: + if value is not None: + return self.convert(value, param, ctx) + + def get_metavar(self, param: "Parameter", ctx: "Context") -> str | None: + """Returns the metavar default for this param if it provides one.""" + pass # pragma: no cover + + def get_missing_message( + self, param: "Parameter", ctx: Union["Context", None] + ) -> str | None: + """Optionally might return extra information about a missing + parameter. + """ + pass # pragma: no cover + + def convert( + self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None] + ) -> Any: + pass # pragma: no cover + + def split_envvar_value(self, rv: str) -> Sequence[str]: + """Given a value from an environment variable this splits it up + into small chunks depending on the defined envvar list splitter. + + If the splitter is set to `None`, which means that whitespace splits, + then leading and trailing whitespace is ignored. Otherwise, leading + and trailing splitters usually lead to empty items being included. + """ + return (rv or "").split(self.envvar_list_splitter) + + def fail( + self, + message: str, + param: Union["Parameter", None] = None, + ctx: Union["Context", None] = None, + ) -> NoReturn: + """Helper method to fail with an invalid value message.""" + raise BadParameter(message, ctx=ctx, param=param) + + def shell_complete( + self, ctx: "Context", param: "Parameter", incomplete: str + ) -> list["CompletionItem"]: + """Return a list of `CompletionItem` objects for the + incomplete value. Most types do not provide completions, but + some do, and this allows custom types to provide custom + completions as well. + """ + return [] + + +class CompositeParamType(ParamType): + is_composite = True + + @property + def arity(self) -> int: # type: ignore + raise NotImplementedError() # pragma: no cover + + +class FuncParamType(ParamType): + def __init__(self, func: Callable[[Any], Any]) -> None: + self.name: str = getattr(func, "__name__", "function") + self.func = func + + def convert( + self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None] + ) -> Any: + try: + return self.func(value) + except ValueError: + try: + value = str(value) + except UnicodeError: # pragma: no cover + assert isinstance(value, bytes) + value = value.decode("utf-8", "replace") + + self.fail(value, param, ctx) + + +class StringParamType(ParamType): + name = "text" + + def convert( + self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None] + ) -> Any: + if isinstance(value, bytes): + enc = _get_argv_encoding() + try: + value = value.decode(enc) + except UnicodeError: + fs_enc = sys.getfilesystemencoding() + if fs_enc != enc: + try: + value = value.decode(fs_enc) + except UnicodeError: + value = value.decode("utf-8", "replace") + else: + value = value.decode("utf-8", "replace") + return value + return str(value) + + def __repr__(self) -> str: + return "STRING" + + +class DateTime(ParamType): + """The DateTime type converts date strings into `datetime` objects. + + The format strings which are checked are configurable, but default to some + common (non-timezone aware) ISO 8601 formats. + + When specifying *DateTime* formats, you should only pass a list or a tuple. + Other iterables, like generators, may lead to surprising results. + + The format strings are processed using ``datetime.strptime``, and this + consequently defines the format strings which are allowed. + + Parsing is tried using each format, in order, and the first format which + parses successfully is used. + """ + + name = "datetime" + + def __init__(self, formats: Sequence[str] | None = None): + self.formats: Sequence[str] = formats or [ + "%Y-%m-%d", + "%Y-%m-%dT%H:%M:%S", + "%Y-%m-%d %H:%M:%S", + ] + + def get_metavar(self, param: "Parameter", ctx: "Context") -> str | None: + return f"[{'|'.join(self.formats)}]" + + def _try_to_convert_date(self, value: Any, format: str) -> datetime | None: + try: + return datetime.strptime(value, format) + except ValueError: + return None + + def convert( + self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None] + ) -> Any: + if isinstance(value, datetime): + return value + + for format in self.formats: + converted = self._try_to_convert_date(value, format) + + if converted is not None: + return converted + + formats_str = ", ".join(map(repr, self.formats)) + self.fail( + f"{value!r} does not match the formats {formats_str}.", + param, + ctx, + ) + + def __repr__(self) -> str: + return "DateTime" + + +class _NumberParamTypeBase(ParamType): + _number_class: ClassVar[type[Any]] + + def convert( + self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None] + ) -> Any: + try: + return self._number_class(value) + except ValueError: + self.fail( + f"{value!r} is not a valid {self.name}.", + param, + ctx, + ) + + +class _NumberRangeBase(_NumberParamTypeBase): + def __init__( + self, + min: float | None = None, + max: float | None = None, + min_open: bool = False, + max_open: bool = False, + clamp: bool = False, + ) -> None: + self.min = min + self.max = max + self.min_open = min_open + self.max_open = max_open + self.clamp = clamp + + def convert( + self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None] + ) -> Any: + import operator + + rv = super().convert(value, param, ctx) + lt_min: bool = self.min is not None and ( + operator.le if self.min_open else operator.lt + )(rv, self.min) + gt_max: bool = self.max is not None and ( + operator.ge if self.max_open else operator.gt + )(rv, self.max) + + if self.clamp: + if lt_min: + return self._clamp(self.min, 1, self.min_open) # type: ignore[arg-type] + + if gt_max: + return self._clamp(self.max, -1, self.max_open) # type: ignore[arg-type] + + if lt_min or gt_max: + self.fail( + f"{rv} is not in the range {self._describe_range()}.", + param, + ctx, + ) + + return rv + + def _clamp(self, bound: float, dir: Literal[1, -1], open: bool) -> float: + """Find the valid value to clamp to bound in the given + direction. + """ + raise NotImplementedError # pragma: no cover + + def _describe_range(self) -> str: + """Describe the range for use in help text.""" + if self.min is None: + op = "<" if self.max_open else "<=" + return f"x{op}{self.max}" + + if self.max is None: + op = ">" if self.min_open else ">=" + return f"x{op}{self.min}" + + lop = "<" if self.min_open else "<=" + rop = "<" if self.max_open else "<=" + return f"{self.min}{lop}x{rop}{self.max}" + + def __repr__(self) -> str: + clamp = " clamped" if self.clamp else "" + return f"<{type(self).__name__} {self._describe_range()}{clamp}>" + + +class IntParamType(_NumberParamTypeBase): + name = "integer" + _number_class = int + + def __repr__(self) -> str: + return "INT" + + +class IntRange(_NumberRangeBase, IntParamType): + """Restrict an `INT` value to a range of accepted values. See + + If ``min`` or ``max`` are not passed, any value is accepted in that + direction. If ``min_open`` or ``max_open`` are enabled, the + corresponding boundary is not included in the range. + + If ``clamp`` is enabled, a value outside the range is clamped to the + boundary instead of failing. + """ + + name = "integer range" + + def _clamp( # type: ignore + self, bound: int, dir: Literal[1, -1], open: bool + ) -> int: + if not open: + return bound + + return bound + dir + + +class FloatParamType(_NumberParamTypeBase): + name = "float" + _number_class = float + + def __repr__(self) -> str: + return "FLOAT" + + +class FloatRange(_NumberRangeBase, FloatParamType): + """Restrict a `FLOAT` value to a range of accepted + values. See `ranges`. + + If ``min`` or ``max`` are not passed, any value is accepted in that + direction. If ``min_open`` or ``max_open`` are enabled, the + corresponding boundary is not included in the range. + + If ``clamp`` is enabled, a value outside the range is clamped to the + boundary instead of failing. This is not supported if either + boundary is marked ``open``. + """ + + name = "float range" + + def __init__( + self, + min: float | None = None, + max: float | None = None, + min_open: bool = False, + max_open: bool = False, + clamp: bool = False, + ) -> None: + super().__init__( + min=min, max=max, min_open=min_open, max_open=max_open, clamp=clamp + ) + + if (min_open or max_open) and clamp: + raise TypeError("Clamping is not supported for open bounds.") + + def _clamp(self, bound: float, dir: Literal[1, -1], open: bool) -> float: + if not open: + return bound + + # Could use math.nextafter here, but clamping an + # open float range doesn't seem to be particularly useful. It's + # left up to the user to write a callback to do it if needed. + raise RuntimeError( + "Clamping is not supported for open bounds." + ) # pragma: no cover + + +class BoolParamType(ParamType): + name = "boolean" + + bool_states: dict[str, bool] = { + "1": True, + "0": False, + "yes": True, + "no": False, + "true": True, + "false": False, + "on": True, + "off": False, + "t": True, + "f": False, + "y": True, + "n": False, + # Absence of value is considered False. + "": False, + } + """A mapping of string values to boolean states. + + Mapping is inspired by `configparser.ConfigParser.BOOLEAN_STATES` + and extends it. + """ + + @staticmethod + def str_to_bool(value: str | bool) -> bool | None: + """Convert a string to a boolean value. + + If the value is already a boolean, it is returned as-is. If the value is a + string, it is stripped of whitespaces and lower-cased, then checked against + the known boolean states pre-defined in the `BoolParamType.bool_states` mapping + above. + + Returns `None` if the value does not match any known boolean state. + """ + if isinstance(value, bool): + return value + return BoolParamType.bool_states.get(value.strip().lower()) + + def convert( + self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None] + ) -> bool: + normalized = self.str_to_bool(value) + if normalized is None: + states = ", ".join(sorted(self.bool_states)) + self.fail( + f"{value!r} is not a valid boolean. Recognized values: {states}", + param, + ctx, + ) + return normalized + + def __repr__(self) -> str: + return "BOOL" + + +class UUIDParameterType(ParamType): + name = "uuid" + + def convert( + self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None] + ) -> Any: + import uuid + + if isinstance(value, uuid.UUID): + return value + + value = value.strip() + + try: + return uuid.UUID(value) + except ValueError: + self.fail(f"{value!r} is not a valid UUID.", param, ctx) + + def __repr__(self) -> str: + return "UUID" + + +class File(ParamType): + """Declares a parameter to be a file for reading or writing. The file + is automatically closed once the context tears down (after the command + finished working). + + Files can be opened for reading or writing. The special value ``-`` + indicates stdin or stdout depending on the mode. + + By default, the file is opened for reading text data, but it can also be + opened in binary mode or for writing. The encoding parameter can be used + to force a specific encoding. + + The `lazy` flag controls if the file should be opened immediately or upon + first IO. The default is to be non-lazy for standard input and output + streams as well as files opened for reading, `lazy` otherwise. When opening a + file lazily for reading, it is still opened temporarily for validation, but + will not be held open until first IO. lazy is mainly useful when opening + for writing to avoid creating the file until it is needed. + + Files can also be opened atomically in which case all writes go into a + separate file in the same folder and upon completion the file will + be moved over to the original location. This is useful if a file + regularly read by other users is modified. + """ + + name = "filename" + envvar_list_splitter: ClassVar[str] = os.path.pathsep + + def __init__( + self, + mode: str = "r", + encoding: str | None = None, + errors: str | None = "strict", + lazy: bool | None = None, + atomic: bool = False, + ) -> None: + self.mode = mode + self.encoding = encoding + self.errors = errors + self.lazy = lazy + self.atomic = atomic + + def resolve_lazy_flag(self, value: str | os.PathLike[str]) -> bool: + if self.lazy is not None: + return self.lazy + if os.fspath(value) == "-": + return False + elif "w" in self.mode: + return True + return False + + def convert( + self, + value: str | os.PathLike[str] | IO[Any], + param: Union["Parameter", None], + ctx: Union["Context", None], + ) -> IO[Any]: + if _is_file_like(value): + return value + + value = cast("str | os.PathLike[str]", value) + + try: + lazy = self.resolve_lazy_flag(value) + + if lazy: + lf = LazyFile( + value, self.mode, self.encoding, self.errors, atomic=self.atomic + ) + + if ctx is not None: + ctx.call_on_close(lf.close_intelligently) + + return cast("IO[Any]", lf) + + f, should_close = open_stream( + value, self.mode, self.encoding, self.errors, atomic=self.atomic + ) + + # If a context is provided, we automatically close the file + # at the end of the context execution (or flush out). If a + # context does not exist, it's the caller's responsibility to + # properly close the file. This for instance happens when the + # type is used with prompts. + if ctx is not None: + if should_close: + ctx.call_on_close(safecall(f.close)) + else: + ctx.call_on_close(safecall(f.flush)) + + return f + except OSError as e: # pragma: no cover + self.fail(f"'{format_filename(value)}': {e.strerror}", param, ctx) + + def shell_complete( + self, ctx: "Context", param: "Parameter", incomplete: str + ) -> list["CompletionItem"]: + """Return a special completion marker that tells the completion + system to use the shell to provide file path completions. + """ + from .shell_completion import CompletionItem + + return [CompletionItem(incomplete, type="file")] + + +def _is_file_like(value: Any) -> TypeGuard[IO[Any]]: + return hasattr(value, "read") or hasattr(value, "write") + + +class Tuple(CompositeParamType): + """The default behavior of Click is to apply a type on a value directly. + This works well in most cases, except for when `nargs` is set to a fixed + count and different types should be used for different items. In this + case the `Tuple` type can be used. This type can only be used + if `nargs` is set to a fixed number. + + For more information see `tuple-type`. + + This can be selected by using a Python tuple literal as a type. + """ + + def __init__(self, types: Sequence[type[Any] | ParamType]) -> None: + self.types: Sequence[ParamType] = [convert_type(ty) for ty in types] + + @property + def name(self) -> str: # type: ignore[override] + return f"<{' '.join(ty.name for ty in self.types)}>" + + @property + def arity(self) -> int: # type: ignore + return len(self.types) + + def convert( + self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None] + ) -> Any: + len_type = len(self.types) + len_value = len(value) + + if len_value != len_type: + self.fail( + f"{len_type} values are required, but {len_value} given.", + param=param, + ctx=ctx, + ) + + return tuple( + ty(x, param, ctx) for ty, x in zip(self.types, value, strict=False) + ) + + +def convert_type(ty: Any | None, default: Any | None = None) -> ParamType: + """Find the most appropriate `ParamType` for the given Python + type. If the type isn't provided, it can be inferred from a default + value. + """ + guessed_type = False + + if ty is None and default is not None: + if isinstance(default, (tuple, list)): + # If the default is empty, ty will remain None and will + # return STRING. + if default: + item = default[0] + + # A tuple of tuples needs to detect the inner types. + # Can't call convert recursively because that would + # incorrectly unwind the tuple to a single type. + if isinstance(item, (tuple, list)): + ty = tuple(map(type, item)) + else: + ty = type(item) + else: + ty = type(default) + + guessed_type = True + + if isinstance(ty, tuple): + return Tuple(ty) + + if isinstance(ty, ParamType): + return ty + + if ty is str or ty is None: + return STRING + + if ty is int: + return INT + + if ty is float: + return FLOAT + + if ty is bool: + return BOOL + + if guessed_type: + return STRING + + return FuncParamType(ty) + + +# A unicode string parameter type which is the implicit default. This +# can also be selected by using ``str`` as type. +STRING = StringParamType() + +# An integer parameter. This can also be selected by using ``int`` as +# type. +INT = IntParamType() + +# A floating point value parameter. This can also be selected by using +# ``float`` as type. +FLOAT = FloatParamType() + +# A boolean parameter. This is the default for boolean flags. This can +# also be selected by using ``bool`` as a type. +BOOL = BoolParamType() + +# A UUID parameter. +UUID = UUIDParameterType() + + +class OptionHelpExtra(TypedDict, total=False): + envvars: tuple[str, ...] + default: str + range: str + required: str diff --git a/typer/_click/utils.py b/typer/_click/utils.py new file mode 100644 index 0000000000..ac8e5ba3f2 --- /dev/null +++ b/typer/_click/utils.py @@ -0,0 +1,470 @@ +import os +import re +import sys +from collections.abc import Callable, Iterable, Iterator +from functools import update_wrapper +from types import ModuleType, TracebackType +from typing import ( + IO, + Any, + AnyStr, + BinaryIO, + Literal, + ParamSpec, + TextIO, + TypeVar, + cast, +) + +from ._compat import ( + WIN, + _default_text_stderr, + _default_text_stdout, + _find_binary_writer, + auto_wrap_for_ansi, + binary_streams, + open_stream, + should_strip_ansi, + strip_ansi, + text_streams, +) +from .globals import resolve_color_default + +P = ParamSpec("P") +R = TypeVar("R") + + +def _posixify(name: str) -> str: + return "-".join(name.split()).lower() + + +def safecall(func: Callable[P, R]) -> Callable[P, R | None]: + """Wraps a function so that it swallows exceptions.""" + + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | None: + try: + return func(*args, **kwargs) + except Exception: # pragma: no cover + pass + return None # pragma: no cover + + return update_wrapper(wrapper, func) + + +def make_default_short_help(help: str, max_length: int = 45) -> str: + """Returns a condensed version of help string.""" + # Consider only the first paragraph. + paragraph_end = help.find("\n\n") + + if paragraph_end != -1: + help = help[:paragraph_end] + + # Collapse newlines, tabs, and spaces. + words = help.split() + + if not words: + return "" + + # The first paragraph started with a "no rewrap" marker, ignore it. + if words[0] == "\b": + words = words[1:] + + total_length = 0 + last_index = len(words) - 1 + + for i, word in enumerate(words): + total_length += len(word) + (i > 0) + + if total_length > max_length: # too long, truncate + break + + if word[-1] == ".": # sentence end, truncate without "..." + return " ".join(words[: i + 1]) + + if total_length == max_length and i != last_index: + break # not at sentence end, truncate with "..." + else: + return " ".join(words) # no truncation needed + + # Account for the length of the suffix. + total_length += len("...") + + # remove words until the length is short enough + while i > 0: + total_length -= len(words[i]) + (i > 0) + + if total_length <= max_length: + break + + i -= 1 + + return " ".join(words[:i]) + "..." + + +class LazyFile: + """A lazy file works like a regular file but it does not fully open + the file but it does perform some basic checks early to see if the + filename parameter does make sense. This is useful for safely opening + files for writing. + """ + + def __init__( + self, + filename: str | os.PathLike[str], + mode: str = "r", + encoding: str | None = None, + errors: str | None = "strict", + atomic: bool = False, + ): + self.name: str = os.fspath(filename) + self.mode = mode + self.encoding = encoding + self.errors = errors + self.atomic = atomic + self._f: IO[Any] | None + self.should_close: bool + + if self.name == "-": + self._f, self.should_close = open_stream(filename, mode, encoding, errors) + else: + if "r" in mode: + # Open and close the file in case we're opening it for + # reading so that we can catch at least some errors in + # some cases early. + open(filename, mode).close() + self._f = None + self.should_close = True + + def __getattr__(self, name: str) -> Any: + return getattr(self.open(), name) + + def __repr__(self) -> str: + if self._f is not None: + return repr(self._f) + return f"" + + def open(self) -> IO[Any]: + """Opens the file if it's not yet open. This call might fail with + a `FileError`. Not handling this error will produce an error + that Click shows. + """ + if self._f is not None: + return self._f + try: + rv, self.should_close = open_stream( + self.name, self.mode, self.encoding, self.errors, atomic=self.atomic + ) + except OSError as e: + from .exceptions import FileError + + raise FileError(self.name, hint=e.strerror) from e + self._f = rv + return rv + + def close(self) -> None: + """Closes the underlying file, no matter what.""" + if self._f is not None: + self._f.close() + + def close_intelligently(self) -> None: + """This function only closes the file if it was opened by the lazy + file wrapper. For instance this will never close stdin. + """ + if self.should_close: + self.close() + + def __enter__(self) -> "LazyFile": + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.close_intelligently() + + def __iter__(self) -> Iterator[AnyStr]: + self.open() + return iter(self._f) # type: ignore + + +def echo( + message: Any | None = None, + file: IO[Any] | None = None, + nl: bool = True, + err: bool = False, + color: bool | None = None, +) -> None: + """Print a message and newline to stdout or a file. This should be + used instead of `print` because it provides better support + for different data, files, and environments. + + Compared to `print`, this does the following: + + - Ensures that the output encoding is not misconfigured on Linux. + - Supports Unicode in the Windows console. + - Supports writing to binary outputs, and supports writing bytes + to text outputs. + - Supports colors and styles on Windows. + - Removes ANSI color and style codes if the output does not look + like an interactive terminal. + - Always flushes the output. + """ + if file is None: + if err: + file = _default_text_stderr() + else: + file = _default_text_stdout() + + # There are no standard streams attached to write to. For example, + # pythonw on Windows. + if file is None: + return + + # Convert non bytes/text into the native string type. + if message is not None and not isinstance(message, (str, bytes, bytearray)): + out: str | bytes | bytearray | None = str(message) + else: + out = message + + if nl: + out = out or "" + if isinstance(out, str): + out += "\n" + else: + out += b"\n" + + if not out: + file.flush() + return + + # If there is a message and the value looks like bytes, we manually + # need to find the binary stream and write the message in there. + # This is done separately so that most stream types will work as you + # would expect. Eg: you can write to StringIO for other cases. + if isinstance(out, (bytes, bytearray)): + binary_file = _find_binary_writer(file) + + if binary_file is not None: + file.flush() + binary_file.write(out) + binary_file.flush() + return + + # ANSI style code support. For no message or bytes, nothing happens. + # When outputting to a file instead of a terminal, strip codes. + else: + color = resolve_color_default(color) + + if should_strip_ansi(file, color): + out = strip_ansi(out) + elif WIN: + if auto_wrap_for_ansi is not None: + file = auto_wrap_for_ansi(file, color) # type: ignore[arg-type,call-arg] + elif not color: + out = strip_ansi(out) + + file.write(out) # type: ignore[arg-type] + file.flush() + + +def get_binary_stream(name: Literal["stdin", "stdout", "stderr"]) -> BinaryIO: + """Returns a system stream for byte processing.""" + opener = binary_streams.get(name) + if opener is None: + raise TypeError(f"Unknown standard stream '{name}'") + return opener() + + +def get_text_stream( + name: Literal["stdin", "stdout", "stderr"], + encoding: str | None = None, + errors: str | None = "strict", +) -> TextIO: + """Returns a system stream for text processing. This usually returns + a wrapped stream around a binary stream returned from + `get_binary_stream` but it also can take shortcuts for already + correctly configured streams. + """ + opener = text_streams.get(name) + if opener is None: + raise TypeError(f"Unknown standard stream '{name}'") + return opener(encoding, errors) + + +def format_filename( + filename: str | bytes | os.PathLike[str] | os.PathLike[bytes], + shorten: bool = False, +) -> str: + """Format a filename as a string for display. Ensures the filename can be + displayed by replacing any invalid bytes or surrogate escapes in the name + with the replacement character ``�``. + + Invalid bytes or surrogate escapes will raise an error when written to a + stream with ``errors="strict"``. This will typically happen with ``stdout`` + when the locale is something like ``en_GB.UTF-8``. + + Many scenarios *are* safe to write surrogates though, due to PEP 538 and + PEP 540, including: + + - Writing to ``stderr``, which uses ``errors="backslashreplace"``. + - The system has ``LANG=C.UTF-8``, ``C``, or ``POSIX``. Python opens + stdout and stderr with ``errors="surrogateescape"``. + - None of ``LANG/LC_*`` are set. Python assumes ``LANG=C.UTF-8``. + - Python is started in UTF-8 mode with ``PYTHONUTF8=1`` or ``-X utf8``. + Python opens stdout and stderr with ``errors="surrogateescape"``. + """ + if shorten: + filename = os.path.basename(filename) + else: + filename = os.fspath(filename) + + if isinstance(filename, bytes): + filename = filename.decode(sys.getfilesystemencoding(), "replace") + else: + filename = filename.encode("utf-8", "surrogateescape").decode( + "utf-8", "replace" + ) + + return filename + + +def get_app_dir(app_name: str, roaming: bool = True, force_posix: bool = False) -> str: + r"""Returns the config folder for the application. The default behavior + is to return whatever is most appropriate for the operating system. + + To give you an idea, for an app called ``"Foo Bar"``, something like + the following folders could be returned: + + Mac OS X: + ``~/Library/Application Support/Foo Bar`` + Mac OS X (POSIX): + ``~/.foo-bar`` + Unix: + ``~/.config/foo-bar`` + Unix (POSIX): + ``~/.foo-bar`` + Windows (roaming): + ``C:\Users\\AppData\Roaming\Foo Bar`` + Windows (not roaming): + ``C:\Users\\AppData\Local\Foo Bar`` + """ + if WIN: + key = "APPDATA" if roaming else "LOCALAPPDATA" + folder = os.environ.get(key) + if folder is None: + folder = os.path.expanduser("~") + return os.path.join(folder, app_name) + if force_posix: + return os.path.join(os.path.expanduser(f"~/.{_posixify(app_name)}")) + if sys.platform == "darwin": + return os.path.join( + os.path.expanduser("~/Library/Application Support"), app_name + ) + return os.path.join( + os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")), + _posixify(app_name), + ) + + +class PacifyFlushWrapper: + """This wrapper is used to catch and suppress BrokenPipeErrors resulting + from ``.flush()`` being called on broken pipe during the shutdown/final-GC + of the Python interpreter. Notably ``.flush()`` is always called on + ``sys.stdout`` and ``sys.stderr``. So as to have minimal impact on any + other cleanup code, and the case where the underlying file is not a broken + pipe, all calls and attributes are proxied. + """ + + def __init__(self, wrapped: IO[Any]) -> None: + self.wrapped = wrapped + + def flush(self) -> None: + try: + self.wrapped.flush() + except OSError as e: # pragma: no cover + import errno + + if e.errno != errno.EPIPE: + raise + + def __getattr__(self, attr: str) -> Any: + return getattr(self.wrapped, attr) + + +def _detect_program_name( + path: str | None = None, _main: ModuleType | None = None +) -> str: + """Determine the command used to run the program, for use in help + text. If a file or entry point was executed, the file name is + returned. If ``python -m`` was used to execute a module or package, + ``python -m name`` is returned. + + This doesn't try to be too precise, the goal is to give a concise + name for help text. Files are only shown as their name without the + path. ``python`` is only shown for modules, and the full path to + ``sys.executable`` is not shown. + """ + if _main is None: + _main = sys.modules["__main__"] + + if not path: + path = sys.argv[0] + + # The value of __package__ indicates how Python was called. It may + # not exist if a setuptools script is installed as an egg. It may be + # set incorrectly for entry points created with pip on Windows. + # It is set to "" inside a Shiv or PEX zipapp. + if getattr(_main, "__package__", None) in {None, ""} or ( + os.name == "nt" + and _main.__package__ == "" + and not os.path.exists(path) + and os.path.exists(f"{path}.exe") + ): + # Executed a file, like "python app.py". + return os.path.basename(path) + + # Executed a module, like "python -m example". + # Rewritten by Python from "-m script" to "/path/to/script.py". + # Need to look at main module to determine how it was executed. + py_module = cast(str, _main.__package__) + name = os.path.splitext(os.path.basename(path))[0] + + # A submodule like "example.cli". + if name != "__main__": + py_module = f"{py_module}.{name}" + + return f"python -m {py_module.lstrip('.')}" + + +def _expand_args( + args: Iterable[str], + *, + user: bool = True, + env: bool = True, + glob_recursive: bool = True, +) -> list[str]: + """Simulate Unix shell expansion with Python functions.""" + from glob import glob + + out = [] + + for arg in args: + if user: + arg = os.path.expanduser(arg) + + if env: + arg = os.path.expandvars(arg) + + try: + matches = glob(arg, recursive=glob_recursive) + except re.error: # pragma: no cover + matches = [] + + if not matches: + out.append(arg) + else: + out.extend(matches) + + return out diff --git a/typer/_completion_classes.py b/typer/_completion_classes.py index 8548fb4d6a..cfae02c9bf 100644 --- a/typer/_completion_classes.py +++ b/typer/_completion_classes.py @@ -4,11 +4,9 @@ import sys from typing import Any -import click -import click.parser -import click.shell_completion -from click.shell_completion import split_arg_string as click_split_arg_string - +from . import _click +from ._click.shell_completion import CompletionItem, ShellComplete, add_completion_class +from ._click.shell_completion import split_arg_string as click_split_arg_string from ._completion_shared import ( COMPLETION_SCRIPT_BASH, COMPLETION_SCRIPT_FISH, @@ -27,7 +25,7 @@ def _sanitize_help_text(text: str) -> str: return rich_utils.rich_render_text(text) -class BashComplete(click.shell_completion.BashComplete): +class BashComplete(ShellComplete): name = Shells.bash.value source_template = COMPLETION_SCRIPT_BASH @@ -50,7 +48,7 @@ def get_completion_args(self) -> tuple[list[str], str]: return args, incomplete - def format_completion(self, item: click.shell_completion.CompletionItem) -> str: + def format_completion(self, item: CompletionItem) -> str: # TODO: Explore replicating the new behavior from Click, with item types and # triggering completion for files and directories # return f"{item.type},{item.value}" @@ -62,8 +60,42 @@ def complete(self) -> str: out = [self.format_completion(item) for item in completions] return "\n".join(out) + @staticmethod + def _check_version() -> None: + import shutil + import subprocess + + bash_exe = shutil.which("bash") + + if bash_exe is None: + match = None # pragma: no cover + else: + output = subprocess.run( + [bash_exe, "--norc", "-c", 'echo "${BASH_VERSION}"'], + stdout=subprocess.PIPE, + ) + match = re.search(r"^(\d+)\.(\d+)\.\d+", output.stdout.decode()) + + if match is not None: + major, minor = match.groups() + + if major < "4" or major == "4" and minor < "4": + _click.utils.echo( + "Shell completion is not supported for Bash versions older than 4.4.", + err=True, + ) + else: + _click.utils.echo( + "Couldn't detect Bash version, shell completion is not supported.", + err=True, + ) # pragma: no cover + + def source(self) -> str: + self._check_version() + return super().source() + -class ZshComplete(click.shell_completion.ZshComplete): +class ZshComplete(ShellComplete): name = Shells.zsh.value source_template = COMPLETION_SCRIPT_ZSH @@ -85,7 +117,7 @@ def get_completion_args(self) -> tuple[list[str], str]: incomplete = "" return args, incomplete - def format_completion(self, item: click.shell_completion.CompletionItem) -> str: + def format_completion(self, item: CompletionItem) -> str: def escape(s: str) -> str: return ( s.replace('"', '""') @@ -114,7 +146,7 @@ def complete(self) -> str: return "_files" -class FishComplete(click.shell_completion.FishComplete): +class FishComplete(ShellComplete): name = Shells.fish.value source_template = COMPLETION_SCRIPT_FISH @@ -136,7 +168,7 @@ def get_completion_args(self) -> tuple[list[str], str]: incomplete = "" return args, incomplete - def format_completion(self, item: click.shell_completion.CompletionItem) -> str: + def format_completion(self, item: CompletionItem) -> str: # TODO: Explore replicating the new behavior from Click, pay attention to # the difference with and without formatted help # if item.help: @@ -167,7 +199,7 @@ def complete(self) -> str: return "" # pragma: no cover -class PowerShellComplete(click.shell_completion.ShellComplete): +class PowerShellComplete(ShellComplete): name = Shells.powershell.value source_template = COMPLETION_SCRIPT_POWER_SHELL @@ -185,15 +217,13 @@ def get_completion_args(self) -> tuple[list[str], str]: args = cwords[1:-1] if incomplete else cwords[1:] return args, incomplete - def format_completion(self, item: click.shell_completion.CompletionItem) -> str: + def format_completion(self, item: CompletionItem) -> str: return f"{item.value}:::{_sanitize_help_text(item.help) if item.help else ' '}" def completion_init() -> None: - click.shell_completion.add_completion_class(BashComplete, Shells.bash.value) - click.shell_completion.add_completion_class(ZshComplete, Shells.zsh.value) - click.shell_completion.add_completion_class(FishComplete, Shells.fish.value) - click.shell_completion.add_completion_class( - PowerShellComplete, Shells.powershell.value - ) - click.shell_completion.add_completion_class(PowerShellComplete, Shells.pwsh.value) + add_completion_class(BashComplete, Shells.bash.value) + add_completion_class(ZshComplete, Shells.zsh.value) + add_completion_class(FishComplete, Shells.fish.value) + add_completion_class(PowerShellComplete, Shells.powershell.value) + add_completion_class(PowerShellComplete, Shells.pwsh.value) diff --git a/typer/_completion_shared.py b/typer/_completion_shared.py index 5a81dcf68c..8d2c19715c 100644 --- a/typer/_completion_shared.py +++ b/typer/_completion_shared.py @@ -4,9 +4,11 @@ from enum import Enum from pathlib import Path -import click import shellingham +from . import _click +from ._click.globals import get_current_context + class Shells(str, Enum): bash = "bash" @@ -78,8 +80,8 @@ def get_completion_script(*, prog_name: str, complete_var: str, shell: str) -> s cf_name = _invalid_ident_char_re.sub("", prog_name.replace("-", "_")) script = _completion_scripts.get(shell) if script is None: - click.echo(f"Shell {shell} not supported.", err=True) - raise click.exceptions.Exit(1) + _click.echo(f"Shell {shell} not supported.", err=True) + raise _click.exceptions.Exit(1) return ( script % { @@ -172,8 +174,8 @@ def install_powershell(*, prog_name: str, complete_var: str, shell: str) -> Path stdout=subprocess.PIPE, ) if result.returncode != 0: # pragma: no cover - click.echo("Couldn't get PowerShell user profile", err=True) - raise click.exceptions.Exit(result.returncode) + _click.echo("Couldn't get PowerShell user profile", err=True) + raise _click.exceptions.Exit(result.returncode) path_str = "" if isinstance(result.stdout, str): # pragma: no cover path_str = result.stdout @@ -185,8 +187,8 @@ def install_powershell(*, prog_name: str, complete_var: str, shell: str) -> Path except UnicodeDecodeError: # pragma: no cover pass if not path_str: # pragma: no cover - click.echo("Couldn't decode the path automatically", err=True) - raise click.exceptions.Exit(1) + _click.echo("Couldn't decode the path automatically", err=True) + raise _click.exceptions.Exit(1) path_obj = Path(path_str.strip()) parent_dir: Path = path_obj.parent parent_dir.mkdir(parents=True, exist_ok=True) @@ -203,7 +205,7 @@ def install( prog_name: str | None = None, complete_var: str | None = None, ) -> tuple[str, Path]: - prog_name = prog_name or click.get_current_context().find_root().info_name + prog_name = prog_name or get_current_context().find_root().info_name assert prog_name if complete_var is None: complete_var = "_{}_COMPLETE".format(prog_name.replace("-", "_").upper()) @@ -231,8 +233,8 @@ def install( ) return shell, installed_path else: - click.echo(f"Shell {shell} is not supported.") - raise click.exceptions.Exit(1) + _click.echo(f"Shell {shell} is not supported.") + raise _click.exceptions.Exit(1) def _get_shell_name() -> str | None: diff --git a/typer/_types.py b/typer/_types.py index dc9fc63220..09b38afb3f 100644 --- a/typer/_types.py +++ b/typer/_types.py @@ -1,21 +1,42 @@ +from collections.abc import Iterable, Mapping, Sequence from enum import Enum -from typing import TypeVar +from typing import Any, Generic, TypeVar -import click +from . import _click +from ._click import types +from ._click.shell_completion import CompletionItem ParamTypeValue = TypeVar("ParamTypeValue") -class TyperChoice(click.Choice[ParamTypeValue]): +class TyperChoice(types.ParamType, Generic[ParamTypeValue]): + # Code adapted from Click 8.3.1, with Typer using enum values in normalize_choice + name = "choice" + + def __init__( + self, choices: Iterable[ParamTypeValue], case_sensitive: bool = True + ) -> None: + self.choices: Sequence[ParamTypeValue] = tuple(choices) + self.case_sensitive = case_sensitive + + def _normalized_mapping( + self, ctx: _click.Context | None = None + ) -> Mapping[ParamTypeValue, str]: + """ + Returns mapping where keys are the original choices and the values are + the normalized values that are accepted via the command line. + """ + return { + choice: self.normalize_choice( + choice=choice, + ctx=ctx, + ) + for choice in self.choices + } + def normalize_choice( - self, choice: ParamTypeValue, ctx: click.Context | None + self, choice: ParamTypeValue, ctx: _click.Context | None ) -> str: - # Click 8.2.0 added a new method `normalize_choice` to the `Choice` class - # to support enums, but it uses the enum names, while Typer has always used the - # enum values. - # This class overrides that method to maintain the previous behavior. - # In Click: - # normed_value = choice.name if isinstance(choice, Enum) else str(choice) normed_value = str(choice.value) if isinstance(choice, Enum) else str(choice) if ctx is not None and ctx.token_normalize_func is not None: @@ -25,3 +46,75 @@ def normalize_choice( normed_value = normed_value.casefold() return normed_value + + def get_metavar(self, param: _click.Parameter, ctx: _click.Context) -> str | None: + if param.param_type_name == "option" and not param.show_choices: # type: ignore + choice_metavars = [ + types.convert_type(type(choice)).name.upper() for choice in self.choices + ] + choices_str = "|".join([*dict.fromkeys(choice_metavars)]) + else: + choices_str = "|".join( + [str(i) for i in self._normalized_mapping(ctx=ctx).values()] + ) + + # Use curly braces to indicate a required argument. + if param.required and param.param_type_name == "argument": + return f"{{{choices_str}}}" + + # Use square braces to indicate an option or optional argument. + return f"[{choices_str}]" + + def get_missing_message( + self, param: _click.Parameter, ctx: _click.Context | None + ) -> str: + """Message shown when no choice is passed.""" + choices = ",\n\t".join(self._normalized_mapping(ctx=ctx).values()) + return f"Choose from:\n\t{choices}" + + def convert( + self, value: Any, param: _click.Parameter | None, ctx: _click.Context | None + ) -> ParamTypeValue: + """ + For a given value from the parser, normalize it and find its + matching normalized value in the list of choices. Then return the + matched "original" choice. + """ + normed_value = self.normalize_choice(choice=value, ctx=ctx) + normalized_mapping = self._normalized_mapping(ctx=ctx) + + try: + return next( + original + for original, normalized in normalized_mapping.items() + if normalized == normed_value + ) + except StopIteration: + self.fail( + self.get_invalid_choice_message(value=value, ctx=ctx), + param=param, + ctx=ctx, + ) + + def get_invalid_choice_message(self, value: Any, ctx: _click.Context | None) -> str: + """Get the error message when the given choice is invalid.""" + choices_str = ", ".join(map(repr, self._normalized_mapping(ctx=ctx).values())) + return f"{value!r} is not one of {choices_str}." + + def __repr__(self) -> str: + return f"Choice({list(self.choices)})" + + def shell_complete( + self, ctx: _click.Context, param: _click.Parameter, incomplete: str + ) -> list[CompletionItem]: + """Complete choices that start with the incomplete value.""" + + str_choices = map(str, self.choices) + + if self.case_sensitive: + matched = (c for c in str_choices if c.startswith(incomplete)) + else: + incomplete = incomplete.lower() + matched = (c for c in str_choices if c.lower().startswith(incomplete)) + + return [CompletionItem(c) for c in matched] diff --git a/typer/cli.py b/typer/cli.py index 2a7d78c3a4..665bcf5a59 100644 --- a/typer/cli.py +++ b/typer/cli.py @@ -4,13 +4,12 @@ from pathlib import Path from typing import Any -import click import typer import typer.core -from click import Command, Group, Option -from . import __version__ -from .core import HAS_RICH, MARKUP_MODE_KEY +from . import __version__, _click +from ._click import Command +from .core import HAS_RICH, MARKUP_MODE_KEY, TyperGroup, TyperOption default_app_names = ("app", "cli", "main") default_func_names = ("main", "cli", "app") @@ -31,7 +30,7 @@ def __init__(self) -> None: state = State() -def maybe_update_state(ctx: click.Context) -> None: +def maybe_update_state(ctx: _click.Context) -> None: path_or_module = ctx.params.get("path_or_module") if path_or_module: file_path = Path(path_or_module) @@ -53,19 +52,19 @@ def maybe_update_state(ctx: click.Context) -> None: class TyperCLIGroup(typer.core.TyperGroup): - def list_commands(self, ctx: click.Context) -> list[str]: + def list_commands(self, ctx: _click.Context) -> list[str]: self.maybe_add_run(ctx) return super().list_commands(ctx) - def get_command(self, ctx: click.Context, name: str) -> Command | None: # ty: ignore[invalid-method-override] + def get_command(self, ctx: _click.Context, name: str) -> Command | None: # ty: ignore[invalid-method-override] self.maybe_add_run(ctx) return super().get_command(ctx, name) - def invoke(self, ctx: click.Context) -> Any: + def invoke(self, ctx: _click.Context) -> Any: self.maybe_add_run(ctx) return super().invoke(ctx) - def maybe_add_run(self, ctx: click.Context) -> None: + def maybe_add_run(self, ctx: _click.Context) -> None: maybe_update_state(ctx) maybe_add_run_to_cli(self) @@ -138,7 +137,7 @@ def get_typer_from_state() -> typer.Typer | None: return obj -def maybe_add_run_to_cli(cli: click.Group) -> None: +def maybe_add_run_to_cli(cli: TyperGroup) -> None: if "run" not in cli.commands: if state.file or state.module: obj = get_typer_from_state() @@ -151,7 +150,7 @@ def maybe_add_run_to_cli(cli: click.Group) -> None: cli.add_command(click_obj) -def print_version(ctx: click.Context, param: Option, value: bool) -> None: +def print_version(ctx: _click.Context, param: TyperOption, value: bool) -> None: if not value or ctx.resilient_parsing: return typer.echo(f"Typer version: {__version__}") @@ -242,7 +241,7 @@ def get_docs_for_click( docs += "\n" if obj.epilog: docs += f"{obj.epilog}\n\n" - if isinstance(obj, Group): + if isinstance(obj, TyperGroup): group = obj commands = group.list_commands(ctx) if commands: diff --git a/typer/completion.py b/typer/completion.py index 0d621e411d..f63692ddf3 100644 --- a/typer/completion.py +++ b/typer/completion.py @@ -3,8 +3,8 @@ from collections.abc import MutableMapping from typing import Any -import click - +from . import _click +from ._click import shell_completion from ._completion_classes import completion_init from ._completion_shared import Shells, _get_shell_name, get_completion_script, install from .models import ParamMeta @@ -27,19 +27,19 @@ def get_completion_inspect_parameters() -> tuple[ParamMeta, ParamMeta]: return install_param, show_param -def install_callback(ctx: click.Context, param: click.Parameter, value: Any) -> Any: +def install_callback(ctx: _click.Context, param: _click.Parameter, value: Any) -> Any: if not value or ctx.resilient_parsing: return value # pragma: no cover if isinstance(value, str): shell, path = install(shell=value) else: shell, path = install() - click.secho(f"{shell} completion installed in {path}", fg="green") - click.echo("Completion will take effect once you restart the terminal") + _click.termui.secho(f"{shell} completion installed in {path}", fg="green") + _click.echo("Completion will take effect once you restart the terminal") sys.exit(0) -def show_callback(ctx: click.Context, param: click.Parameter, value: Any) -> Any: +def show_callback(ctx: _click.Context, param: _click.Parameter, value: Any) -> Any: if not value or ctx.resilient_parsing: return value # pragma: no cover prog_name = ctx.find_root().info_name @@ -56,7 +56,7 @@ def show_callback(ctx: click.Context, param: click.Parameter, value: Any) -> Any script_content = get_completion_script( prog_name=prog_name, complete_var=complete_var, shell=shell ) - click.echo(script_content) + _click.echo(script_content) sys.exit(0) @@ -103,17 +103,16 @@ def _install_completion_no_auto_placeholder_function( # And to add extra error messages, for compatibility with Typer in previous versions # This is only called in new Command method, only used by Click 8.x+ def shell_complete( - cli: click.Command, + cli: _click.Command, ctx_args: MutableMapping[str, Any], prog_name: str, complete_var: str, instruction: str, ) -> int: - import click - import click.shell_completion + from . import _click if "_" not in instruction: - click.echo("Invalid completion instruction.", err=True) + _click.echo("Invalid completion instruction.", err=True) return 1 # Click 8 changed the order/style of shell instructions from e.g. @@ -124,23 +123,23 @@ def shell_complete( instruction, _, shell = instruction.partition("_") # Typer override end - comp_cls = click.shell_completion.get_completion_class(shell) + comp_cls = shell_completion.get_completion_class(shell) if comp_cls is None: - click.echo(f"Shell {shell} not supported.", err=True) + _click.echo(f"Shell {shell} not supported.", err=True) return 1 comp = comp_cls(cli, ctx_args, prog_name, complete_var) if instruction == "source": - click.echo(comp.source()) + _click.echo(comp.source()) return 0 # Typer override to print the completion help msg with Rich if instruction == "complete": - click.echo(comp.complete()) + _click.echo(comp.complete()) return 0 # Typer override end - click.echo(f'Completion instruction "{instruction}" not supported.', err=True) + _click.echo(f'Completion instruction "{instruction}" not supported.', err=True) return 1 diff --git a/typer/core.py b/typer/core.py index 48fee64e34..edd9370eb2 100644 --- a/typer/core.py +++ b/typer/core.py @@ -2,7 +2,7 @@ import inspect import os import sys -from collections.abc import Callable, MutableMapping, Sequence +from collections.abc import Callable, Mapping, MutableMapping, Sequence from difflib import get_close_matches from enum import Enum from gettext import gettext as _ @@ -13,13 +13,10 @@ cast, ) -import click -import click.core -import click.formatting -import click.shell_completion -import click.types -import click.utils - +from . import _click +from ._click import types +from ._click.parser import _OptionParser +from ._click.shell_completion import CompletionItem from ._typing import Literal from .utils import parse_boolean_env_var @@ -34,7 +31,7 @@ DEFAULT_MARKUP_MODE = None -# Copy from click.parser._split_opt +# Copy from _click.parser._split_opt def _split_opt(opt: str) -> tuple[str, str]: first = opt[:1] if first.isalnum(): @@ -45,10 +42,10 @@ def _split_opt(opt: str) -> tuple[str, str]: def _typer_param_setup_autocompletion_compat( - self: click.Parameter, + self: _click.Parameter, *, autocompletion: Callable[ - [click.Context, list[str], str], list[tuple[str, str] | str] + [_click.Context, list[str], str], list[tuple[str, str] | str] ] | None = None, ) -> None: @@ -65,10 +62,8 @@ def _typer_param_setup_autocompletion_compat( if autocompletion is not None: def compat_autocompletion( - ctx: click.Context, param: click.core.Parameter, incomplete: str - ) -> list["click.shell_completion.CompletionItem"]: - from click.shell_completion import CompletionItem - + ctx: _click.Context, param: _click.core.Parameter, incomplete: str + ) -> list[CompletionItem]: out = [] for c in autocompletion(ctx, [], incomplete): @@ -89,11 +84,11 @@ def compat_autocompletion( def _get_default_string( obj: Union["TyperArgument", "TyperOption"], *, - ctx: click.Context, + ctx: _click.Context, show_default_is_str: bool, default_value: list[Any] | tuple[Any, ...] | str | Callable[..., Any] | Any, ) -> str: - # Extracted from click.core.Option.get_help_record() to be reused by + # Extracted from _click.core.Option.get_help_record() to be reused by # rich_utils avoiding RegEx hacks if show_default_is_str: default_string = f"({obj.show_default})" @@ -112,7 +107,7 @@ def _get_default_string( # For boolean flags that have distinct True/False opts, # use the opt without prefix instead of the value. # Typer override, original commented - # default_string = click.parser.split_opt( + # default_string = _click.parser.split_opt( # (self.opts if self.default else self.secondary_opts)[0] # )[1] if obj.default: @@ -136,9 +131,9 @@ def _get_default_string( def _extract_default_help_str( - obj: Union["TyperArgument", "TyperOption"], *, ctx: click.Context + obj: Union["TyperArgument", "TyperOption"], *, ctx: _click.Context ) -> Any | Callable[[], Any] | None: - # Extracted from click.core.Option.get_help_record() to be reused by + # Extracted from _click.core.Option.get_help_record() to be reused by # rich_utils avoiding RegEx hacks # Temporarily enable resilient parsing to avoid type casting # failing for the default. Might be possible to extend this to @@ -154,7 +149,7 @@ def _extract_default_help_str( def _main( - self: click.Command, + self: _click.Command, *, args: Sequence[str] | None = None, prog_name: str | None = None, @@ -164,7 +159,7 @@ def _main( rich_markup_mode: MarkupMode = DEFAULT_MARKUP_MODE, **extra: Any, ) -> Any: - # Typer override, duplicated from click.main() to handle custom rich exceptions + # Typer override, duplicated from _click.main() to handle custom rich exceptions # Verify that the environment is configured correctly, or reject # further execution to avoid a broken script. if args is None: @@ -172,12 +167,12 @@ def _main( # Covered in Click tests if os.name == "nt" and windows_expand_args: # pragma: no cover - args = click.utils._expand_args(args) + args = _click.utils._expand_args(args) else: args = list(args) if prog_name is None: - prog_name = click.utils._detect_program_name() + prog_name = _click.utils._detect_program_name() # Process shell completion requests and exit early. self._main_shell_completion(extra, prog_name, complete_var) @@ -197,11 +192,11 @@ def _main( # by its truthiness/falsiness ctx.exit() except EOFError as e: - click.echo(file=sys.stderr) - raise click.Abort() from e + _click.echo(file=sys.stderr) + raise _click.exceptions.Abort() from e except KeyboardInterrupt as e: - raise click.exceptions.Exit(130) from e - except click.ClickException as e: + raise _click.exceptions.Exit(130) from e + except _click.exceptions.ClickException as e: if not standalone_mode: raise # Typer override @@ -215,12 +210,12 @@ def _main( sys.exit(e.exit_code) except OSError as e: if e.errno == errno.EPIPE: - sys.stdout = cast(TextIO, click.utils.PacifyFlushWrapper(sys.stdout)) - sys.stderr = cast(TextIO, click.utils.PacifyFlushWrapper(sys.stderr)) + sys.stdout = cast(TextIO, _click.utils.PacifyFlushWrapper(sys.stdout)) + sys.stderr = cast(TextIO, _click.utils.PacifyFlushWrapper(sys.stderr)) sys.exit(1) else: raise - except click.exceptions.Exit as e: + except _click.exceptions.Exit as e: if standalone_mode: sys.exit(e.exit_code) else: @@ -233,7 +228,7 @@ def _main( # `ctx.exit(1)` and to `return 1`, the caller won't be able to # tell the difference between the two return e.exit_code - except click.Abort: + except _click.exceptions.Abort: if not standalone_mode: raise # Typer override @@ -242,19 +237,21 @@ def _main( rich_utils.rich_abort_error() else: - click.echo(_("Aborted!"), file=sys.stderr) + _click.echo(_("Aborted!"), file=sys.stderr) # Typer override end sys.exit(1) -class TyperArgument(click.core.Argument): +class TyperArgument(_click.core.Parameter): + param_type_name = "argument" + def __init__( self, *, # Parameter param_decls: list[str], type: Any | None = None, - required: bool | None = None, + required: bool = False, default: Any | None = None, callback: Callable[..., Any] | None = None, nargs: int | None = None, @@ -265,8 +262,8 @@ def __init__( # Note that shell_complete is not fully supported and will be removed in future versions # TODO: Remove shell_complete in a future version (after 0.16.0) shell_complete: Callable[ - [click.Context, click.Parameter, str], - list["click.shell_completion.CompletionItem"] | list[str], + [_click.Context, _click.Parameter, str], + list[CompletionItem] | list[str], ] | None = None, autocompletion: Callable[..., Any] | None = None, @@ -301,10 +298,17 @@ def __init__( ) _typer_param_setup_autocompletion_compat(self, autocompletion=autocompletion) + @property + def human_readable_name(self) -> str: + if self.metavar is not None: + return self.metavar + assert self.name is not None, "self.name or self.metavar should be set" + return self.name.upper() + def _get_default_string( self, *, - ctx: click.Context, + ctx: _click.Context, show_default_is_str: bool, default_value: list[Any] | tuple[Any, ...] | str | Callable[..., Any] | Any, ) -> str: @@ -316,12 +320,12 @@ def _get_default_string( ) def _extract_default_help_str( - self, *, ctx: click.Context + self, *, ctx: _click.Context ) -> Any | Callable[[], Any] | None: return _extract_default_help_str(self, ctx=ctx) - def get_help_record(self, ctx: click.Context) -> tuple[str, str] | None: - # Modified version of click.core.Option.get_help_record() + def get_help_record(self, ctx: _click.Context) -> tuple[str, str] | None: + # Modified version of _click.core.Option.get_help_record() # to support Arguments if self.hidden: return None @@ -376,8 +380,8 @@ def get_help_record(self, ctx: click.Context) -> tuple[str, str] | None: help = f"{help} {extra_str}" if help else f"{extra_str}" return name, help - def make_metavar(self, ctx: click.Context) -> str: - # Modified version of click.core.Argument.make_metavar() + def make_metavar(self, ctx: _click.Context) -> str: + # Modified version of _click.core.Argument.make_metavar() # to include Argument name if self.metavar is not None: var = self.metavar @@ -397,15 +401,45 @@ def make_metavar(self, ctx: click.Context) -> str: def value_is_missing(self, value: Any) -> bool: return _value_is_missing(self, value) + def _parse_decls( + self, decls: Sequence[str], expose_value: bool + ) -> tuple[str | None, list[str], list[str]]: + if not decls: + if not expose_value: + return None, [], [] + raise TypeError("Argument is marked as exposed, but does not have a name.") + if len(decls) == 1: + name = arg = decls[0] + name = name.replace("-", "_").lower() + else: + raise TypeError( + "Arguments take exactly one parameter declaration, got" + f" {len(decls)}: {decls}." + ) + return name, [arg], [] + + def get_usage_pieces(self, ctx: _click.Context) -> list[str]: + return [self.make_metavar(ctx)] + + def get_error_hint(self, ctx: _click.Context) -> str: + return f"'{self.make_metavar(ctx)}'" + + def add_to_parser(self, parser: _OptionParser, ctx: _click.Context) -> None: + parser.add_argument(dest=self.name, nargs=self.nargs, obj=self) + + +class TyperOption(_click.Parameter): + param_type_name = "option" + + _depr_flag_value: bool | None -class TyperOption(click.core.Option): def __init__( self, *, # Parameter param_decls: list[str], - type: click.types.ParamType | Any | None = None, - required: bool | None = None, + type: types.ParamType | Any | None = None, + required: bool = False, default: Any | None = None, callback: Callable[..., Any] | None = None, nargs: int | None = None, @@ -416,8 +450,8 @@ def __init__( # Note that shell_complete is not fully supported and will be removed in future versions # TODO: Remove shell_complete in a future version (after 0.16.0) shell_complete: Callable[ - [click.Context, click.Parameter, str], - list["click.shell_completion.CompletionItem"] | list[str], + [_click.Context, _click.Parameter, str], + list[CompletionItem] | list[str], ] | None = None, autocompletion: Callable[..., Any] | None = None, @@ -438,9 +472,13 @@ def __init__( # Rich settings rich_help_panel: str | None = None, ): + if help: + help = inspect.cleandoc(help) + super().__init__( - param_decls=param_decls, + param_decls, type=type, + multiple=multiple, required=required, default=default, callback=callback, @@ -449,28 +487,240 @@ def __init__( expose_value=expose_value, is_eager=is_eager, envvar=envvar, - show_default=show_default, - prompt=prompt, - confirmation_prompt=confirmation_prompt, - hide_input=hide_input, - is_flag=is_flag, - multiple=multiple, - count=count, - allow_from_autoenv=allow_from_autoenv, - help=help, - hidden=hidden, - show_choices=show_choices, - show_envvar=show_envvar, - prompt_required=prompt_required, shell_complete=shell_complete, ) + + if prompt is True: + if self.name is None: + raise TypeError("'name' is required with 'prompt=True'.") + + prompt_text: str | None = self.name.replace("_", " ").capitalize() + elif prompt is False: + prompt_text = None + else: + prompt_text = prompt + + self.prompt = prompt_text + self.confirmation_prompt = confirmation_prompt + self.prompt_required = prompt_required + self.hide_input = hide_input + self.hidden = hidden + + # TODO: revisit all of this flag stuff + if is_flag and type is None: + self.type: types.ParamType = types.BoolParamType() + + self.is_flag: bool = bool(is_flag) + self.is_bool_flag: bool = bool( + is_flag and isinstance(self.type, types.BoolParamType) + ) + + if self.is_flag: + self._depr_flag_value = True + else: + self._depr_flag_value = None + + # Counting. TODO: test or remove? Not currently in coverage. + self.count = count + if count and type is None: + self.type = types.IntRange(min=0) + + self.allow_from_autoenv = allow_from_autoenv + self.help = help + self.show_default = show_default + self.show_choices = show_choices + self.show_envvar = show_envvar + _typer_param_setup_autocompletion_compat(self, autocompletion=autocompletion) self.rich_help_panel = rich_help_panel + def get_error_hint(self, ctx: _click.Context) -> str: + result = super().get_error_hint(ctx) + if self.show_envvar and self.envvar is not None: + result += f" (env var: '{self.envvar}')" + return result + + def _parse_decls( + self, decls: Sequence[str], expose_value: bool + ) -> tuple[str | None, list[str], list[str]]: + opts = [] + secondary_opts = [] + name = None + possible_names = [] + + for decl in decls: + if decl.isidentifier(): + if name is not None: + raise TypeError(f"Name '{name}' defined twice") + name = decl + else: + split_char = ";" if decl[:1] == "/" else "/" + if split_char in decl: + first, second = decl.split(split_char, 1) + first = first.rstrip() + if first: + possible_names.append(_split_opt(first)) + opts.append(first) + second = second.lstrip() + if second: + secondary_opts.append(second.lstrip()) + if first == second: + raise ValueError( + f"Boolean option {decl!r} cannot use the" + " same flag for true/false." + ) + else: + possible_names.append(_split_opt(decl)) + opts.append(decl) + + if name is None and possible_names: + possible_names.sort(key=lambda x: -len(x[0])) # group long options first + name = possible_names[0][1].replace("-", "_").lower() + if not name.isidentifier(): + name = None + + return name, opts, secondary_opts + + def add_to_parser(self, parser: _OptionParser, ctx: _click.Context) -> None: + if self.multiple: + action = "append" + elif self.count: + action = "count" + else: + action = "store" + + if self.is_flag: + action = f"{action}_const" + + if self.is_bool_flag and self.secondary_opts: + parser.add_option( + obj=self, opts=self.opts, dest=self.name, action=action, const=True + ) + parser.add_option( + obj=self, + opts=self.secondary_opts, + dest=self.name, + action=action, + const=False, + ) + else: + parser.add_option( + obj=self, + opts=self.opts, + dest=self.name, + action=action, + const=self._depr_flag_value, + ) + else: + parser.add_option( + obj=self, + opts=self.opts, + dest=self.name, + action=action, + nargs=self.nargs, + ) + + def prompt_for_value(self, ctx: _click.Context) -> Any: + """This is an alternative flow that can be activated in the full + value processing if a value does not exist. It will prompt the + user until a valid value exists and then returns the processed + value as result. + """ + assert self.prompt is not None + + # Calculate the default before prompting anything to lock in the value before + # attempting any user interaction. + default = self.get_default(ctx) + + # A boolean flag can use a simplified [y/n] confirmation prompt. + if self.is_bool_flag: + # Nothing prevent you to declare an option that is simultaneously: + # 1) auto-detected as a boolean flag, + # 2) allowed to prompt, and + # 3) still declare a non-boolean default. + # This forced casting into a boolean is necessary to align any non-boolean + # default to the prompt, which is going to be a [y/n]-style confirmation + # because the option is still a boolean flag. That way, instead of [y/n], + # we get [Y/n] or [y/N] depending on the truthy value of the default. + # Refs: https://github.com/pallets/click/pull/3030#discussion_r2289180249 + if default is not None: + default = bool(default) + return _click.termui.confirm(self.prompt, default) + + # If show_default is set to True/False, provide this to `prompt` as well. For + # non-bool values of `show_default`, we use `prompt`'s default behavior + prompt_kwargs: Any = {} + if isinstance(self.show_default, bool): + prompt_kwargs["show_default"] = self.show_default + + return _click.termui.prompt( + self.prompt, + # Use ``None`` to inform the prompt() function to reiterate until a valid + # value is provided by the user if we have no default. + default=default, + type=self.type, + hide_input=self.hide_input, + show_choices=self.show_choices, + confirmation_prompt=self.confirmation_prompt, + value_proc=lambda x: self.process_value(ctx, x), + **prompt_kwargs, + ) + + def value_from_envvar(self, ctx: _click.Context) -> Any: + rv = self.resolve_envvar_value(ctx) + + # Absent environment variable or an empty string is interpreted as unset. + if rv is None: + return None + + def resolve_envvar_value(self, ctx: _click.Context) -> str | None: + rv = super().resolve_envvar_value(ctx) + + if rv is not None: + return rv + + if ( + self.allow_from_autoenv + and ctx.auto_envvar_prefix is not None + and self.name is not None + ): + envvar = f"{ctx.auto_envvar_prefix}_{self.name.upper()}" + rv = os.environ.get(envvar) + + if rv: + return rv + + return None + + def consume_value( + self, ctx: _click.Context, opts: Mapping[str, _click.Parameter] + ) -> tuple[Any, _click.core.ParameterSource]: + """For `Option`, the value can be collected from an interactive prompt + if the option is a flag that needs a value (and the `prompt` property is + set). + + Additionally, this method handles flag option that are activated without a + value, in which case the `flag_value` is returned. + """ + value, source = super().consume_value(ctx, opts) + + # The value wasn't set, or used the param's default, prompt for one to the user + # if prompting is enabled. + if ( + source in {None, _click.core.ParameterSource.DEFAULT} + and self.prompt is not None + and (self.required or self.prompt_required) + and not ctx.resilient_parsing + ): + value = self.prompt_for_value(ctx) + source = _click.core.ParameterSource.PROMPT + + return value, source + def _get_default_string( self, *, - ctx: click.Context, + ctx: _click.Context, show_default_is_str: bool, default_value: list[Any] | tuple[Any, ...] | str | Callable[..., Any] | Any, ) -> str: @@ -482,14 +732,14 @@ def _get_default_string( ) def _extract_default_help_str( - self, *, ctx: click.Context + self, *, ctx: _click.Context ) -> Any | Callable[[], Any] | None: return _extract_default_help_str(self, ctx=ctx) - def make_metavar(self, ctx: click.Context) -> str: + def make_metavar(self, ctx: _click.Context) -> str: return super().make_metavar(ctx=ctx) - def get_help_record(self, ctx: click.Context) -> tuple[str, str] | None: + def get_help_record(self, ctx: _click.Context) -> tuple[str, str] | None: # Duplicate all of Click's logic only to modify a single line, to allow boolean # flags with only names for False values as it's currently supported by Typer # Ref: https://typer.tiangolo.com/tutorial/parameter-types/bool/#only-names-for-false @@ -501,7 +751,7 @@ def get_help_record(self, ctx: click.Context) -> tuple[str, str] | None: def _write_opts(opts: Sequence[str]) -> str: nonlocal any_prefix_is_slash - rv, any_slashes = click.formatting.join_options(opts) + rv, any_slashes = _click.formatting.join_options(opts) if any_slashes: any_prefix_is_slash = True @@ -559,7 +809,7 @@ def _write_opts(opts: Sequence[str]) -> str: if default_string: extra.append(_("default: {default}").format(default=default_string)) - if isinstance(self.type, click.types._NumberRangeBase): + if isinstance(self.type, types._NumberRangeBase): range_str = self.type._describe_range() if range_str: @@ -588,14 +838,10 @@ def value_is_missing(self, value: Any) -> bool: return _value_is_missing(self, value) -def _value_is_missing(param: click.Parameter, value: Any) -> bool: +def _value_is_missing(param: _click.Parameter, value: Any) -> bool: if value is None: return True - # Click 8.3 and beyond - # if value is UNSET: - # return True - if (param.nargs != 1 or param.multiple) and value == (): return True # pragma: no cover @@ -603,7 +849,7 @@ def _value_is_missing(param: click.Parameter, value: Any) -> bool: def _typer_format_options( - self: click.core.Command, *, ctx: click.Context, formatter: click.HelpFormatter + self: _click.core.Command, *, ctx: _click.Context, formatter: _click.HelpFormatter ) -> None: args = [] opts = [] @@ -624,7 +870,7 @@ def _typer_format_options( def _typer_main_shell_completion( - self: click.core.Command, + self: _click.core.Command, *, ctx_args: MutableMapping[str, Any], prog_name: str, @@ -644,14 +890,14 @@ def _typer_main_shell_completion( sys.exit(rv) -class TyperCommand(click.core.Command): +class TyperCommand(_click.core.Command): def __init__( self, name: str | None, *, context_settings: dict[str, Any] | None = None, callback: Callable[..., Any] | None = None, - params: list[click.Parameter] | None = None, + params: list[_click.Parameter] | None = None, help: str | None = None, epilog: str | None = None, short_help: str | None = None, @@ -682,7 +928,7 @@ def __init__( self.rich_help_panel = rich_help_panel def format_options( - self, ctx: click.Context, formatter: click.HelpFormatter + self, ctx: _click.Context, formatter: _click.HelpFormatter ) -> None: _typer_format_options(self, ctx=ctx, formatter=formatter) @@ -716,7 +962,7 @@ def main( **extra, ) - def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: + def format_help(self, ctx: _click.Context, formatter: _click.HelpFormatter) -> None: if not HAS_RICH or self.rich_markup_mode is None: if not hasattr(ctx, "obj") or ctx.obj is None: ctx.ensure_object(dict) @@ -732,25 +978,151 @@ def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> Non ) -class TyperGroup(click.core.Group): +class TyperGroup(_click.Command): + allow_extra_args = True + allow_interspersed_args = False + command_class: type[_click.Command] | None = None + group_class: type["TyperGroup"] | type[type] | None = None + def __init__( self, *, name: str | None = None, - commands: dict[str, click.Command] | Sequence[click.Command] | None = None, + commands: dict[str, _click.Command] | Sequence[_click.Command] | None = None, # Rich settings rich_markup_mode: MarkupMode = DEFAULT_MARKUP_MODE, rich_help_panel: str | None = None, suggest_commands: bool = True, + # Click settings + invoke_without_command: bool = False, + no_args_is_help: bool = False, + subcommand_metavar: str | None = None, + result_callback: Callable[..., Any] | None = None, **attrs: Any, ) -> None: - super().__init__(name=name, commands=commands, **attrs) + super().__init__(name=name, **attrs) self.rich_markup_mode: MarkupMode = rich_markup_mode self.rich_help_panel = rich_help_panel self.suggest_commands = suggest_commands + # copied from Click's init + if commands is None: + commands = {} + elif isinstance(commands, Sequence): + commands = { + c.name: c + for c in commands + if isinstance(c, _click.Command) and c.name is not None + } + + self.commands = cast(MutableMapping[str, _click.Command], commands) + self.no_args_is_help = no_args_is_help + self.invoke_without_command = invoke_without_command + + if subcommand_metavar is None: + subcommand_metavar = "COMMAND [ARGS]..." + + self.subcommand_metavar = subcommand_metavar + self._result_callback = result_callback + + def add_command(self, cmd: _click.Command, name: str | None = None) -> None: + name = name or cmd.name + if name is None: + raise TypeError("Command has no name.") + self.commands[name] = cmd + + def get_command(self, ctx: _click.Context, cmd_name: str) -> _click.Command | None: + return self.commands.get(cmd_name) + + def collect_usage_pieces(self, ctx: _click.Context) -> list[str]: + rv = super().collect_usage_pieces(ctx) + rv.append(self.subcommand_metavar) + return rv + + def format_commands( + self, ctx: _click.Context, formatter: _click.HelpFormatter + ) -> None: + commands = [] + for subcommand in self.list_commands(ctx): + cmd = self.get_command(ctx, subcommand) + + commands.append((subcommand, cmd)) + + # allow for 3 times the default spacing + if len(commands): + limit = formatter.width - 6 - max(len(cmd[0]) for cmd in commands) + + rows = [] + for subcommand, cmd in commands: + assert cmd is not None + help = cmd.get_short_help_str(limit) + rows.append((subcommand, help)) + + if rows: + with formatter.section(_("Commands")): + formatter.write_dl(rows) + + def parse_args(self, ctx: _click.Context, args: list[str]) -> list[str]: + if not args and self.no_args_is_help and not ctx.resilient_parsing: + raise _click.exceptions.NoArgsIsHelpError(ctx) + + rest = super().parse_args(ctx, args) + + if rest: + ctx._protected_args, ctx.args = rest[:1], rest[1:] + + return ctx.args + + def invoke(self, ctx: _click.Context) -> Any: + def _process_result(value: Any) -> Any: + if self._result_callback is not None: + value = ctx.invoke(self._result_callback, value, **ctx.params) + return value + + if not ctx._protected_args: + if self.invoke_without_command: + # No subcommand was invoked, so the result callback is + # invoked with the group return value for regular + # groups, or an empty list for chained groups. + with ctx: + rv = super().invoke(ctx) + # return _process_result([] if self.chain else rv) + return _process_result(rv) + ctx.fail(_("Missing command.")) + + # Fetch args back out + args = [*ctx._protected_args, *ctx.args] + ctx.args = [] + ctx._protected_args = [] + + # Make sure the context is entered so we do not clean up + # resources until the result processor has worked. + with ctx: + cmd_name, cmd, args = self.resolve_command(ctx, args) + assert cmd is not None + ctx.invoked_subcommand = cmd_name + super().invoke(ctx) + sub_ctx = cmd.make_context(cmd_name, args, parent=ctx) + with sub_ctx: + return _process_result(sub_ctx.command.invoke(sub_ctx)) + + def shell_complete( + self, ctx: _click.Context, incomplete: str + ) -> list[CompletionItem]: + """Return a list of completions for the incomplete value. Looks + at the names of options, subcommands, and chained + multi-commands. + """ + + results = [ + CompletionItem(name, help=command.get_short_help_str()) + for name, command in _click.core._complete_visible_commands(ctx, incomplete) + ] + results.extend(super().shell_complete(ctx, incomplete)) + return results + def format_options( - self, ctx: click.Context, formatter: click.HelpFormatter + self, ctx: _click.Context, formatter: _click.HelpFormatter ) -> None: _typer_format_options(self, ctx=ctx, formatter=formatter) self.format_commands(ctx, formatter) @@ -765,12 +1137,31 @@ def _main_shell_completion( self, ctx_args=ctx_args, prog_name=prog_name, complete_var=complete_var ) + def _click_resolve_command( + self, ctx: _click.Context, args: list[str] + ) -> tuple[str | None, _click.Command | None, list[str]]: + cmd_name = args[0] + original_cmd_name = cmd_name + + # Get the command + cmd = self.get_command(ctx, cmd_name) + + if cmd is None and ctx.token_normalize_func is not None: + cmd_name = ctx.token_normalize_func(cmd_name) + cmd = self.get_command(ctx, cmd_name) + + if cmd is None and not ctx.resilient_parsing: + if _split_opt(cmd_name)[0]: + self.parse_args(ctx, args) + ctx.fail(_("No such command {name!r}.").format(name=original_cmd_name)) + return cmd_name if cmd else None, cmd, args[1:] + def resolve_command( - self, ctx: click.Context, args: list[str] - ) -> tuple[str | None, click.Command | None, list[str]]: + self, ctx: _click.Context, args: list[str] + ) -> tuple[str | None, _click.Command | None, list[str]]: try: - return super().resolve_command(ctx, args) - except click.UsageError as e: + return self._click_resolve_command(ctx, args) + except _click.exceptions.UsageError as e: if self.suggest_commands: available_commands = list(self.commands.keys()) if available_commands and args: @@ -802,7 +1193,7 @@ def main( **extra, ) - def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: + def format_help(self, ctx: _click.Context, formatter: _click.HelpFormatter) -> None: if not HAS_RICH or self.rich_markup_mode is None: return super().format_help(ctx, formatter) from . import rich_utils @@ -813,8 +1204,6 @@ def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> Non markup_mode=self.rich_markup_mode, ) - def list_commands(self, ctx: click.Context) -> list[str]: - """Returns a list of subcommand names. - Note that in Click's Group class, these are sorted. - In Typer, we wish to maintain the original order of creation (cf Issue #933)""" + def list_commands(self, ctx: _click.Context) -> list[str]: + """Returns a list of subcommand names, maintaining the original order of creation (cf Issue #933)""" return [n for n, c in self.commands.items()] diff --git a/typer/main.py b/typer/main.py index ebcf639a2c..43c7a921f7 100644 --- a/typer/main.py +++ b/typer/main.py @@ -15,10 +15,12 @@ from typing import Annotated, Any from uuid import UUID -import click from annotated_doc import Doc from typer._types import TyperChoice +from . import _click +from ._click import types +from ._click.globals import get_current_context from ._typing import get_args, get_origin, is_literal_type, is_union, literal_values from .completion import get_completion_inspect_parameters from .core import ( @@ -73,7 +75,7 @@ def except_hook( _original_except_hook(exc_type, exc_value, tb) return typer_path = os.path.dirname(__file__) - click_path = os.path.dirname(click.__file__) + click_path = os.path.dirname(_click.__file__) internal_dir_names = [typer_path, click_path] exc = exc_value if HAS_RICH: @@ -107,7 +109,7 @@ def except_hook( return -def get_install_completion_arguments() -> tuple[click.Parameter, click.Parameter]: +def get_install_completion_arguments() -> tuple[_click.Parameter, _click.Parameter]: install_param, show_param = get_completion_inspect_parameters() click_install_param, _ = get_click_param(install_param) click_show_param, _ = get_click_param(show_param) @@ -1168,7 +1170,7 @@ def get_group(typer_instance: Typer) -> TyperGroup: return group -def get_command(typer_instance: Typer) -> click.Command: +def get_command(typer_instance: Typer) -> _click.Command: if typer_instance._add_completion: click_install_param, click_show_param = get_install_completion_arguments() if ( @@ -1178,7 +1180,7 @@ def get_command(typer_instance: Typer) -> click.Command: or len(typer_instance.registered_commands) > 1 ): # Create a Group - click_command: click.Command = get_group(typer_instance) + click_command: _click.Command = get_group(typer_instance) if typer_instance._add_completion: click_command.params.append(click_install_param) click_command.params.append(click_show_param) @@ -1288,7 +1290,7 @@ def get_group_from_info( assert group_info.typer_instance, ( "A Typer instance is needed to generate a Click Group" ) - commands: dict[str, click.Command] = {} + commands: dict[str, _click.Command] = {} for command_info in group_info.typer_instance.registered_commands: command = get_command_from_info( command_info=command_info, @@ -1330,7 +1332,6 @@ def get_group_from_info( invoke_without_command=solved_info.invoke_without_command, no_args_is_help=solved_info.no_args_is_help, subcommand_metavar=solved_info.subcommand_metavar, - chain=solved_info.chain, result_callback=solved_info.result_callback, context_settings=solved_info.context_settings, callback=get_callback( @@ -1362,14 +1363,14 @@ def get_command_name(name: str) -> str: def get_params_convertors_ctx_param_name_from_function( callback: Callable[..., Any] | None, -) -> tuple[list[click.Argument | click.Option], dict[str, Any], str | None]: +) -> tuple[list[TyperArgument | TyperOption], dict[str, Any], str | None]: params = [] convertors = {} context_param_name = None if callback: parameters = get_params_from_function(callback) for param_name, param in parameters.items(): - if lenient_issubclass(param.annotation, click.Context): + if lenient_issubclass(param.annotation, _click.Context): context_param_name = param_name continue click_param, convertor = get_click_param(param) @@ -1384,7 +1385,7 @@ def get_command_from_info( *, pretty_exceptions_short: bool, rich_markup_mode: MarkupMode, -) -> click.Command: +) -> _click.Command: assert command_info.callback, "A command must have a callback function" name = command_info.name or get_command_name(command_info.callback.__name__) # ty: ignore use_help = command_info.help @@ -1486,7 +1487,7 @@ def internal_convertor( def get_callback( *, callback: Callable[..., Any] | None = None, - params: Sequence[click.Parameter] = [], + params: Sequence[_click.Parameter] = [], convertors: dict[str, Callable[[str], Any]] | None = None, context_param_name: str | None = None, pretty_exceptions_short: bool, @@ -1510,7 +1511,7 @@ def wrapper(**kwargs: Any) -> Any: else: use_params[k] = v if context_param_name: - use_params[context_param_name] = click.get_current_context() + use_params[context_param_name] = get_current_context() return callback(**use_params) update_wrapper(wrapper, callback) @@ -1519,15 +1520,15 @@ def wrapper(**kwargs: Any) -> Any: def get_click_type( *, annotation: Any, parameter_info: ParameterInfo -) -> click.ParamType: +) -> types.ParamType: if parameter_info.click_type is not None: return parameter_info.click_type elif parameter_info.parser is not None: - return click.types.FuncParamType(parameter_info.parser) + return types.FuncParamType(parameter_info.parser) elif annotation is str: - return click.STRING + return types.STRING elif annotation is int: if parameter_info.min is not None or parameter_info.max is not None: min_ = None @@ -1536,24 +1537,24 @@ def get_click_type( min_ = int(parameter_info.min) if parameter_info.max is not None: max_ = int(parameter_info.max) - return click.IntRange(min=min_, max=max_, clamp=parameter_info.clamp) + return types.IntRange(min=min_, max=max_, clamp=parameter_info.clamp) else: - return click.INT + return types.INT elif annotation is float: if parameter_info.min is not None or parameter_info.max is not None: - return click.FloatRange( + return types.FloatRange( min=parameter_info.min, max=parameter_info.max, clamp=parameter_info.clamp, ) else: - return click.FLOAT + return types.FLOAT elif annotation is bool: - return click.BOOL + return types.BOOL elif annotation == UUID: - return click.UUID + return types.UUID elif annotation == datetime: - return click.DateTime(formats=parameter_info.formats) + return types.DateTime(formats=parameter_info.formats) elif ( annotation == Path or parameter_info.allow_dash @@ -1571,7 +1572,7 @@ def get_click_type( path_type=parameter_info.path_type, ) elif lenient_issubclass(annotation, FileTextWrite): - return click.File( + return types.File( mode=parameter_info.mode or "w", encoding=parameter_info.encoding, errors=parameter_info.errors, @@ -1579,7 +1580,7 @@ def get_click_type( atomic=parameter_info.atomic, ) elif lenient_issubclass(annotation, FileText): - return click.File( + return types.File( mode=parameter_info.mode or "r", encoding=parameter_info.encoding, errors=parameter_info.errors, @@ -1587,7 +1588,7 @@ def get_click_type( atomic=parameter_info.atomic, ) elif lenient_issubclass(annotation, FileBinaryRead): - return click.File( + return types.File( mode=parameter_info.mode or "rb", encoding=parameter_info.encoding, errors=parameter_info.errors, @@ -1595,7 +1596,7 @@ def get_click_type( atomic=parameter_info.atomic, ) elif lenient_issubclass(annotation, FileBinaryWrite): - return click.File( + return types.File( mode=parameter_info.mode or "wb", encoding=parameter_info.encoding, errors=parameter_info.errors, @@ -1603,17 +1604,12 @@ def get_click_type( atomic=parameter_info.atomic, ) elif lenient_issubclass(annotation, Enum): - # The custom TyperChoice is only needed for Click < 8.2.0, to parse the - # command line values matching them to the enum values. Click 8.2.0 added - # support for enum values but reading enum names. - # Passing here the list of enum values (instead of just the enum) accounts for - # Click < 8.2.0. return TyperChoice( [item.value for item in annotation], case_sensitive=parameter_info.case_sensitive, ) elif is_literal_type(annotation): - return click.Choice( + return TyperChoice( literal_values(annotation), case_sensitive=parameter_info.case_sensitive, ) @@ -1626,7 +1622,7 @@ def lenient_issubclass(cls: Any, class_or_tuple: AnyType | tuple[AnyType, ...]) def get_click_param( param: ParamMeta, -) -> tuple[click.Argument | click.Option, Any]: +) -> tuple[TyperArgument | TyperOption, Any]: # First, find out what will be: # * ParamInfo (ArgumentInfo or OptionInfo) # * default_value @@ -1784,7 +1780,7 @@ def get_click_param( ), convertor, ) - raise AssertionError("A click.Parameter should be returned") # pragma: no cover + raise AssertionError("A _click.Parameter should be returned") # pragma: no cover def get_param_callback( @@ -1800,9 +1796,9 @@ def get_param_callback( value_name = None untyped_names: list[str] = [] for param_name, param_sig in parameters.items(): - if lenient_issubclass(param_sig.annotation, click.Context): + if lenient_issubclass(param_sig.annotation, _click.Context): ctx_name = param_name - elif lenient_issubclass(param_sig.annotation, click.Parameter): + elif lenient_issubclass(param_sig.annotation, _click.Parameter): click_param_name = param_name else: untyped_names.append(param_name) @@ -1817,11 +1813,11 @@ def get_param_callback( if untyped_names: click_param_name = untyped_names.pop(0) if untyped_names: - raise click.ClickException( + raise _click.ClickException( "Too many CLI parameter callback function parameters" ) - def wrapper(ctx: click.Context, param: click.Parameter, value: Any) -> Any: + def wrapper(ctx: _click.Context, param: _click.Parameter, value: Any) -> Any: use_params: dict[str, Any] = {} if ctx_name: use_params[ctx_name] = ctx @@ -1851,7 +1847,7 @@ def get_param_completion( unassigned_params = list(parameters.values()) for param_sig in unassigned_params[:]: origin = get_origin(param_sig.annotation) - if lenient_issubclass(param_sig.annotation, click.Context): + if lenient_issubclass(param_sig.annotation, _click.Context): ctx_name = param_sig.name unassigned_params.remove(param_sig) elif lenient_issubclass(origin, list): @@ -1874,11 +1870,11 @@ def get_param_completion( # Extract value param name first if unassigned_params: show_params = " ".join([param.name for param in unassigned_params]) - raise click.ClickException( + raise _click.ClickException( f"Invalid autocompletion callback parameters: {show_params}" ) - def wrapper(ctx: click.Context, args: list[str], incomplete: str | None) -> Any: + def wrapper(ctx: _click.Context, args: list[str], incomplete: str | None) -> Any: use_params: dict[str, Any] = {} if ctx_name: use_params[ctx_name] = ctx @@ -2010,4 +2006,4 @@ def launch( return 0 else: - return click.launch(url, wait=wait, locate=locate) + return _click.launch(url, wait=wait, locate=locate) diff --git a/typer/models.py b/typer/models.py index 3285a96a24..00385c38ce 100644 --- a/typer/models.py +++ b/typer/models.py @@ -1,15 +1,20 @@ import inspect import io +import os +import stat from collections.abc import Callable, Sequence from typing import ( TYPE_CHECKING, Any, + ClassVar, Optional, TypeVar, + cast, ) -import click -import click.shell_completion +from . import _click +from ._click import types +from ._click.shell_completion import CompletionItem if TYPE_CHECKING: # pragma: no cover from .core import TyperCommand, TyperGroup @@ -23,7 +28,7 @@ Required = ... -class Context(click.Context): +class Context(_click.Context): """ The [`Context`](https://click.palletsprojects.com/en/stable/api/#click.Context) has some additional data about the current execution of your program. When declaring it in a [callback](https://typer.tiangolo.com/tutorial/options/callback-and-context/) function, @@ -153,7 +158,7 @@ def main(file: Annotated[typer.FileBinaryWrite, typer.Option()]): pass -class CallbackParam(click.Parameter): +class CallbackParam(_click.Parameter): """ In a callback function, you can declare a function parameter with type `CallbackParam` to access the specific Click [`Parameter`](https://click.palletsprojects.com/en/stable/api/#click.Parameter) object. @@ -286,15 +291,15 @@ def __init__( # Note that shell_complete is not fully supported and will be removed in future versions # TODO: Remove shell_complete in a future version (after 0.16.0) shell_complete: Callable[ - [click.Context, click.Parameter, str], - list["click.shell_completion.CompletionItem"] | list[str], + [_click.Context, _click.Parameter, str], + list["CompletionItem"] | list[str], ] | None = None, autocompletion: Callable[..., Any] | None = None, default_factory: Callable[[], Any] | None = None, # Custom type parser: Callable[[str], Any] | None = None, - click_type: click.ParamType | None = None, + click_type: types.ParamType | None = None, # TyperArgument show_default: bool | str = True, show_choices: bool = True, @@ -395,15 +400,15 @@ def __init__( # Note that shell_complete is not fully supported and will be removed in future versions # TODO: Remove shell_complete in a future version (after 0.16.0) shell_complete: Callable[ - [click.Context, click.Parameter, str], - list["click.shell_completion.CompletionItem"] | list[str], + [_click.Context, _click.Parameter, str], + list["CompletionItem"] | list[str], ] | None = None, autocompletion: Callable[..., Any] | None = None, default_factory: Callable[[], Any] | None = None, # Custom type parser: Callable[[str], Any] | None = None, - click_type: click.ParamType | None = None, + click_type: types.ParamType | None = None, # Option show_default: bool | str = True, prompt: bool | str = False, @@ -523,15 +528,15 @@ def __init__( # Note that shell_complete is not fully supported and will be removed in future versions # TODO: Remove shell_complete in a future version (after 0.16.0) shell_complete: Callable[ - [click.Context, click.Parameter, str], - list["click.shell_completion.CompletionItem"] | list[str], + [_click.Context, _click.Parameter, str], + list["CompletionItem"] | list[str], ] | None = None, autocompletion: Callable[..., Any] | None = None, default_factory: Callable[[], Any] | None = None, # Custom type parser: Callable[[str], Any] | None = None, - click_type: click.ParamType | None = None, + click_type: types.ParamType | None = None, # TyperArgument show_default: bool | str = True, show_choices: bool = True, @@ -640,11 +645,98 @@ def __init__( self.pretty_exceptions_short = pretty_exceptions_short -class TyperPath(click.Path): - # Overwrite Click's behaviour to be compatible with Typer's autocompletion system +class TyperPath(types.ParamType): + # Based originally on code from Click 8.3.1 + # Partly rewritten and added an override for shell_complete + + envvar_list_splitter: ClassVar[str] = os.path.pathsep + + def __init__( + self, + exists: bool = False, + file_okay: bool = True, + dir_okay: bool = True, + writable: bool = False, + readable: bool = True, + resolve_path: bool = False, + allow_dash: bool = False, + path_type: type[Any] | None = None, + ): + self.exists = exists + self.file_okay = file_okay + self.dir_okay = dir_okay + self.readable = readable + self.writable = writable + self.resolve_path = resolve_path + self.allow_dash = allow_dash + self.type = path_type + + if self.file_okay and not self.dir_okay: + self.name = "file" + elif self.dir_okay and not self.file_okay: + self.name = "directory" + else: + self.name = "path" + + def coerce_path_result( + self, value: str | os.PathLike[str] + ) -> str | bytes | os.PathLike[str]: + if self.type is not None and not isinstance(value, self.type): + if ( + self.type is str + ): # pragma: no cover # TODO: perhaps this branch can't be hit and should be removed + return os.fsdecode(value) + elif self.type is bytes: + return os.fsencode(value) + else: + return cast("os.PathLike[str]", self.type(value)) + + return value + + def convert( # ty: ignore[invalid-method-override] + self, + value: str | os.PathLike[str], + param: _click.Parameter | None, + ctx: Context | None, # type: ignore[override] + ) -> str | bytes | os.PathLike[str]: + rv = value + + is_dash = self.file_okay and self.allow_dash and rv in (b"-", "-") + + if not is_dash: + if self.resolve_path: + rv = os.path.realpath(rv) + + try: + st = os.stat(rv) + except OSError: + if not self.exists: + return self.coerce_path_result(rv) + self.fail( + f"{self.name.title()} {_click.utils.format_filename(value)!r} does not exist.", + param, + ctx, + ) + + name = self.name.title() + loc = repr(_click.utils.format_filename(value)) + if not self.file_okay and stat.S_ISREG(st.st_mode): + self.fail(f"{name} {loc} is a file.", param, ctx) + + if not self.dir_okay and stat.S_ISDIR(st.st_mode): + self.fail(f"{name} {loc} is a directory.", param, ctx) + + if self.readable and not os.access(rv, os.R_OK): + self.fail(f"{name} {loc} is not readable.", param, ctx) + + if self.writable and not os.access(rv, os.W_OK): + self.fail(f"{name} {loc} is not writable.", param, ctx) + + return self.coerce_path_result(rv) + def shell_complete( - self, ctx: click.Context, param: click.Parameter, incomplete: str - ) -> list[click.shell_completion.CompletionItem]: + self, ctx: _click.Context, param: _click.Parameter, incomplete: str + ) -> list[CompletionItem]: """Return an empty list so that the autocompletion functionality will work properly from the commandline. """ diff --git a/typer/params.py b/typer/params.py index b325b273c4..833461fa78 100644 --- a/typer/params.py +++ b/typer/params.py @@ -1,13 +1,15 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Annotated, Any, overload -import click from annotated_doc import Doc +from . import _click +from ._click import types +from ._click.shell_completion import CompletionItem from .models import ArgumentInfo, OptionInfo if TYPE_CHECKING: # pragma: no cover - import click.shell_completion + pass # Overload for Option created with custom type 'parser' @@ -24,8 +26,8 @@ def Option( # Note that shell_complete is not fully supported and will be removed in future versions # TODO: Remove shell_complete in a future version (after 0.16.0) shell_complete: Callable[ - [click.Context, click.Parameter, str], - list["click.shell_completion.CompletionItem"] | list[str], + [_click.Context, _click.Parameter, str], + list["CompletionItem"] | list[str], ] | None = None, autocompletion: Callable[..., Any] | None = None, @@ -89,14 +91,14 @@ def Option( # Note that shell_complete is not fully supported and will be removed in future versions # TODO: Remove shell_complete in a future version (after 0.16.0) shell_complete: Callable[ - [click.Context, click.Parameter, str], - list["click.shell_completion.CompletionItem"] | list[str], + [_click.Context, _click.Parameter, str], + list["CompletionItem"] | list[str], ] | None = None, autocompletion: Callable[..., Any] | None = None, default_factory: Callable[[], Any] | None = None, # Custom type - click_type: click.ParamType | None = None, + click_type: types.ParamType | None = None, # Option show_default: bool | str = True, prompt: bool | str = False, @@ -265,8 +267,8 @@ def main(user: Annotated[str, typer.Option(envvar="ME")]): # TODO: Remove shell_complete in a future version (after 0.16.0) shell_complete: Annotated[ Callable[ - [click.Context, click.Parameter, str], - list["click.shell_completion.CompletionItem"] | list[str], + [_click.Context, _click.Parameter, str], + list["CompletionItem"] | list[str], ] | None, Doc( @@ -343,7 +345,7 @@ def main(opt: Annotated[CustomClass, typer.Option(parser=my_parser)] = "Foo"): ), ] = None, click_type: Annotated[ - click.ParamType | None, + types.ParamType | None, Doc( """ Define this parameter to use a [custom Click type](https://click.palletsprojects.com/en/stable/parameters/#implementing-custom-types) in your Typer applications. @@ -1014,8 +1016,8 @@ def Argument( # Note that shell_complete is not fully supported and will be removed in future versions # TODO: Remove shell_complete in a future version (after 0.16.0) shell_complete: Callable[ - [click.Context, click.Parameter, str], - list["click.shell_completion.CompletionItem"] | list[str], + [_click.Context, _click.Parameter, str], + list["CompletionItem"] | list[str], ] | None = None, autocompletion: Callable[..., Any] | None = None, @@ -1070,14 +1072,14 @@ def Argument( # Note that shell_complete is not fully supported and will be removed in future versions # TODO: Remove shell_complete in a future version (after 0.16.0) shell_complete: Callable[ - [click.Context, click.Parameter, str], - list["click.shell_completion.CompletionItem"] | list[str], + [_click.Context, _click.Parameter, str], + list["CompletionItem"] | list[str], ] | None = None, autocompletion: Callable[..., Any] | None = None, default_factory: Callable[[], Any] | None = None, # Custom type - click_type: click.ParamType | None = None, + click_type: types.ParamType | None = None, # TyperArgument show_default: bool | str = True, show_choices: bool = True, @@ -1219,8 +1221,8 @@ def main(name: Annotated[str, typer.Argument(envvar="ME")]): # TODO: Remove shell_complete in a future version (after 0.16.0) shell_complete: Annotated[ Callable[ - [click.Context, click.Parameter, str], - list["click.shell_completion.CompletionItem"] | list[str], + [_click.Context, _click.Parameter, str], + list["CompletionItem"] | list[str], ] | None, Doc( @@ -1297,7 +1299,7 @@ def main(arg: Annotated[CustomClass, typer.Argument(parser=my_parser): ), ] = None, click_type: Annotated[ - click.ParamType | None, + types.ParamType | None, Doc( """ Define this parameter to use a [custom Click type](https://click.palletsprojects.com/en/stable/parameters/#implementing-custom-types) in your Typer applications. diff --git a/typer/rich_utils.py b/typer/rich_utils.py index 69be631207..e5b106126e 100644 --- a/typer/rich_utils.py +++ b/typer/rich_utils.py @@ -8,7 +8,6 @@ from os import getenv from typing import Any, Literal -import click from rich import box from rich.align import Align from rich.columns import Columns @@ -25,6 +24,10 @@ from rich.traceback import Traceback from typer.models import DeveloperExceptionConfig +from . import _click +from ._click import types +from .core import TyperArgument, TyperGroup, TyperOption + # Default styles STYLE_OPTION = "bold cyan" STYLE_SWITCH = "bold green" @@ -184,7 +187,7 @@ def _make_rich_text( @group() def _get_help_text( *, - obj: click.Command | click.Group, + obj: _click.Command | TyperGroup, markup_mode: MarkupModeStrict, ) -> Iterable[Markdown | Text]: """Build primary help text for a click command or group. @@ -231,8 +234,8 @@ def _get_help_text( def _get_parameter_help( *, - param: click.Option | click.Argument | click.Parameter, - ctx: click.Context, + param: TyperOption | TyperArgument | _click.Parameter, + ctx: _click.Context, markup_mode: MarkupModeStrict, ) -> Columns: """Build primary help text for a click option or argument. @@ -348,8 +351,8 @@ def _make_command_help( def _print_options_panel( *, name: str, - params: list[click.Option] | list[click.Argument], - ctx: click.Context, + params: list[TyperOption] | list[TyperArgument], + ctx: _click.Context, markup_mode: MarkupModeStrict, console: Console, ) -> None: @@ -377,7 +380,7 @@ def _print_options_panel( metavar_str = param.make_metavar(ctx=ctx) # Do it ourselves if this is a positional argument if ( - isinstance(param, click.Argument) + isinstance(param, TyperArgument) and param.name and metavar_str == param.name.upper() ): @@ -391,8 +394,8 @@ def _print_options_panel( # https://github.com/pallets/click/blob/c63c70dabd3f86ca68678b4f00951f78f52d0270/src/click/core.py#L2698-L2706 # noqa: E501 # skip count with default range type if ( - isinstance(param.type, click.types._NumberRangeBase) - and isinstance(param, click.Option) + isinstance(param.type, types._NumberRangeBase) + and isinstance(param, TyperOption) and not (param.count and param.type.min == 0 and param.type.max is None) ): range_str = param.type._describe_range() @@ -459,7 +462,7 @@ def _print_options_panel( def _print_commands_panel( *, name: str, - commands: list[click.Command], + commands: list[_click.Command], markup_mode: MarkupModeStrict, console: Console, cmd_len: int, @@ -534,8 +537,8 @@ def _print_commands_panel( def rich_format_help( *, - obj: click.Command | click.Group, - ctx: click.Context, + obj: _click.Command | TyperGroup, + ctx: _click.Context, markup_mode: MarkupModeStrict, ) -> None: """Print nicely formatted help text using rich. @@ -568,18 +571,18 @@ def rich_format_help( (0, 1, 1, 1), ) ) - panel_to_arguments: defaultdict[str, list[click.Argument]] = defaultdict(list) - panel_to_options: defaultdict[str, list[click.Option]] = defaultdict(list) + panel_to_arguments: defaultdict[str, list[TyperArgument]] = defaultdict(list) + panel_to_options: defaultdict[str, list[TyperOption]] = defaultdict(list) for param in obj.get_params(ctx): # Skip if option is hidden if getattr(param, "hidden", False): continue - if isinstance(param, click.Argument): + if isinstance(param, TyperArgument): panel_name = ( getattr(param, _RICH_HELP_PANEL_NAME, None) or ARGUMENTS_PANEL_TITLE ) panel_to_arguments[panel_name].append(param) - elif isinstance(param, click.Option): + elif isinstance(param, TyperOption): panel_name = ( getattr(param, _RICH_HELP_PANEL_NAME, None) or OPTIONS_PANEL_TITLE ) @@ -623,8 +626,8 @@ def rich_format_help( console=console, ) - if isinstance(obj, click.Group): - panel_to_commands: defaultdict[str, list[click.Command]] = defaultdict(list) + if isinstance(obj, TyperGroup): + panel_to_commands: defaultdict[str, list[_click.Command]] = defaultdict(list) for command_name in obj.list_commands(ctx): command = obj.get_command(ctx, command_name) if command and not command.hidden: @@ -674,18 +677,18 @@ def rich_format_help( console.print(Padding(Align(epilogue_text, pad=False), 1)) -def rich_format_error(self: click.ClickException) -> None: +def rich_format_error(self: _click.ClickException) -> None: """Print richly formatted click errors. Called by custom exception handler to print richly formatted click errors. - Mimics original click.ClickException.echo() function but with rich formatting. + Mimics original _click.ClickException.echo() function but with rich formatting. """ # Don't do anything when it's a NoArgsIsHelpError (without importing it, cf. #1278) if self.__class__.__name__ == "NoArgsIsHelpError": return console = _get_rich_console(stderr=True) - ctx: click.Context | None = getattr(self, "ctx", None) + ctx: _click.Context | None = getattr(self, "ctx", None) if ctx is not None: console.print(ctx.get_usage()) diff --git a/typer/testing.py b/typer/testing.py index 09711e66fd..6035867662 100644 --- a/typer/testing.py +++ b/typer/testing.py @@ -1,11 +1,12 @@ from collections.abc import Mapping, Sequence from typing import IO, Any -from click.testing import CliRunner as ClickCliRunner # noqa -from click.testing import Result from typer.main import Typer from typer.main import get_command as _get_command +from ._click.testing import CliRunner as ClickCliRunner # noqa +from ._click.testing import Result + class CliRunner(ClickCliRunner): def invoke( # type: ignore diff --git a/uv.lock b/uv.lock index cf893201e8..207311edeb 100644 --- a/uv.lock +++ b/uv.lock @@ -1787,7 +1787,7 @@ name = "typer" source = { editable = "." } dependencies = [ { name = "annotated-doc" }, - { name = "click" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "rich" }, { name = "shellingham" }, ] @@ -1850,7 +1850,7 @@ tests = [ [package.metadata] requires-dist = [ { name = "annotated-doc", specifier = ">=0.0.2" }, - { name = "click", specifier = ">=8.2.1,<8.4" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "rich", specifier = ">=13.8.0" }, { name = "shellingham", specifier = ">=1.3.0" }, ]