diff --git a/.github/DISCUSSION_TEMPLATE/questions.yml b/.github/DISCUSSION_TEMPLATE/questions.yml
index 92d4d7a143..c4f3256d44 100644
--- a/.github/DISCUSSION_TEMPLATE/questions.yml
+++ b/.github/DISCUSSION_TEMPLATE/questions.yml
@@ -36,8 +36,6 @@ body:
required: true
- label: I already read and followed all the tutorials in the docs and didn't find an answer.
required: true
- - label: I already checked if it is not related to Typer but to [Click](https://github.com/pallets/click).
- required: true
- type: checkboxes
id: help
attributes:
diff --git a/CITATION.cff b/CITATION.cff
index 5fe4aa2d34..b4902c6ab1 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -18,5 +18,6 @@ abstract: >-
Typer, build great CLIs. Easy to code. Based on Python type hints.
keywords:
- typer
- - click
+ - cli
+ - python
license: MIT
diff --git a/README.md b/README.md
index 80b0b20ebd..43a46c5aa0 100644
--- a/README.md
+++ b/README.md
@@ -352,11 +352,20 @@ For a more complete example including more features, see the vendored Click version 8.3.1 and adds a layer on top of it.
-**Typer** mainly adds a layer on top of Click, making the code simpler and easier to use, with autocompletion everywhere, etc, but providing all the powerful features of Click underneath.
+Typer aims to make the code simpler and easier to use, with autocompletion everywhere, etc, while still providing many of the powerful features of Click underneath.
As someone pointed out: ["Nice to see it is built on Click but adds the type stuff. Me gusta!"](https://twitter.com/fishnets88/status/1210126833745838080)
diff --git a/docs/features.md b/docs/features.md
index bd11fa5a9d..25490ec610 100644
--- a/docs/features.md
+++ b/docs/features.md
@@ -63,14 +63,6 @@ Auto completion works when you create a package (installable with `pip`). Or whe
///
-/// tip
-
-**Typer**'s completion is implemented internally, it uses ideas and components from Click and ideas from `click-completion`, but it doesn't use `click-completion` and re-implements some of the relevant parts of Click.
-
-Then it extends those ideas with features and bug fixes. For example, **Typer** programs also support modern versions of PowerShell (e.g. in Windows 10) among all the other shells.
-
-///
-
## Tested
* 100% test coverage.
diff --git a/docs/index.md b/docs/index.md
index 1add7a39dc..fec3eb402e 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -360,11 +360,20 @@ For a more complete example including more features, see the = 8.2.1, < 8.4",
"shellingham >=1.3.0",
"rich >=13.8.0",
"annotated-doc >=0.0.2",
+ "colorama; platform_system == 'Windows'",
]
readme = "README.md"
@@ -189,7 +189,7 @@ ignore = [
"docs_src/*" = ["TID"]
[tool.ruff.lint.isort]
-known-third-party = ["typer", "click"]
+known-third-party = ["typer"]
# For docs_src/subcommands/tutorial003/
known-first-party = ["reigns", "towns", "lands", "items", "users"]
diff --git a/tests/assets/completion_argument.py b/tests/assets/completion_argument.py
index f91e2b7cfb..e2754c4357 100644
--- a/tests/assets/completion_argument.py
+++ b/tests/assets/completion_argument.py
@@ -1,10 +1,10 @@
-import click
import typer
+from typer import _click
app = typer.Typer()
-def shell_complete(ctx: click.Context, param: click.Parameter, incomplete: str):
+def shell_complete(ctx: _click.Context, param: _click.Parameter, incomplete: str):
typer.echo(f"ctx: {ctx.info_name}", err=True)
typer.echo(f"arg is: {param.name}", err=True)
typer.echo(f"incomplete is: {incomplete}", err=True)
diff --git a/tests/atomic_write_example.py b/tests/atomic_write_example.py
new file mode 100644
index 0000000000..a190cc154f
--- /dev/null
+++ b/tests/atomic_write_example.py
@@ -0,0 +1,63 @@
+import time
+
+import typer
+
+app = typer.Typer()
+
+
+@app.command()
+def write_atomic(
+ config: typer.FileText = typer.Option(..., mode="w", atomic=True),
+ pause: float = typer.Option(0.3),
+) -> None:
+ config.write("atomic-content-1\n")
+ config.flush()
+ typer.echo("halfway")
+ time.sleep(pause)
+ config.write("atomic-content-2\n")
+ config.flush()
+ typer.echo("written atomically")
+
+
+@app.command()
+def write_atomic_binary(
+ config: typer.FileBinaryWrite = typer.Option(..., atomic=True, lazy=False),
+) -> None:
+ config.write(b"\x00\x01binary-atomic\n")
+ typer.echo("written binary atomically")
+
+
+@app.command()
+def api_atomic(
+ config: typer.FileText = typer.Option(..., mode="w", atomic=True, lazy=False),
+) -> None:
+ typer.echo(f"name={config.name}")
+ typer.echo(f"repr={repr(config)}")
+ with config as entered:
+ typer.echo(f"entered={entered is config}")
+ entered.write("atomic-api-done\n")
+
+
+@app.command()
+def invalid_atomic_append(
+ config: typer.FileText = typer.Option(..., mode="a", atomic=True, lazy=False),
+) -> None:
+ typer.echo(config.name) # pragma: no cover
+
+
+@app.command()
+def invalid_atomic_exclusive(
+ config: typer.FileText = typer.Option(..., mode="x", atomic=True, lazy=False),
+) -> None:
+ typer.echo(config.name) # pragma: no cover
+
+
+@app.command()
+def invalid_atomic_read(
+ config: typer.FileText = typer.Option(..., mode="r", atomic=True, lazy=False),
+) -> None:
+ typer.echo(config.name) # pragma: no cover
+
+
+if __name__ == "__main__":
+ app()
diff --git a/tests/test_annotated.py b/tests/test_annotated.py
index c487eae9d0..af3cc0680b 100644
--- a/tests/test_annotated.py
+++ b/tests/test_annotated.py
@@ -2,6 +2,7 @@
from pathlib import Path
from typing import Annotated
+import pytest
import typer
from typer.testing import CliRunner
@@ -95,3 +96,14 @@ def custom_parser(
result = runner.invoke(app, "/some/quirky/path/implementation")
assert result.exit_code == 0
+
+
+def test_annotated_option_invalid():
+ app = typer.Typer()
+
+ @app.command()
+ def cmd(value: Annotated[str, typer.Option(..., "foo-bar")]):
+ print(value) # pragma: no cover
+
+ with pytest.raises(ValueError, match="Invalid start character for option"):
+ runner.invoke(app, ["--help"], catch_exceptions=False)
diff --git a/tests/test_atomic_file.py b/tests/test_atomic_file.py
new file mode 100644
index 0000000000..010f899034
--- /dev/null
+++ b/tests/test_atomic_file.py
@@ -0,0 +1,122 @@
+import subprocess
+import sys
+from pathlib import Path
+
+import pytest
+
+from . import atomic_write_example as mod
+
+
+def test_atomic_write(tmp_path: Path) -> None:
+ original_content = "existing-content\n"
+ output_file = tmp_path / "atomic-write-target.txt"
+ output_file.write_text(original_content, encoding="utf-8")
+
+ process = subprocess.Popen(
+ [
+ sys.executable,
+ "-m",
+ "coverage",
+ "run",
+ mod.__file__,
+ "write-atomic",
+ f"--config={output_file}",
+ ],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True,
+ )
+ assert process.stdout is not None
+
+ # Halfway of writing the file, check that the original content is still there
+ halfway_line = process.stdout.readline().strip()
+ assert halfway_line == "halfway"
+ assert output_file.read_text(encoding="utf-8") == original_content
+
+ # Only at the end, the full new content is visible
+ stdout, stderr = process.communicate(timeout=5)
+ assert process.returncode == 0, stderr
+ assert "written atomically" in stdout
+ assert (
+ output_file.read_text(encoding="utf-8")
+ == "atomic-content-1\natomic-content-2\n"
+ )
+
+
+def test_atomic_binary_write(tmp_path: Path) -> None:
+ output_file = tmp_path / "atomic-binary.bin"
+
+ result = subprocess.run(
+ [
+ sys.executable,
+ "-m",
+ "coverage",
+ "run",
+ mod.__file__,
+ "write-atomic-binary",
+ f"--config={output_file}",
+ ],
+ capture_output=True,
+ encoding="utf-8",
+ )
+
+ assert result.returncode == 0, result.stderr
+ assert "written binary atomically" in result.stdout
+ assert output_file.read_bytes() == b"\x00\x01binary-atomic\n"
+
+
+def test_atomic_api(tmp_path: Path) -> None:
+ output_file = tmp_path / "atomic-api.txt"
+
+ result = subprocess.run(
+ [
+ sys.executable,
+ "-m",
+ "coverage",
+ "run",
+ mod.__file__,
+ "api-atomic",
+ f"--config={output_file}",
+ ],
+ capture_output=True,
+ encoding="utf-8",
+ )
+
+ assert result.returncode == 0, result.stderr
+ assert f"name={output_file}" in result.stdout
+ assert "repr=<_io.TextIOWrapper" in result.stdout
+ assert "entered=True" in result.stdout
+ assert output_file.read_text(encoding="utf-8") == "atomic-api-done\n"
+
+
+@pytest.mark.parametrize(
+ ("command_name", "expected_message"),
+ [
+ ("invalid-atomic-append", "Appending to an existing file is not supported"),
+ ("invalid-atomic-exclusive", "Use the `overwrite`-parameter instead."),
+ ("invalid-atomic-read", "Atomic writes only make sense with `w`-mode."),
+ ],
+)
+def test_atomic_mode_invalid_options(
+ tmp_path: Path, command_name: str, expected_message: str
+) -> None:
+ output_file = tmp_path / "atomic-invalid-mode.txt"
+ output_file.write_text("existing-content\n", encoding="utf-8")
+
+ result = subprocess.run(
+ [
+ sys.executable,
+ "-m",
+ "coverage",
+ "run",
+ mod.__file__,
+ command_name,
+ f"--config={output_file}",
+ ],
+ capture_output=True,
+ encoding="utf-8",
+ )
+
+ assert result.returncode != 0
+ combined_output = f"{result.stdout}\n{result.stderr}"
+ assert expected_message in combined_output
diff --git a/tests/test_cli/test_help.py b/tests/test_cli/test_help.py
index 64e5495c9a..e829c5801b 100644
--- a/tests/test_cli/test_help.py
+++ b/tests/test_cli/test_help.py
@@ -1,6 +1,11 @@
import subprocess
import sys
+import typer
+from typer.testing import CliRunner
+
+runner = CliRunner()
+
def test_script_help():
result = subprocess.run(
@@ -36,3 +41,119 @@ def test_not_python():
encoding="utf-8",
)
assert "Could not import as Python file" in result.stderr
+
+
+def test_short_help() -> None:
+ app = typer.Typer(
+ rich_markup_mode=None,
+ context_settings={"max_content_width": 50},
+ )
+
+ @app.command(help=" \n\t ")
+ def empty() -> None:
+ pass # pragma: no cover
+
+ @app.command(help="\b first sentence.")
+ def marker() -> None:
+ pass # pragma: no cover
+
+ # Forcing truncation
+ @app.command(help=f"{'x' * 30} {'y' * 5} z trailing")
+ def long() -> None:
+ pass # pragma: no cover
+
+ result = runner.invoke(app, ["--help"], terminal_width=50)
+ assert result.exit_code == 0
+ assert "empty" in result.output
+ assert "marker" in result.output
+ assert "long" in result.output
+ assert "first sentence." in result.output
+ assert f"{'x' * 30}..." in result.output
+
+
+def test_help_wrapping() -> None:
+ app = typer.Typer(
+ rich_markup_mode=None,
+ context_settings={"max_content_width": 50},
+ )
+
+ @app.command(
+ help=(
+ "Wrapped paragraph has enough words to wrap in help output.\n"
+ "\n"
+ "\n"
+ "\b\n"
+ "RAW-LINE-ONE stays on one line even with many many many words.\n"
+ "RAW-LINE-TWO keeps original formatting.\n"
+ "\n"
+ "Final paragraph wraps normally as well."
+ )
+ )
+ def cmd() -> None:
+ pass # pragma: no cover
+
+ result = runner.invoke(app, ["cmd", "--help"], terminal_width=50)
+ assert result.exit_code == 0
+ assert "Wrapped paragraph has enough words to wrap" in result.output
+ assert (
+ "RAW-LINE-ONE stays on one line even with many many many words."
+ in result.output
+ )
+ assert "RAW-LINE-TWO keeps original formatting." in result.output
+ assert "Final paragraph wraps normally as well." in result.output
+
+
+def test_help_wrapping_long_name() -> None:
+ app = typer.Typer(rich_markup_mode=None)
+
+ @app.command()
+ def cmd(value: str) -> None:
+ pass # pragma: no cover
+
+ result = runner.invoke(
+ app,
+ ["cmd", "--help"],
+ terminal_width=40,
+ prog_name="very-long-program-name-that-forces-wrap",
+ )
+ assert result.exit_code == 0
+
+ output_lines = result.output.splitlines()
+ usage_idx = output_lines.index("Usage: very-long-program-name-that-forces-wrap ")
+ args_line = output_lines[usage_idx + 1]
+ assert args_line.lstrip() == "[OPTIONS] VALUE"
+ assert args_line.startswith(" ")
+
+
+def test_format_long_help_option() -> None:
+ app = typer.Typer(rich_markup_mode=None)
+
+ @app.command()
+ def cmd(
+ very_long: str = typer.Option(
+ ...,
+ "--this-is-a-very-very-very-long-option-name",
+ help="Description is rendered in the next line for long option labels.",
+ ),
+ ) -> None:
+ pass # pragma: no cover
+
+ result = runner.invoke(app, ["cmd", "--help"], terminal_width=80)
+ assert result.exit_code == 0
+
+ output_lines = result.output.splitlines()
+ option_idx = next(
+ i
+ for i, line in enumerate(output_lines)
+ if "--this-is-a-very-very-very-long-option-name" in line
+ )
+ assert "Description is rendered" not in output_lines[option_idx]
+ first_desc_line = output_lines[option_idx + 1]
+ assert first_desc_line.lstrip().startswith("Description is rendered")
+ continuation_block = " ".join(
+ line.strip() for line in output_lines[option_idx + 1 :] if line.startswith(" ")
+ )
+ assert (
+ "Description is rendered in the next line for long option labels."
+ in continuation_block
+ )
diff --git a/tests/test_cli/test_parser.py b/tests/test_cli/test_parser.py
new file mode 100644
index 0000000000..9d9b641e4b
--- /dev/null
+++ b/tests/test_cli/test_parser.py
@@ -0,0 +1,67 @@
+import subprocess
+import sys
+
+import typer
+from typer.testing import CliRunner
+
+runner = CliRunner()
+
+
+def test_double_dash() -> None:
+ result = subprocess.run(
+ [
+ sys.executable,
+ "-m",
+ "coverage",
+ "run",
+ "-m",
+ "typer",
+ "tests/assets/cli/sample.py",
+ "run",
+ "hello",
+ "--",
+ "--name",
+ "Camila",
+ ],
+ capture_output=True,
+ encoding="utf-8",
+ )
+ assert "Got unexpected extra argument" in result.stderr
+ assert "--name Camila" in result.stderr
+
+
+def test_unknown_short_option() -> None:
+ result = subprocess.run(
+ [
+ sys.executable,
+ "-m",
+ "coverage",
+ "run",
+ "-m",
+ "typer",
+ "tests/assets/cli/sample.py",
+ "run",
+ "hello",
+ "-x",
+ ],
+ capture_output=True,
+ encoding="utf-8",
+ )
+ assert "No such option: -x" in result.stderr
+
+
+def test_ignore_unknown_short_option() -> None:
+ app = typer.Typer(
+ context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
+ )
+
+ @app.command()
+ def main(
+ ctx: typer.Context, all_: bool = typer.Option(False, "--all", "-a")
+ ) -> None:
+ assert all_
+ print(ctx.args)
+
+ result = runner.invoke(app, ["-azq"])
+ assert result.exit_code == 0
+ assert "['-zq']" in result.output
diff --git a/tests/test_cli/test_program_name.py b/tests/test_cli/test_program_name.py
new file mode 100644
index 0000000000..4c20942596
--- /dev/null
+++ b/tests/test_cli/test_program_name.py
@@ -0,0 +1,13 @@
+from typer import _click
+
+
+def test_detect_program_name_submodule_path() -> None:
+ class MainModule:
+ __package__ = "example"
+
+ program_name = _click.utils._detect_program_name(
+ path="/tmp/cli.py",
+ _main=MainModule(),
+ )
+
+ assert program_name == "python -m example.cli"
diff --git a/tests/test_completion/choice_case_insensitive_example.py b/tests/test_completion/choice_case_insensitive_example.py
new file mode 100644
index 0000000000..1f208dd0b5
--- /dev/null
+++ b/tests/test_completion/choice_case_insensitive_example.py
@@ -0,0 +1,19 @@
+from enum import Enum
+
+import typer
+
+app = typer.Typer()
+
+
+class User(str, Enum):
+ rick = "rick"
+ morty = "morty"
+
+
+@app.command()
+def main(name: User = typer.Option(User.rick, "--name", case_sensitive=False)):
+ print(name.value)
+
+
+if __name__ == "__main__":
+ app()
diff --git a/tests/test_completion/choice_example.py b/tests/test_completion/choice_example.py
new file mode 100644
index 0000000000..2c47d0fdc4
--- /dev/null
+++ b/tests/test_completion/choice_example.py
@@ -0,0 +1,19 @@
+from enum import Enum
+
+import typer
+
+app = typer.Typer()
+
+
+class User(str, Enum):
+ rick = "rick"
+ morty = "morty"
+
+
+@app.command()
+def main(name: User = typer.Option(User.rick, "--name")):
+ print(name.value)
+
+
+if __name__ == "__main__":
+ app()
diff --git a/tests/test_completion/completion_option_then_argument.py b/tests/test_completion/completion_option_then_argument.py
new file mode 100644
index 0000000000..ae5abd8d2e
--- /dev/null
+++ b/tests/test_completion/completion_option_then_argument.py
@@ -0,0 +1,23 @@
+import typer
+
+app = typer.Typer()
+
+
+def complete_name(ctx, args, incomplete):
+ return ["opt-choice"] # pragma: no cover
+
+
+def complete_target(ctx, args, incomplete):
+ return ["arg-choice"]
+
+
+@app.command()
+def main(
+ name: str = typer.Option(..., "--name", autocompletion=complete_name),
+ target: str = typer.Argument(..., autocompletion=complete_target),
+):
+ print(name, target) # pragma: no cover
+
+
+if __name__ == "__main__":
+ app()
diff --git a/tests/test_completion/file_example.py b/tests/test_completion/file_example.py
new file mode 100644
index 0000000000..56556d2006
--- /dev/null
+++ b/tests/test_completion/file_example.py
@@ -0,0 +1,12 @@
+import typer
+
+app = typer.Typer()
+
+
+@app.command()
+def main(config: typer.FileText = typer.Option(...)):
+ print(config.read())
+
+
+if __name__ == "__main__":
+ app()
diff --git a/tests/test_completion/test_completion.py b/tests/test_completion/test_completion.py
index 049ec4f6af..e9be4e25d1 100644
--- a/tests/test_completion/test_completion.py
+++ b/tests/test_completion/test_completion.py
@@ -3,9 +3,12 @@
import sys
from pathlib import Path
+from typer._click.shell_completion import CompletionItem
+
from docs_src.typer_app import tutorial001_py310 as mod
from ..utils import needs_bash, needs_linux, requires_completion_permission
+from . import completion_option_then_argument as mod_option_arg
@needs_bash
@@ -164,3 +167,27 @@ def test_completion_source_pwsh():
"Register-ArgumentCompleter -Native -CommandName tutorial001_py310.py -ScriptBlock $scriptblock"
in result.stdout
)
+
+
+def test_completion_option_argument() -> None:
+ file_name = Path(mod_option_arg.__file__).name
+ result = subprocess.run(
+ [sys.executable, "-m", "coverage", "run", mod_option_arg.__file__, " "],
+ capture_output=True,
+ encoding="utf-8",
+ env={
+ **os.environ,
+ f"_{file_name.upper()}_COMPLETE": "complete_bash",
+ "COMP_WORDS": f"{file_name} --name chosen ",
+ "COMP_CWORD": "3",
+ },
+ )
+ assert "arg-choice" in result.stdout
+ assert "opt-choice" not in result.stdout
+
+
+def test_completion_item_getattr() -> None:
+ item = CompletionItem("demo", source="envvar")
+
+ assert item.source == "envvar"
+ assert item.missing is None
diff --git a/tests/test_completion/test_completion_choice.py b/tests/test_completion/test_completion_choice.py
new file mode 100644
index 0000000000..7cc4d81235
--- /dev/null
+++ b/tests/test_completion/test_completion_choice.py
@@ -0,0 +1,32 @@
+import os
+import subprocess
+import sys
+
+from . import choice_example as mod
+
+
+def test_script() -> None:
+ result = subprocess.run(
+ [sys.executable, "-m", "coverage", "run", mod.__file__, "--name", "rick"],
+ capture_output=True,
+ encoding="utf-8",
+ )
+ assert result.returncode == 0
+ assert "rick" in result.stdout
+
+
+def test_completion_choice_bash() -> None:
+ result = subprocess.run(
+ [sys.executable, "-m", "coverage", "run", mod.__file__, " "],
+ capture_output=True,
+ encoding="utf-8",
+ env={
+ **os.environ,
+ "_CHOICE_EXAMPLE.PY_COMPLETE": "complete_bash",
+ "COMP_WORDS": "choice_example.py --name mo",
+ "COMP_CWORD": "2",
+ },
+ )
+ assert result.returncode == 0
+ assert "morty" in result.stdout
+ assert "rick" not in result.stdout
diff --git a/tests/test_completion/test_completion_choice_no_case.py b/tests/test_completion/test_completion_choice_no_case.py
new file mode 100644
index 0000000000..f14097761d
--- /dev/null
+++ b/tests/test_completion/test_completion_choice_no_case.py
@@ -0,0 +1,32 @@
+import os
+import subprocess
+import sys
+
+from . import choice_case_insensitive_example as mod
+
+
+def test_script() -> None:
+ result = subprocess.run(
+ [sys.executable, "-m", "coverage", "run", mod.__file__, "--name", "rick"],
+ capture_output=True,
+ encoding="utf-8",
+ )
+ assert result.returncode == 0
+ assert "rick" in result.stdout
+
+
+def test_completion_choice_bash_case_insensitive() -> None:
+ result = subprocess.run(
+ [sys.executable, "-m", "coverage", "run", mod.__file__, " "],
+ capture_output=True,
+ encoding="utf-8",
+ env={
+ **os.environ,
+ "_CHOICE_CASE_INSENSITIVE_EXAMPLE.PY_COMPLETE": "complete_bash",
+ "COMP_WORDS": "choice_case_insensitive_example.py --name MO",
+ "COMP_CWORD": "2",
+ },
+ )
+ assert result.returncode == 0
+ assert "morty" in result.stdout
+ assert "rick" not in result.stdout
diff --git a/tests/test_completion/test_completion_file.py b/tests/test_completion/test_completion_file.py
new file mode 100644
index 0000000000..782d3a04bc
--- /dev/null
+++ b/tests/test_completion/test_completion_file.py
@@ -0,0 +1,39 @@
+import os
+import subprocess
+import sys
+
+from . import file_example as mod
+
+
+def test_script() -> None:
+ result = subprocess.run(
+ [
+ sys.executable,
+ "-m",
+ "coverage",
+ "run",
+ mod.__file__,
+ "--config",
+ mod.__file__,
+ ],
+ capture_output=True,
+ encoding="utf-8",
+ )
+ assert result.returncode == 0
+ assert "def main(config: typer.FileText = typer.Option(...)):" in result.stdout
+
+
+def test_completion_file_bash() -> None:
+ result = subprocess.run(
+ [sys.executable, "-m", "coverage", "run", mod.__file__, " "],
+ capture_output=True,
+ encoding="utf-8",
+ env={
+ **os.environ,
+ "_FILE_EXAMPLE.PY_COMPLETE": "complete_bash",
+ "COMP_WORDS": "file_example.py --config file_ex",
+ "COMP_CWORD": "2",
+ },
+ )
+ assert result.returncode == 0
+ assert "file_ex" in result.stdout
diff --git a/tests/test_completion/test_completion_option_colon.py b/tests/test_completion/test_completion_option_colon.py
index 6b65d786e5..8818ee4a1a 100644
--- a/tests/test_completion/test_completion_option_colon.py
+++ b/tests/test_completion/test_completion_option_colon.py
@@ -177,6 +177,39 @@ def test_completion_colon_powershell_single():
assert "nvidia/cuda:10.0-devel-ubuntu18.04" not in result.stdout
+def test_completion_powershell_option_equals_value() -> None:
+ result = subprocess.run(
+ [sys.executable, "-m", "coverage", "run", mod.__file__, " "],
+ capture_output=True,
+ encoding="utf-8",
+ env={
+ **os.environ,
+ "_COLON_EXAMPLE.PY_COMPLETE": "complete_powershell",
+ "_TYPER_COMPLETE_ARGS": "colon_example.py --name=alpine",
+ "_TYPER_COMPLETE_WORD_TO_COMPLETE": "--name=alpine",
+ },
+ )
+ assert "alpine:hello" in result.stdout
+ assert "alpine:latest" in result.stdout
+ assert "nvidia/cuda:10.0-devel-ubuntu18.04" not in result.stdout
+
+
+def test_completion_powershell_option_equals_only() -> None:
+ result = subprocess.run(
+ [sys.executable, "-m", "coverage", "run", mod.__file__, " "],
+ capture_output=True,
+ encoding="utf-8",
+ env={
+ **os.environ,
+ "_COLON_EXAMPLE.PY_COMPLETE": "complete_powershell",
+ "_TYPER_COMPLETE_ARGS": "colon_example.py --name=",
+ "_TYPER_COMPLETE_WORD_TO_COMPLETE": "=",
+ },
+ )
+ assert result.returncode == 0
+ assert result.stdout.strip() == ""
+
+
def test_completion_colon_pwsh_all():
result = subprocess.run(
[sys.executable, "-m", "coverage", "run", mod.__file__, " "],
diff --git a/tests/test_core.py b/tests/test_core.py
new file mode 100644
index 0000000000..2cb759918b
--- /dev/null
+++ b/tests/test_core.py
@@ -0,0 +1,379 @@
+from typing import Annotated
+
+import pytest
+import typer
+import typer._completion_shared
+import typer.completion
+from typer import _click
+from typer.core import TyperArgument, TyperCommand, TyperGroup, TyperOption, _split_opt
+from typer.testing import CliRunner
+
+runner = CliRunner()
+
+
+def test_human_readable_name() -> None:
+ app = typer.Typer()
+
+ @app.command()
+ def main(
+ my_arg_1: Annotated[str, typer.Argument()],
+ my_arg_2: Annotated[str, typer.Argument(metavar="META_ARG")],
+ my_opt: Annotated[str, typer.Option()],
+ ):
+ pass # pragma: no cover
+
+ command = typer.main.get_command(app)
+ params = {param.name: param for param in command.params}
+
+ assert params["my_arg_1"].human_readable_name == "MY_ARG_1"
+ assert params["my_arg_2"].human_readable_name == "META_ARG"
+ assert params["my_opt"].human_readable_name == "my_opt"
+
+
+def test_parameter_metavar() -> None:
+ app = typer.Typer(rich_markup_mode=None)
+
+ @app.command()
+ def cmd(name: Annotated[str, typer.Option(metavar="CUSTOM")]) -> None:
+ pass # pragma: no cover
+
+ result = runner.invoke(app, ["--help"])
+ assert result.exit_code == 0
+ assert "--name CUSTOM" in result.output
+
+
+def test_parameter_nargs_gt_1() -> None:
+ param = TyperArgument(param_decls=["value"], type=str, nargs=2)
+ ctx = _click.Context(TyperCommand(name="cmd"))
+
+ assert param.type_cast_value(ctx, ("one", "two")) == ("one", "two")
+
+ with pytest.raises(
+ _click.exceptions.BadParameter, match="Takes 2 values but 1 given."
+ ):
+ param.type_cast_value(ctx, ("one",))
+
+
+def test_parameter_constructor() -> None:
+ # no param_decl and expose_value is False: sets name to None
+ arg = TyperArgument(param_decls=[], expose_value=False)
+ assert arg.name is None
+ assert arg.opts == []
+ assert arg.secondary_opts == []
+
+ # no param_decl and expose_value is True: raises
+ with pytest.raises(TypeError, match="does not have a name."):
+ TyperArgument(param_decls=[], expose_value=True)
+
+ # len(param_decl) > 1: raises
+ with pytest.raises(TypeError, match="take exactly one parameter declaration"):
+ TyperArgument(param_decls=["first", "second"])
+
+ # duplicated identifier in option declarations: raises
+ with pytest.raises(TypeError, match="Name 'name' defined twice"):
+ TyperOption(param_decls=["name", "name"], required=False)
+
+ # same true/false flag in boolean option declaration: raises
+ with pytest.raises(ValueError, match="cannot use the same flag for true/false"):
+ TyperOption(param_decls=["flag", "--flag/--flag"], required=False, is_flag=True)
+
+ # inferred name is not a valid identifier: sets name to None
+ unnamed_option = TyperOption(param_decls=["--123"], required=False)
+ assert unnamed_option.name is None
+
+ # no param_decl and prompt=True: raises
+ with pytest.raises(TypeError, match="'name' is required with 'prompt=True'."):
+ TyperOption(param_decls=[], expose_value=False, prompt=True, required=False)
+
+ # count works
+ option = TyperOption(
+ param_decls=["verbose", "--verbose", "-v"],
+ type=None,
+ default=0,
+ required=False,
+ count=True,
+ )
+ assert isinstance(option.type, _click.types.IntRange)
+ assert option.type.min == 0
+
+
+def test_option_error_hint() -> None:
+ option = TyperOption(
+ param_decls=["name", "--name"],
+ required=False,
+ show_envvar=True,
+ envvar="APP_NAME",
+ )
+ hint = option.get_error_hint(_click.Context(TyperCommand(name="cmd")))
+ assert "(env var: 'APP_NAME')" in hint
+
+
+def test_group_init() -> None:
+ group_no_commands = TyperGroup(name="root", commands=None)
+ assert group_no_commands.commands == {}
+
+ named = TyperCommand(name="named")
+ unnamed = TyperCommand(name=None)
+ group_command_sequence = TyperGroup(name="root", commands=[named, unnamed])
+ assert group_command_sequence.commands == {"named": named}
+
+
+@pytest.mark.parametrize("with_result_callback", [False, True])
+def test_group_result_callback(with_result_callback: bool) -> None:
+ called = {"child": False, "result_callback": False}
+
+ def child_callback() -> None:
+ called["child"] = True
+ return None
+
+ def result_callback(value, **kwargs): # type: ignore[no-untyped-def]
+ called["result_callback"] = True
+ return value
+
+ child = TyperCommand(name="child", callback=child_callback)
+ group = TyperGroup(
+ name="root",
+ commands={"child": child},
+ result_callback=result_callback if with_result_callback else None,
+ )
+ ctx = group.make_context("root", ["child"])
+
+ result = group.invoke(ctx)
+
+ assert result is None
+ assert called["child"] is True
+ assert called["result_callback"] is with_result_callback
+ assert ctx.invoked_subcommand == "child"
+
+
+def test_group_add_command() -> None:
+ group = TyperGroup(name="root")
+ unnamed_command = TyperCommand(name=None)
+
+ with pytest.raises(TypeError, match="Command has no name."):
+ group.add_command(unnamed_command)
+
+
+def test_group_click_resolve_command() -> None:
+ child = TyperCommand(name="child")
+ group = TyperGroup(name="root", commands={"child": child})
+ ctx = group.make_context("root", ["CHILD"], token_normalize_func=str.lower)
+
+ cmd_name, cmd, remaining = group._click_resolve_command(ctx, ["CHILD"])
+
+ assert cmd_name == "child"
+ assert cmd is child
+ assert remaining == []
+
+
+@pytest.mark.parametrize(
+ ("envvar", "auto_prefix", "set_env", "expected"),
+ [
+ ("APP_NAME", None, True, "my-precious"),
+ (None, "APP", True, "my-precious"),
+ (None, None, False, None),
+ ],
+)
+def test_option_resolve_envvar(
+ monkeypatch: pytest.MonkeyPatch,
+ envvar: str | None,
+ auto_prefix: str | None,
+ set_env: bool,
+ expected: str | None,
+) -> None:
+ option = TyperOption(
+ param_decls=["name", "--name"],
+ required=False,
+ envvar=envvar,
+ )
+ if set_env:
+ monkeypatch.setenv("APP_NAME", "my-precious")
+
+ ctx = _click.Context(TyperCommand(name="cmd"), auto_envvar_prefix=auto_prefix)
+ assert option.resolve_envvar_value(ctx) == expected
+
+
+def test_option_resolve_envvar_list(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ option = TyperOption(
+ param_decls=["name", "--name"],
+ required=False,
+ envvar=["APP_NAME_1", "APP_NAME_2"],
+ )
+ monkeypatch.delenv("APP_NAME_1", raising=False)
+ monkeypatch.delenv("APP_NAME_2", raising=False)
+ ctx = _click.Context(TyperCommand(name="cmd"))
+
+ assert option.resolve_envvar_value(ctx) is None
+
+
+def test_context_auto_envvar() -> None:
+ app = typer.Typer(context_settings={"auto_envvar_prefix": "APP"})
+ sub_app = typer.Typer()
+
+ @sub_app.command()
+ def clone(ctx: typer.Context) -> None:
+ print(ctx.auto_envvar_prefix)
+
+ app.add_typer(sub_app, name="beth")
+
+ result = runner.invoke(app, ["beth", "clone"])
+ assert result.exit_code == 0
+ assert "APP_BETH_CLONE" in result.stdout
+
+
+def test_context_with_resource() -> None:
+ events: list[str] = []
+
+ class DemoResource:
+ def __enter__(self) -> str:
+ events.append("enter")
+ return "pickle-rick"
+
+ def __exit__(self, *args: object) -> None:
+ events.append("exit")
+
+ app = typer.Typer()
+
+ @app.command()
+ def cmd(ctx: typer.Context) -> None:
+ value = ctx.with_resource(DemoResource())
+ assert value == "pickle-rick"
+ assert events == ["enter"]
+ print("I'm a pickle")
+
+ result = runner.invoke(app)
+
+ assert result.exit_code == 0
+ assert "I'm a pickle" in result.stdout
+ assert events == ["enter", "exit"]
+
+
+def test_context_find_root() -> None:
+ app = typer.Typer()
+ sub_app = typer.Typer()
+
+ @sub_app.command()
+ def child(ctx: typer.Context) -> None:
+ root = ctx.find_root()
+ assert root.parent is None
+ assert root is ctx.parent.parent
+ print("ok")
+
+ app.add_typer(sub_app, name="sub")
+
+ result = runner.invoke(app, ["sub", "child"])
+ assert result.exit_code == 0
+ assert "ok" in result.stdout
+
+
+def test_context_find_object() -> None:
+ class Marker:
+ pass
+
+ marker = Marker()
+ app = typer.Typer()
+
+ @app.callback()
+ def callback(ctx: typer.Context) -> None:
+ ctx.obj = marker
+
+ @app.command()
+ def child(ctx: typer.Context) -> None:
+ assert ctx.find_object(Marker) is marker
+ assert ctx.find_object(str) is None
+ print("ok")
+
+ result = runner.invoke(app, ["child"])
+ assert result.exit_code == 0
+ assert "ok" in result.stdout
+
+
+def test_context_lookup_default_callable() -> None:
+ app = typer.Typer()
+
+ @app.command()
+ def child(ctx: typer.Context) -> None:
+ ctx.default_map = {"planet": lambda: "Earth"}
+ assert ctx.lookup_default("planet") == "Earth"
+ value = ctx.lookup_default("planet", call=False)
+ assert callable(value)
+ print("ok")
+
+ result = runner.invoke(app)
+ assert result.exit_code == 0
+ assert "ok" in result.stdout
+
+
+def test_context_abort() -> None:
+ app = typer.Typer()
+
+ @app.command()
+ def cmd(ctx: typer.Context) -> None:
+ ctx.abort()
+
+ result = runner.invoke(app, standalone_mode=False)
+ assert result.exit_code == 1
+ assert isinstance(result.exception, _click.core.Abort)
+
+
+def test_command_help_disabled() -> None:
+ app = typer.Typer()
+
+ @app.command(add_help_option=False)
+ def cmd() -> None:
+ pass # pragma: no cover
+
+ result = runner.invoke(app, ["--help"], standalone_mode=False)
+ assert result.exit_code == 1
+ assert isinstance(result.exception, _click.exceptions.NoSuchOption)
+ assert result.exception.option_name == "--help"
+
+
+def test_command_help_deprecated() -> None:
+ app = typer.Typer(rich_markup_mode=None, epilog="Built with love")
+
+ @app.command(short_help="Shorty", help="Regular help text.", deprecated=True)
+ def one() -> None:
+ pass # pragma: no cover
+
+ @app.command()
+ def two() -> None:
+ pass # pragma: no cover
+
+ result = runner.invoke(app, ["one", "--help"])
+ assert result.exit_code == 0
+ assert "Regular help text. (DEPRECATED)" in result.output
+
+ result = runner.invoke(app, ["--help"])
+ assert result.exit_code == 0
+ assert "Built with love" in result.output
+ assert "oneShorty(DEPRECATED)" in result.output.replace(" ", "")
+
+
+@pytest.mark.parametrize(
+ ("value", "expected_prefix", "expected_opt"),
+ [
+ ("--verbose", "--", "verbose"),
+ ("//verbose", "//", "verbose"),
+ ("-verbose", "-", "verbose"),
+ ("verbose", "", "verbose"),
+ ],
+)
+def test_split_opt(value: str, expected_prefix: str, expected_opt: str) -> None:
+ prefix, opt = _split_opt(value)
+ assert prefix == expected_prefix
+ assert opt == expected_opt
+
+
+def test_nargs_default_map():
+ app = typer.Typer()
+
+ @app.command()
+ def main(names: list[str] = typer.Option(None)):
+ print(names) # pragma: no cover
+
+ result = runner.invoke(app, [], default_map={"names": "not-a-list"})
+ assert result.exit_code == 2
+ assert "Invalid value" in result.output
diff --git a/tests/test_launch.py b/tests/test_launch.py
index c15d6c57da..1a97c197aa 100644
--- a/tests/test_launch.py
+++ b/tests/test_launch.py
@@ -1,10 +1,11 @@
+import io
import subprocess
from unittest.mock import patch
import pytest
import typer
-from tests.utils import needs_windows
+from tests.utils import needs_linux, needs_macos, needs_windows
url = "http://example.com"
@@ -51,19 +52,94 @@ def test_launch_url_no_xdg_open():
mock_webbrowser_open.assert_called_once_with(url)
-def test_calls_original_launch_when_not_passing_urls():
- with patch("typer.main.click.launch", return_value=0) as launch_mock:
- typer.launch("not a url")
+@pytest.fixture
+def allow_dev_null(monkeypatch):
+ real_open = open
- launch_mock.assert_called_once_with("not a url", wait=False, locate=False)
+ def fake_open(path, *args, **kwargs):
+ if path == "/dev/null":
+ return io.StringIO()
+ return real_open(path, *args, **kwargs) # pragma: no cover
+
+ monkeypatch.setattr("builtins.open", fake_open)
+
+
+@needs_macos
+def test_open_url_macos(monkeypatch, allow_dev_null):
+ recorded: list[list[str]] = []
+
+ class Proc:
+ def wait(self) -> int:
+ return 42
+
+ def fake_popen(args, **kwargs):
+ recorded.append(list(args))
+ return Proc()
+
+ monkeypatch.setattr(subprocess, "Popen", fake_popen)
+
+ assert typer.launch("/path/to/file", wait=True, locate=True) == 42
+ assert recorded[0][:3] == ["open", "-W", "-R"]
+ assert recorded[0][-1] == "/path/to/file"
+
+
+@needs_windows
+def test_launch_files_windows(monkeypatch):
+ calls: list[list[str]] = []
+
+ def fake_call(args):
+ calls.append(list(args))
+ return 0
+
+ monkeypatch.setattr(subprocess, "call", fake_call)
+
+ assert typer.launch("C:/Tools/readme.txt", wait=True, locate=False) == 0
+ assert typer.launch("file:///C:/tmp/a.txt", wait=False, locate=True) == 0
+ assert calls.pop(0) == ["start", "/WAIT", "", "C:/Tools/readme.txt"]
+ assert calls.pop(0) == ["explorer", "/select,/C:/tmp/a.txt"]
+
+ monkeypatch.setattr(subprocess, "call", lambda a: (_ for _ in ()).throw(OSError()))
+ assert typer.launch("D:/no/such/file.txt", wait=False, locate=False) == 127
+
+
+@needs_linux
+def test_open_url_linux_wait(monkeypatch):
+ class Proc:
+ def __init__(self, code: int = 0) -> None:
+ self._code = code
+
+ def wait(self) -> int:
+ return self._code
+
+ monkeypatch.setattr(subprocess, "Popen", lambda *a, **k: Proc(7))
+
+ assert typer.launch("/file", wait=True, locate=False) == 7
+
+
+@needs_linux
+def test_open_url_linux_locate(monkeypatch):
+ recorded: list[list[str]] = []
+
+ class Proc:
+ def wait(self) -> int:
+ return 0 # pragma: no cover
+
+ def fake_popen(args, **kwargs):
+ recorded.append(list(args))
+ return Proc()
+
+ monkeypatch.setattr(subprocess, "Popen", fake_popen)
+
+ assert typer.launch("/tmp/sub/file.txt", wait=False, locate=True) == 0
+ assert recorded[-1] == ["xdg-open", "/tmp/sub"]
@needs_windows
def test_launch_file():
with (
- patch("click._termui_impl.sys.platform", "win32"),
- patch("click._termui_impl.WIN", True),
- patch("click._termui_impl.CYGWIN", False),
+ patch("typer._click._termui_impl.sys.platform", "win32"),
+ patch("typer._click._termui_impl.WIN", True),
+ patch("typer._click._termui_impl.CYGWIN", False),
patch("subprocess.call", return_value=0) as call_mock,
):
result = typer.launch("C:/tmp/file.txt", locate=True)
diff --git a/tests/test_others.py b/tests/test_others.py
index b389ed353f..bd83e60a07 100644
--- a/tests/test_others.py
+++ b/tests/test_others.py
@@ -6,12 +6,11 @@
from typing import Annotated
from unittest import mock
-import click
import pytest
import typer
import typer._completion_shared
import typer.completion
-from typer.core import _split_opt
+from typer import _click
from typer.main import solve_typer_info_defaults, solve_typer_info_help
from typer.models import ParameterInfo, TyperInfo
from typer.testing import CliRunner
@@ -37,14 +36,14 @@ def test_too_many_parsers():
def custom_parser(value: str) -> int:
return int(value) # pragma: no cover
- class CustomClickParser(click.ParamType):
+ class CustomClickParser(_click.types.ParamType):
name = "custom_parser"
def convert(
self,
value: str,
- param: click.Parameter | None,
- ctx: click.Context | None,
+ param: _click.Parameter | None,
+ ctx: _click.Context | None,
) -> typing.Any:
return int(value) # pragma: no cover
@@ -61,14 +60,14 @@ def test_valid_parser_permutations():
def custom_parser(value: str) -> int:
return int(value) # pragma: no cover
- class CustomClickParser(click.ParamType):
+ class CustomClickParser(_click.types.ParamType):
name = "custom_parser"
def convert(
self,
value: str,
- param: click.Parameter | None,
- ctx: click.Context | None,
+ param: _click.Parameter | None,
+ ctx: _click.Context | None,
) -> typing.Any:
return int(value) # pragma: no cover
@@ -104,7 +103,7 @@ def name_callback(ctx, param, val1, val2):
def main(name: str = typer.Option(..., callback=name_callback)):
pass # pragma: no cover
- with pytest.raises(click.ClickException) as exc_info:
+ with pytest.raises(_click.ClickException) as exc_info:
runner.invoke(app, ["--name", "Camila"])
assert (
exc_info.value.message == "Too many CLI parameter callback function parameters"
@@ -127,6 +126,135 @@ def main(name: str = typer.Option(..., callback=name_callback)):
assert "value is: Camila" in result.stdout
+@pytest.mark.parametrize(
+ ("param_hint", "option_decls", "expected_message"),
+ [
+ ("--name", (), "Invalid value for --name"),
+ (None, ("--name", "-n"), "Invalid value for '--name' / '-n'"),
+ ],
+)
+def test_bad_parameter_callback(
+ param_hint: str | None, option_decls: tuple[str, ...], expected_message: str
+) -> None:
+ app = typer.Typer()
+
+ def my_bad(value: str) -> str:
+ kwargs = {"param_hint": param_hint} if param_hint is not None else {}
+ raise typer.BadParameter("custom validation failed", **kwargs)
+
+ @app.command()
+ def main(name: str = typer.Option(..., *option_decls, callback=my_bad)) -> None:
+ typer.echo(name) # pragma: no cover
+
+ result = runner.invoke(app, ["--name", "Camila"])
+ assert result.exit_code == 2
+ assert expected_message in result.stderr
+ assert "custom validation failed" in result.stderr
+
+
+def test_bad_parameter_main() -> None:
+ app = typer.Typer()
+
+ @app.command()
+ def main() -> None:
+ raise typer.BadParameter("custom validation failed")
+
+ result = runner.invoke(app, [])
+ assert result.exit_code == 2
+ assert "Invalid value: custom validation failed" in result.stderr
+
+
+@pytest.mark.parametrize(
+ ("kw", "msg"),
+ [
+ (
+ {"param_hint": ["--name", "-n"], "param_type": "parameter"},
+ "Missing parameter '--name' / '-n'.",
+ ),
+ ({"param_type": "value"}, "Missing value."),
+ ],
+)
+def test_missing_parameter_msg(kw: dict[str, object], msg: str) -> None:
+ app = typer.Typer(rich_markup_mode=None)
+
+ @app.command()
+ def main() -> None:
+ raise typer._click.exceptions.MissingParameter(**kw)
+
+ result = runner.invoke(app, [])
+ assert result.exit_code == 2
+ assert msg in result.stderr
+
+
+def test_missing_parameter_callback_msg() -> None:
+ def my_cb(ctx: typer.Context, param: typer.CallbackParam, value: str) -> str:
+ raise typer._click.exceptions.MissingParameter(
+ message="My bad", ctx=ctx, param=param, param_type="parameter"
+ )
+
+ app = typer.Typer(rich_markup_mode=None)
+
+ @app.command()
+ def main(
+ mode: Annotated[
+ typing.Literal["alpha", "beta"],
+ typer.Option(..., "--mode", callback=my_cb),
+ ],
+ ) -> None:
+ typer.echo(mode) # pragma: no cover
+
+ result = runner.invoke(app, ["--mode", "alpha"])
+ assert result.exit_code == 2
+ assert "Missing parameter '--mode'." in result.stderr
+ assert "My bad. Choose from:" in result.stderr
+ assert "alpha" in result.stderr
+ assert "beta" in result.stderr
+ result_msg = runner.invoke(app, ["--mode", "alpha"], standalone_mode=False)
+ assert isinstance(result_msg.exception, typer._click.exceptions.MissingParameter)
+ assert str(result_msg.exception) == "My bad"
+
+
+def test_missing_parameter_str() -> None:
+ def my_cb(ctx: typer.Context, param: typer.CallbackParam, value: str) -> str:
+ raise typer._click.exceptions.MissingParameter(ctx=ctx, param=param)
+
+ app = typer.Typer()
+
+ @app.command()
+ def main(mode: str = typer.Option(..., "--mode", callback=my_cb)) -> None:
+ typer.echo(mode) # pragma: no cover
+
+ result2 = runner.invoke(app, ["--mode", "alpha"], standalone_mode=False)
+ assert isinstance(result2.exception, typer._click.exceptions.MissingParameter)
+ assert str(result2.exception) == "Missing parameter: mode"
+
+
+def test_click_exception_show_default_file() -> None:
+ app = typer.Typer(rich_markup_mode=None)
+
+ @app.command()
+ def main() -> None:
+ raise typer._click.ClickException("custom click failure")
+
+ result = runner.invoke(app, [])
+ assert result.exit_code == 1
+ assert "custom click" in result.stderr
+ assert "failure" in result.stderr
+
+
+def test_no_args_is_help_show() -> None:
+ app = typer.Typer(rich_markup_mode=None)
+
+ @app.callback(invoke_without_command=True, no_args_is_help=True)
+ def main() -> None:
+ return None # pragma: no cover
+
+ result = runner.invoke(app, [])
+ assert result.exit_code == 2
+ assert "Usage:" in result.stderr
+ assert "Show this message and exit." in result.stderr
+
+
def test_callback_3_untyped_parameters():
app = typer.Typer()
@@ -169,6 +297,18 @@ def main(
assert "Hello World" in result.stdout
+def test_multiple_bool_flags() -> None:
+ app = typer.Typer()
+
+ @app.command()
+ def main(choices: list[bool] = typer.Option([], "--accept/--reject")) -> None:
+ print(choices)
+
+ result = runner.invoke(app, ["--accept", "--reject", "--accept"])
+ assert result.exit_code == 0
+ assert "[True, False, True]" in result.stdout
+
+
def test_empty_list_default_generator():
def empty_list() -> list[str]:
return []
@@ -266,7 +406,7 @@ def name_callback(ctx, args, incomplete, val2):
def main(name: str = typer.Option(..., autocompletion=name_callback)):
pass # pragma: no cover
- with pytest.raises(click.ClickException) as exc_info:
+ with pytest.raises(_click.ClickException) as exc_info:
runner.invoke(app, ["--name", "Camila"])
assert exc_info.value.message == "Invalid autocompletion callback parameters: val2"
@@ -303,24 +443,6 @@ def main(name: str):
assert "Show this message and exit." in result.stdout
-def test_split_opt():
- prefix, opt = _split_opt("--verbose")
- assert prefix == "--"
- assert opt == "verbose"
-
- prefix, opt = _split_opt("//verbose")
- assert prefix == "//"
- assert opt == "verbose"
-
- prefix, opt = _split_opt("-verbose")
- assert prefix == "-"
- assert opt == "verbose"
-
- prefix, opt = _split_opt("verbose")
- assert prefix == ""
- assert opt == "verbose"
-
-
def test_options_metadata_typer_default():
app = typer.Typer(options_metavar="[options]")
diff --git a/tests/test_progress_bar.py b/tests/test_progress_bar.py
new file mode 100644
index 0000000000..c684476dec
--- /dev/null
+++ b/tests/test_progress_bar.py
@@ -0,0 +1,397 @@
+"""
+Tests for the Progress bar functionality.
+Created after vendoring Click to ensure test coverage is back up to 100%.
+"""
+
+import io
+import shutil
+
+import pytest
+import typer
+from typer import progressbar
+from typer._click import _termui_impl
+from typer.testing import CliRunner
+
+runner = CliRunner()
+
+
+def _fake_clock(monkeypatch: pytest.MonkeyPatch) -> list[float]:
+ clock = [0.0]
+ monkeypatch.setattr(_termui_impl.time, "time", lambda: clock[0])
+ return clock
+
+
+def _pbar(**kw):
+ return progressbar(file=kw.pop("file", io.StringIO()), **kw)
+
+
+@pytest.mark.parametrize(
+ ("iterable", "length", "hidden", "label", "expected_count"),
+ [
+ (["a", "b"], None, False, "Processing", 2),
+ (None, 3, False, "Counting", 3),
+ (["x", "y"], None, True, "Hidden", 2),
+ ],
+)
+def test_progressbar(iterable, length, hidden, label, expected_count):
+ app = typer.Typer()
+
+ @app.command()
+ def main():
+ bar_out = io.StringIO()
+ count = 0
+ with progressbar(
+ iterable, length=length, hidden=hidden, label=label, file=bar_out
+ ) as bar:
+ for _ in bar:
+ count += 1
+ typer.echo(f"count={count}")
+ typer.echo(f"bar={bar_out.getvalue()!r}")
+
+ result = runner.invoke(app, [])
+ assert result.exit_code == 0, result.output
+ assert f"count={expected_count}" in result.stdout
+ assert (label in result.stdout) == (not hidden)
+
+
+@pytest.mark.parametrize(
+ ("label", "pbar_kw", "must_contain", "must_not_contain"),
+ [
+ pytest.param(
+ "TTY",
+ {
+ "show_pos": True,
+ "show_percent": True,
+ "item_show_func": lambda item: f"item={item}",
+ },
+ ("TTY", "1/1", "100%", "item=x"),
+ (),
+ ),
+ pytest.param(
+ "HeurPct",
+ {},
+ ("HeurPct", "100%"),
+ ("1/1",),
+ ),
+ pytest.param(
+ "HeurPos",
+ {"show_pos": True},
+ ("HeurPos", "1/1"),
+ ("100%",),
+ ),
+ ],
+)
+def test_progressbar_tty(
+ monkeypatch, label: str, pbar_kw: dict, must_contain, must_not_contain
+):
+ monkeypatch.setattr(_termui_impl, "isatty", lambda f: True)
+ _fake_clock(monkeypatch)
+
+ app = typer.Typer()
+
+ @app.command()
+ def main():
+ bar_out = io.StringIO()
+ with progressbar(
+ ["x"],
+ label=label,
+ file=bar_out,
+ bar_template="%(label)s %(info)s",
+ width=1,
+ **pbar_kw,
+ ) as bar:
+ for _ in bar:
+ pass
+ typer.echo(bar_out.getvalue())
+
+ result = runner.invoke(app, [])
+ assert result.exit_code == 0, result.output
+ for part in must_contain:
+ assert part in result.stdout
+ for part in must_not_contain:
+ assert part not in result.stdout
+
+
+def test_progressbar_tty_show_eta(monkeypatch):
+ monkeypatch.setattr(_termui_impl, "isatty", lambda f: True)
+ clock = _fake_clock(monkeypatch)
+ clock[0] = 1_000.0
+
+ app = typer.Typer()
+
+ @app.command()
+ def main():
+ bar_out = io.StringIO()
+ with progressbar(
+ ["a", "b"],
+ label="ETA",
+ file=bar_out,
+ show_pos=True,
+ show_percent=False,
+ show_eta=True,
+ bar_template="%(label)s %(info)s",
+ width=1,
+ ) as bar:
+ for i, _ in enumerate(bar):
+ if i == 0:
+ clock[0] = 1_001.0
+ typer.echo(bar_out.getvalue())
+
+ result = runner.invoke(app, [])
+ assert result.exit_code == 0, result.output
+ for part in ("ETA", "1/2", "00:00:01"):
+ assert part in result.stdout
+
+
+def test_progressbar_autowidth(monkeypatch):
+ monkeypatch.setattr(_termui_impl, "isatty", lambda f: True)
+ call = [0]
+ real_get_terminal_size = shutil.get_terminal_size
+
+ def fake_get_terminal_size(*args, **kwargs):
+ # Pytest (and others) call get_terminal_size(fallback=...); only stub no-arg calls
+ if args or kwargs:
+ return real_get_terminal_size(*args, **kwargs)
+ col = 120 if call[0] == 0 else 40
+ call[0] += 1
+ return type("TS", (), {"columns": col, "lines": 24})()
+
+ monkeypatch.setattr(shutil, "get_terminal_size", fake_get_terminal_size)
+
+ state: dict[str, object] = {}
+
+ app = typer.Typer()
+
+ @app.command()
+ def main():
+ out = io.StringIO()
+ with progressbar(["a", "b"], width=0, label="AW", file=out) as bar:
+ state["autowidth"] = bar.autowidth
+ for _ in bar:
+ pass
+ state["call_count"] = call[0]
+ state["out"] = out.getvalue()
+ typer.echo("done")
+
+ result = runner.invoke(app, [])
+ assert result.exit_code == 0, result.output
+ assert state["autowidth"] is True
+ assert state["call_count"] >= 2
+ out = str(state["out"])
+ assert "\r" in out and "AW" in out
+ assert "0%" in out and "50%" in out and "100%" in out
+
+
+def test_progress_bar_iter():
+ not_entered = _pbar(iterable=[1, 2], length=2)
+ with pytest.raises(RuntimeError, match="with block"):
+ iter(not_entered)
+
+ entered = _pbar(iterable=[10, 20], length=2)
+ with entered:
+ iterator = iter(entered)
+ assert next(iterator) == 10
+ assert next(entered) == 20
+ with pytest.raises(StopIteration):
+ next(iterator)
+
+
+def test_progress_bar_time(monkeypatch):
+ clock = _fake_clock(monkeypatch)
+ state: dict[str, object] = {}
+ clock[0] = 1_000.0
+
+ app = typer.Typer()
+
+ @app.command()
+ def main():
+ bar = _pbar(iterable=None, length=10)
+ state["tpi0"] = bar.time_per_iteration
+ clock[0] = 1_000.5
+ bar.make_step(1)
+ state["avg_after_one"] = list(bar.avg)
+ state["tpi_after_one"] = bar.time_per_iteration
+ clock[0] = 1_001.0
+ bar.make_step(1)
+ state["pos2"] = bar.pos
+ state["avg2"] = list(bar.avg)
+ state["tpi2"] = bar.time_per_iteration
+ clock[0] = 1_002.0
+ bar.make_step(1)
+ state["avg3"] = list(bar.avg)
+ state["tpi3"] = bar.time_per_iteration
+ typer.echo("ok")
+
+ result = runner.invoke(app, [])
+ assert result.exit_code == 0, result.output
+ assert state["tpi0"] == 0.0
+ assert state["avg_after_one"] == [] and state["tpi_after_one"] == 0.0
+ assert state["pos2"] == 2
+ assert state["avg2"] == [(1_001.0 - 1_000.0) / 2.0]
+ assert state["tpi2"] == pytest.approx(0.5)
+ assert state["avg3"] == [0.5, (1_002.0 - 1_000.0) / 3.0]
+ assert state["tpi3"] == pytest.approx(sum(state["avg3"]) / 2.0) # type: ignore[arg-type]
+
+
+def test_progress_bar_time_zero_steps(monkeypatch):
+ clock = _fake_clock(monkeypatch)
+ state: dict[str, object] = {}
+ clock[0] = 2_000.0
+
+ app = typer.Typer()
+
+ @app.command()
+ def main():
+ bar = _pbar(iterable=None, length=3)
+ clock[0] = 2_001.0
+ bar.make_step(0)
+ state["pos"] = bar.pos
+ state["avg"] = list(bar.avg)
+ state["tpi"] = bar.time_per_iteration
+ typer.echo("ok")
+
+ result = runner.invoke(app, [])
+ assert result.exit_code == 0, result.output
+ assert state["pos"] == 0
+ assert state["avg"] == [1.0]
+ assert state["tpi"] == pytest.approx(1.0)
+
+
+def test_progress_bar_eta(monkeypatch):
+ state: dict[str, object] = {}
+
+ app = typer.Typer()
+
+ @app.command()
+ def main():
+ state["eta0"] = _pbar(iterable=[1, 2], length=None).eta
+
+ done = _pbar(iterable=None, length=5)
+ done.pos, done.finished = 2, True
+ state["eta_done"] = done.eta
+
+ fresh = _pbar(iterable=None, length=5)
+ state["eta_known_fresh"] = fresh.eta_known
+ state["fmt_eta_fresh"] = fresh.format_eta()
+
+ clock = _fake_clock(monkeypatch)
+ clock[0] = 5_000.0
+ bar = _pbar(iterable=None, length=10)
+ clock[0] = 5_001.0
+ bar.make_step(3)
+ state["pos3"] = bar.pos
+ state["eta_after"] = bar.eta
+ state["tpi"] = bar.time_per_iteration
+
+ cases_out = []
+ for t0, t1, length, n_steps, _expected_fmt, expected_eta_int in (
+ (9_000.0, 9_001.0, 10, 1, "00:00:09", None),
+ (1_000.0, 100_000.0, 2, 1, "1d 03:30:00", 99_000),
+ ):
+ clock2 = _fake_clock(monkeypatch)
+ clock2[0] = t0
+ b = _pbar(iterable=None, length=length)
+ clock2[0] = t1
+ b.make_step(n_steps)
+ cases_out.append(
+ (
+ b.eta_known,
+ b.format_eta(),
+ int(b.eta) if expected_eta_int is not None else None,
+ )
+ )
+
+ state["cases"] = cases_out
+
+ clock3 = _fake_clock(monkeypatch)
+ clock3[0] = 3_000.0
+ b2 = _pbar(iterable=None, length=2)
+ clock3[0] = 3_001.0
+ b2.make_step(1)
+ state["fmt_before_finish"] = b2.format_eta()
+ b2.finish()
+ state["fmt_after_finish"] = b2.format_eta()
+ typer.echo("ok")
+
+ result = runner.invoke(app, [])
+ assert result.exit_code == 0, result.output
+ assert state["eta0"] == 0.0
+ assert state["eta_done"] == 0.0
+ assert not state["eta_known_fresh"] and state["fmt_eta_fresh"] == ""
+ assert state["pos3"] == 3
+ assert state["eta_after"] == pytest.approx(state["tpi"] * (10 - 3)) # type: ignore[operator]
+
+ (ek1, fmt1, ei1), (ek2, fmt2, ei2) = state["cases"] # type: ignore[misc]
+ assert ek1 and fmt1 == "00:00:09" and ei1 is None
+ assert ek2 and fmt2 == "1d 03:30:00" and ei2 == 99_000
+
+ assert state["fmt_before_finish"] != ""
+ assert state["fmt_after_finish"] == ""
+
+
+@pytest.mark.parametrize(
+ ("width", "fill_char", "empty_char", "expected_bar", "finished", "sample_timing"),
+ [
+ pytest.param(4, "X", "-", "XXXX", True, False, id="finished"),
+ pytest.param(4, "#", "-", "----", False, False, id="no_timing_yet"),
+ pytest.param(5, "*", ".", None, False, True, id="indeterminate"),
+ ],
+)
+def test_progress_bar_unknown_length(
+ monkeypatch,
+ width: int,
+ fill_char: str,
+ empty_char: str,
+ expected_bar: str | None,
+ finished: bool,
+ sample_timing: bool,
+):
+ clock: list[float] | None = _fake_clock(monkeypatch) if sample_timing else None
+ if clock is not None:
+ clock[0] = 100.0
+
+ state: dict[str, object] = {}
+
+ class _IterableWithoutLength:
+ def __iter__(self):
+ return iter((1, 2, 3))
+
+ app = typer.Typer()
+
+ @app.command()
+ def main():
+ bar = _pbar(
+ iterable=_IterableWithoutLength(),
+ length=None,
+ width=width,
+ fill_char=fill_char,
+ empty_char=empty_char,
+ )
+ assert bar.length is None
+
+ if sample_timing:
+ assert clock is not None
+ clock[0] = 101.0
+ bar.make_step(1)
+ assert bar.time_per_iteration > 0
+ rendered = bar.format_bar()
+ assert len(rendered) == width
+ assert rendered.count(fill_char) == 1
+ assert rendered.count(empty_char) == width - 1
+ state["branch"] = "sample_timing"
+ elif finished:
+ bar.finished = True
+ state["bar"] = bar.format_bar()
+ state["branch"] = "finished"
+ else:
+ assert bar.time_per_iteration == 0.0
+ state["bar"] = bar.format_bar()
+ state["branch"] = "no_timing"
+ typer.echo("ok")
+
+ result = runner.invoke(app, [])
+ assert result.exit_code == 0, result.output
+ if sample_timing:
+ assert state["branch"] == "sample_timing"
+ else:
+ assert state["bar"] == expected_bar
diff --git a/tests/test_termui.py b/tests/test_termui.py
new file mode 100644
index 0000000000..f6e04199a7
--- /dev/null
+++ b/tests/test_termui.py
@@ -0,0 +1,458 @@
+"""
+Tests for the termui, echo, and CliRunner isolation functionality.
+Created after vendoring Click to ensure test coverage is back up to 100%.
+"""
+
+import io
+import os
+from contextlib import contextmanager
+from typing import Literal
+
+import pytest
+import typer
+from typer._click import _termui_impl, termui
+from typer.testing import CliRunner
+
+from tests.utils import needs_windows, skip_if_windows
+
+
+def test_raw_terminal(monkeypatch):
+ runner = CliRunner()
+ app = typer.Typer()
+ state = {"entered": 0, "exited": 0}
+
+ @contextmanager
+ def fake_raw_terminal():
+ state["entered"] += 1
+ try:
+ yield 42
+ finally:
+ state["exited"] += 1
+
+ monkeypatch.setattr(_termui_impl, "raw_terminal", fake_raw_terminal)
+
+ @app.command()
+ def main():
+ with termui.raw_terminal() as fd:
+ typer.echo(f"fd={fd}")
+
+ result = runner.invoke(app, [])
+ assert result.exit_code == 0, result.output
+ assert "fd=42" in result.stdout
+ assert state["entered"] == 1
+ assert state["exited"] == 1
+
+
+def test_getchar(monkeypatch):
+ # Cached path: call the existing _getchar directly.
+ cached_state = {"echo": None}
+
+ def cached_getchar(echo: bool) -> str:
+ cached_state["echo"] = echo
+ return "x"
+
+ monkeypatch.setattr(termui, "_getchar", cached_getchar)
+ assert termui.getchar(echo=True) == "x"
+ assert cached_state["echo"] is True
+
+ # Lazy-load path: _getchar is None, so import/cache _termui_impl.getchar.
+ lazy_state = {"calls": 0}
+
+ def lazy_getchar(echo: bool) -> str:
+ lazy_state["calls"] += 1
+ return "y" if not echo else "z"
+
+ monkeypatch.setattr(termui, "_getchar", None)
+ monkeypatch.setattr(_termui_impl, "getchar", lazy_getchar)
+
+ assert termui.getchar(echo=False) == "y"
+ assert termui._getchar is lazy_getchar
+ assert termui.getchar(echo=True) == "z"
+ assert lazy_state["calls"] == 2
+
+
+def test_clirunner_getchar(monkeypatch) -> None:
+ runner = CliRunner()
+ app = typer.Typer()
+
+ @app.command()
+ def main() -> None:
+ first = termui.getchar(echo=False)
+ second = termui.getchar(echo=True)
+ typer.echo(f"\nfirst={first};second={second}")
+
+ monkeypatch.setattr(termui, "_getchar", None)
+ result = runner.invoke(app, [], input="ab")
+ assert result.exit_code == 0, result.output
+ assert result.stdout.splitlines() == ["b", "first=a;second=b"]
+
+
+def test_clirunner_env_none(monkeypatch) -> None:
+ runner = CliRunner()
+ app = typer.Typer()
+ env_key = "TYPER_TEST_ENV_REMOVE"
+ monkeypatch.setenv(env_key, "present")
+
+ @app.command()
+ def main() -> None:
+ typer.echo(f"inside={os.environ.get(env_key)}")
+
+ result = runner.invoke(app, [], env={env_key: None})
+ assert result.exit_code == 0, result.output
+ assert "inside=None" in result.stdout
+ assert os.environ.get(env_key) == "present"
+
+
+@pytest.mark.parametrize(
+ ("runner_exc", "invoke_exc"),
+ [
+ (False, None),
+ (True, False),
+ ],
+)
+def test_clirunner_invoke_catch_exceptions(
+ runner_exc: bool, invoke_exc: bool | None
+) -> None:
+ runner = CliRunner(catch_exceptions=runner_exc)
+ app = typer.Typer()
+
+ @app.command()
+ def main() -> None:
+ raise RuntimeError("boom")
+
+ with pytest.raises(RuntimeError, match="boom"):
+ runner.invoke(app, [], catch_exceptions=invoke_exc)
+
+
+@pytest.mark.parametrize(
+ ("exit_value", "expected_exit_code", "expected_stdout"),
+ [
+ (None, 0, ""),
+ ("bad-exit", 1, "bad-exit\n"),
+ ],
+)
+def test_clirunner_invoke_system_exit_branches(
+ exit_value: object,
+ expected_exit_code: int,
+ expected_stdout: str,
+) -> None:
+ runner = CliRunner()
+ app = typer.Typer()
+
+ @app.command()
+ def main() -> None:
+ raise SystemExit(exit_value)
+
+ result = runner.invoke(app, [])
+ assert result.exit_code == expected_exit_code
+ assert result.stdout == expected_stdout
+ if expected_exit_code:
+ assert isinstance(result.exception, SystemExit)
+ else:
+ assert result.exception is None
+
+
+@needs_windows
+def test_termui_impl_windows_raw_terminal():
+ with _termui_impl.raw_terminal() as fd:
+ assert fd == -1
+ with termui.raw_terminal() as fd:
+ assert fd == -1
+
+
+@needs_windows
+def test_termui_impl_windows_getchar(monkeypatch):
+ monkeypatch.setattr(_termui_impl.msvcrt, "getwch", lambda: "a")
+ monkeypatch.setattr(_termui_impl.msvcrt, "getwche", lambda: "b")
+ assert _termui_impl.getchar(echo=False) == "a"
+ assert _termui_impl.getchar(echo=True) == "b"
+
+ seq_null = iter(["\x00", "K"])
+ monkeypatch.setattr(_termui_impl.msvcrt, "getwch", lambda: next(seq_null))
+ assert _termui_impl.getchar(echo=False) == "\x00K"
+
+ seq_e0 = iter(["\xe0", "H"])
+ monkeypatch.setattr(_termui_impl.msvcrt, "getwch", lambda: next(seq_e0))
+ assert _termui_impl.getchar(echo=False) == "\xe0H"
+
+ seq_echo = iter(["\x00", "M"])
+ monkeypatch.setattr(_termui_impl.msvcrt, "getwche", lambda: next(seq_echo))
+ assert _termui_impl.getchar(echo=True) == "\x00M"
+
+ seq_e0_echo = iter(["\xe0", "Z"])
+ monkeypatch.setattr(_termui_impl.msvcrt, "getwche", lambda: next(seq_e0_echo))
+ assert _termui_impl.getchar(echo=True) == "\xe0Z"
+
+ monkeypatch.setattr(_termui_impl.msvcrt, "getwch", lambda: "\x03")
+ with pytest.raises(KeyboardInterrupt):
+ _termui_impl.getchar(echo=False)
+
+ monkeypatch.setattr(_termui_impl.msvcrt, "getwch", lambda: "\x1a")
+ with pytest.raises(EOFError):
+ _termui_impl.getchar(echo=False)
+
+
+@skip_if_windows
+@pytest.mark.parametrize("use_stdin_tty", [True, False])
+def test_termui_impl_posix_raw_terminal(monkeypatch, use_stdin_tty: bool):
+ state: dict[str, object] = {}
+ flushed: list[None] = []
+ fake_tty = None
+
+ if use_stdin_tty:
+ expected_fd = 14
+ old_termios = "old_settings"
+ monkeypatch.setattr(
+ _termui_impl, "isatty", lambda s: s is _termui_impl.sys.stdin
+ )
+ monkeypatch.setattr(_termui_impl.sys.stdin, "fileno", lambda: expected_fd)
+ else:
+ expected_fd = 27
+ old_termios = "old"
+ monkeypatch.setattr(
+ _termui_impl,
+ "isatty",
+ lambda s: s is not _termui_impl.sys.stdin,
+ )
+
+ class FakeTTY:
+ def __init__(self) -> None:
+ self.closed = False
+
+ def fileno(self) -> int:
+ return expected_fd
+
+ def close(self) -> None:
+ self.closed = True
+
+ fake_tty = FakeTTY()
+ real_open = open
+
+ def fake_open(path, *args, **kwargs):
+ if path == "/dev/tty":
+ return fake_tty
+ return real_open(path, *args, **kwargs) # pragma: no cover
+
+ monkeypatch.setattr("builtins.open", fake_open)
+
+ def tcgetattr(fd: int) -> str:
+ state["tcgetattr_fd"] = fd
+ return old_termios
+
+ def setraw(fd: int) -> None:
+ state["setraw_fd"] = fd
+
+ def tcsetattr(fd: int, when: int, old: str) -> None:
+ state["tcsetattr"] = (fd, when, old)
+
+ monkeypatch.setattr(_termui_impl.termios, "tcgetattr", tcgetattr)
+ monkeypatch.setattr(_termui_impl.tty, "setraw", setraw)
+ monkeypatch.setattr(_termui_impl.termios, "tcsetattr", tcsetattr)
+ monkeypatch.setattr(
+ _termui_impl.sys.stdout, "flush", lambda *a, **k: flushed.append(None)
+ )
+
+ with _termui_impl.raw_terminal() as fd:
+ assert fd == expected_fd
+
+ assert state["tcgetattr_fd"] == expected_fd
+ assert state["setraw_fd"] == expected_fd
+ assert state["tcsetattr"] == (
+ expected_fd,
+ _termui_impl.termios.TCSADRAIN,
+ old_termios,
+ )
+ assert flushed == [None]
+ if fake_tty is not None:
+ assert fake_tty.closed is True
+
+
+@skip_if_windows
+def test_termui_impl_posix_getchar(monkeypatch):
+ @contextmanager
+ def fake_raw():
+ yield 7
+
+ monkeypatch.setattr(_termui_impl, "raw_terminal", fake_raw)
+ monkeypatch.setattr(_termui_impl.os, "read", lambda fd, n: b"q")
+ monkeypatch.setattr(_termui_impl, "get_best_encoding", lambda stdin: "utf-8")
+ monkeypatch.setattr(_termui_impl, "isatty", lambda f: f is _termui_impl.sys.stdout)
+ written: list[str] = []
+ monkeypatch.setattr(_termui_impl.sys.stdout, "write", lambda s: written.append(s))
+
+ assert _termui_impl.getchar(echo=True) == "q"
+ assert written == ["q"]
+
+
+@skip_if_windows
+def test_termui_impl_posix_getchar_eof(monkeypatch):
+ @contextmanager
+ def fake_raw():
+ yield 5
+
+ monkeypatch.setattr(_termui_impl, "raw_terminal", fake_raw)
+ monkeypatch.setattr(_termui_impl.os, "read", lambda fd, n: b"\x04")
+ monkeypatch.setattr(_termui_impl, "get_best_encoding", lambda stdin: "utf-8")
+ monkeypatch.setattr(_termui_impl, "isatty", lambda f: False)
+
+ with pytest.raises(EOFError):
+ _termui_impl.getchar(echo=False)
+
+
+def test_prompt():
+ runner = CliRunner()
+ app = typer.Typer()
+ fake_file = io.StringIO("data")
+ fake_file.name = "demo.txt"
+
+ @app.command()
+ def main(
+ accept: bool = typer.Option(True, prompt=True),
+ name: str = typer.Option(..., prompt=True),
+ flavor: Literal["a", "b"] = typer.Option(..., prompt=True),
+ city: str = typer.Option("London", prompt=True),
+ config: str = typer.Option(fake_file, prompt=True),
+ password: str = typer.Option(
+ ...,
+ prompt=True,
+ hide_input=True,
+ confirmation_prompt=True,
+ ),
+ ):
+ typer.echo(
+ f"accept={accept};name={name};flavor={flavor};city={city};config={config};pass_len={len(password)}"
+ )
+
+ result = runner.invoke(app, [], input="\nAda\na\n\ncustom.ini\nsecret\nsecret\n")
+ assert result.exit_code == 0, result.output
+ assert (
+ "accept=True;name=Ada;flavor=a;city=London;config=custom.ini;pass_len=6"
+ in result.stdout
+ )
+ assert "(a, b): " in result.stdout
+ assert "[demo.txt]: " in result.stdout
+
+
+def test_hidden_prompt_func(monkeypatch):
+ monkeypatch.setattr("getpass.getpass", lambda prompt: "secret")
+ assert termui.hidden_prompt_func("Password: ") == "secret"
+
+
+def test_echo_stdout_missing(monkeypatch: pytest.MonkeyPatch) -> None:
+ monkeypatch.setattr("sys.stdout", None)
+ typer.echo("ignored")
+
+
+def test_echo_stringifies() -> None:
+ stream = io.StringIO()
+ typer.echo(123, file=stream, nl=False)
+ assert stream.getvalue() == "123"
+
+
+def test_echo_bytes() -> None:
+ buffer = io.BytesIO()
+ stream = io.TextIOWrapper(buffer, encoding="utf-8")
+ typer.echo(b"abc", file=stream, nl=True)
+ assert buffer.getvalue() == b"abc\n"
+
+
+def test_echo_empty_output() -> None:
+ class FlushTrackingTextStream(io.StringIO):
+ def __init__(self) -> None:
+ super().__init__()
+ self.flush_count = 0
+
+ def flush(self) -> None:
+ self.flush_count += 1
+ super().flush()
+
+ def write(self, s: str) -> int:
+ raise AssertionError("Empty output") # pragma: no cover
+
+ stream = FlushTrackingTextStream()
+ typer.echo("", file=stream, nl=False)
+ assert stream.flush_count == 1
+
+
+@needs_windows
+def test_echo_windows_color_none(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ class TtyStream(io.StringIO):
+ def isatty(self) -> bool:
+ return True
+
+ stream = TtyStream()
+ monkeypatch.setattr("typer._click.utils.auto_wrap_for_ansi", None)
+ typer.echo("\x1b[31mred\x1b[0m", file=stream, nl=False, color=None)
+ assert stream.getvalue() == "red"
+
+
+@pytest.mark.parametrize(
+ ("flag", "true_code", "false_code"),
+ [
+ ("bold", "\x1b[1m", "\x1b[22m"),
+ ("dim", "\x1b[2m", "\x1b[22m"),
+ ("underline", "\x1b[4m", "\x1b[24m"),
+ ("overline", "\x1b[53m", "\x1b[55m"),
+ ("italic", "\x1b[3m", "\x1b[23m"),
+ ("blink", "\x1b[5m", "\x1b[25m"),
+ ("reverse", "\x1b[7m", "\x1b[27m"),
+ ("strikethrough", "\x1b[9m", "\x1b[29m"),
+ ],
+)
+def test_style(flag, true_code, false_code):
+ runner = CliRunner()
+ app = typer.Typer()
+
+ @app.command()
+ def main():
+ # testing an int and a str on purpose
+ typer.echo("TRUE=" + typer.style("42", **{flag: True}), color=True)
+ typer.echo("FALSE=" + typer.style(666, **{flag: False}), color=True)
+
+ result = runner.invoke(app, [])
+ assert result.exit_code == 0, result.output
+ lines = [line for line in result.stdout.splitlines() if line]
+ true_line = next(line for line in lines if line.startswith("TRUE="))
+ false_line = next(line for line in lines if line.startswith("FALSE="))
+ assert "42" in true_line
+ assert "666" in false_line
+
+ assert true_code in true_line
+ assert true_code not in false_line
+ assert false_code in false_line
+ assert false_code not in true_line
+
+
+def test_style_color():
+ fg_int = typer.style("x", fg=123)
+ assert "\x1b[38;5;123m" in fg_int
+
+ bg_list = typer.style("x", bg=[1, 2, 3])
+ assert "\x1b[48;2;1;2;3m" in bg_list
+
+ with pytest.raises(TypeError, match="Unknown color"):
+ typer.style("x", fg="not-a-color")
+
+ with pytest.raises(TypeError, match="Unknown color"):
+ typer.style("x", bg="not-a-color")
+
+
+def test_termui_launch(monkeypatch):
+ captured = {}
+
+ def fake_open_url(url, wait=False, locate=False):
+ captured["url"] = url
+ captured["wait"] = wait
+ captured["locate"] = locate
+ return 7
+
+ monkeypatch.setattr(_termui_impl, "open_url", fake_open_url)
+ rv = termui.launch("https://example.com", wait=True, locate=True)
+ assert rv == 7
+ assert captured == {
+ "url": "https://example.com",
+ "wait": True,
+ "locate": True,
+ }
diff --git a/tests/test_tutorial/test_parameter_types/test_bool/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_bool/test_tutorial001.py
index 6941a35c37..32c07214c0 100644
--- a/tests/test_tutorial/test_parameter_types/test_bool/test_tutorial001.py
+++ b/tests/test_tutorial/test_parameter_types/test_bool/test_tutorial001.py
@@ -4,6 +4,7 @@
from types import ModuleType
import pytest
+import typer
from typer.testing import CliRunner
runner = CliRunner()
@@ -22,6 +23,12 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType:
return mod
+def test_type_repr(mod: ModuleType):
+ command = typer.main.get_command(mod.app)
+ force_param = next(param for param in command.params if param.name == "force")
+ assert repr(force_param.type) == "BOOL"
+
+
def test_help(mod: ModuleType):
result = runner.invoke(mod.app, ["--help"])
assert result.exit_code == 0
diff --git a/tests/test_tutorial/test_parameter_types/test_datetime/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_datetime/test_tutorial001.py
index 9dfcdf19e3..eea63c5a8b 100644
--- a/tests/test_tutorial/test_parameter_types/test_datetime/test_tutorial001.py
+++ b/tests/test_tutorial/test_parameter_types/test_datetime/test_tutorial001.py
@@ -1,6 +1,8 @@
import subprocess
import sys
+from datetime import datetime
+import typer
from typer.testing import CliRunner
from docs_src.parameter_types.datetime import tutorial001_py310 as mod
@@ -9,6 +11,12 @@
app = mod.app
+def test_type_repr():
+ command = typer.main.get_command(app)
+ birth_param = next(param for param in command.params if param.name == "birth")
+ assert repr(birth_param.type) == "DateTime"
+
+
def test_help():
result = runner.invoke(app, ["--help"])
assert result.exit_code == 0
@@ -22,6 +30,15 @@ def test_main():
assert "Birth hour: 10" in result.output
+def test_main_datetime_object():
+ result = runner.invoke(
+ app, [], default_map={"birth": datetime(1956, 1, 31, 10, 0, 0)}
+ )
+ assert result.exit_code == 0
+ assert "Interesting day to be born: 1956-01-31 10:00:00" in result.output
+ assert "Birth hour: 10" in result.output
+
+
def test_invalid():
result = runner.invoke(app, ["july-19-1989"])
assert result.exit_code != 0
diff --git a/tests/test_tutorial/test_parameter_types/test_index/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_index/test_tutorial001.py
index 4b2f422fbb..ee0daa9f06 100644
--- a/tests/test_tutorial/test_parameter_types/test_index/test_tutorial001.py
+++ b/tests/test_tutorial/test_parameter_types/test_index/test_tutorial001.py
@@ -1,6 +1,7 @@
import subprocess
import sys
+import typer
from typer.testing import CliRunner
from docs_src.parameter_types.index import tutorial001_py310 as mod
@@ -9,6 +10,16 @@
app = mod.app
+def test_type_repr():
+ command = typer.main.get_command(app)
+ age_param = next(param for param in command.params if param.name == "age")
+ height_meters_param = next(
+ param for param in command.params if param.name == "height_meters"
+ )
+ assert repr(age_param.type) == "INT"
+ assert repr(height_meters_param.type) == "FLOAT"
+
+
def test_help():
result = runner.invoke(app, ["--help"])
assert result.exit_code == 0
diff --git a/tests/test_tutorial/test_parameter_types/test_number/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_number/test_tutorial001.py
index 6b9445a97f..ffe3ad09a0 100644
--- a/tests/test_tutorial/test_parameter_types/test_number/test_tutorial001.py
+++ b/tests/test_tutorial/test_parameter_types/test_number/test_tutorial001.py
@@ -24,6 +24,19 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType:
return mod
+def test_type_repr(mod: ModuleType):
+ command = typer.main.get_command(mod.app)
+
+ id_param = next(param for param in command.params if param.name == "id")
+ assert repr(id_param.type) == ""
+
+ age_param = next(param for param in command.params if param.name == "age")
+ assert repr(age_param.type) == "=18>"
+
+ score_param = next(param for param in command.params if param.name == "score")
+ assert repr(score_param.type) == ""
+
+
def test_help(mod: ModuleType):
result = runner.invoke(mod.app, ["--help"])
assert result.exit_code == 0
diff --git a/tests/test_tutorial/test_parameter_types/test_uuid/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_uuid/test_tutorial001.py
index cad8c69cc4..7b79e81405 100644
--- a/tests/test_tutorial/test_parameter_types/test_uuid/test_tutorial001.py
+++ b/tests/test_tutorial/test_parameter_types/test_uuid/test_tutorial001.py
@@ -1,6 +1,8 @@
import subprocess
import sys
+import uuid
+import typer
from typer.testing import CliRunner
from docs_src.parameter_types.uuid import tutorial001_py310 as mod
@@ -9,6 +11,12 @@
app = mod.app
+def test_type_repr():
+ command = typer.main.get_command(app)
+ user_id_param = next(param for param in command.params if param.name == "user_id")
+ assert repr(user_id_param.type) == "UUID"
+
+
def test_main():
result = runner.invoke(app, ["d48edaa6-871a-4082-a196-4daab372d4a1"])
assert result.exit_code == 0
@@ -16,6 +24,14 @@ def test_main():
assert "UUID version is: 4" in result.output
+def test_main_with_uuid_object():
+ user_id = uuid.UUID("d48edaa6-871a-4082-a196-4daab372d4a1")
+ result = runner.invoke(app, [], default_map={"user_id": user_id})
+ assert result.exit_code == 0
+ assert "USER_ID is d48edaa6-871a-4082-a196-4daab372d4a1" in result.output
+ assert "UUID version is: 4" in result.output
+
+
def test_invalid_uuid():
result = runner.invoke(app, ["7479706572-72756c6573"])
assert result.exit_code != 0
diff --git a/tests/test_type_conversion.py b/tests/test_type_conversion.py
index 83edd1ecb5..26709db0e2 100644
--- a/tests/test_type_conversion.py
+++ b/tests/test_type_conversion.py
@@ -1,12 +1,15 @@
+import os
from enum import Enum
from pathlib import Path
from typing import Any
-import click
import pytest
import typer
+from typer import _click, models
from typer.testing import CliRunner
+from tests.utils import needs_linux, needs_windows
+
runner = CliRunner()
@@ -133,6 +136,18 @@ def tuple_recursive_conversion(container: type_annotation):
assert result.exit_code == 0
+def test_tuple_wrong_arity():
+ app = typer.Typer()
+
+ @app.command()
+ def tuple_arity(value: tuple[str, str] = typer.Option(...)):
+ print(value) # pragma: no cover
+
+ result = runner.invoke(app, [], default_map={"value": ("only-one",)})
+ assert result.exit_code == 2
+ assert "2 values are required, but 1 given." in result.output
+
+
def test_custom_parse():
app = typer.Typer()
@@ -146,15 +161,29 @@ def custom_parser(
assert result.exit_code == 0
+def test_custom_parse_value_error():
+ app = typer.Typer()
+
+ @app.command()
+ def custom_parser(
+ hex_value: int = typer.Argument(None, parser=lambda x: int(x, 0)),
+ ):
+ print(hex_value) # pragma: no cover
+
+ result = runner.invoke(app, ["not-a-hex"])
+ assert result.exit_code == 2
+ assert "Invalid value" in result.output
+
+
def test_custom_click_type():
- class BaseNumberParamType(click.ParamType):
+ class BaseNumberParamType(_click.types.ParamType):
name = "base_integer"
def convert(
self,
value: Any,
- param: click.Parameter | None,
- ctx: click.Context | None,
+ param: _click.Parameter | None,
+ ctx: _click.Context | None,
) -> Any:
return int(value, 0)
@@ -168,3 +197,204 @@ def custom_click_type(
result = runner.invoke(app, ["0x56"])
assert result.exit_code == 0
+
+
+def test_int_range_open_bound_clamp():
+ app = typer.Typer()
+
+ @app.command()
+ def custom_click_type(
+ value: int = typer.Argument(
+ ...,
+ click_type=_click.types.IntRange(min=1, min_open=True, clamp=True),
+ ),
+ ):
+ print(value)
+
+ result = runner.invoke(app, ["1"])
+ assert result.exit_code == 0
+ assert "2" in result.output
+
+
+def test_bool_convert_invalid():
+ app = typer.Typer()
+
+ @app.command()
+ def main(value: bool):
+ print(value) # pragma: no cover
+
+ result = runner.invoke(app, ["maybe"])
+ assert result.exit_code == 2
+ assert "is not a valid boolean" in result.output
+ assert "yes" in result.output
+ assert "false" in result.output
+
+
+@pytest.mark.parametrize(
+ ("arg_enc", "system_enc", "raw_value", "expected_output"),
+ [
+ pytest.param("latin-1", "utf-8", b"\xff", "ÿ"),
+ pytest.param("ascii", "latin-1", b"\xff", "ÿ"),
+ pytest.param("ascii", "utf-16", b"\xff", "�"),
+ pytest.param("ascii", "ascii", b"\xff", "�"),
+ ],
+)
+def test_string_param_type_converts_bytes(
+ monkeypatch: pytest.MonkeyPatch,
+ arg_enc: str,
+ system_enc: str,
+ raw_value: bytes,
+ expected_output: str,
+):
+ app = typer.Typer()
+
+ @app.command()
+ def show(name: str = typer.Option(...)):
+ print(name)
+
+ command = typer.main.get_command(app)
+ name_param = next(param for param in command.params if param.name == "name")
+ assert repr(name_param.type) == "STRING"
+
+ monkeypatch.setattr(_click.types, "_get_argv_encoding", lambda: arg_enc)
+ monkeypatch.setattr(_click.types.sys, "getfilesystemencoding", lambda: system_enc)
+
+ result = runner.invoke(app, [], default_map={"name": raw_value})
+ assert result.exit_code == 0
+ assert expected_output in result.output
+
+
+@pytest.mark.parametrize("path_type", [str, bytes, Path])
+def test_path_coerced(path_type) -> None:
+ # Ensure coerce_path_result works correctly
+ app = typer.Typer()
+
+ @app.command()
+ def show(path: Any = typer.Option(..., path_type=path_type)):
+ print(path)
+
+ result = runner.invoke(app, ["--path", "dir/my_awesome_file.txt"])
+ assert result.exit_code == 0
+ assert "my_awesome_file" in result.output
+
+
+@pytest.mark.parametrize(
+ ("create_file", "option_kwargs", "deny_mode", "expected_error"),
+ [
+ (True, {"file_okay": False, "dir_okay": True}, None, "is a file"),
+ (False, {"file_okay": True, "dir_okay": False}, None, "is a directory"),
+ (True, {"readable": True}, os.R_OK, "is not readable"),
+ (True, {"readable": False, "writable": True}, os.W_OK, "is not writable"),
+ ],
+)
+def test_path_convert_failures(
+ monkeypatch: pytest.MonkeyPatch,
+ tmp_path: Path,
+ create_file: bool,
+ option_kwargs: dict[str, bool],
+ deny_mode: int | None,
+ expected_error: str,
+) -> None:
+ app = typer.Typer()
+
+ @app.command()
+ def show(path: Path = typer.Option(..., **option_kwargs)):
+ print(path) # pragma: no cover
+
+ if deny_mode is not None:
+ original_access = os.access
+
+ def fake_access(path: str, mode: int) -> bool:
+ if mode == deny_mode:
+ return False
+ return original_access(path, mode) # pragma: no cover
+
+ monkeypatch.setattr(models.os, "access", fake_access)
+
+ path = tmp_path / "some_path"
+ if create_file:
+ path.write_text("hello")
+ else:
+ path.mkdir()
+ result = runner.invoke(app, ["--path", str(path)])
+
+ assert result.exit_code != 0
+ assert expected_error in result.output
+
+
+def test_convert_type():
+ from typer._click.types import convert_type
+
+ # str
+ assert convert_type(str) is _click.types.STRING
+ assert convert_type(None) is _click.types.STRING
+ assert convert_type(None, default=["a"]) is _click.types.STRING
+
+ # tuples
+ tuple_type = convert_type((str, int))
+ assert isinstance(tuple_type, _click.types.Tuple)
+ assert [type(item) for item in tuple_type.types] == [
+ type(_click.types.STRING),
+ type(_click.types.INT),
+ ]
+
+ guessed_tuple = convert_type(None, default=[(1, "x")])
+ assert isinstance(guessed_tuple, _click.types.Tuple)
+ assert [type(item) for item in guessed_tuple.types] == [
+ type(_click.types.INT),
+ type(_click.types.STRING),
+ ]
+
+ # numbers
+ assert convert_type(int) is _click.types.INT
+ assert convert_type(float) is _click.types.FLOAT
+ assert convert_type(bool) is _click.types.BOOL
+
+ param_type = _click.types.IntRange(min=0, max=10)
+ assert convert_type(param_type) is param_type
+
+ guessed_int = convert_type(None, default=42)
+ assert guessed_int is _click.types.INT
+
+ # custom type
+ class CustomType:
+ pass
+
+ guessed_unknown = convert_type(None, default=CustomType())
+ assert guessed_unknown is _click.types.STRING
+
+ func_type = convert_type(CustomType)
+ assert isinstance(func_type, _click.types.FuncParamType)
+ assert func_type.name == "CustomType"
+
+
+@pytest.mark.parametrize(
+ ("platform_case", "stdin_encoding", "filesystem_encoding"),
+ [
+ pytest.param("windows", None, "utf-8", marks=needs_windows),
+ pytest.param("linux", "latin-1", "utf-8", marks=needs_linux),
+ pytest.param("linux", None, "latin-1", marks=needs_linux),
+ ],
+)
+def test_argv_encoding(
+ monkeypatch: pytest.MonkeyPatch,
+ platform_case: str,
+ stdin_encoding: str | None,
+ filesystem_encoding: str,
+) -> None:
+ sys = _click._compat.sys
+ if platform_case == "windows":
+ import locale
+
+ monkeypatch.setattr(locale, "getpreferredencoding", lambda: "latin-1")
+ else:
+
+ class FakeStdin:
+ def __init__(self, encoding: str | None) -> None:
+ self.encoding = encoding
+
+ monkeypatch.setattr(sys, "stdin", FakeStdin(stdin_encoding))
+ monkeypatch.setattr(sys, "getfilesystemencoding", lambda: filesystem_encoding)
+
+ converted = _click.types.STRING.convert(b"\xff", None, None)
+ assert converted == "ÿ"
diff --git a/tests/test_types.py b/tests/test_types.py
index adb100eb82..caeef451aa 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -1,6 +1,8 @@
from enum import Enum
+import pytest
import typer
+from typer import _click
from typer.testing import CliRunner
app = typer.Typer(context_settings={"token_normalize_func": str.lower})
@@ -12,23 +14,108 @@ class User(str, Enum):
@app.command()
-def hello(name: User = User.rick) -> None:
+def hello_option(name: User = User.rick) -> None:
print(f"Hello {name.value}!")
+@app.command()
+def hello_argument(name: User) -> None:
+ print(f"Hello {name.value}!")
+
+
+@app.command()
+def hello_no_choices(
+ name: User = typer.Option(..., "--name", show_choices=False),
+):
+ print(f"Hello {name.value}!")
+
+
+@app.command()
+def hello_all(names: list[str] = typer.Argument(["World"], envvar="NAMES")) -> None:
+ for name in names:
+ print(f"Hello {name}!")
+
+
+@app.command()
+def split_variadic_and_pair(items: list[str], pair: tuple[str, str]) -> None:
+ print(f"items={items}")
+ print(f"pair={pair}")
+
+
runner = CliRunner()
def test_enum_choice() -> None:
- # This test is only for coverage of the new custom TyperChoice class
- result = runner.invoke(app, ["--name", "morty"], catch_exceptions=False)
+ result = runner.invoke(
+ app, ["hello-option", "--name", "morty"], catch_exceptions=False
+ )
assert result.exit_code == 0
assert "Hello Morty!" in result.output
- result = runner.invoke(app, ["--name", "Rick"])
+ result = runner.invoke(app, ["hello-option", "--name", "Rick"])
+ assert result.exit_code == 0
+ assert "Hello Rick!" in result.output
+
+ result = runner.invoke(app, ["hello-option", "--name", "RICK"])
assert result.exit_code == 0
assert "Hello Rick!" in result.output
- result = runner.invoke(app, ["--name", "RICK"])
+ result = runner.invoke(app, ["hello-no-choices", "--name", "RICK"])
assert result.exit_code == 0
assert "Hello Rick!" in result.output
+
+ result = runner.invoke(app, ["hello-argument", "RICK"])
+ assert result.exit_code == 0
+ assert "Hello Rick!" in result.output
+
+
+def test_enum_choice_repr() -> None:
+ root_command = typer.main.get_command(app)
+ command = root_command.commands["hello-option"]
+ name_param = next(param for param in command.params if param.name == "name")
+ assert repr(name_param.type).startswith("Choice([")
+
+
+def test_enum_choice_help() -> None:
+ result = runner.invoke(app, ["hello-argument", "--help"])
+ assert result.exit_code == 0
+ assert "{rick|morty}" in result.output
+
+ result = runner.invoke(app, ["hello-option", "--help"])
+ assert result.exit_code == 0
+ assert "[rick|morty]" in result.output
+
+ result = runner.invoke(app, ["hello-no-choices", "--help"])
+ assert result.exit_code == 0
+ assert "--name" in result.output
+ assert "rick|morty" not in result.output
+
+
+def test_enum_choice_missing_message() -> None:
+ result = runner.invoke(app, ["hello-argument"])
+ assert result.exit_code != 0
+ assert "Missing argument" in result.output
+ assert "Choose from:" in result.output
+ assert "rick" in result.output
+ assert "morty" in result.output
+
+
+def test_split_envvar_value(monkeypatch) -> None:
+ # This will use split_envvar_value to produce two strings from the envvar
+ monkeypatch.setenv("NAMES", "Rick Morty")
+ result = runner.invoke(app, ["hello-all"])
+ assert result.exit_code == 0
+ assert "Hello Rick!" in result.output
+ assert "Hello Morty!" in result.output
+
+
+def test_list_pair() -> None:
+ result = runner.invoke(app, ["split-variadic-and-pair", "a", "b", "c", "x", "y"])
+ assert result.exit_code == 0
+ assert "items=['a', 'b', 'c']" in result.output
+ assert "pair=('x', 'y')" in result.output
+
+
+def test_float_range_open_bounds_with_clamp_not_allowed():
+ with pytest.raises(TypeError, match="Clamping is not supported for open bounds."):
+ _click.types.FloatRange(min=0.0, min_open=True, clamp=True)
diff --git a/tests/test_types_file.py b/tests/test_types_file.py
new file mode 100644
index 0000000000..61fd71c300
--- /dev/null
+++ b/tests/test_types_file.py
@@ -0,0 +1,395 @@
+import subprocess
+import sys
+from io import BytesIO, StringIO, TextIOWrapper
+from pathlib import Path
+
+import pytest
+import typer
+from typer._click._compat import get_best_encoding, should_strip_ansi
+from typer._click.testing import make_input_stream
+from typer._click.utils import PacifyFlushWrapper
+from typer.testing import CliRunner
+
+from tests.utils import needs_linux, needs_windows
+
+app = typer.Typer()
+
+
+@app.command()
+def read_text(file_in: typer.FileText = typer.Option(..., lazy=True)):
+ data = file_in.read()
+ typer.echo(f"text-len={len(data)}")
+
+
+@app.command()
+def write_text(file_out: typer.FileTextWrite = typer.Option(..., lazy=None)):
+ file_out.write("This is a single line\n")
+ typer.echo("1 line written")
+
+
+@app.command()
+def write_lazy(file_out: typer.FileTextWrite = typer.Option(..., lazy=True)):
+ file_out.write("This is a single lazy line\n")
+ typer.echo("1 line written")
+
+
+@app.command()
+def probe_lazy_file_behaviors(
+ file_in: typer.FileText = typer.Option(..., lazy=True),
+ file_out: typer.FileTextWrite = typer.Option(..., lazy=True),
+):
+ typer.echo(f"repr-before={repr(file_out)}")
+ file_out.write("repr-opened\n")
+ typer.echo(f"repr-after={repr(file_out)}")
+ with file_in as stream:
+ typer.echo(f"context-len={len(stream.read())}")
+ stream.seek(0)
+ first_line = next(iter(stream), "")
+ typer.echo(f"first-line={first_line.rstrip()}")
+
+
+@app.command()
+def write_binary(file_out: typer.FileBinaryWrite = typer.Option(...)):
+ file_out.write(b"binary-written\n")
+
+
+@app.command()
+def write_binary_stderr():
+ stream = typer.get_binary_stream("stderr")
+ stream.write(b"binary-stderr\n")
+ stream.flush()
+
+
+@app.command()
+def read_binary(file_in: typer.FileBinaryRead = typer.Option(...)):
+ data = file_in.read()
+ typer.echo(f"binary-len={len(data)}")
+
+
+runner = CliRunner()
+
+
+def test_text_stdin_dash() -> None:
+ result = runner.invoke(app, ["read-text", "--file-in=-"], input="hello\n")
+ assert result.exit_code == 0
+ assert "text-len=6" in result.output
+
+
+def test_lazy_file(tmp_path: Path) -> None:
+ # dash: written to stdout
+ result = runner.invoke(app, ["write-text", "--file-out=-"])
+ assert result.exit_code == 0
+ assert "This is a single line" in result.output
+ assert "1 line written" in result.output
+
+ # lazy + file
+ file_path = tmp_path / "example.txt"
+ result = runner.invoke(app, ["write-lazy", f"--file-out={file_path}"])
+ assert result.exit_code == 0
+ assert "This is a single lazy line" not in result.output
+ assert "1 line written" in result.output
+ assert file_path.exists()
+ assert file_path.read_text() == "This is a single lazy line\n"
+
+ # lazy probe: unopened/opened repr, context manager, and iteration.
+ result = runner.invoke(
+ app,
+ [
+ "probe-lazy-file-behaviors",
+ f"--file-in={file_path}",
+ f"--file-out={tmp_path / 'repr-opened.txt'}",
+ ],
+ )
+ assert result.exit_code == 0
+ assert "repr-before= None:
+ stream = StringIO()
+ result = runner.invoke(
+ app, ["write-text"], default_map={"write-text": {"file_out": stream}}
+ )
+ assert result.exit_code == 0
+ assert "1 line written" in result.output
+ assert stream.getvalue() == "This is a single line\n"
+
+
+def test_input_stream() -> None:
+ binary_stream = BytesIO(b"hello")
+ converted = make_input_stream(binary_stream, charset="utf-8")
+ assert converted is binary_stream
+
+ text_stream = TextIOWrapper(BytesIO(b"hello"), encoding="utf-8")
+ converted = make_input_stream(text_stream, charset="utf-8")
+ assert converted is text_stream.buffer
+
+
+def test_binary_dash() -> None:
+ result = runner.invoke(app, ["write-binary", "--file-out=-"])
+ assert result.exit_code == 0
+ assert result.stdout_bytes == b"binary-written\n"
+
+ result = runner.invoke(
+ app, ["read-binary", "--file-in=-"], input=b"\x00\x01\x02abc"
+ )
+ assert result.exit_code == 0
+ assert "binary-len=6" in result.output
+
+
+def test_binary_stderr() -> None:
+ result = subprocess.run(
+ [
+ sys.executable,
+ "-c",
+ "from tests.test_types_file import app; app()",
+ "write-binary-stderr",
+ ],
+ capture_output=True,
+ )
+ assert result.returncode == 0
+ assert result.stderr == b"binary-stderr\n"
+
+
+@pytest.mark.parametrize(
+ ("errors_arg", "expected_errors"),
+ [
+ (None, "replace"),
+ ("strict", "strict"),
+ ],
+)
+def test_get_text_stream_errors(
+ monkeypatch,
+ errors_arg: str | None,
+ expected_errors: str,
+) -> None:
+ class BinaryStdout(BytesIO):
+ pass
+
+ binary_stdout = BinaryStdout()
+ monkeypatch.setattr(sys, "stdout", binary_stdout)
+
+ text_stream = typer.get_text_stream("stdout", encoding=None, errors=errors_arg)
+ text_stream.write("stream-text")
+ text_stream.flush()
+
+ assert text_stream.errors == expected_errors
+ assert text_stream.writable() is True
+ assert binary_stdout.getvalue() == b"stream-text"
+
+
+def test_get_best_encoding() -> None:
+ """Test that ASCII is being transformed into UTF-8"""
+
+ class AsciiStream:
+ encoding = "ascii"
+
+ class Utf8Stream:
+ encoding = "utf-8"
+
+ class UnknownStream:
+ encoding = "unknown"
+
+ assert get_best_encoding(AsciiStream()) == "utf-8"
+ assert get_best_encoding(Utf8Stream()) == "utf-8"
+ assert get_best_encoding(UnknownStream()) == "unknown"
+
+
+def test_pacify_flush_wrapper() -> None:
+ class Wrapped:
+ def __init__(self) -> None:
+ self.name = "wrapped-stream"
+
+ def flush(self) -> None:
+ return None # pragma: no cover
+
+ wrapped = PacifyFlushWrapper(Wrapped())
+ assert wrapped.name == "wrapped-stream"
+
+
+def test_text_stream_isatty(monkeypatch) -> None:
+ class BinaryStdout(BytesIO):
+ def isatty(self) -> bool:
+ return True
+
+ binary_stdout = BinaryStdout()
+ monkeypatch.setattr(sys, "stdout", binary_stdout)
+ text_stream = typer.get_text_stream("stdout", encoding="utf-8", errors=None)
+ assert text_stream.isatty() is True
+
+
+def test_text_stream_buffer_read1(monkeypatch) -> None:
+ class BinaryStdinNoRead1:
+ def __init__(self, data: bytes) -> None:
+ self._data = data
+ self._pos = 0
+
+ def read(self, size: int = -1) -> bytes:
+ if size < 0:
+ size = len(self._data) - self._pos # pragma: no cover
+ chunk = self._data[self._pos : self._pos + size]
+ self._pos += len(chunk)
+ return chunk
+
+ binary_stdin = BinaryStdinNoRead1(b"hello")
+ monkeypatch.setattr(sys, "stdin", binary_stdin)
+ text_stream = typer.get_text_stream("stdin", encoding="utf-8", errors=None)
+ assert text_stream._stream.read1(4) == b"hell"
+
+
+def test_binary_stream(monkeypatch) -> None:
+ binary_stdin = BytesIO(b"hello")
+ binary_stdout = BytesIO()
+ monkeypatch.setattr(sys, "stdin", binary_stdin)
+ monkeypatch.setattr(sys, "stdout", binary_stdout)
+
+ assert typer.get_binary_stream("stdin") is binary_stdin
+ assert typer.get_binary_stream("stdout") is binary_stdout
+
+
+def test_binary_stream_raises(monkeypatch) -> None:
+ class TextOnlyStdin:
+ def read(self, n: int = -1) -> str:
+ return "hello"
+
+ monkeypatch.setattr(sys, "stdin", TextOnlyStdin())
+ with pytest.raises(RuntimeError, match="Was not able to determine binary stream"):
+ typer.get_binary_stream("stdin")
+
+
+def test_stream_unknown() -> None:
+ with pytest.raises(TypeError, match="Unknown standard stream 'Plumbus'"):
+ typer.get_binary_stream("Plumbus") # type: ignore[arg-type]
+
+ with pytest.raises(TypeError, match="Unknown standard stream 'Fleeb'"):
+ typer.get_text_stream("Fleeb") # type: ignore[arg-type]
+
+
+def test_format_filename() -> None:
+ filename = b"folder/subdir/demo.txt"
+ assert typer.format_filename(filename, shorten=True) == "demo.txt"
+
+
+def test_file_error(monkeypatch, tmp_path: Path) -> None:
+ file_path = tmp_path / "cannot-open.txt"
+
+ def fake_open(path, *args, **kwargs):
+ if Path(path) == file_path:
+ raise OSError()
+
+ monkeypatch.setattr("builtins.open", fake_open)
+ result = runner.invoke(app, ["write-text", f"--file-out={file_path}"])
+ assert result.exit_code == 1
+ assert "Could not open file" in result.output
+ assert "cannot-open.txt" in result.output
+ assert "unknown error" in result.output
+
+
+@needs_windows
+def test_app_dir_windows_fallback(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ monkeypatch.delenv("APPDATA", raising=False)
+ monkeypatch.setattr("os.path.expanduser", lambda _path: r"C:\Users\Tester")
+
+ assert typer.get_app_dir("My App", roaming=True) == r"C:\Users\Tester\My App"
+
+
+@needs_linux
+def test_app_dir_force_posix(monkeypatch: pytest.MonkeyPatch) -> None:
+ monkeypatch.setattr("os.path.expanduser", lambda _path: "/home/tester/.my-app")
+
+ assert typer.get_app_dir("My App", force_posix=True) == "/home/tester/.my-app"
+
+
+def test_text_stream_binary_buffer(monkeypatch) -> None:
+ class TextStdinWithBinaryBuffer:
+ def __init__(self, data: bytes) -> None:
+ self.buffer = BytesIO(data)
+ self.encoding = "latin-1"
+
+ def read(self, n: int = -1) -> str:
+ raise OSError("text stream is not readable directly")
+
+ class TextStdoutWithBinaryBuffer:
+ def __init__(self) -> None:
+ self.buffer = BytesIO()
+ self.encoding = "latin-1"
+
+ def write(self, s: str) -> int:
+ raise OSError("text stream is not writable directly")
+
+ stdin = TextStdinWithBinaryBuffer(b"hello")
+ stdout = TextStdoutWithBinaryBuffer()
+
+ monkeypatch.setattr(sys, "stdin", stdin)
+ monkeypatch.setattr(sys, "stdout", stdout)
+
+ text_stdin = typer.get_text_stream("stdin", encoding="utf-8", errors=None)
+ text_stdout = typer.get_text_stream("stdout", encoding="utf-8", errors=None)
+
+ assert text_stdin.read() == "hello"
+ text_stdout.write("ok")
+ text_stdout.flush()
+ assert stdout.buffer.getvalue() == b"ok"
+
+
+def test_text_stream_binary_stream(monkeypatch) -> None:
+ binary_stdout = BytesIO()
+ monkeypatch.setattr(sys, "stdout", binary_stdout)
+ text_stream = typer.get_text_stream("stdout", encoding="utf-8", errors=None)
+ text_stream.write("ok")
+ text_stream.flush()
+ assert binary_stdout.getvalue() == b"ok"
+
+
+def test_text_stream_stdout_no_binary(
+ monkeypatch,
+) -> None:
+ class TextStdoutNoBinaryFallback:
+ encoding = "utf-8"
+ errors = "strict"
+
+ def write(self, s: str) -> int:
+ if isinstance(s, bytes):
+ raise TypeError("bytes not supported")
+ return len(s)
+
+ stdout = TextStdoutNoBinaryFallback()
+ monkeypatch.setattr(sys, "stdout", stdout)
+ text_stream = typer.get_text_stream("stdout", encoding="utf-8", errors="replace")
+ assert text_stream is stdout
+
+
+def test_jupyter_wrapped_stream(monkeypatch) -> None:
+ class JupyterLikeStdout(BytesIO):
+ __module__ = "ipykernel.iostream"
+
+ def isatty(self) -> bool:
+ return False
+
+ binary_stdout = JupyterLikeStdout()
+ monkeypatch.setattr(sys, "stdout", binary_stdout)
+ text_stream = typer.get_text_stream("stdout", encoding="utf-8", errors=None)
+ assert should_strip_ansi(text_stream, color=None) is False
+
+
+def test_should_strip_ansi(monkeypatch) -> None:
+ class NonTtyStdin(BytesIO):
+ def isatty(self) -> bool:
+ return False
+
+ stdin = NonTtyStdin()
+ monkeypatch.setattr(sys, "stdin", stdin)
+ assert should_strip_ansi(stream=None, color=None) is True
+ assert should_strip_ansi(stream=None, color=True) is False
+ assert should_strip_ansi(stream=None, color=False) is True
diff --git a/tests/test_win_console.py b/tests/test_win_console.py
new file mode 100644
index 0000000000..13c390cfe2
--- /dev/null
+++ b/tests/test_win_console.py
@@ -0,0 +1,326 @@
+"""
+Tests for the Windows console functionality.
+Created after vendoring Click to ensure test coverage is back up to 100%.
+"""
+
+import ctypes
+import io
+import sys
+
+import pytest
+import typer
+from typer.testing import CliRunner
+
+from .utils import needs_windows
+
+pytestmark = needs_windows
+
+
+if sys.platform == "win32":
+ from typer._click import _compat, _winconsole
+
+
+def _identity_buffer(obj, writable=False): # noqa: ARG001
+ return obj
+
+
+def _route_console_stream(target_name, wrapper, state=None):
+ def patched_windows_stream(stream, encoding, errors): # noqa: ARG001
+ current_target = getattr(sys, target_name)
+ if stream is current_target:
+ if state is not None and target_name == "stderr":
+ state["stderr_wrap_calls"] += 1
+ buffer = getattr(stream, "buffer", None)
+ return wrapper(buffer) if buffer else None
+ return None
+
+ return patched_windows_stream
+
+
+def _capture_write_console(state):
+ def fake_write_console(handle, buffer, units_to_write, units_written_ptr, reserved): # noqa: ARG001
+ state["write_calls"] += 1
+ bytes_to_write = units_to_write * 2
+ state["written"].extend(buffer[:bytes_to_write])
+ units_written_ptr._obj.value = units_to_write
+ return 1
+
+ return fake_write_console
+
+
+def test_winconsole_stdin(monkeypatch):
+ runner = CliRunner()
+ app = typer.Typer()
+
+ @app.command()
+ def read_name(config: typer.FileText = typer.Option(...)) -> None:
+ name = config.readline().strip()
+ typer.echo(f"Hello {name}")
+
+ utf16_data = bytearray("Rick\r\n".encode("utf-16-le"))
+ state = {"pos": 0, "read_calls": 0}
+
+ def fake_read_console(handle, buffer, units_to_read, units_read_ptr, reserved): # noqa: ARG001
+ state["read_calls"] += 1
+ max_bytes = units_to_read * 2
+ chunk = utf16_data[state["pos"] : state["pos"] + max_bytes]
+ if chunk:
+ buffer[0 : len(chunk)] = chunk
+ state["pos"] += len(chunk)
+ units_read_ptr._obj.value = len(chunk) // 2
+ return 1
+
+ return 1 # pragma: no cover
+
+ monkeypatch.setattr(_winconsole, "get_buffer", _identity_buffer)
+ monkeypatch.setattr(_winconsole, "ReadConsoleW", fake_read_console)
+ monkeypatch.setattr(_winconsole, "GetLastError", lambda: 0)
+ monkeypatch.setattr(
+ _compat,
+ "_get_windows_console_stream",
+ _route_console_stream("stdin", _winconsole._get_text_stdin),
+ )
+
+ result = runner.invoke(app, ["--config", "-"])
+ assert result.exit_code == 0, result.output
+ assert "Hello Rick" in result.stdout
+ assert state["read_calls"] > 0
+
+
+def test_winconsole_stdout(monkeypatch):
+ runner = CliRunner()
+ app = typer.Typer()
+ state = {"write_calls": 0, "written": bytearray()}
+
+ @app.command()
+ def write_message(out: typer.FileTextWrite = typer.Option(...)) -> None:
+ out.write("Hello Summer\n")
+
+ monkeypatch.setattr(_winconsole, "get_buffer", _identity_buffer)
+ monkeypatch.setattr(_winconsole, "WriteConsoleW", _capture_write_console(state))
+ monkeypatch.setattr(_winconsole, "GetLastError", lambda: 0)
+ monkeypatch.setattr(
+ _compat,
+ "_get_windows_console_stream",
+ _route_console_stream("stdout", _winconsole._get_text_stdout),
+ )
+
+ result = runner.invoke(app, ["--out", "-"])
+ assert result.exit_code == 0, result.output
+ assert state["write_calls"] > 0
+ assert _winconsole._WindowsConsoleWriter(1).isatty() is True
+ decoded = state["written"].decode("utf-16-le", errors="ignore")
+ assert "Hello Summer\r\n" in decoded
+
+
+def test_winconsole_stderr(monkeypatch):
+ runner = CliRunner()
+ app = typer.Typer()
+ state = {"write_calls": 0, "written": bytearray(), "stderr_wrap_calls": 0}
+
+ @app.command()
+ def main() -> None:
+ typer.echo("Ran out of adventure time!", err=True)
+
+ monkeypatch.setattr(_winconsole, "get_buffer", _identity_buffer)
+ monkeypatch.setattr(_winconsole, "WriteConsoleW", _capture_write_console(state))
+ monkeypatch.setattr(_winconsole, "GetLastError", lambda: 0)
+ monkeypatch.setattr(
+ _compat,
+ "_get_windows_console_stream",
+ _route_console_stream("stderr", _winconsole._get_text_stderr, state),
+ )
+
+ result = runner.invoke(app)
+ assert result.exit_code == 0, result.output
+ assert state["stderr_wrap_calls"] > 0
+ assert state["write_calls"] > 0
+ decoded = state["written"].decode("utf-16-le", errors="ignore")
+ assert "Ran out of adventure time!\r\n" in decoded
+
+
+@pytest.mark.parametrize(
+ ("writable", "source", "expected_flags"),
+ [
+ (True, bytearray(b"python"), 1), # PyBUF_WRITABLE
+ (False, b"python", 0), # PyBUF_SIMPLE
+ ],
+)
+def test_get_buffer(monkeypatch, writable, source, expected_flags):
+ state = {"flags": None, "released": 0}
+ if writable:
+ backing = (_winconsole.c_char * len(source)).from_buffer(source)
+ else:
+ backing = (_winconsole.c_char * len(source)).from_buffer_copy(source)
+ backing_ptr = _winconsole.c_void_p(ctypes.addressof(backing))
+
+ def fake_object_get_buffer(obj, buf_ref, flags): # noqa: ARG001
+ state["flags"] = flags
+ buf = buf_ref._obj
+ buf.buf = backing_ptr
+ buf.len = len(source)
+
+ def fake_buffer_release(buf_ref): # noqa: ARG001
+ state["released"] += 1
+
+ monkeypatch.setattr(_winconsole, "PyObject_GetBuffer", fake_object_get_buffer)
+ monkeypatch.setattr(_winconsole, "PyBuffer_Release", fake_buffer_release)
+
+ probe = source if writable else b"x"
+ out = _winconsole.get_buffer(probe, writable=writable)
+ if writable:
+ # mutate the first byte of "python" to obtain another beloved programming language
+ out[0] = b"c"
+ assert source == bytearray(b"cython")
+ else:
+ assert bytes(out[: len(source)]) == source
+ assert state["flags"] == expected_flags
+ assert state["released"] == 1
+
+
+def test_isatty():
+ assert _winconsole._WindowsConsoleRawIOBase(None).isatty() is True
+ assert _winconsole._WindowsConsoleReader(0).isatty() is True
+ assert _winconsole._WindowsConsoleReader(1).isatty() is True
+
+
+def test_console_stream():
+ class NamedBytesIO(io.BytesIO):
+ name = "fake-buffer"
+
+ def isatty(self):
+ return False
+
+ stream = _winconsole.ConsoleStream(
+ io.TextIOWrapper(io.BytesIO(), encoding="utf-8"), NamedBytesIO()
+ )
+ assert stream.isatty() is False
+ assert stream.name == "fake-buffer"
+ assert "fake-buffer" in repr(stream)
+ assert "utf-8" in repr(stream)
+
+ # test writelines
+ stream.writelines(["hello", " ", "world"])
+ stream._text_stream.flush()
+ assert stream._text_stream.buffer.getvalue().decode("utf-8") == "hello world"
+
+ # Cover bytes write path.
+ assert stream.write(b"!") == 1
+ assert stream.buffer.getvalue().endswith(b"!")
+
+
+@pytest.mark.parametrize(
+ ("error", "msg"),
+ [
+ (0, "ERROR_SUCCESS"), # ERROR_SUCCESS
+ (8, "ERROR_NOT_ENOUGH_MEMORY"), # ERROR_NOT_ENOUGH_MEMORY
+ (342, "Windows error 342"),
+ ],
+)
+def test_error_message(error, msg):
+ writer = _winconsole._WindowsConsoleWriter
+ assert writer._get_error_message(error) == msg
+
+
+def test_is_console():
+ assert _winconsole._is_console(object()) is False
+
+
+def test_get_windows_console_stream_factory_and_buffer_paths(monkeypatch):
+ monkeypatch.setattr(_winconsole, "_is_console", lambda f: True)
+ monkeypatch.setattr(_winconsole, "get_buffer", object())
+
+ class FakeStream:
+ def __init__(self, fd, buffer=None):
+ self._fd = fd
+ self.buffer = buffer
+
+ def fileno(self):
+ return self._fd
+
+ wrapped = {"called": False, "buffer": None}
+
+ def fake_factory(buffer):
+ wrapped["called"] = True
+ wrapped["buffer"] = buffer
+ return "wrapped", buffer
+
+ monkeypatch.setattr(_winconsole, "_stream_factories", {7: fake_factory})
+
+ # Known console stream preconditions pass, but no stream factory for this fd.
+ get_stream = _winconsole._get_windows_console_stream
+ assert get_stream(FakeStream(99, object()), "utf-16-le", "strict") is None
+
+ # Factory exists, but stream has no usable .buffer.
+ assert get_stream(FakeStream(7, None), "utf-16-le", "strict") is None
+
+ # Factory exists and buffer is present, so wrapper result is returned.
+ raw_buffer = object()
+ out = get_stream(FakeStream(7, raw_buffer), "utf-16-le", "strict")
+ assert out == ("wrapped", raw_buffer)
+ assert wrapped["called"] is True
+ assert wrapped["buffer"] is raw_buffer
+
+
+def test_windows_console_reader(monkeypatch):
+ reader = _winconsole._WindowsConsoleReader(42)
+
+ # Empty input buffer returns early
+ assert reader.readinto(bytearray()) == 0
+
+ # Require an even number of bytes
+ with pytest.raises(ValueError):
+ reader.readinto(bytearray(3))
+
+ def writable_buffer(obj, writable=False): # noqa: ARG001
+ return (ctypes.c_char * len(obj)).from_buffer(obj)
+
+ monkeypatch.setattr(_winconsole, "get_buffer", writable_buffer)
+
+ def patch_console(read_console, error):
+ monkeypatch.setattr(_winconsole, "ReadConsoleW", read_console)
+ monkeypatch.setattr(_winconsole, "GetLastError", lambda: error)
+
+ def make_read(payload=b"", rv=1, units_read=None):
+ def read_console(handle, buffer, units_to_read, units_read_ptr, reserved): # noqa: ARG001
+ bytes_to_copy = min(len(payload), units_to_read * 2)
+ if bytes_to_copy:
+ buffer[0:bytes_to_copy] = payload[:bytes_to_copy]
+ read_units = units_read if units_read is not None else bytes_to_copy // 2
+ units_read_ptr._obj.value = read_units
+ return rv
+
+ return read_console
+
+ # Normal successful read returns the number of bytes read
+ patch_console(make_read(payload=b"A\x00B\x00"), _winconsole.ERROR_SUCCESS)
+ assert reader.readinto(bytearray(4)) == 4
+
+ # CTRL+Z (EOF) should be translated into an empty read
+ patch_console(
+ make_read(payload=_winconsole.EOF + b"\x00", units_read=1),
+ _winconsole.ERROR_SUCCESS,
+ )
+ assert reader.readinto(bytearray(2)) == 0
+
+ # An aborted read should sleep briefly while waiting for KeyboardInterrupt
+ sleep_state = {"calls": 0}
+
+ def fake_sleep(seconds):
+ sleep_state["calls"] += 1
+ assert seconds == 0.1
+
+ monkeypatch.setattr(
+ _winconsole, "time", type("FakeTime", (), {"sleep": fake_sleep})
+ )
+ patch_console(
+ make_read(payload=b"Z\x00", units_read=1),
+ _winconsole.ERROR_OPERATION_ABORTED,
+ )
+ assert reader.readinto(bytearray(2)) == 2
+ assert sleep_state["calls"] == 1
+
+ # Failed reads propagate a Windows error
+ patch_console(make_read(rv=0), _winconsole.ERROR_NOT_ENOUGH_MEMORY)
+ with pytest.raises(OSError, match="Windows error"):
+ reader.readinto(bytearray(2))
diff --git a/tests/utils.py b/tests/utils.py
index 35c441d365..7cb285bc53 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -9,9 +9,16 @@
needs_linux = pytest.mark.skipif(
not sys.platform.startswith("linux"), reason="Test requires Linux"
)
+needs_macos = pytest.mark.skipif(
+ not sys.platform.startswith("darwin"), reason="Test requires macOS"
+)
needs_windows = pytest.mark.skipif(
not sys.platform.startswith("win"), reason="Test requires Windows"
)
+skip_if_windows = pytest.mark.skipif(
+ sys.platform == "win32",
+ reason="Test should not be run on Windows",
+)
needs_rich = pytest.mark.skipif(not HAS_RICH, reason="Test requires Rich")
diff --git a/typer/.agents/skills/typer/SKILL.md b/typer/.agents/skills/typer/SKILL.md
index 20e6cdd589..45bf622407 100644
--- a/typer/.agents/skills/typer/SKILL.md
+++ b/typer/.agents/skills/typer/SKILL.md
@@ -259,7 +259,7 @@ if __name__ == "__main__":
## Click
-Originally, Typer was built on Click. However, going forward Typer will vendor Click. As such, Click extensions should not be used anymore.
+Originally, Typer was built on Click. However, since version 0.26.0, Typer has vendored Click. As such, Click extensions should not be used anymore.
Other settings of `Option` and `Argument` that came from Click but shouldn't be used in Typer anymore, include: `expose_value`, `shell_complete`, `show_choices`, `errors`, `prompt_required`, `is_flag`, `flag_value` and `allow_from_autoenv`.
diff --git a/typer/__init__.py b/typer/__init__.py
index 548408fe98..392bd006ff 100644
--- a/typer/__init__.py
+++ b/typer/__init__.py
@@ -4,28 +4,21 @@
from shutil import get_terminal_size as get_terminal_size
-from click.exceptions import Abort as Abort
-from click.exceptions import BadParameter as BadParameter
-from click.exceptions import Exit as Exit
-from click.termui import clear as clear
-from click.termui import confirm as confirm
-from click.termui import echo_via_pager as echo_via_pager
-from click.termui import edit as edit
-from click.termui import getchar as getchar
-from click.termui import pause as pause
-from click.termui import progressbar as progressbar
-from click.termui import prompt as prompt
-from click.termui import secho as secho
-from click.termui import style as style
-from click.termui import unstyle as unstyle
-from click.utils import echo as echo
-from click.utils import format_filename as format_filename
-from click.utils import get_app_dir as get_app_dir
-from click.utils import get_binary_stream as get_binary_stream
-from click.utils import get_text_stream as get_text_stream
-from click.utils import open_file as open_file
-
from . import colors as colors
+from ._click.exceptions import Abort as Abort
+from ._click.exceptions import BadParameter as BadParameter
+from ._click.exceptions import Exit as Exit
+from ._click.termui import confirm as confirm
+from ._click.termui import getchar as getchar
+from ._click.termui import progressbar as progressbar
+from ._click.termui import prompt as prompt
+from ._click.termui import secho as secho
+from ._click.termui import style as style
+from ._click.utils import echo as echo
+from ._click.utils import format_filename as format_filename
+from ._click.utils import get_app_dir as get_app_dir
+from ._click.utils import get_binary_stream as get_binary_stream
+from ._click.utils import get_text_stream as get_text_stream
from .main import Typer as Typer
from .main import launch as launch
from .main import run as run
diff --git a/typer/_click/LICENSE.txt b/typer/_click/LICENSE.txt
new file mode 100644
index 0000000000..d12a849186
--- /dev/null
+++ b/typer/_click/LICENSE.txt
@@ -0,0 +1,28 @@
+Copyright 2014 Pallets
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+1. Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
+PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
+TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/typer/_click/__init__.py b/typer/_click/__init__.py
new file mode 100644
index 0000000000..48c9edcddd
--- /dev/null
+++ b/typer/_click/__init__.py
@@ -0,0 +1,11 @@
+"""
+Code taken and adapted from Click: https://github.com/pallets/click/releases/tag/8.3.1
+"""
+
+from .core import Command as Command
+from .core import Context as Context
+from .core import Parameter as Parameter
+from .exceptions import ClickException as ClickException
+from .formatting import HelpFormatter as HelpFormatter
+from .termui import launch as launch
+from .utils import echo as echo
diff --git a/typer/_click/_compat.py b/typer/_click/_compat.py
new file mode 100644
index 0000000000..6ed0ceb8ac
--- /dev/null
+++ b/typer/_click/_compat.py
@@ -0,0 +1,569 @@
+import codecs
+import io
+import os
+import re
+import sys
+from collections.abc import Callable, Mapping, MutableMapping
+from types import TracebackType
+from typing import (
+ IO,
+ Any,
+ BinaryIO,
+ TextIO,
+ cast,
+)
+from weakref import WeakKeyDictionary
+
+CYGWIN = sys.platform.startswith("cygwin")
+WIN = sys.platform.startswith("win")
+auto_wrap_for_ansi: Callable[[TextIO], TextIO] | None = None
+_ansi_re = re.compile(r"\033\[[;?0-9]*[a-zA-Z]")
+
+
+def _make_text_stream(
+ stream: BinaryIO,
+ encoding: str | None,
+ errors: str,
+) -> TextIO:
+ if encoding is None:
+ encoding = get_best_encoding(stream)
+ return _NonClosingTextIOWrapper(
+ stream,
+ encoding,
+ errors,
+ line_buffering=True,
+ )
+
+
+def is_ascii_encoding(encoding: str) -> bool:
+ """Checks if a given encoding is ascii."""
+ try:
+ return codecs.lookup(encoding).name == "ascii"
+ except LookupError:
+ return False
+
+
+def get_best_encoding(stream: IO[Any]) -> str:
+ """Returns the default stream encoding if not found."""
+ rv = getattr(stream, "encoding", None) or sys.getdefaultencoding()
+ if is_ascii_encoding(rv):
+ return "utf-8"
+ return rv
+
+
+class _NonClosingTextIOWrapper(io.TextIOWrapper):
+ def __init__(
+ self,
+ stream: BinaryIO,
+ encoding: str | None,
+ errors: str | None,
+ **extra: Any,
+ ) -> None:
+ self._stream = stream = cast(BinaryIO, _FixupStream(stream))
+ super().__init__(stream, encoding, errors, **extra)
+
+ def __del__(self) -> None:
+ try:
+ self.detach()
+ except Exception: # pragma: no cover
+ pass
+
+ def isatty(self) -> bool:
+ # https://bitbucket.org/pypy/pypy/issue/1803
+ return self._stream.isatty()
+
+
+class _FixupStream:
+ """The new io interface needs more from streams than streams
+ traditionally implement. As such, this fix-up code is necessary in
+ some circumstances.
+ """
+
+ def __init__(
+ self,
+ stream: BinaryIO,
+ ):
+ self._stream = stream
+
+ def __getattr__(self, name: str) -> Any:
+ return getattr(self._stream, name)
+
+ def read1(self, size: int) -> bytes:
+ f = getattr(self._stream, "read1", None)
+
+ if f is not None:
+ return cast(bytes, f(size))
+
+ return self._stream.read(size)
+
+ def readable(self) -> bool:
+ return True
+
+ def writable(self) -> bool:
+ return True
+
+ def seekable(self) -> bool:
+ x = getattr(self._stream, "seekable", None)
+ if x is not None:
+ return cast(bool, x())
+ return False
+
+
+def _is_binary_reader(stream: IO[Any], default: bool = False) -> bool:
+ try:
+ return isinstance(stream.read(0), bytes)
+ except Exception: # pragma: no cover
+ return default
+ # This happens in some cases where the stream was already
+ # closed. In this case, we assume the default.
+
+
+def _is_binary_writer(stream: IO[Any], default: bool = False) -> bool:
+ try:
+ stream.write(b"")
+ except Exception: # pragma: no cover
+ try:
+ stream.write("")
+ return False
+ except Exception:
+ pass
+ return default
+ return True
+
+
+def _find_binary_reader(stream: IO[Any]) -> BinaryIO | None:
+ # We need to figure out if the given stream is already binary.
+ # This can happen because the official docs recommend detaching
+ # the streams to get binary streams. Some code might do this, so
+ # we need to deal with this case explicitly.
+ if _is_binary_reader(stream, False):
+ return cast(BinaryIO, stream)
+
+ buf = getattr(stream, "buffer", None)
+
+ # Same situation here; this time we assume that the buffer is
+ # actually binary in case it's closed.
+ if buf is not None and _is_binary_reader(buf, True):
+ return cast(BinaryIO, buf)
+
+ return None
+
+
+def _find_binary_writer(stream: IO[Any]) -> BinaryIO | None:
+ # We need to figure out if the given stream is already binary.
+ # This can happen because the official docs recommend detaching
+ # the streams to get binary streams. Some code might do this, so
+ # we need to deal with this case explicitly.
+ if _is_binary_writer(stream, False):
+ return cast(BinaryIO, stream)
+
+ buf = getattr(stream, "buffer", None)
+
+ # Same situation here; this time we assume that the buffer is
+ # actually binary in case it's closed.
+ if buf is not None and _is_binary_writer(buf, True):
+ return cast(BinaryIO, buf)
+
+ return None
+
+
+def _stream_is_misconfigured(stream: TextIO) -> bool:
+ """A stream is misconfigured if its encoding is ASCII."""
+ # If the stream does not have an encoding set, we assume it's set
+ # to ASCII. This appears to happen in certain unittest
+ # environments. It's not quite clear what the correct behavior is
+ # but this at least will force Click to recover somehow.
+ return is_ascii_encoding(getattr(stream, "encoding", None) or "ascii")
+
+
+def _is_compat_stream_attr(stream: TextIO, attr: str, value: str | None) -> bool:
+ """A stream attribute is compatible if it is equal to the
+ desired value or the desired value is unset and the attribute
+ has a value.
+ """
+ stream_value = getattr(stream, attr, None)
+ return stream_value == value or (value is None and stream_value is not None)
+
+
+def _is_compatible_text_stream(
+ stream: TextIO, encoding: str | None, errors: str | None
+) -> bool:
+ """Check if a stream's encoding and errors attributes are
+ compatible with the desired values.
+ """
+ return _is_compat_stream_attr(
+ stream, "encoding", encoding
+ ) and _is_compat_stream_attr(stream, "errors", errors)
+
+
+def _force_correct_text_stream(
+ text_stream: IO[Any],
+ encoding: str | None,
+ errors: str | None,
+ is_binary: Callable[[IO[Any], bool], bool],
+ find_binary: Callable[[IO[Any]], BinaryIO | None],
+) -> TextIO:
+ if is_binary(text_stream, False):
+ binary_reader = cast(BinaryIO, text_stream)
+ else:
+ text_stream = cast(TextIO, text_stream)
+ # If the stream looks compatible, and won't default to a
+ # misconfigured ascii encoding, return it as-is.
+ if _is_compatible_text_stream(text_stream, encoding, errors) and not (
+ encoding is None and _stream_is_misconfigured(text_stream)
+ ):
+ return text_stream
+
+ # Otherwise, get the underlying binary reader.
+ possible_binary_reader = find_binary(text_stream)
+
+ # If that's not possible, silently use the original reader
+ # and get mojibake instead of exceptions.
+ if possible_binary_reader is None:
+ return text_stream
+
+ binary_reader = possible_binary_reader
+
+ # Default errors to replace instead of strict in order to get
+ # something that works.
+ if errors is None:
+ errors = "replace"
+
+ # Wrap the binary stream in a text stream with the correct
+ # encoding parameters.
+ return _make_text_stream(
+ binary_reader,
+ encoding,
+ errors,
+ )
+
+
+def _force_correct_text_reader(
+ text_reader: IO[Any],
+ encoding: str | None,
+ errors: str | None,
+) -> TextIO:
+ return _force_correct_text_stream(
+ text_reader,
+ encoding,
+ errors,
+ _is_binary_reader,
+ _find_binary_reader,
+ )
+
+
+def _force_correct_text_writer(
+ text_writer: IO[Any],
+ encoding: str | None,
+ errors: str | None,
+) -> TextIO:
+ return _force_correct_text_stream(
+ text_writer,
+ encoding,
+ errors,
+ _is_binary_writer,
+ _find_binary_writer,
+ )
+
+
+def get_binary_stdin() -> BinaryIO:
+ reader = _find_binary_reader(sys.stdin)
+ if reader is None: # pragma: no cover
+ raise RuntimeError("Was not able to determine binary stream for sys.stdin.")
+ return reader
+
+
+def get_binary_stdout() -> BinaryIO:
+ writer = _find_binary_writer(sys.stdout)
+ if writer is None: # pragma: no cover
+ raise RuntimeError("Was not able to determine binary stream for sys.stdout.")
+ return writer
+
+
+def get_binary_stderr() -> BinaryIO:
+ writer = _find_binary_writer(sys.stderr)
+ if writer is None: # pragma: no cover
+ raise RuntimeError("Was not able to determine binary stream for sys.stderr.")
+ return writer
+
+
+def get_text_stdin(encoding: str | None = None, errors: str | None = None) -> TextIO:
+ rv = _get_windows_console_stream(sys.stdin, encoding, errors)
+ if rv is not None:
+ return rv
+ return _force_correct_text_reader(sys.stdin, encoding, errors)
+
+
+def get_text_stdout(encoding: str | None = None, errors: str | None = None) -> TextIO:
+ rv = _get_windows_console_stream(sys.stdout, encoding, errors)
+ if rv is not None:
+ return rv
+ return _force_correct_text_writer(sys.stdout, encoding, errors)
+
+
+def get_text_stderr(encoding: str | None = None, errors: str | None = None) -> TextIO:
+ rv = _get_windows_console_stream(sys.stderr, encoding, errors)
+ if rv is not None:
+ return rv
+ return _force_correct_text_writer(sys.stderr, encoding, errors)
+
+
+def _wrap_io_open(
+ file: str | os.PathLike[str] | int,
+ mode: str,
+ encoding: str | None,
+ errors: str | None,
+) -> IO[Any]:
+ """Handles not passing ``encoding`` and ``errors`` in binary mode."""
+ if "b" in mode:
+ return open(file, mode)
+
+ return open(file, mode, encoding=encoding, errors=errors)
+
+
+def open_stream(
+ filename: str | os.PathLike[str],
+ mode: str = "r",
+ encoding: str | None = None,
+ errors: str | None = "strict",
+ atomic: bool = False,
+) -> tuple[IO[Any], bool]:
+ binary = "b" in mode
+ filename = os.fspath(filename)
+
+ # Standard streams first, ignoring the atomic flag.
+ if os.fsdecode(filename) == "-":
+ if any(m in mode for m in ["w", "a", "x"]):
+ if binary:
+ return get_binary_stdout(), False
+ return get_text_stdout(encoding=encoding, errors=errors), False
+ if binary:
+ return get_binary_stdin(), False
+ return get_text_stdin(encoding=encoding, errors=errors), False
+
+ # Non-atomic writes directly go out through the regular open functions.
+ if not atomic:
+ return _wrap_io_open(filename, mode, encoding, errors), True
+
+ # Some usability stuff for atomic writes
+ if "a" in mode:
+ raise ValueError(
+ "Appending to an existing file is not supported, because that"
+ " would involve an expensive `copy`-operation to a temporary"
+ " file. Open the file in normal `w`-mode and copy explicitly"
+ " if that's what you're after."
+ )
+ if "x" in mode:
+ raise ValueError("Use the `overwrite`-parameter instead.")
+ if "w" not in mode:
+ raise ValueError("Atomic writes only make sense with `w`-mode.")
+
+ # Atomic writes are more complicated. They work by opening a file
+ # as a proxy in the same folder and then using the fdopen
+ # functionality to wrap it in a Python file. Then we wrap it in an
+ # atomic file that moves the file over on close.
+ import errno
+ import random
+
+ try:
+ perm: int | None = os.stat(filename).st_mode
+ except OSError: # pragma: no cover
+ perm = None
+
+ flags = os.O_RDWR | os.O_CREAT | os.O_EXCL
+
+ if binary:
+ flags |= getattr(os, "O_BINARY", 0)
+
+ while True:
+ tmp_filename = os.path.join(
+ os.path.dirname(filename),
+ f".__atomic-write{random.randrange(1 << 32):08x}",
+ )
+ try:
+ fd = os.open(tmp_filename, flags, 0o666 if perm is None else perm)
+ break
+ except OSError as e: # pragma: no cover
+ if e.errno == errno.EEXIST or (
+ os.name == "nt"
+ and e.errno == errno.EACCES
+ and os.path.isdir(e.filename)
+ and os.access(e.filename, os.W_OK)
+ ):
+ continue
+ raise
+
+ if perm is not None:
+ os.chmod(tmp_filename, perm) # in case perm includes bits in umask
+
+ f = _wrap_io_open(fd, mode, encoding, errors)
+ af = _AtomicFile(f, tmp_filename, os.path.realpath(filename))
+ return cast(IO[Any], af), True
+
+
+class _AtomicFile:
+ def __init__(self, f: IO[Any], tmp_filename: str, real_filename: str) -> None:
+ self._f = f
+ self._tmp_filename = tmp_filename
+ self._real_filename = real_filename
+ self.closed = False
+
+ @property
+ def name(self) -> str:
+ return self._real_filename
+
+ def close(self, delete: bool = False) -> None:
+ if self.closed:
+ return # pragma: no cover
+ self._f.close()
+ os.replace(self._tmp_filename, self._real_filename)
+ self.closed = True
+
+ def __getattr__(self, name: str) -> Any:
+ return getattr(self._f, name)
+
+ def __enter__(self) -> "_AtomicFile":
+ return self
+
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ tb: TracebackType | None,
+ ) -> None:
+ self.close(delete=exc_type is not None)
+
+ def __repr__(self) -> str:
+ return repr(self._f)
+
+
+def strip_ansi(value: str) -> str:
+ return _ansi_re.sub("", value)
+
+
+def _is_jupyter_kernel_output(stream: IO[Any]) -> bool:
+ while isinstance(stream, (_FixupStream, _NonClosingTextIOWrapper)):
+ stream = stream._stream
+
+ return stream.__class__.__module__.startswith("ipykernel.")
+
+
+def should_strip_ansi(stream: IO[Any] | None = None, color: bool | None = None) -> bool:
+ if color is None:
+ if stream is None:
+ stream = sys.stdin
+ return not isatty(stream) and not _is_jupyter_kernel_output(stream)
+ return not color
+
+
+# On Windows, wrap the output streams with colorama to support ANSI
+# color codes.
+# NOTE: double check is needed so mypy does not analyze this on Linux
+if sys.platform.startswith("win") and WIN:
+ from ._winconsole import _get_windows_console_stream
+
+ def _get_argv_encoding() -> str:
+ import locale
+
+ return locale.getpreferredencoding()
+
+ _ansi_stream_wrappers: MutableMapping[TextIO, TextIO] = WeakKeyDictionary()
+
+ def auto_wrap_for_ansi(stream: TextIO, color: bool | None = None) -> TextIO:
+ """Support ANSI color and style codes on Windows by wrapping a
+ stream with colorama.
+ """
+ try:
+ cached = _ansi_stream_wrappers.get(stream)
+ except Exception: # pragma: no cover
+ cached = None
+
+ if cached is not None:
+ return cached
+
+ import colorama
+
+ strip = should_strip_ansi(stream, color)
+ ansi_wrapper = colorama.AnsiToWin32(stream, strip=strip)
+ rv = cast(TextIO, ansi_wrapper.stream)
+ _write = rv.write
+
+ def _safe_write(s: str) -> int:
+ try:
+ return _write(s)
+ except BaseException: # pragma: no cover
+ ansi_wrapper.reset_all()
+ raise
+
+ rv.write = _safe_write # type: ignore[method-assign] # ty: ignore[invalid-assignment]
+
+ try:
+ _ansi_stream_wrappers[stream] = rv
+ except Exception: # pragma: no cover
+ pass
+
+ return rv
+
+else:
+
+ def _get_argv_encoding() -> str:
+ return getattr(sys.stdin, "encoding", None) or sys.getfilesystemencoding()
+
+ def _get_windows_console_stream(
+ f: TextIO, encoding: str | None, errors: str | None
+ ) -> TextIO | None:
+ return None
+
+
+def term_len(x: str) -> int:
+ return len(strip_ansi(x))
+
+
+def isatty(stream: IO[Any]) -> bool:
+ try:
+ return stream.isatty()
+ except Exception: # pragma: no cover
+ return False
+
+
+def _make_cached_stream_func(
+ src_func: Callable[[], TextIO],
+ wrapper_func: Callable[[], TextIO],
+) -> Callable[[], TextIO]:
+ cache: MutableMapping[TextIO, TextIO] = WeakKeyDictionary()
+
+ def func() -> TextIO:
+ stream = src_func()
+
+ try:
+ rv = cache.get(stream)
+ except Exception: # pragma: no cover
+ rv = None
+ if rv is not None:
+ return rv
+ rv = wrapper_func()
+ try:
+ cache[stream] = rv
+ except Exception: # pragma: no cover
+ pass
+ return rv
+
+ return func
+
+
+_default_text_stdin = _make_cached_stream_func(lambda: sys.stdin, get_text_stdin)
+_default_text_stdout = _make_cached_stream_func(lambda: sys.stdout, get_text_stdout)
+_default_text_stderr = _make_cached_stream_func(lambda: sys.stderr, get_text_stderr)
+
+
+binary_streams: Mapping[str, Callable[[], BinaryIO]] = {
+ "stdin": get_binary_stdin,
+ "stdout": get_binary_stdout,
+ "stderr": get_binary_stderr,
+}
+
+text_streams: Mapping[str, Callable[[str | None, str | None], TextIO]] = {
+ "stdin": get_text_stdin,
+ "stdout": get_text_stdout,
+ "stderr": get_text_stderr,
+}
diff --git a/typer/_click/_termui_impl.py b/typer/_click/_termui_impl.py
new file mode 100644
index 0000000000..c621810cbf
--- /dev/null
+++ b/typer/_click/_termui_impl.py
@@ -0,0 +1,522 @@
+"""
+To keep the import times down, some infrequently used termui functionality
+is placed here and only imported as needed.
+"""
+
+import contextlib
+import math
+import os
+import sys
+import time
+from collections.abc import Callable, Iterable, Iterator
+from io import StringIO
+from types import TracebackType
+from typing import Generic, TextIO, TypeVar, cast
+
+from ._compat import (
+ CYGWIN,
+ WIN,
+ _default_text_stdout,
+ get_best_encoding,
+ isatty,
+ term_len,
+)
+from .utils import echo
+
+V = TypeVar("V")
+
+if os.name == "nt":
+ BEFORE_BAR = "\r"
+ AFTER_BAR = "\n"
+else:
+ BEFORE_BAR = "\r\033[?25l"
+ AFTER_BAR = "\033[?25h\n"
+
+
+class ProgressBar(Generic[V]):
+ def __init__(
+ self,
+ iterable: Iterable[V] | None,
+ length: int | None = None,
+ fill_char: str = "#",
+ empty_char: str = " ",
+ bar_template: str = "%(bar)s",
+ info_sep: str = " ",
+ hidden: bool = False,
+ show_eta: bool = True,
+ show_percent: bool | None = None,
+ show_pos: bool = False,
+ item_show_func: Callable[[V | None], str | None] | None = None,
+ label: str | None = None,
+ file: TextIO | None = None,
+ color: bool | None = None,
+ update_min_steps: int = 1,
+ width: int = 30,
+ ) -> None:
+ self.fill_char = fill_char
+ self.empty_char = empty_char
+ self.bar_template = bar_template
+ self.info_sep = info_sep
+ self.hidden = hidden
+ self.show_eta = show_eta
+ self.show_percent = show_percent
+ self.show_pos = show_pos
+ self.item_show_func = item_show_func
+ self.label: str = label or ""
+
+ if file is None:
+ file = _default_text_stdout()
+
+ # There are no standard streams attached to write to. For example,
+ # pythonw on Windows.
+ if file is None: # pragma: no cover
+ file = StringIO()
+
+ self.file = file
+ self.color = color
+ self.update_min_steps = update_min_steps
+ self._completed_intervals = 0
+ self.width: int = width
+ self.autowidth: bool = width == 0
+
+ if length is None:
+ from operator import length_hint
+
+ length = length_hint(iterable, -1)
+
+ if length == -1: # pragma: no cover
+ length = None
+ if iterable is None:
+ if length is None: # pragma: no cover
+ raise TypeError("iterable or length is required")
+ iterable = cast("Iterable[V]", range(length))
+ self.iter: Iterable[V] = iter(iterable)
+ self.length = length
+ self.pos: int = 0
+ self.avg: list[float] = []
+ self.last_eta: float
+ self.start: float
+ self.start = self.last_eta = time.time()
+ self.eta_known: bool = False
+ self.finished: bool = False
+ self.max_width: int | None = None
+ self.entered: bool = False
+ self.current_item: V | None = None
+ self._is_atty = isatty(self.file)
+ self._last_line: str | None = None
+
+ def __enter__(self) -> "ProgressBar[V]":
+ self.entered = True
+ self.render_progress()
+ return self
+
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ tb: TracebackType | None,
+ ) -> None:
+ self.render_finish()
+
+ def __iter__(self) -> Iterator[V]:
+ if not self.entered:
+ raise RuntimeError("You need to use progress bars in a with block.")
+ self.render_progress()
+ return self.generator()
+
+ def __next__(self) -> V:
+ # Iteration is defined in terms of a generator function,
+ # returned by iter(self); use that to define next(). This works
+ # because `self.iter` is an iterable consumed by that generator,
+ # so it is re-entry safe. Calling `next(self.generator())`
+ # twice works and does "what you want".
+ return next(iter(self))
+
+ def render_finish(self) -> None:
+ if self.hidden or not self._is_atty:
+ return
+ self.file.write(AFTER_BAR)
+ self.file.flush()
+
+ @property
+ def pct(self) -> float:
+ if self.finished:
+ return 1.0
+ return min(self.pos / (float(self.length or 1) or 1), 1.0)
+
+ @property
+ def time_per_iteration(self) -> float:
+ if not self.avg:
+ return 0.0
+ return sum(self.avg) / float(len(self.avg))
+
+ @property
+ def eta(self) -> float:
+ if self.length is not None and not self.finished:
+ return self.time_per_iteration * (self.length - self.pos)
+ return 0.0
+
+ def format_eta(self) -> str:
+ if self.eta_known:
+ t = int(self.eta)
+ seconds = t % 60
+ t //= 60
+ minutes = t % 60
+ t //= 60
+ hours = t % 24
+ t //= 24
+ if t > 0:
+ return f"{t}d {hours:02}:{minutes:02}:{seconds:02}"
+ else:
+ return f"{hours:02}:{minutes:02}:{seconds:02}"
+ return ""
+
+ def format_pos(self) -> str:
+ pos = str(self.pos)
+ if self.length is not None:
+ pos += f"/{self.length}"
+ return pos
+
+ def format_pct(self) -> str:
+ return f"{int(self.pct * 100): 4}%"[1:]
+
+ def format_bar(self) -> str:
+ if self.length is not None:
+ bar_length = int(self.pct * self.width)
+ bar = self.fill_char * bar_length
+ bar += self.empty_char * (self.width - bar_length)
+ elif self.finished:
+ bar = self.fill_char * self.width
+ else:
+ chars = list(self.empty_char * (self.width or 1))
+ if self.time_per_iteration != 0:
+ chars[
+ int(
+ (math.cos(self.pos * self.time_per_iteration) / 2.0 + 0.5)
+ * self.width
+ )
+ ] = self.fill_char
+ bar = "".join(chars)
+ return bar
+
+ def format_progress_line(self) -> str:
+ show_percent = self.show_percent
+
+ info_bits = []
+ if self.length is not None and show_percent is None:
+ show_percent = not self.show_pos
+
+ if self.show_pos:
+ info_bits.append(self.format_pos())
+ if show_percent:
+ info_bits.append(self.format_pct())
+ if self.show_eta and self.eta_known and not self.finished:
+ info_bits.append(self.format_eta())
+ if self.item_show_func is not None:
+ item_info = self.item_show_func(self.current_item)
+ if item_info is not None:
+ info_bits.append(item_info)
+
+ return (
+ self.bar_template
+ % {
+ "label": self.label,
+ "bar": self.format_bar(),
+ "info": self.info_sep.join(info_bits),
+ }
+ ).rstrip()
+
+ def render_progress(self) -> None:
+ if self.hidden:
+ return
+
+ if not self._is_atty:
+ # Only output the label once if the output is not a TTY.
+ if self._last_line != self.label:
+ self._last_line = self.label
+ echo(self.label, file=self.file, color=self.color)
+ return
+
+ buf = []
+ # Update width in case the terminal has been resized
+ if self.autowidth:
+ import shutil
+
+ old_width = self.width
+ self.width = 0
+ clutter_length = term_len(self.format_progress_line())
+ new_width = max(0, shutil.get_terminal_size().columns - clutter_length)
+ if new_width < old_width and self.max_width is not None:
+ buf.append(BEFORE_BAR)
+ buf.append(" " * self.max_width)
+ self.max_width = new_width
+ self.width = new_width
+
+ clear_width = self.width
+ if self.max_width is not None:
+ clear_width = self.max_width
+
+ buf.append(BEFORE_BAR)
+ line = self.format_progress_line()
+ line_len = term_len(line)
+ if self.max_width is None or self.max_width < line_len:
+ self.max_width = line_len
+
+ buf.append(line)
+ buf.append(" " * (clear_width - line_len))
+ line = "".join(buf)
+ # Render the line only if it changed.
+
+ if line != self._last_line:
+ self._last_line = line
+ echo(line, file=self.file, color=self.color, nl=False)
+ self.file.flush()
+
+ def make_step(self, n_steps: int) -> None:
+ self.pos += n_steps
+ if self.length is not None and self.pos >= self.length:
+ self.finished = True
+
+ if (time.time() - self.last_eta) < 1.0:
+ return
+
+ self.last_eta = time.time()
+
+ # self.avg is a rolling list of length <= 7 of steps where steps are
+ # defined as time elapsed divided by the total progress through
+ # self.length.
+ if self.pos:
+ step = (time.time() - self.start) / self.pos
+ else:
+ step = time.time() - self.start
+
+ self.avg = self.avg[-6:] + [step]
+
+ self.eta_known = self.length is not None
+
+ def update(self, n_steps: int) -> None:
+ """Update the progress bar by advancing a specified number of steps."""
+ self._completed_intervals += n_steps
+
+ if self._completed_intervals >= self.update_min_steps:
+ self.make_step(self._completed_intervals)
+ self.render_progress()
+ self._completed_intervals = 0
+
+ def finish(self) -> None:
+ self.eta_known = False
+ self.current_item = None
+ self.finished = True
+
+ def generator(self) -> Iterator[V]:
+ """Return a generator which yields the items added to the bar
+ during construction, and updates the progress bar *after* the
+ yielded block returns.
+ """
+ # WARNING: the iterator interface for `ProgressBar` relies on
+ # this and only works because this is a simple generator which
+ # doesn't create or manage additional state. If this function
+ # changes, the impact should be evaluated both against
+ # `iter(bar)` and `next(bar)`. `next()` in particular may call
+ # `self.generator()` repeatedly, and this must remain safe in
+ # order for that interface to work.
+ if not self.entered: # pragma: no cover
+ raise RuntimeError("You need to use progress bars in a with block.")
+
+ if not self._is_atty:
+ yield from self.iter
+ else:
+ for rv in self.iter:
+ self.current_item = rv
+
+ # This allows show_item_func to be updated before the
+ # item is processed. Only trigger at the beginning of
+ # the update interval.
+ if self._completed_intervals == 0:
+ self.render_progress()
+
+ yield rv
+ self.update(1)
+
+ self.finish()
+ self.render_progress()
+
+
+def open_url(url: str, wait: bool = False, locate: bool = False) -> int:
+ import subprocess
+
+ def _unquote_file(url: str) -> str:
+ from urllib.parse import unquote
+
+ if url.startswith("file://"):
+ url = unquote(url[7:])
+
+ return url
+
+ if sys.platform == "darwin":
+ args = ["open"]
+ if wait:
+ args.append("-W")
+ if locate:
+ args.append("-R")
+ args.append(_unquote_file(url))
+ null = open("/dev/null", "w")
+ try:
+ return subprocess.Popen(args, stderr=null).wait()
+ finally:
+ null.close()
+ elif WIN:
+ if locate:
+ url = _unquote_file(url)
+ args = ["explorer", f"/select,{url}"]
+ else:
+ args = ["start"]
+ if wait:
+ args.append("/WAIT")
+ args.append("")
+ args.append(url)
+ try:
+ return subprocess.call(args)
+ except OSError:
+ # Command not found
+ return 127
+ elif CYGWIN: # pragma: no cover
+ if locate:
+ url = _unquote_file(url)
+ args = ["cygstart", os.path.dirname(url)]
+ else:
+ args = ["cygstart"]
+ if wait:
+ args.append("-w")
+ args.append(url)
+ try:
+ return subprocess.call(args)
+ except OSError:
+ # Command not found
+ return 127
+
+ try:
+ if locate:
+ url = os.path.dirname(_unquote_file(url)) or "."
+ else:
+ url = _unquote_file(url)
+ c = subprocess.Popen(["xdg-open", url])
+ if wait:
+ return c.wait()
+ return 0
+ except OSError: # pragma: no cover
+ # TODO: remove this part, doesn't get hit by Typer code paths?
+ if url.startswith(("http://", "https://")) and not locate and not wait:
+ import webbrowser
+
+ webbrowser.open(url)
+ return 0
+ return 1
+
+
+def _translate_ch_to_exc(ch: str) -> None:
+ if ch == "\x03":
+ raise KeyboardInterrupt()
+
+ if ch == "\x04" and not WIN: # Unix-like, Ctrl+D
+ raise EOFError()
+
+ if ch == "\x1a" and WIN: # Windows, Ctrl+Z
+ raise EOFError()
+
+ return None
+
+
+if sys.platform == "win32":
+ import msvcrt
+
+ @contextlib.contextmanager
+ def raw_terminal() -> Iterator[int]:
+ yield -1
+
+ def getchar(echo: bool) -> str:
+ # The function `getch` will return a bytes object corresponding to
+ # the pressed character. Since Windows 10 build 1803, it will also
+ # return \x00 when called a second time after pressing a regular key.
+ #
+ # `getwch` does not share this probably-bugged behavior. Moreover, it
+ # returns a Unicode object by default, which is what we want.
+ #
+ # Either of these functions will return \x00 or \xe0 to indicate
+ # a special key, and you need to call the same function again to get
+ # the "rest" of the code. The fun part is that \u00e0 is
+ # "latin small letter a with grave", so if you type that on a French
+ # keyboard, you _also_ get a \xe0.
+ # E.g., consider the Up arrow. This returns \xe0 and then \x48. The
+ # resulting Unicode string reads as "a with grave" + "capital H".
+ # This is indistinguishable from when the user actually types
+ # "a with grave" and then "capital H".
+ #
+ # When \xe0 is returned, we assume it's part of a special-key sequence
+ # and call `getwch` again, but that means that when the user types
+ # the \u00e0 character, `getchar` doesn't return until a second
+ # character is typed.
+ # The alternative is returning immediately, but that would mess up
+ # cross-platform handling of arrow keys and others that start with
+ # \xe0. Another option is using `getch`, but then we can't reliably
+ # read non-ASCII characters, because return values of `getch` are
+ # limited to the current 8-bit codepage.
+ #
+ # Anyway, Click doesn't claim to do this Right(tm), and using `getwch`
+ # is doing the right thing in more situations than with `getch`.
+
+ if echo:
+ func = cast(Callable[[], str], msvcrt.getwche)
+ else:
+ func = cast(Callable[[], str], msvcrt.getwch)
+
+ rv = func()
+
+ if rv in ("\x00", "\xe0"):
+ # \x00 and \xe0 are control characters that indicate special key,
+ # see above.
+ rv += func()
+
+ _translate_ch_to_exc(rv)
+ return rv
+
+else:
+ import termios
+ import tty
+
+ @contextlib.contextmanager
+ def raw_terminal() -> Iterator[int]:
+ f: TextIO | None
+ fd: int
+
+ if not isatty(sys.stdin):
+ f = open("/dev/tty")
+ fd = f.fileno()
+ else:
+ fd = sys.stdin.fileno()
+ f = None
+
+ try:
+ old_settings = termios.tcgetattr(fd)
+
+ try:
+ tty.setraw(fd)
+ yield fd
+ finally:
+ termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
+ sys.stdout.flush()
+
+ if f is not None:
+ f.close()
+ except termios.error: # pragma: no cover
+ pass
+
+ def getchar(echo: bool) -> str:
+ with raw_terminal() as fd:
+ ch = os.read(fd, 32).decode(get_best_encoding(sys.stdin), "replace")
+
+ if echo and isatty(sys.stdout): # pragma: no cover
+ sys.stdout.write(ch)
+
+ _translate_ch_to_exc(ch)
+ return ch
diff --git a/typer/_click/_textwrap.py b/typer/_click/_textwrap.py
new file mode 100644
index 0000000000..9f9636b31f
--- /dev/null
+++ b/typer/_click/_textwrap.py
@@ -0,0 +1,46 @@
+import textwrap
+from collections.abc import Iterator
+from contextlib import contextmanager
+
+
+class TextWrapper(textwrap.TextWrapper):
+ def _handle_long_word(
+ self,
+ reversed_chunks: list[str],
+ cur_line: list[str],
+ cur_len: int,
+ width: int,
+ ) -> None:
+ space_left = max(width - cur_len, 1)
+
+ last = reversed_chunks[-1]
+ cut = last[:space_left]
+ res = last[space_left:]
+ cur_line.append(cut)
+ reversed_chunks[-1] = res
+
+ @contextmanager
+ def extra_indent(self, indent: str) -> Iterator[None]:
+ old_initial_indent = self.initial_indent
+ old_subsequent_indent = self.subsequent_indent
+ self.initial_indent += indent
+ self.subsequent_indent += indent
+
+ try:
+ yield
+ finally:
+ self.initial_indent = old_initial_indent
+ self.subsequent_indent = old_subsequent_indent
+
+ def indent_only(self, text: str) -> str:
+ rv = []
+
+ for idx, line in enumerate(text.splitlines()):
+ indent = self.initial_indent
+
+ if idx > 0:
+ indent = self.subsequent_indent
+
+ rv.append(f"{indent}{line}")
+
+ return "\n".join(rv)
diff --git a/typer/_click/_winconsole.py b/typer/_click/_winconsole.py
new file mode 100644
index 0000000000..f6dedbfd6e
--- /dev/null
+++ b/typer/_click/_winconsole.py
@@ -0,0 +1,300 @@
+# This module is based on the excellent work by Adam Bartoš who
+# provided a lot of what went into the implementation here in
+# the discussion to issue1602 in the Python bug tracker.
+#
+# There are some general differences in regards to how this works
+# compared to the original patches as we do not need to patch
+# the entire interpreter but just work in our little world of
+# echo and prompt.
+import io
+import sys
+import time
+from collections.abc import Callable, Iterable, Mapping
+from ctypes import (
+ POINTER,
+ Array,
+ Structure,
+ byref,
+ c_char,
+ c_char_p,
+ c_int,
+ c_ssize_t,
+ c_ulong,
+ c_void_p,
+ py_object,
+)
+from ctypes.wintypes import DWORD, HANDLE, LPCWSTR, LPWSTR
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ AnyStr,
+ BinaryIO,
+ Literal,
+ TextIO,
+ cast,
+)
+
+from ._compat import _NonClosingTextIOWrapper
+
+assert sys.platform == "win32"
+import msvcrt # noqa: E402
+from ctypes import WINFUNCTYPE, windll # noqa: E402
+
+c_ssize_p = POINTER(c_ssize_t)
+
+kernel32 = windll.kernel32
+GetStdHandle = kernel32.GetStdHandle
+ReadConsoleW = kernel32.ReadConsoleW
+WriteConsoleW = kernel32.WriteConsoleW
+GetConsoleMode = kernel32.GetConsoleMode
+GetLastError = kernel32.GetLastError
+GetCommandLineW = WINFUNCTYPE(LPWSTR)(("GetCommandLineW", windll.kernel32))
+CommandLineToArgvW = WINFUNCTYPE(POINTER(LPWSTR), LPCWSTR, POINTER(c_int))(
+ ("CommandLineToArgvW", windll.shell32)
+)
+LocalFree = WINFUNCTYPE(c_void_p, c_void_p)(("LocalFree", windll.kernel32))
+
+STDIN_HANDLE = GetStdHandle(-10)
+STDOUT_HANDLE = GetStdHandle(-11)
+STDERR_HANDLE = GetStdHandle(-12)
+
+PyBUF_SIMPLE = 0
+PyBUF_WRITABLE = 1
+
+ERROR_SUCCESS = 0
+ERROR_NOT_ENOUGH_MEMORY = 8
+ERROR_OPERATION_ABORTED = 995
+
+STDIN_FILENO = 0
+STDOUT_FILENO = 1
+STDERR_FILENO = 2
+
+EOF = b"\x1a"
+MAX_BYTES_WRITTEN = 32767
+
+if TYPE_CHECKING:
+ try:
+ # Using `typing_extensions.Buffer` instead of `collections.abc`
+ # on Windows for some reason does not have `Sized` implemented.
+ from collections.abc import Buffer # type: ignore
+ except ImportError:
+ from typing_extensions import Buffer
+
+try:
+ from ctypes import pythonapi
+except ImportError: # pragma: no cover
+ # On PyPy we cannot get buffers so our ability to operate here is
+ # severely limited.
+ get_buffer = None
+else:
+
+ class Py_buffer(Structure):
+ _fields_ = [ # noqa: RUF012
+ ("buf", c_void_p),
+ ("obj", py_object),
+ ("len", c_ssize_t),
+ ("itemsize", c_ssize_t),
+ ("readonly", c_int),
+ ("ndim", c_int),
+ ("format", c_char_p),
+ ("shape", c_ssize_p),
+ ("strides", c_ssize_p),
+ ("suboffsets", c_ssize_p),
+ ("internal", c_void_p),
+ ]
+
+ PyObject_GetBuffer = pythonapi.PyObject_GetBuffer
+ PyBuffer_Release = pythonapi.PyBuffer_Release
+
+ def get_buffer(obj: "Buffer", writable: bool = False) -> Array[c_char]:
+ buf = Py_buffer()
+ flags: int = PyBUF_WRITABLE if writable else PyBUF_SIMPLE
+ PyObject_GetBuffer(py_object(obj), byref(buf), flags)
+
+ try:
+ buffer_type = c_char * buf.len
+ out: Array[c_char] = buffer_type.from_address(buf.buf)
+ return out
+ finally:
+ PyBuffer_Release(byref(buf))
+
+
+class _WindowsConsoleRawIOBase(io.RawIOBase):
+ def __init__(self, handle: int | None) -> None:
+ self.handle = handle
+
+ def isatty(self) -> Literal[True]:
+ super().isatty()
+ return True
+
+
+class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
+ def readable(self) -> Literal[True]:
+ return True
+
+ def readinto(self, b: "Buffer") -> int:
+ bytes_to_be_read = len(b)
+ if not bytes_to_be_read:
+ return 0
+ elif bytes_to_be_read % 2:
+ raise ValueError(
+ "cannot read odd number of bytes from UTF-16-LE encoded console"
+ )
+
+ buffer = get_buffer(b, writable=True)
+ code_units_to_be_read = bytes_to_be_read // 2
+ code_units_read = c_ulong()
+
+ rv = ReadConsoleW(
+ HANDLE(self.handle),
+ buffer,
+ code_units_to_be_read,
+ byref(code_units_read),
+ None,
+ )
+ if GetLastError() == ERROR_OPERATION_ABORTED:
+ # wait for KeyboardInterrupt
+ time.sleep(0.1)
+ if not rv:
+ raise OSError(f"Windows error: {GetLastError()}")
+
+ if buffer[0] == EOF:
+ return 0
+ return 2 * code_units_read.value
+
+
+class _WindowsConsoleWriter(_WindowsConsoleRawIOBase):
+ def writable(self) -> Literal[True]:
+ return True
+
+ @staticmethod
+ def _get_error_message(errno: int) -> str:
+ if errno == ERROR_SUCCESS:
+ return "ERROR_SUCCESS"
+ elif errno == ERROR_NOT_ENOUGH_MEMORY:
+ return "ERROR_NOT_ENOUGH_MEMORY"
+ return f"Windows error {errno}"
+
+ def write(self, b: "Buffer") -> int:
+ bytes_to_be_written = len(b)
+ buf = get_buffer(b)
+ code_units_to_be_written = min(bytes_to_be_written, MAX_BYTES_WRITTEN) // 2
+ code_units_written = c_ulong()
+
+ WriteConsoleW(
+ HANDLE(self.handle),
+ buf,
+ code_units_to_be_written,
+ byref(code_units_written),
+ None,
+ )
+ bytes_written = 2 * code_units_written.value
+
+ if bytes_written == 0 and bytes_to_be_written > 0:
+ raise OSError(self._get_error_message(GetLastError())) # pragma: no cover
+ return bytes_written
+
+
+class ConsoleStream:
+ def __init__(self, text_stream: TextIO, byte_stream: BinaryIO) -> None:
+ self._text_stream = text_stream
+ self.buffer = byte_stream
+
+ @property
+ def name(self) -> str:
+ return self.buffer.name
+
+ def write(self, x: AnyStr) -> int:
+ if isinstance(x, str):
+ return self._text_stream.write(x)
+ try:
+ self.flush()
+ except Exception: # pragma: no cover
+ pass
+ return self.buffer.write(x)
+
+ def writelines(self, lines: Iterable[AnyStr]) -> None:
+ for line in lines:
+ self.write(line)
+
+ def __getattr__(self, name: str) -> Any:
+ return getattr(self._text_stream, name)
+
+ def isatty(self) -> bool:
+ return self.buffer.isatty()
+
+ def __repr__(self) -> str:
+ return f""
+
+
+def _get_text_stdin(buffer_stream: BinaryIO) -> TextIO:
+ text_stream = _NonClosingTextIOWrapper(
+ io.BufferedReader(_WindowsConsoleReader(STDIN_HANDLE)),
+ "utf-16-le",
+ "strict",
+ line_buffering=True,
+ )
+ return cast(TextIO, ConsoleStream(text_stream, buffer_stream))
+
+
+def _get_text_stdout(buffer_stream: BinaryIO) -> TextIO:
+ text_stream = _NonClosingTextIOWrapper(
+ io.BufferedWriter(_WindowsConsoleWriter(STDOUT_HANDLE)),
+ "utf-16-le",
+ "strict",
+ line_buffering=True,
+ )
+ return cast(TextIO, ConsoleStream(text_stream, buffer_stream))
+
+
+def _get_text_stderr(buffer_stream: BinaryIO) -> TextIO:
+ text_stream = _NonClosingTextIOWrapper(
+ io.BufferedWriter(_WindowsConsoleWriter(STDERR_HANDLE)),
+ "utf-16-le",
+ "strict",
+ line_buffering=True,
+ )
+ return cast(TextIO, ConsoleStream(text_stream, buffer_stream))
+
+
+_stream_factories: Mapping[int, Callable[[BinaryIO], TextIO]] = {
+ 0: _get_text_stdin,
+ 1: _get_text_stdout,
+ 2: _get_text_stderr,
+}
+
+
+def _is_console(f: TextIO) -> bool:
+ if not hasattr(f, "fileno"):
+ return False
+
+ try:
+ fileno = f.fileno()
+ except (OSError, io.UnsupportedOperation):
+ return False
+
+ handle = msvcrt.get_osfhandle(fileno)
+ return bool(GetConsoleMode(handle, byref(DWORD())))
+
+
+def _get_windows_console_stream(
+ f: TextIO, encoding: str | None, errors: str | None
+) -> TextIO | None:
+ if (
+ get_buffer is None
+ or encoding not in {"utf-16-le", None}
+ or errors not in {"strict", None}
+ or not _is_console(f)
+ ):
+ return None
+
+ func = _stream_factories.get(f.fileno())
+ if func is None:
+ return None
+
+ b = getattr(f, "buffer", None)
+
+ if b is None:
+ return None
+
+ return func(b)
diff --git a/typer/_click/core.py b/typer/_click/core.py
new file mode 100644
index 0000000000..580b558b9f
--- /dev/null
+++ b/typer/_click/core.py
@@ -0,0 +1,1111 @@
+import enum
+import inspect
+import os
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Iterator, Mapping, MutableMapping, Sequence
+from contextlib import AbstractContextManager, ExitStack, contextmanager
+from types import TracebackType
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Literal,
+ NoReturn,
+ TypeVar,
+ Union,
+ cast,
+ overload,
+)
+
+from . import types
+from .exceptions import (
+ Abort,
+ BadParameter,
+ Exit,
+ MissingParameter,
+ NoArgsIsHelpError,
+ UsageError,
+)
+from .formatting import HelpFormatter
+from .globals import pop_context, push_context
+from .parser import _OptionParser
+from .termui import style
+from .utils import echo, make_default_short_help
+
+if TYPE_CHECKING:
+ from ..core import TyperOption
+ from .shell_completion import CompletionItem
+
+F = TypeVar("F", bound="Callable[..., Any]")
+V = TypeVar("V")
+
+
+def _complete_visible_commands(
+ ctx: "Context", incomplete: str
+) -> Iterator[tuple[str, "Command"]]:
+ """List all the subcommands of a group that start with the
+ incomplete value and aren't hidden.
+ """
+ # avoid circular imports
+ from ..core import TyperGroup
+
+ multi = cast(TyperGroup, ctx.command)
+
+ for name in multi.list_commands(ctx):
+ if name.startswith(incomplete):
+ command = multi.get_command(ctx, name)
+
+ if command is not None and not command.hidden:
+ yield name, command
+
+
+@contextmanager
+def augment_usage_errors(
+ ctx: "Context", param: Union["Parameter", None] = None
+) -> Iterator[None]:
+ """Context manager that attaches extra information to exceptions."""
+ try:
+ yield
+ except BadParameter as e:
+ if e.ctx is None:
+ e.ctx = ctx
+ if param is not None and e.param is None:
+ e.param = param
+ raise
+ except UsageError as e: # pragma: no cover
+ if e.ctx is None:
+ e.ctx = ctx
+ raise
+
+
+def iter_params_for_processing(
+ invocation_order: Sequence["Parameter"],
+ declaration_order: Sequence["Parameter"],
+) -> list["Parameter"]:
+ """Returns all declared parameters in the order they should be processed.
+
+ The declared parameters are re-shuffled depending on the order in which
+ they were invoked, as well as the eagerness of each parameters.
+
+ The invocation order takes precedence over the declaration order. I.e. the
+ order in which the user provided them to the CLI is respected.
+
+ This behavior and its effect on callback evaluation is detailed at:
+ https://click.palletsprojects.com/en/stable/advanced/#callback-evaluation-order
+ """
+
+ def sort_key(item: Parameter) -> tuple[bool, float]:
+ try:
+ idx: float = invocation_order.index(item)
+ except ValueError:
+ idx = float("inf")
+
+ return not item.is_eager, idx
+
+ return sorted(declaration_order, key=sort_key)
+
+
+class ParameterSource(enum.Enum):
+ """This is an `Enum` that indicates the source of a
+ parameter's value.
+ """
+
+ COMMANDLINE = enum.auto()
+ """The value was provided by the command line args."""
+ ENVIRONMENT = enum.auto()
+ """The value was provided with an environment variable."""
+ DEFAULT = enum.auto()
+ """Used the default specified by the parameter."""
+ DEFAULT_MAP = enum.auto()
+ """Used a default provided by `Context.default_map`."""
+ PROMPT = enum.auto()
+ """Used a prompt to confirm a default or provide a value."""
+
+
+class Context:
+ """The context is a special internal object that holds state relevant
+ for the script execution at every single level. It's normally invisible
+ to commands unless they opt-in to getting access to it.
+
+ The context is useful as it can pass internal objects around and can
+ control special execution features such as reading data from
+ environment variables.
+
+ A context can be used as context manager in which case it will call
+ `close` on teardown.
+ """
+
+ formatter_class: type[HelpFormatter] = HelpFormatter
+
+ def __init__(
+ self,
+ command: "Command",
+ parent: Union["Context", None] = None,
+ info_name: str | None = None,
+ obj: Any | None = None,
+ auto_envvar_prefix: str | None = None,
+ default_map: MutableMapping[str, Any] | None = None,
+ terminal_width: int | None = None,
+ max_content_width: int | None = None,
+ resilient_parsing: bool = False,
+ allow_extra_args: bool | None = None,
+ allow_interspersed_args: bool | None = None,
+ ignore_unknown_options: bool | None = None,
+ help_option_names: list[str] | None = None,
+ token_normalize_func: Callable[[str], str] | None = None,
+ color: bool | None = None,
+ show_default: bool | None = None,
+ ) -> None:
+ self.parent = parent
+ self.command = command
+ self.info_name = info_name
+ # Map of parameter names to their parsed values.
+ self.params: dict[str, Any] = {}
+ # the leftover arguments.
+ self.args: list[str] = []
+ # protected arguments. used to implement nested parsing.
+ self._protected_args: list[str] = []
+ # the collected prefixes of the command's options.
+ self._opt_prefixes: set[str] = set(parent._opt_prefixes) if parent else set()
+
+ if obj is None and parent is not None:
+ obj = parent.obj
+
+ self.obj: Any = obj
+ self._meta: dict[str, Any] = getattr(parent, "meta", {})
+
+ # A dictionary (-like object) with defaults for parameters.
+ if (
+ default_map is None
+ and info_name is not None
+ and parent is not None
+ and parent.default_map is not None
+ ):
+ default_map = parent.default_map.get(info_name)
+
+ self.default_map: MutableMapping[str, Any] | None = default_map
+
+ # This flag indicates if a subcommand is going to be executed.
+ self.invoked_subcommand: str | None = None
+
+ if terminal_width is None and parent is not None:
+ terminal_width = parent.terminal_width
+
+ # The width of the terminal (None is autodetection).
+ self.terminal_width: int | None = terminal_width
+
+ if max_content_width is None and parent is not None:
+ max_content_width = parent.max_content_width
+
+ self.max_content_width: int | None = max_content_width
+
+ if allow_extra_args is None:
+ allow_extra_args = command.allow_extra_args
+
+ self.allow_extra_args = allow_extra_args
+
+ if allow_interspersed_args is None:
+ allow_interspersed_args = command.allow_interspersed_args
+
+ self.allow_interspersed_args: bool = allow_interspersed_args
+
+ if ignore_unknown_options is None:
+ ignore_unknown_options = command.ignore_unknown_options
+
+ self.ignore_unknown_options: bool = ignore_unknown_options
+
+ if help_option_names is None:
+ if parent is not None:
+ help_option_names = parent.help_option_names
+ else:
+ help_option_names = ["--help"]
+
+ self.help_option_names: list[str] = help_option_names
+
+ if token_normalize_func is None and parent is not None:
+ token_normalize_func = parent.token_normalize_func
+
+ # An optional normalization function for tokens. (options, choices, commands etc.)
+ self.token_normalize_func: Callable[[str], str] | None = token_normalize_func
+
+ # Indicates if resilient parsing is enabled.
+ self.resilient_parsing: bool = resilient_parsing
+
+ # If there is no envvar prefix yet, but the parent has one and
+ # the command on this level has a name, we can expand the envvar
+ # prefix automatically.
+ if auto_envvar_prefix is None:
+ if (
+ parent is not None
+ and parent.auto_envvar_prefix is not None
+ and self.info_name is not None
+ ):
+ auto_envvar_prefix = (
+ f"{parent.auto_envvar_prefix}_{self.info_name.upper()}"
+ )
+ else:
+ auto_envvar_prefix = auto_envvar_prefix.upper()
+
+ if auto_envvar_prefix is not None:
+ auto_envvar_prefix = auto_envvar_prefix.replace("-", "_")
+
+ self.auto_envvar_prefix: str | None = auto_envvar_prefix
+
+ if color is None and parent is not None:
+ color = parent.color
+
+ # Controls if styling output is wanted or not.
+ self.color: bool | None = color
+
+ if show_default is None and parent is not None:
+ show_default = parent.show_default
+
+ # Show option default values when formatting help text.
+ self.show_default: bool | None = show_default
+
+ self._close_callbacks: list[Callable[[], Any]] = []
+ self._depth = 0
+ self._parameter_source: dict[str, ParameterSource] = {}
+ self._exit_stack = ExitStack()
+
+ def __enter__(self) -> "Context":
+ self._depth += 1
+ push_context(self)
+ return self
+
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ tb: TracebackType | None,
+ ) -> bool | None:
+ self._depth -= 1
+ exit_result: bool | None = None
+ if self._depth == 0:
+ exit_result = self._close_with_exception_info(exc_type, exc_value, tb)
+ pop_context()
+
+ return exit_result
+
+ @contextmanager
+ def scope(self, cleanup: bool = True) -> Iterator["Context"]:
+ """This helper method can be used with the context object to promote
+ it to the current thread local (see `get_current_context`).
+ The default behavior of this is to invoke the cleanup functions which
+ can be disabled by setting `cleanup` to `False`. The cleanup
+ functions are typically used for things such as closing file handles.
+
+ If the cleanup is intended the context object can also be directly
+ used as a context manager.
+ """
+ if not cleanup:
+ self._depth += 1
+ try:
+ with self as rv:
+ yield rv
+ finally:
+ if not cleanup:
+ self._depth -= 1
+
+ @property
+ def meta(self) -> dict[str, Any]:
+ """This is a dictionary which is shared with all the contexts
+ that are nested. It exists so that click utilities can store some
+ state here if they need to. It is however the responsibility of
+ that code to manage this dictionary well.
+
+ The keys are supposed to be unique dotted strings. For instance
+ module paths are a good choice for it. What is stored in there is
+ irrelevant for the operation of click. However what is important is
+ that code that places data here adheres to the general semantics of
+ the system.
+ """
+ return self._meta
+
+ def make_formatter(self) -> HelpFormatter:
+ """Creates the HelpFormatter for the help and
+ usage output.
+ """
+ return self.formatter_class(
+ width=self.terminal_width, max_width=self.max_content_width
+ )
+
+ def with_resource(self, context_manager: AbstractContextManager[V]) -> V:
+ """Register a resource as if it were used in a ``with``
+ statement. The resource will be cleaned up when the context is
+ popped.
+
+ Uses `contextlib.ExitStack.enter_context`. It calls the
+ resource's ``__enter__()`` method and returns the result. When
+ the context is popped, it closes the stack, which calls the
+ resource's ``__exit__()`` method.
+
+ To register a cleanup function for something that isn't a
+ context manager, use `call_on_close`. Or use something
+ from `contextlib` to turn it into a context manager first.
+ """
+ return self._exit_stack.enter_context(context_manager)
+
+ def call_on_close(self, f: Callable[..., Any]) -> Callable[..., Any]:
+ """Register a function to be called when the context tears down.
+
+ This can be used to close resources opened during the script
+ execution. Resources that support Python's context manager
+ protocol which would be used in a ``with`` statement should be
+ registered with `with_resource` instead.
+ """
+ return self._exit_stack.callback(f)
+
+ def close(self) -> None:
+ """Invoke all close callbacks registered with `call_on_close`,
+ and exit all context managers entered with `with_resource`.
+ """
+ self._close_with_exception_info(None, None, None)
+
+ def _close_with_exception_info(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ tb: TracebackType | None,
+ ) -> bool | None:
+ """Unwind the exit stack by calling its `__exit__` providing the exception
+ information to allow for exception handling by the various resources registered
+ using `with_resource`
+ """
+ exit_result = self._exit_stack.__exit__(exc_type, exc_value, tb)
+ # In case the context is reused, create a new exit stack.
+ self._exit_stack = ExitStack()
+
+ return exit_result
+
+ @property
+ def command_path(self) -> str:
+ """The computed command path. This is used for the ``usage``
+ information on the help page. It's automatically created by
+ combining the info names of the chain of contexts to the root.
+ """
+ rv = ""
+ if self.info_name is not None:
+ rv = self.info_name
+ if self.parent is not None:
+ parent_command_path = [self.parent.command_path]
+
+ if isinstance(self.parent.command, Command):
+ for param in self.parent.command.get_params(self):
+ parent_command_path.extend(param.get_usage_pieces(self))
+
+ rv = f"{' '.join(parent_command_path)} {rv}"
+ return rv.lstrip()
+
+ def find_root(self) -> "Context":
+ """Finds the outermost context."""
+ node = self
+ while node.parent is not None:
+ node = node.parent
+ return node
+
+ def find_object(self, object_type: type[V]) -> V | None:
+ """Finds the closest object of a given type."""
+ node: Context | None = self
+
+ while node is not None:
+ if isinstance(node.obj, object_type):
+ return node.obj
+
+ node = node.parent
+
+ return None
+
+ def ensure_object(self, object_type: type[V]) -> V:
+ """Like `find_object` but sets the innermost object to a
+ new instance of `object_type` if it does not exist.
+ """
+ rv = self.find_object(object_type)
+ if rv is None:
+ self.obj = rv = object_type()
+ return rv
+
+ @overload
+ def lookup_default(self, name: str, call: Literal[True] = True) -> Any | None: ...
+
+ @overload
+ def lookup_default(
+ self, name: str, call: Literal[False] = ...
+ ) -> Any | Callable[[], Any] | None: ...
+
+ def lookup_default(self, name: str, call: bool = True) -> Any | None:
+ """Get the default for a parameter from `default_map`."""
+ if self.default_map is not None:
+ value = self.default_map.get(name)
+
+ if call and callable(value):
+ return value()
+
+ return value
+
+ return None
+
+ def fail(self, message: str) -> NoReturn:
+ """Aborts the execution of the program with a specific error
+ message.
+ """
+ raise UsageError(message, self)
+
+ def abort(self) -> NoReturn:
+ """Aborts the script."""
+ raise Abort()
+
+ def exit(self, code: int = 0) -> NoReturn:
+ """Exits the application with a given exit code."""
+ self.close()
+ raise Exit(code)
+
+ def get_usage(self) -> str:
+ """Helper method to get formatted usage string for the current
+ context and command.
+ """
+ return self.command.get_usage(self)
+
+ def get_help(self) -> str:
+ """Helper method to get formatted help page for the current
+ context and command.
+ """
+ return self.command.get_help(self)
+
+ def invoke(self, callback: Callable[..., V], /, *args: Any, **kwargs: Any) -> V:
+ """Invokes a command callback in exactly the way it expects. There
+ are two ways to invoke this method:
+
+ 1. the first argument can be a callback and all other arguments and
+ keyword arguments are forwarded directly to the function.
+ 2. the first argument is a click command object. In that case all
+ arguments are forwarded as well but proper click parameters
+ (options and click arguments) must be keyword arguments and Click
+ will fill in defaults.
+ """
+ ctx = self
+
+ with augment_usage_errors(self):
+ with ctx:
+ return callback(*args, **kwargs)
+
+ def set_parameter_source(self, name: str, source: ParameterSource) -> None:
+ """Set the source of a parameter. This indicates the location
+ from which the value of the parameter was obtained.
+ """
+ self._parameter_source[name] = source
+
+ def get_parameter_source(self, name: str) -> ParameterSource | None:
+ """Get the source of a parameter. This indicates the location
+ from which the value of the parameter was obtained.
+
+ This can be useful for determining when a user specified a value
+ on the command line that is the same as the default value. It
+ will be `ParameterSource.DEFAULT` only if the
+ value was actually taken from the default.
+ """
+ return self._parameter_source.get(name)
+
+
+class Command(ABC):
+ """Commands are the basic building block of command line interfaces in
+ Click. A basic command handles command line parsing and might dispatch
+ more parsing to commands nested below it.
+ """
+
+ context_class: type[Context] = Context
+ allow_extra_args = False
+ allow_interspersed_args = True
+ ignore_unknown_options = False
+
+ def __init__(
+ self,
+ name: str | None,
+ context_settings: MutableMapping[str, Any] | None = None,
+ callback: Callable[..., Any] | None = None,
+ params: list["Parameter"] | None = None,
+ help: str | None = None,
+ epilog: str | None = None,
+ short_help: str | None = None,
+ options_metavar: str | None = "[OPTIONS]",
+ add_help_option: bool = True,
+ no_args_is_help: bool = False,
+ hidden: bool = False,
+ deprecated: bool | str = False,
+ ) -> None:
+ self.name = name
+
+ if context_settings is None:
+ context_settings = {}
+
+ self.context_settings: MutableMapping[str, Any] = context_settings
+
+ self.callback = callback
+ self.params: list[Parameter] = params or []
+ self.help = help
+ self.epilog = epilog
+ self.options_metavar = options_metavar
+ self.short_help = short_help
+ self.add_help_option = add_help_option
+ self._help_option: TyperOption | None = None
+ self.no_args_is_help = no_args_is_help
+ self.hidden = hidden
+ self.deprecated = deprecated
+
+ def __repr__(self) -> str:
+ return f"<{self.__class__.__name__} {self.name}>"
+
+ def get_usage(self, ctx: Context) -> str:
+ """Formats the usage line into a string and returns it."""
+ formatter = ctx.make_formatter()
+ self.format_usage(ctx, formatter)
+ return formatter.getvalue().rstrip("\n")
+
+ def get_params(self, ctx: Context) -> list["Parameter"]:
+ params = self.params
+ help_option = self.get_help_option(ctx)
+
+ if help_option is not None:
+ params = [*params, help_option]
+
+ return params
+
+ def format_usage(self, ctx: Context, formatter: HelpFormatter) -> None:
+ """Writes the usage line into the formatter."""
+ pieces = self.collect_usage_pieces(ctx)
+ formatter.write_usage(ctx.command_path, " ".join(pieces))
+
+ def collect_usage_pieces(self, ctx: Context) -> list[str]:
+ """Returns all the pieces that go into the usage line and returns
+ it as a list of strings.
+ """
+ rv = [self.options_metavar] if self.options_metavar else []
+
+ for param in self.get_params(ctx):
+ rv.extend(param.get_usage_pieces(ctx))
+
+ return rv
+
+ def get_help_option_names(self, ctx: Context) -> list[str]:
+ """Returns the names for the help option."""
+ all_names = set(ctx.help_option_names)
+ for param in self.params:
+ all_names.difference_update(param.opts)
+ all_names.difference_update(param.secondary_opts)
+ return list(all_names)
+
+ def get_help_option(self, ctx: Context) -> Union["TyperOption", None]:
+ """Returns the help option object."""
+ help_option_names = self.get_help_option_names(ctx)
+
+ if not help_option_names or not self.add_help_option:
+ return None
+
+ # Cache the help option object in private _help_option attribute to
+ # avoid creating it multiple times. Not doing this will break the
+ # callback ordering by iter_params_for_processing(), which relies on
+ # object comparison.
+ if self._help_option is None:
+ # Avoid circular import.
+ from .decorators import help_option
+
+ # Apply help_option decorator and pop resulting option
+ help_option(help_option_names)(self)
+ self._help_option = cast("TyperOption", self.params.pop())
+
+ return self._help_option
+
+ def make_parser(self, ctx: Context) -> _OptionParser:
+ """Creates the underlying option parser for this command."""
+ parser = _OptionParser(ctx)
+ for param in self.get_params(ctx):
+ param.add_to_parser(parser, ctx)
+ return parser
+
+ def get_help(self, ctx: Context) -> str:
+ """Formats the help into a string and returns it."""
+ formatter = ctx.make_formatter()
+ self.format_help(ctx, formatter)
+ return formatter.getvalue().rstrip("\n")
+
+ def get_short_help_str(self, limit: int = 45) -> str:
+ """Gets short help for the command or makes it by shortening the
+ long help string.
+ """
+ if self.short_help:
+ text = inspect.cleandoc(self.short_help)
+ elif self.help:
+ text = make_default_short_help(self.help, limit)
+ else:
+ text = ""
+
+ if self.deprecated:
+ deprecated_message = (
+ f"(DEPRECATED: {self.deprecated})"
+ if isinstance(self.deprecated, str)
+ else "(DEPRECATED)"
+ )
+ text = f"{text} {deprecated_message}"
+
+ return text.strip()
+
+ def format_help(self, ctx: Context, formatter: HelpFormatter) -> None:
+ """Writes the help into the formatter if it exists."""
+ self.format_usage(ctx, formatter)
+ self.format_help_text(ctx, formatter)
+ self.format_options(ctx, formatter)
+ self.format_epilog(ctx, formatter)
+
+ def format_help_text(self, ctx: Context, formatter: HelpFormatter) -> None:
+ """Writes the help text to the formatter if it exists."""
+ if self.help is not None:
+ # truncate the help text to the first form feed
+ text = inspect.cleandoc(self.help).partition("\f")[0]
+ else:
+ text = ""
+
+ if self.deprecated:
+ deprecated_message = (
+ f"(DEPRECATED: {self.deprecated})"
+ if isinstance(self.deprecated, str)
+ else "(DEPRECATED)"
+ )
+ text = f"{text} {deprecated_message}"
+
+ if text:
+ formatter.write_paragraph()
+
+ with formatter.indentation():
+ formatter.write_text(text)
+
+ @abstractmethod
+ def format_options(self, ctx: Context, formatter: HelpFormatter) -> None:
+ pass # pragma: no cover
+
+ def format_epilog(self, ctx: Context, formatter: HelpFormatter) -> None:
+ """Writes the epilog into the formatter if it exists."""
+ if self.epilog:
+ epilog = inspect.cleandoc(self.epilog)
+ formatter.write_paragraph()
+
+ with formatter.indentation():
+ formatter.write_text(epilog)
+
+ def make_context(
+ self,
+ info_name: str | None,
+ args: list[str],
+ parent: Context | None = None,
+ **extra: Any,
+ ) -> Context:
+ """This function when given an info name and arguments will kick
+ off the parsing and create a new `Context`. It does not
+ invoke the actual command callback though.
+
+ To quickly customize the context class used without overriding
+ this method, set the `context_class` attribute.
+ """
+ for key, value in self.context_settings.items():
+ if key not in extra:
+ extra[key] = value
+
+ ctx = self.context_class(self, info_name=info_name, parent=parent, **extra)
+
+ with ctx.scope(cleanup=False):
+ self.parse_args(ctx, args)
+ return ctx
+
+ def parse_args(self, ctx: Context, args: list[str]) -> list[str]:
+ if not args and self.no_args_is_help and not ctx.resilient_parsing:
+ raise NoArgsIsHelpError(ctx) # pragma: no cover
+
+ parser = self.make_parser(ctx)
+ opts, args, param_order = parser.parse_args(args=args)
+
+ for param in iter_params_for_processing(param_order, self.get_params(ctx)):
+ _, args = param.handle_parse_result(ctx, opts, args)
+
+ if args and not ctx.allow_extra_args and not ctx.resilient_parsing:
+ ctx.fail(f"Got unexpected extra argument(s) ({' '.join(map(str, args))})")
+
+ ctx.args = args
+ ctx._opt_prefixes.update(parser._opt_prefixes)
+ return args
+
+ def invoke(self, ctx: Context) -> Any:
+ """Given a context, this invokes the attached callback (if it exists)
+ in the right way.
+ """
+ if self.deprecated:
+ extra_message = (
+ f" {self.deprecated}" if isinstance(self.deprecated, str) else ""
+ )
+ message = f"DeprecationWarning: The command {self.name!r} is deprecated.{extra_message}"
+ echo(style(message, fg="red"), err=True)
+
+ if self.callback is not None:
+ return ctx.invoke(self.callback, **ctx.params)
+
+ def shell_complete(self, ctx: Context, incomplete: str) -> list["CompletionItem"]:
+ """Return a list of completions for the incomplete value. Looks
+ at the names of options and chained multi-commands.
+
+ Any command could be part of a chained multi-command, so sibling
+ commands are valid at any point during command completion.
+ """
+ # avoid circular imports
+ from .shell_completion import CompletionItem
+
+ results: list[CompletionItem] = []
+
+ if incomplete and not incomplete[0].isalnum():
+ # avoid circular imports
+ from ..core import TyperOption
+
+ for param in self.get_params(ctx):
+ if (
+ not isinstance(param, TyperOption)
+ or param.hidden
+ or (
+ not param.multiple
+ and ctx.get_parameter_source(param.name) # type: ignore
+ is ParameterSource.COMMANDLINE
+ )
+ ):
+ continue
+
+ results.extend(
+ CompletionItem(name, help=param.help)
+ for name in [*param.opts, *param.secondary_opts]
+ if name.startswith(incomplete)
+ )
+
+ return results
+
+ @abstractmethod
+ def main(
+ self,
+ args: Sequence[str] | None = None,
+ prog_name: str | None = None,
+ complete_var: str | None = None,
+ standalone_mode: bool = True,
+ windows_expand_args: bool = True,
+ **extra: Any,
+ ) -> Any:
+ pass # pragma: no cover
+
+ @abstractmethod
+ def _main_shell_completion(
+ self,
+ ctx_args: MutableMapping[str, Any],
+ prog_name: str,
+ complete_var: str | None = None,
+ ) -> None:
+ pass # pragma: no cover
+
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
+ """Alias for self.main"""
+ return self.main(*args, **kwargs)
+
+
+class Parameter(ABC):
+ r"""A parameter to a command comes in two versions: they are either
+ `Option`\s or `Argument`\s.
+
+ Some settings are supported by both options and arguments.
+ """
+
+ param_type_name = "parameter"
+
+ def __init__(
+ self,
+ param_decls: Sequence[str] | None = None,
+ type: types.ParamType | Any | None = None,
+ required: bool = False,
+ default: Any | Callable[[], Any] | None = None,
+ callback: Callable[[Context, "Parameter", Any], Any] | None = None,
+ nargs: int | None = None,
+ multiple: bool = False,
+ metavar: str | None = None,
+ expose_value: bool = True,
+ is_eager: bool = False,
+ envvar: str | Sequence[str] | None = None,
+ shell_complete: Callable[
+ [Context, "Parameter", str], list["CompletionItem"] | list[str]
+ ]
+ | None = None,
+ ) -> None:
+ self.name: str | None
+ self.opts: list[str]
+ self.secondary_opts: list[str]
+ self.name, self.opts, self.secondary_opts = self._parse_decls(
+ param_decls or (), expose_value
+ )
+ self.type: types.ParamType = types.convert_type(type, default)
+
+ # Default nargs to what the type tells us if we have that
+ # information available.
+ if nargs is None:
+ if self.type.is_composite:
+ nargs = self.type.arity
+ else:
+ nargs = 1
+
+ self.required = required
+ self.callback = callback
+ self.nargs = nargs
+ self.multiple = multiple
+ self.expose_value = expose_value
+ self.default: Any | Callable[[], Any] | None = default
+ self.is_eager = is_eager
+ self.metavar = metavar
+ self.envvar = envvar
+ self._custom_shell_complete = shell_complete
+
+ def __repr__(self) -> str:
+ return f"<{self.__class__.__name__} {self.name}>"
+
+ @abstractmethod
+ def _parse_decls(
+ self, decls: Sequence[str], expose_value: bool
+ ) -> tuple[str | None, list[str], list[str]]:
+ pass # pragma: no cover
+
+ @property
+ def human_readable_name(self) -> str:
+ """Returns the human readable name of this parameter. This is the
+ same as the name for options, but the metavar for arguments.
+ """
+ assert self.name is not None, "self.name should be set"
+ return self.name
+
+ def make_metavar(self, ctx: Context) -> str:
+ if self.metavar is not None:
+ return self.metavar
+
+ metavar = self.type.get_metavar(param=self, ctx=ctx)
+
+ if metavar is None:
+ metavar = self.type.name.upper()
+
+ if self.nargs != 1:
+ metavar += "..."
+
+ return metavar
+
+ @overload
+ def get_default(self, ctx: Context, call: Literal[True] = True) -> Any | None: ...
+
+ @overload
+ def get_default(
+ self, ctx: Context, call: bool = ...
+ ) -> Any | Callable[[], Any] | None: ...
+
+ def get_default(
+ self, ctx: Context, call: bool = True
+ ) -> Any | Callable[[], Any] | None:
+ """Get the default for the parameter"""
+ value = ctx.lookup_default(self.name, call=False) # type: ignore
+
+ if value is None:
+ value = self.default
+
+ if call and callable(value):
+ value = value()
+
+ return value
+
+ @abstractmethod
+ def add_to_parser(self, parser: _OptionParser, ctx: Context) -> None:
+ pass # pragma: no cover
+
+ def consume_value(
+ self, ctx: Context, opts: Mapping[str, Any]
+ ) -> tuple[Any, ParameterSource]:
+ value = opts.get(self.name) # type: ignore
+ source = ParameterSource.COMMANDLINE
+
+ if value is None:
+ value = self.value_from_envvar(ctx)
+ source = ParameterSource.ENVIRONMENT
+
+ if value is None:
+ value = ctx.lookup_default(self.name) # type: ignore
+ source = ParameterSource.DEFAULT_MAP
+
+ if value is None:
+ value = self.get_default(ctx)
+ source = ParameterSource.DEFAULT
+
+ return value, source
+
+ def type_cast_value(self, ctx: Context, value: Any) -> Any:
+ """Convert and validate a value against the parameter's
+ `type`, `multiple`, and `nargs`.
+ """
+ if value is None:
+ return () if self.multiple or self.nargs == -1 else None
+
+ def check_iter(value: Any) -> Iterator[Any]:
+ if isinstance(value, str):
+ raise BadParameter("Value must be an iterable.", ctx=ctx, param=self)
+ else:
+ return iter(value)
+
+ # Define the conversion function based on nargs and type.
+ if self.nargs == 1 or self.type.is_composite:
+
+ def convert(value: Any) -> Any:
+ return self.type(value, param=self, ctx=ctx)
+
+ elif self.nargs == -1:
+
+ def convert(value: Any) -> Any: # tuple[t.Any, ...]
+ return tuple(self.type(x, self, ctx) for x in check_iter(value))
+
+ # TODO: evaluate whether we need to keep this in Typer
+ else: # nargs > 1
+
+ def convert(value: Any) -> Any: # tuple[t.Any, ...]
+ value = tuple(check_iter(value))
+
+ if len(value) != self.nargs:
+ raise BadParameter(
+ f"Takes {self.nargs} values but {len(value)} given.",
+ ctx=ctx,
+ param=self,
+ )
+
+ return tuple(self.type(x, self, ctx) for x in value)
+
+ if self.multiple:
+ return tuple(convert(x) for x in check_iter(value))
+
+ return convert(value)
+
+ @abstractmethod
+ def value_is_missing(self, value: Any) -> bool:
+ pass # pragma: no cover
+
+ def process_value(self, ctx: Context, value: Any) -> Any:
+ """Process the value of this parameter"""
+ value = self.type_cast_value(ctx, value)
+
+ if self.required and self.value_is_missing(value):
+ raise MissingParameter(ctx=ctx, param=self)
+
+ if self.callback is not None:
+ value = self.callback(ctx, self, value)
+
+ return value
+
+ def resolve_envvar_value(self, ctx: Context) -> str | None:
+ """Returns the value found in the environment variable(s) attached to this
+ parameter.
+
+ Environment variables values are `always returned as strings
+ `_.
+
+ This method returns ``None`` if:
+
+ - the `envvar` property is not set on `Parameter`,
+ - the environment variable is not found in the environment,
+ - the variable is found in the environment but its value is empty (i.e. the
+ environment variable is present but has an empty string).
+
+ If `envvar` is setup with multiple environment variables,
+ then only the first non-empty value is returned.
+ """
+ if self.envvar is None:
+ return None
+
+ if isinstance(self.envvar, str):
+ rv = os.environ.get(self.envvar)
+
+ if rv:
+ return rv
+ else:
+ for envvar in self.envvar:
+ rv = os.environ.get(envvar)
+
+ # Return the first non-empty value of the list of environment variables.
+ if rv:
+ return rv
+ # Else, absence of value is interpreted as an environment variable that
+ # is not set, so proceed to the next one.
+
+ return None
+
+ def value_from_envvar(self, ctx: Context) -> str | Sequence[str] | None:
+ """Process the raw environment variable string for this parameter.
+
+ Returns the string as-is or splits it into a sequence of strings if the
+ parameter is expecting multiple values (i.e. its `nargs` property is set
+ to a value other than ``1``).
+ """
+ rv: Any | None = self.resolve_envvar_value(ctx)
+
+ if rv is not None and self.nargs != 1:
+ rv = self.type.split_envvar_value(rv)
+
+ return rv
+
+ def handle_parse_result(
+ self, ctx: Context, opts: Mapping[str, Any], args: list[str]
+ ) -> tuple[Any, list[str]]:
+ """Process the value produced by the parser from user input.
+
+ Always process the value through the Parameter's `type`, wherever it
+ comes from.
+
+ If the parameter is deprecated, this method warn the user about it. But only if
+ the value has been explicitly set by the user (and as such, is not coming from
+ a default).
+ """
+ with augment_usage_errors(ctx, param=self):
+ value, source = self.consume_value(ctx, opts)
+
+ ctx.set_parameter_source(self.name, source) # type: ignore
+
+ # Process the value through the parameter's type.
+ try:
+ value = self.process_value(ctx, value)
+ except Exception:
+ if not ctx.resilient_parsing:
+ raise
+ value = None
+
+ if self.expose_value:
+ ctx.params[self.name] = value # type: ignore
+
+ return value, args
+
+ @abstractmethod
+ def get_help_record(self, ctx: Context) -> tuple[str, str] | None:
+ pass # pragma: no cover
+
+ def get_usage_pieces(self, ctx: Context) -> list[str]:
+ return []
+
+ def get_error_hint(self, ctx: Context) -> str:
+ """Get a stringified version of the param for use in error messages to
+ indicate which param caused the error.
+ """
+ hint_list = self.opts or [self.human_readable_name]
+ return " / ".join(f"'{x}'" for x in hint_list)
+
+ def shell_complete(self, ctx: Context, incomplete: str) -> list["CompletionItem"]:
+ """Return a list of completions for the incomplete value. If a
+ ``shell_complete`` function was given during init, it is used.
+ Otherwise, the `type` `ParamType.shell_complete` function is used.
+ """
+ if self._custom_shell_complete is not None:
+ results = self._custom_shell_complete(ctx, self, incomplete)
+
+ if results and isinstance(results[0], str):
+ from .shell_completion import CompletionItem
+
+ results = [CompletionItem(c) for c in results]
+
+ return cast("list[CompletionItem]", results)
+
+ return self.type.shell_complete(ctx, self, incomplete)
diff --git a/typer/_click/decorators.py b/typer/_click/decorators.py
new file mode 100644
index 0000000000..28ad656a8c
--- /dev/null
+++ b/typer/_click/decorators.py
@@ -0,0 +1,60 @@
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
+
+from .core import Command, Context, Parameter
+from .utils import echo
+
+if TYPE_CHECKING:
+ from ..core import TyperGroup, TyperOption
+
+ GrpType = TypeVar("GrpType", bound=TyperGroup)
+
+
+P = ParamSpec("P")
+
+R = TypeVar("R")
+T = TypeVar("T")
+_AnyCallable = Callable[..., Any]
+
+
+CmdType = TypeVar("CmdType", bound=Command)
+
+
+def option(
+ param_decls: list[str], cls: type["TyperOption"] | None = None, **attrs: Any
+) -> Callable[[Command], Command]:
+ """Attaches an option to the command."""
+ if cls is None:
+ # avoid circular imports
+ from ..core import TyperOption
+
+ cls = TyperOption
+
+ def decorator(f: Command) -> Command:
+ param = cls(param_decls=param_decls, **attrs)
+ f.params.append(param)
+ return f
+
+ return decorator
+
+
+def help_option(param_decls: list[str]) -> Callable[[Command], Command]:
+ """Help option which prints the help page and exits the program."""
+
+ def show_help(ctx: Context, param: Parameter, value: bool) -> None:
+ """Callback that print the help page on ```` and exits."""
+ if value and not ctx.resilient_parsing:
+ echo(ctx.get_help(), color=ctx.color)
+ ctx.exit()
+
+ assert len(param_decls) > 0, "At least one help option should be provided"
+
+ return option(
+ param_decls,
+ is_flag=True,
+ expose_value=False,
+ is_eager=True,
+ help="Show this message and exit.",
+ callback=show_help,
+ required=False,
+ )
diff --git a/typer/_click/exceptions.py b/typer/_click/exceptions.py
new file mode 100644
index 0000000000..af2af260a6
--- /dev/null
+++ b/typer/_click/exceptions.py
@@ -0,0 +1,260 @@
+from collections.abc import Sequence
+from typing import IO, TYPE_CHECKING, Any, Union
+
+from ._compat import get_text_stderr
+from .globals import resolve_color_default
+from .utils import echo, format_filename
+
+if TYPE_CHECKING:
+ from .core import Command, Context, Parameter
+
+
+def _join_param_hints(param_hint: Sequence[str] | str | None) -> str | None:
+ if param_hint is not None and not isinstance(param_hint, str):
+ return " / ".join(repr(x) for x in param_hint)
+
+ return param_hint
+
+
+class ClickException(Exception):
+ """An exception that Click can handle and show to the user."""
+
+ exit_code = 1
+
+ def __init__(self, message: str) -> None:
+ super().__init__(message)
+ # The context will be removed by the time we print the message, so cache
+ # the color settings here to be used later on (in `show`)
+ self.show_color: bool | None = resolve_color_default()
+ self.message = message
+
+ def format_message(self) -> str:
+ return self.message
+
+ def __str__(self) -> str:
+ return self.message
+
+ def show(self, file: IO[Any] | None = None) -> None:
+ if file is None:
+ file = get_text_stderr()
+
+ echo(
+ f"Error: {self.format_message()}",
+ file=file,
+ color=self.show_color,
+ )
+
+
+class UsageError(ClickException):
+ """An internal exception that signals a usage error. This typically
+ aborts any further handling.
+ """
+
+ exit_code = 2
+
+ def __init__(self, message: str, ctx: Union["Context", None] = None) -> None:
+ super().__init__(message)
+ self.ctx = ctx
+ self.cmd: Command | None = self.ctx.command if self.ctx else None
+
+ def show(self, file: IO[Any] | None = None) -> None:
+ if file is None:
+ file = get_text_stderr()
+ color = None
+ hint = ""
+ if (
+ self.ctx is not None
+ and self.ctx.command.get_help_option(self.ctx) is not None
+ ):
+ command = self.ctx.command_path
+ option = self.ctx.help_option_names[0]
+ hint = f"Try '{command} {option}' for help.\n"
+ if self.ctx is not None:
+ color = self.ctx.color
+ echo(f"{self.ctx.get_usage()}\n{hint}", file=file, color=color)
+ echo(
+ f"Error: {self.format_message()}",
+ file=file,
+ color=color,
+ )
+
+
+class BadParameter(UsageError):
+ """An exception that formats out a standardized error message for a
+ bad parameter. This is useful when thrown from a callback or type as
+ Click will attach contextual information to it (for instance, which
+ parameter it is).
+ """
+
+ def __init__(
+ self,
+ message: str,
+ ctx: Union["Context", None] = None,
+ param: Union["Parameter", None] = None,
+ param_hint: Sequence[str] | str | None = None,
+ ) -> None:
+ super().__init__(message, ctx)
+ self.param = param
+ self.param_hint = param_hint
+
+ def format_message(self) -> str:
+ if self.param_hint is not None:
+ param_hint = self.param_hint
+ elif self.param is not None:
+ param_hint = self.param.get_error_hint(self.ctx) # type: ignore
+ else:
+ return f"Invalid value: {self.message}"
+
+ hint = _join_param_hints(param_hint)
+ return f"Invalid value for {hint}: {self.message}"
+
+
+class MissingParameter(BadParameter):
+ """Raised if click required an option or argument but it was not
+ provided when invoking the script.
+ """
+
+ def __init__(
+ self,
+ message: str | None = None,
+ ctx: Union["Context", None] = None,
+ param: Union["Parameter", None] = None,
+ param_hint: Sequence[str] | str | None = None,
+ param_type: str | None = None,
+ ) -> None:
+ super().__init__(message or "", ctx, param, param_hint)
+ self.param_type = param_type
+
+ def format_message(self) -> str:
+ if self.param_hint is not None:
+ param_hint: Sequence[str] | str | None = self.param_hint
+ elif self.param is not None:
+ param_hint = self.param.get_error_hint(self.ctx) # type: ignore
+ else:
+ param_hint = None
+
+ param_hint = _join_param_hints(param_hint)
+ param_hint = f" {param_hint}" if param_hint else ""
+
+ param_type = self.param_type
+ if param_type is None and self.param is not None:
+ param_type = self.param.param_type_name
+
+ msg = self.message
+ if self.param is not None:
+ msg_extra = self.param.type.get_missing_message(
+ param=self.param, ctx=self.ctx
+ )
+ if msg_extra:
+ if msg:
+ msg += f". {msg_extra}"
+ else:
+ msg = msg_extra
+
+ msg = f" {msg}" if msg else ""
+
+ # Translate param_type for known types.
+ if param_type == "argument":
+ missing = "Missing argument"
+ elif param_type == "option":
+ missing = "Missing option"
+ elif param_type == "parameter":
+ missing = "Missing parameter"
+ else:
+ missing = f"Missing {param_type}"
+
+ return f"{missing}{param_hint}.{msg}"
+
+ def __str__(self) -> str:
+ if not self.message:
+ param_name = self.param.name if self.param else None
+ return f"Missing parameter: {param_name}"
+ else:
+ return self.message
+
+
+class NoSuchOption(UsageError):
+ """Raised if click attempted to handle an option that does not
+ exist.
+ """
+
+ def __init__(
+ self,
+ option_name: str,
+ message: str | None = None,
+ possibilities: Sequence[str] | None = None,
+ ctx: Union["Context", None] = None,
+ ) -> None:
+ if message is None:
+ message = f"No such option: {option_name}"
+
+ super().__init__(message, ctx)
+ self.option_name = option_name
+ self.possibilities = possibilities
+
+ def format_message(self) -> str:
+ if not self.possibilities:
+ return self.message
+
+ possibility_str = ", ".join(sorted(self.possibilities))
+ suggest = (f"(Possible options: {possibility_str})",)
+ return f"{self.message} {suggest}"
+
+
+class BadOptionUsage(UsageError):
+ """Raised if an option is generally supplied but the use of the option
+ was incorrect. This is for instance raised if the number of arguments
+ for an option is not correct.
+ """
+
+ def __init__(
+ self, option_name: str, message: str, ctx: Union["Context", None] = None
+ ) -> None:
+ super().__init__(message, ctx)
+ self.option_name = option_name
+
+
+class BadArgumentUsage(UsageError):
+ """Raised if an argument is generally supplied but the use of the argument
+ was incorrect. This is for instance raised if the number of values
+ for an argument is not correct.
+ """
+
+
+class NoArgsIsHelpError(UsageError):
+ def __init__(self, ctx: "Context") -> None:
+ self.ctx: Context
+ super().__init__(ctx.get_help(), ctx=ctx)
+
+ def show(self, file: IO[Any] | None = None) -> None:
+ echo(self.format_message(), file=file, err=True, color=self.ctx.color)
+
+
+class FileError(ClickException):
+ """Raised if a file cannot be opened."""
+
+ def __init__(self, filename: str, hint: str | None = None) -> None:
+ if hint is None:
+ hint = "unknown error"
+
+ super().__init__(hint)
+ self.ui_filename: str = format_filename(filename)
+ self.filename = filename
+
+ def format_message(self) -> str:
+ return f"Could not open file {self.ui_filename!r}: {self.message}"
+
+
+class Abort(RuntimeError):
+ """An internal signalling exception that signals Click to abort."""
+
+
+class Exit(RuntimeError):
+ """An exception that indicates that the application should exit with some
+ status code.
+ """
+
+ __slots__ = ("exit_code",)
+
+ def __init__(self, code: int = 0) -> None:
+ self.exit_code: int = code
diff --git a/typer/_click/formatting.py b/typer/_click/formatting.py
new file mode 100644
index 0000000000..b5eaab3bd0
--- /dev/null
+++ b/typer/_click/formatting.py
@@ -0,0 +1,272 @@
+from collections.abc import Iterable, Iterator, Sequence
+from contextlib import contextmanager
+
+from ._compat import term_len
+from .parser import _split_opt
+
+# Can force a width. This is used by the test system
+FORCED_WIDTH: int | None = None
+
+
+def measure_table(rows: Iterable[tuple[str, str]]) -> tuple[int, ...]:
+ widths: dict[int, int] = {}
+
+ for row in rows:
+ for idx, col in enumerate(row):
+ widths[idx] = max(widths.get(idx, 0), term_len(col))
+
+ return tuple(y for x, y in sorted(widths.items()))
+
+
+def iter_rows(
+ rows: Iterable[tuple[str, str]], col_count: int
+) -> Iterator[tuple[str, ...]]:
+ for row in rows:
+ yield row + ("",) * (col_count - len(row))
+
+
+def wrap_text(
+ text: str,
+ width: int = 78,
+ initial_indent: str = "",
+ subsequent_indent: str = "",
+ preserve_paragraphs: bool = False,
+) -> str:
+ """A helper function that intelligently wraps text. By default, it
+ assumes that it operates on a single paragraph of text but if the
+ `preserve_paragraphs` parameter is provided it will intelligently
+ handle paragraphs (defined by two empty lines).
+
+ If paragraphs are handled, a paragraph can be prefixed with an empty
+ line containing the ``\\b`` character (``\\x08``) to indicate that
+ no rewrapping should happen in that block.
+ """
+ from ._textwrap import TextWrapper
+
+ text = text.expandtabs()
+ wrapper = TextWrapper(
+ width,
+ initial_indent=initial_indent,
+ subsequent_indent=subsequent_indent,
+ replace_whitespace=False,
+ )
+ if not preserve_paragraphs:
+ return wrapper.fill(text)
+
+ p: list[tuple[int, bool, str]] = []
+ buf: list[str] = []
+ indent = None
+
+ def _flush_par() -> None:
+ if not buf:
+ return
+ if buf[0].strip() == "\b":
+ p.append((indent or 0, True, "\n".join(buf[1:])))
+ else:
+ p.append((indent or 0, False, " ".join(buf)))
+ del buf[:]
+
+ for line in text.splitlines():
+ if not line:
+ _flush_par()
+ indent = None
+ else:
+ if indent is None:
+ orig_len = term_len(line)
+ line = line.lstrip()
+ indent = orig_len - term_len(line)
+ buf.append(line)
+ _flush_par()
+
+ rv = []
+ for indent, raw, text in p:
+ with wrapper.extra_indent(" " * indent):
+ if raw:
+ rv.append(wrapper.indent_only(text))
+ else:
+ rv.append(wrapper.fill(text))
+
+ return "\n\n".join(rv)
+
+
+class HelpFormatter:
+ """This class helps with formatting text-based help pages. It's
+ usually just needed for very special internal cases, but it's also
+ exposed so that developers can write their own fancy outputs.
+
+ At present, it always writes into memory.
+ """
+
+ def __init__(
+ self,
+ indent_increment: int = 2,
+ width: int | None = None,
+ max_width: int | None = None,
+ ) -> None:
+ self.indent_increment = indent_increment
+ if max_width is None:
+ max_width = 80
+ if width is None:
+ import shutil
+
+ width = FORCED_WIDTH
+ if width is None:
+ width = max(min(shutil.get_terminal_size().columns, max_width) - 2, 50)
+ self.width = width
+ self.current_indent: int = 0
+ self.buffer: list[str] = []
+
+ def write(self, string: str) -> None:
+ """Writes a unicode string into the internal buffer."""
+ self.buffer.append(string)
+
+ def indent(self) -> None:
+ """Increases the indentation."""
+ self.current_indent += self.indent_increment
+
+ def dedent(self) -> None:
+ """Decreases the indentation."""
+ self.current_indent -= self.indent_increment
+
+ def write_usage(self, prog: str, args: str = "", prefix: str | None = None) -> None:
+ """Writes a usage line into the buffer."""
+ if prefix is None:
+ prefix = "Usage: "
+
+ usage_prefix = f"{prefix:>{self.current_indent}}{prog} "
+ text_width = self.width - self.current_indent
+
+ if text_width >= (term_len(usage_prefix) + 20):
+ # The arguments will fit to the right of the prefix.
+ indent = " " * term_len(usage_prefix)
+ self.write(
+ wrap_text(
+ args,
+ text_width,
+ initial_indent=usage_prefix,
+ subsequent_indent=indent,
+ )
+ )
+ else:
+ # The prefix is too long, put the arguments on the next line.
+ self.write(usage_prefix)
+ self.write("\n")
+ indent = " " * (max(self.current_indent, term_len(prefix)) + 4)
+ self.write(
+ wrap_text(
+ args, text_width, initial_indent=indent, subsequent_indent=indent
+ )
+ )
+
+ self.write("\n")
+
+ def write_heading(self, heading: str) -> None:
+ """Writes a heading into the buffer."""
+ self.write(f"{'':>{self.current_indent}}{heading}:\n")
+
+ def write_paragraph(self) -> None:
+ """Writes a paragraph into the buffer."""
+ if self.buffer:
+ self.write("\n")
+
+ def write_text(self, text: str) -> None:
+ """Writes re-indented text into the buffer. This rewraps and
+ preserves paragraphs.
+ """
+ indent = " " * self.current_indent
+ self.write(
+ wrap_text(
+ text,
+ self.width,
+ initial_indent=indent,
+ subsequent_indent=indent,
+ preserve_paragraphs=True,
+ )
+ )
+ self.write("\n")
+
+ def write_dl(
+ self,
+ rows: Sequence[tuple[str, str]],
+ col_max: int = 30,
+ col_spacing: int = 2,
+ ) -> None:
+ """Writes a definition list into the buffer. This is how options
+ and commands are usually formatted.
+ """
+ rows = list(rows)
+ widths = measure_table(rows)
+ if len(widths) != 2: # pragma: no cover
+ raise TypeError("Expected two columns for definition list")
+
+ first_col = min(widths[0], col_max) + col_spacing
+
+ for first, second in iter_rows(rows, len(widths)):
+ self.write(f"{'':>{self.current_indent}}{first}")
+ if not second:
+ self.write("\n")
+ continue
+ if term_len(first) <= first_col - col_spacing:
+ self.write(" " * (first_col - term_len(first)))
+ else:
+ self.write("\n")
+ self.write(" " * (first_col + self.current_indent))
+
+ text_width = max(self.width - first_col - 2, 10)
+ wrapped_text = wrap_text(second, text_width, preserve_paragraphs=True)
+ lines = wrapped_text.splitlines()
+
+ if lines:
+ self.write(f"{lines[0]}\n")
+
+ for line in lines[1:]:
+ self.write(f"{'':>{first_col + self.current_indent}}{line}\n")
+ else: # pragma: no cover
+ self.write("\n")
+
+ @contextmanager
+ def section(self, name: str) -> Iterator[None]:
+ """Helpful context manager that writes a paragraph, a heading,
+ and the indents.
+ """
+ self.write_paragraph()
+ self.write_heading(name)
+ self.indent()
+ try:
+ yield
+ finally:
+ self.dedent()
+
+ @contextmanager
+ def indentation(self) -> Iterator[None]:
+ """A context manager that increases the indentation."""
+ self.indent()
+ try:
+ yield
+ finally:
+ self.dedent()
+
+ def getvalue(self) -> str:
+ """Returns the buffer contents."""
+ return "".join(self.buffer)
+
+
+def join_options(options: Sequence[str]) -> tuple[str, bool]:
+ """Given a list of option strings this joins them in the most appropriate
+ way and returns them in the form ``(formatted_string,
+ any_prefix_is_slash)`` where the second item in the tuple is a flag that
+ indicates if any of the option prefixes was a slash.
+ """
+ rv = []
+ any_prefix_is_slash = False
+
+ for opt in options:
+ prefix = _split_opt(opt)[0]
+
+ if prefix == "/":
+ any_prefix_is_slash = True
+
+ rv.append((len(prefix), opt))
+
+ rv.sort(key=lambda x: x[0])
+ return ", ".join(x[1] for x in rv), any_prefix_is_slash
diff --git a/typer/_click/globals.py b/typer/_click/globals.py
new file mode 100644
index 0000000000..372dc40749
--- /dev/null
+++ b/typer/_click/globals.py
@@ -0,0 +1,61 @@
+from threading import local
+from typing import TYPE_CHECKING, Literal, Union, cast, overload
+
+if TYPE_CHECKING:
+ from .core import Context
+
+_local = local()
+
+
+@overload
+def get_current_context(silent: Literal[False] = False) -> "Context": ...
+
+
+@overload
+def get_current_context(silent: bool = ...) -> Union["Context", None]: ...
+
+
+def get_current_context(silent: bool = False) -> Union["Context", None]:
+ """Returns the current click context. This can be used as a way to
+ access the current context object from anywhere. This is a more implicit
+ alternative to the `pass_context` decorator. This function is
+ primarily useful for helpers such as `echo` which might be
+ interested in changing its behavior based on the current context.
+
+ To push the current context, `Context.scope` can be used.
+ """
+ try:
+ return cast("Context", _local.stack[-1])
+ except (AttributeError, IndexError) as e:
+ if not silent:
+ raise RuntimeError(
+ "There is no active click context."
+ ) from e # pragma: no cover
+
+ return None
+
+
+def push_context(ctx: "Context") -> None:
+ """Pushes a new context to the current stack."""
+ _local.__dict__.setdefault("stack", []).append(ctx)
+
+
+def pop_context() -> None:
+ """Removes the top level from the stack."""
+ _local.stack.pop()
+
+
+def resolve_color_default(color: bool | None = None) -> bool | None:
+ """Internal helper to get the default value of the color flag. If a
+ value is passed it's returned unchanged, otherwise it's looked up from
+ the current context.
+ """
+ if color is not None:
+ return color
+
+ ctx = get_current_context(silent=True)
+
+ if ctx is not None:
+ return ctx.color
+
+ return None
diff --git a/typer/_click/parser.py b/typer/_click/parser.py
new file mode 100644
index 0000000000..71eb3003cc
--- /dev/null
+++ b/typer/_click/parser.py
@@ -0,0 +1,459 @@
+"""
+This module started out as largely a copy paste from the stdlib's
+optparse module with the features removed that we do not need from
+optparse because we implement them in Click on a higher level (for
+instance type handling, help formatting and a lot more).
+
+The plan is to remove more and more from here over time.
+
+The reason this is a different module and not optparse from the stdlib
+is that there are differences in 2.x and 3.x about the error messages
+generated and optparse in the stdlib uses gettext for no good reason
+and might cause us issues.
+
+Click uses parts of optparse written by Gregory P. Ward and maintained
+by the Python Software Foundation. This is limited to code in parser.py.
+
+Copyright 2001-2006 Gregory P. Ward. All rights reserved.
+Copyright 2002-2006 Python Software Foundation. All rights reserved.
+"""
+
+# This code uses parts of optparse written by Gregory P. Ward and
+# maintained by the Python Software Foundation.
+# Copyright 2001-2006 Gregory P. Ward
+# Copyright 2002-2006 Python Software Foundation
+from collections import deque
+from collections.abc import Sequence
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ TypeVar,
+ Union,
+)
+
+from .exceptions import BadArgumentUsage, BadOptionUsage, NoSuchOption, UsageError
+
+if TYPE_CHECKING:
+ from typer.core import TyperArgument as CoreArgument
+ from typer.core import TyperOption as CoreOption
+
+ from .core import Context
+ from .core import Parameter as CoreParameter
+
+V = TypeVar("V")
+
+
+def _unpack_args(
+ args: Sequence[str], nargs_spec: Sequence[int]
+) -> tuple[Sequence[str | Sequence[str | None] | None], list[str]]:
+ """Given an iterable of arguments and an iterable of nargs specifications,
+ it returns a tuple with all the unpacked arguments at the first index
+ and all remaining arguments as the second.
+
+ The nargs specification is the number of arguments that should be consumed
+ or `-1` to indicate that this position should eat up all the remainders.
+ """
+ args = deque(args)
+ nargs_spec = deque(nargs_spec)
+ rv: list[str | tuple[str | None, ...] | None] = []
+ spos: int | None = None
+
+ def _fetch(c: deque[V]) -> V | None:
+ try:
+ if spos is None:
+ return c.popleft()
+ else:
+ return c.pop()
+ except IndexError:
+ return None
+
+ while nargs_spec:
+ nargs = _fetch(nargs_spec)
+ assert nargs is not None
+
+ if nargs == 1:
+ rv.append(_fetch(args))
+ elif nargs > 1:
+ x = [_fetch(args) for _ in range(nargs)]
+
+ # If we're reversed, we're pulling in the arguments in reverse,
+ # so we need to turn them around.
+ if spos is not None:
+ x.reverse()
+
+ rv.append(tuple(x))
+ elif nargs < 0:
+ if spos is not None: # pragma: no cover
+ raise TypeError("Cannot have two nargs < 0")
+
+ spos = len(rv)
+ rv.append(None)
+
+ # spos is the position of the wildcard (star). If it's not `None`,
+ # we fill it with the remainder.
+ if spos is not None:
+ rv[spos] = tuple(args)
+ args = []
+ rv[spos + 1 :] = reversed(rv[spos + 1 :])
+
+ return tuple(rv), list(args)
+
+
+def _split_opt(opt: str) -> tuple[str, str]:
+ first = opt[:1]
+ if first.isalnum():
+ return "", opt
+ if opt[1:2] == first:
+ return opt[:2], opt[2:]
+ return first, opt[1:]
+
+
+def _normalize_opt(opt: str, ctx: Union["Context", None]) -> str:
+ if ctx is None or ctx.token_normalize_func is None:
+ return opt
+ prefix, opt = _split_opt(opt)
+ return f"{prefix}{ctx.token_normalize_func(opt)}"
+
+
+class _Option:
+ def __init__(
+ self,
+ obj: "CoreOption",
+ opts: Sequence[str],
+ dest: str | None,
+ action: str = "store",
+ nargs: int = 1,
+ const: Any | None = None,
+ ):
+ self._short_opts = []
+ self._long_opts = []
+ self.prefixes: set[str] = set()
+
+ for opt in opts:
+ prefix, value = _split_opt(opt)
+ if not prefix: # pragma: no cover
+ raise ValueError(f"Invalid start character for option ({opt})")
+ self.prefixes.add(prefix[0])
+ if len(prefix) == 1 and len(value) == 1:
+ self._short_opts.append(opt)
+ else:
+ self._long_opts.append(opt)
+ self.prefixes.add(prefix)
+
+ self.dest = dest
+ self.action = action
+ self.nargs = nargs
+ self.const = const
+ self.obj = obj
+
+ @property
+ def takes_value(self) -> bool:
+ return self.action in ("store", "append")
+
+ def process(self, value: Any, state: "_ParsingState") -> None:
+ if self.action == "store":
+ state.opts[self.dest] = value # type: ignore
+ elif self.action == "store_const":
+ state.opts[self.dest] = self.const # type: ignore
+ elif self.action == "append":
+ state.opts.setdefault(self.dest, []).append(value) # type: ignore
+ elif self.action == "append_const":
+ state.opts.setdefault(self.dest, []).append(self.const) # type: ignore
+ elif self.action == "count":
+ state.opts[self.dest] = state.opts.get(self.dest, 0) + 1 # type: ignore
+ else: # pragma: no cover
+ raise ValueError(f"unknown action '{self.action}'")
+ state.order.append(self.obj)
+
+
+class _Argument:
+ def __init__(self, obj: "CoreArgument", dest: str | None, nargs: int = 1):
+ self.dest = dest
+ self.nargs = nargs
+ self.obj = obj
+
+ def process(
+ self,
+ value: str | Sequence[str | None] | None,
+ state: "_ParsingState",
+ ) -> None:
+ if self.nargs > 1:
+ assert value is not None
+ holes = sum(1 for x in value if x is None)
+ if holes == len(value):
+ value = None
+ elif holes != 0:
+ raise BadArgumentUsage(
+ f"Argument {self.dest!r} takes {self.nargs} values."
+ )
+
+ if self.nargs == -1 and self.obj.envvar is not None and value == ():
+ # Replace empty tuple with None so that a value from the
+ # environment may be tried.
+ value = None
+
+ state.opts[self.dest] = value # type: ignore
+ state.order.append(self.obj)
+
+
+class _ParsingState:
+ def __init__(self, rargs: list[str]) -> None:
+ self.opts: dict[str, Any] = {}
+ self.largs: list[str] = []
+ self.rargs = rargs
+ self.order: list[CoreParameter] = []
+
+
+class _OptionParser:
+ """The option parser is an internal class that is ultimately used to
+ parse options and arguments. It's modelled after optparse and brings
+ a similar but vastly simplified API. It should generally not be used
+ directly as the high level Click classes wrap it for you.
+
+ It's not nearly as extensible as optparse or argparse as it does not
+ implement features that are implemented on a higher level (such as
+ types or defaults).
+ """
+
+ def __init__(self, ctx: Union["Context", None] = None) -> None:
+ self.ctx = ctx
+ # This controls how the parser deals with interspersed arguments.
+ # If this is set to `False`, the parser will stop on the first
+ # non-option. Click uses this to implement nested subcommands
+ # safely.
+ self.allow_interspersed_args: bool = True
+ # This tells the parser how to deal with unknown options. By
+ # default it will error out (which is sensible), but there is a
+ # second mode where it will ignore it and continue processing
+ # after shifting all the unknown options into the resulting args.
+ self.ignore_unknown_options: bool = False
+
+ if ctx is not None:
+ self.allow_interspersed_args = ctx.allow_interspersed_args
+ self.ignore_unknown_options = ctx.ignore_unknown_options
+
+ self._short_opt: dict[str, _Option] = {}
+ self._long_opt: dict[str, _Option] = {}
+ self._opt_prefixes = {"-", "--"}
+ self._args: list[_Argument] = []
+
+ def add_option(
+ self,
+ obj: "CoreOption",
+ opts: Sequence[str],
+ dest: str | None,
+ action: str = "store",
+ nargs: int = 1,
+ const: Any | None = None,
+ ) -> None:
+ """Adds a new option named `dest` to the parser. The destination
+ is not inferred (unlike with optparse) and needs to be explicitly
+ provided. Action can be any of ``store``, ``store_const``,
+ ``append``, ``append_const`` or ``count``.
+
+ The `obj` can be used to identify the option in the order list
+ that is returned from the parser.
+ """
+ opts = [_normalize_opt(opt, self.ctx) for opt in opts]
+ option = _Option(obj, opts, dest, action=action, nargs=nargs, const=const)
+ self._opt_prefixes.update(option.prefixes)
+ for opt in option._short_opts:
+ self._short_opt[opt] = option
+ for opt in option._long_opts:
+ self._long_opt[opt] = option
+
+ def add_argument(
+ self, obj: "CoreArgument", dest: str | None, nargs: int = 1
+ ) -> None:
+ """Adds a positional argument named `dest` to the parser.
+
+ The `obj` can be used to identify the option in the order list
+ that is returned from the parser.
+ """
+ self._args.append(_Argument(obj, dest=dest, nargs=nargs))
+
+ def parse_args(
+ self, args: list[str]
+ ) -> tuple[dict[str, Any], list[str], list["CoreParameter"]]:
+ """Parses positional arguments and returns ``(values, args, order)``
+ for the parsed options and arguments as well as the leftover
+ arguments if there are any. The order is a list of objects as they
+ appear on the command line. If arguments appear multiple times they
+ will be memorized multiple times as well.
+ """
+ state = _ParsingState(args)
+ try:
+ self._process_args_for_options(state)
+ self._process_args_for_args(state)
+ except UsageError:
+ if self.ctx is None or not self.ctx.resilient_parsing:
+ raise
+ return state.opts, state.largs, state.order
+
+ def _process_args_for_args(self, state: _ParsingState) -> None:
+ pargs, args = _unpack_args(
+ state.largs + state.rargs, [x.nargs for x in self._args]
+ )
+
+ for idx, arg in enumerate(self._args):
+ arg.process(pargs[idx], state)
+
+ state.largs = args
+ state.rargs = []
+
+ def _process_args_for_options(self, state: _ParsingState) -> None:
+ while state.rargs:
+ arg = state.rargs.pop(0)
+ arglen = len(arg)
+ # Double dashes always handled explicitly regardless of what
+ # prefixes are valid.
+ if arg == "--":
+ return
+ elif arg[:1] in self._opt_prefixes and arglen > 1:
+ self._process_opts(arg, state)
+ elif self.allow_interspersed_args:
+ state.largs.append(arg)
+ else:
+ state.rargs.insert(0, arg)
+ return
+
+ # Say this is the original argument list:
+ # [arg0, arg1, ..., arg(i-1), arg(i), arg(i+1), ..., arg(N-1)]
+ # ^
+ # (we are about to process arg(i)).
+ #
+ # Then rargs is [arg(i), ..., arg(N-1)] and largs is a *subset* of
+ # [arg0, ..., arg(i-1)] (any options and their arguments will have
+ # been removed from largs).
+ #
+ # The while loop will usually consume 1 or more arguments per pass.
+ # If it consumes 1 (eg. arg is an option that takes no arguments),
+ # then after _process_arg() is done the situation is:
+ #
+ # largs = subset of [arg0, ..., arg(i)]
+ # rargs = [arg(i+1), ..., arg(N-1)]
+ #
+ # If allow_interspersed_args is false, largs will always be
+ # *empty* -- still a subset of [arg0, ..., arg(i-1)], but
+ # not a very interesting subset!
+
+ def _match_long_opt(
+ self, opt: str, explicit_value: str | None, state: _ParsingState
+ ) -> None:
+ if opt not in self._long_opt:
+ from difflib import get_close_matches
+
+ possibilities = get_close_matches(opt, self._long_opt)
+ raise NoSuchOption(opt, possibilities=possibilities, ctx=self.ctx)
+
+ option = self._long_opt[opt]
+ if option.takes_value:
+ # At this point it's safe to modify rargs by injecting the
+ # explicit value, because no exception is raised in this
+ # branch. This means that the inserted value will be fully
+ # consumed.
+ if explicit_value is not None:
+ state.rargs.insert(0, explicit_value)
+
+ value = self._get_value_from_state(opt, option, state)
+
+ elif explicit_value is not None: # pragma: no cover
+ raise BadOptionUsage(opt, f"Option {opt!r} does not take a value.")
+
+ else:
+ value = None
+
+ option.process(value, state)
+
+ def _match_short_opt(self, arg: str, state: _ParsingState) -> None:
+ stop = False
+ i = 1
+ prefix = arg[0]
+ unknown_options = []
+
+ for ch in arg[1:]:
+ opt = _normalize_opt(f"{prefix}{ch}", self.ctx)
+ option = self._short_opt.get(opt)
+ i += 1
+
+ if not option:
+ if self.ignore_unknown_options:
+ unknown_options.append(ch)
+ continue
+ raise NoSuchOption(opt, ctx=self.ctx)
+ if option.takes_value:
+ # Any characters left in arg? Pretend they're the
+ # next arg, and stop consuming characters of arg.
+ if i < len(arg):
+ state.rargs.insert(0, arg[i:])
+ stop = True
+
+ value = self._get_value_from_state(opt, option, state)
+
+ else:
+ value = None
+
+ option.process(value, state)
+
+ if stop:
+ break
+
+ # If we got any unknown options we recombine the string of the
+ # remaining options and re-attach the prefix, then report that
+ # to the state as new 'largs'. This way there is basic combinatorics
+ # that can be achieved while still ignoring unknown arguments.
+ if self.ignore_unknown_options and unknown_options:
+ state.largs.append(f"{prefix}{''.join(unknown_options)}")
+
+ def _get_value_from_state(
+ self, option_name: str, option: _Option, state: _ParsingState
+ ) -> str | Sequence[str]:
+ nargs = option.nargs
+
+ value: str | Sequence[str]
+
+ if len(state.rargs) < nargs:
+ msg = "an argument." if nargs == 1 else f"{nargs} arguments."
+ raise BadOptionUsage(
+ option_name,
+ f"Option {option_name!r} requires {msg}",
+ )
+ elif nargs == 1:
+ value = state.rargs.pop(0)
+ else:
+ value = tuple(state.rargs[:nargs])
+ del state.rargs[:nargs]
+
+ return value
+
+ def _process_opts(self, arg: str, state: _ParsingState) -> None:
+ explicit_value = None
+ # Long option handling happens in two parts. The first part is
+ # supporting explicitly attached values. In any case, we will try
+ # to long match the option first.
+ if "=" in arg:
+ long_opt, explicit_value = arg.split("=", 1)
+ else:
+ long_opt = arg
+ norm_long_opt = _normalize_opt(long_opt, self.ctx)
+
+ # At this point we will match the (assumed) long option through
+ # the long option matching code. Note that this allows options
+ # like "-foo" to be matched as long options.
+ try:
+ self._match_long_opt(norm_long_opt, explicit_value, state)
+ except NoSuchOption:
+ # At this point the long option matching failed, and we need
+ # to try with short options. However there is a special rule
+ # which says, that if we have a two character options prefix
+ # (applies to "--foo" for instance), we do not dispatch to the
+ # short option code and will instead raise the no option
+ # error.
+ if arg[:2] not in self._opt_prefixes:
+ self._match_short_opt(arg, state)
+ return
+
+ if not self.ignore_unknown_options:
+ raise
+
+ state.largs.append(arg)
diff --git a/typer/_click/py.typed b/typer/_click/py.typed
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/typer/_click/shell_completion.py b/typer/_click/shell_completion.py
new file mode 100644
index 0000000000..4490f18810
--- /dev/null
+++ b/typer/_click/shell_completion.py
@@ -0,0 +1,306 @@
+import re
+from abc import ABC, abstractmethod
+from collections.abc import MutableMapping
+from typing import Any, ClassVar, TypeVar
+
+from .core import Command, Context, Parameter, ParameterSource
+
+
+class CompletionItem:
+ """Represents a completion value and metadata about the value. The
+ default metadata is ``type`` to indicate special shell handling,
+ and ``help`` if a shell supports showing a help string next to the
+ value.
+
+ Arbitrary parameters can be passed when creating the object, and
+ accessed using ``item.attr``. If an attribute wasn't passed,
+ accessing it returns ``None``.
+ """
+
+ __slots__ = ("value", "type", "help", "_info")
+
+ def __init__(
+ self,
+ value: Any,
+ type: str = "plain",
+ help: str | None = None,
+ **kwargs: Any,
+ ) -> None:
+ self.value: Any = value
+ self.type: str = type
+ self.help: str | None = help
+ self._info = kwargs
+
+ def __getattr__(self, name: str) -> Any:
+ return self._info.get(name)
+
+
+class ShellComplete(ABC):
+ """Base class for providing shell completion support. A subclass for
+ a given shell will override attributes and methods to implement the
+ completion instructions (``source`` and ``complete``).
+ """
+
+ name: ClassVar[str]
+ """Name to register the shell as with `add_completion_class`.
+ This is used in completion instructions (``{name}_source`` and
+ ``{name}_complete``).
+ """
+
+ source_template: ClassVar[str]
+ """Completion script template formatted by `source`. This must
+ be provided by subclasses.
+ """
+
+ def __init__(
+ self,
+ cli: Command,
+ ctx_args: MutableMapping[str, Any],
+ prog_name: str,
+ complete_var: str,
+ ) -> None:
+ self.cli = cli
+ self.ctx_args = ctx_args
+ self.prog_name = prog_name
+ self.complete_var = complete_var
+
+ @property
+ def func_name(self) -> str:
+ """The name of the shell function defined by the completion
+ script.
+ """
+ safe_name = re.sub(r"\W*", "", self.prog_name.replace("-", "_"), flags=re.ASCII)
+ return f"_{safe_name}_completion"
+
+ @abstractmethod
+ def source_vars(self) -> dict[str, Any]:
+ """Vars for formatting `source_template`."""
+ pass # pragma: no cover
+
+ def source(self) -> str:
+ """Produce the shell script that defines the completion
+ function. By default this ``%``-style formats
+ `source_template` with the dict returned by `source_vars`.
+ """
+ return self.source_template % self.source_vars()
+
+ @abstractmethod
+ def get_completion_args(self) -> tuple[list[str], str]:
+ """Use the env vars defined by the shell script to return a
+ tuple of ``args, incomplete``. This must be implemented by
+ subclasses.
+ """
+ pass # pragma: no cover
+
+ def get_completions(self, args: list[str], incomplete: str) -> list[CompletionItem]:
+ """Determine the context and last complete command or parameter
+ from the complete args. Call that object's ``shell_complete``
+ method to get the completions for the incomplete value.
+ """
+ ctx = _resolve_context(self.cli, self.ctx_args, self.prog_name, args)
+ obj, incomplete = _resolve_incomplete(ctx, args, incomplete)
+ return obj.shell_complete(ctx, incomplete)
+
+ @abstractmethod
+ def format_completion(self, item: CompletionItem) -> str:
+ """Format a completion item into the form recognized by the
+ shell script. This must be implemented by subclasses.
+ """
+ pass # pragma: no cover
+
+ def complete(self) -> str:
+ """Produce the completion data to send back to the shell.
+
+ By default this calls `get_completion_args`, gets the
+ completions, then calls `format_completion` for each
+ completion.
+ """
+ args, incomplete = self.get_completion_args()
+ completions = self.get_completions(args, incomplete)
+ out = [self.format_completion(item) for item in completions]
+ return "\n".join(out)
+
+
+ShellCompleteType = TypeVar("ShellCompleteType", bound="type[ShellComplete]")
+
+
+_available_shells: dict[str, type[ShellComplete]] = {}
+
+
+def add_completion_class(cls: ShellCompleteType, name: str) -> ShellCompleteType:
+ """Register a `ShellComplete` subclass under the given name.
+ The name will be provided by the completion instruction environment
+ variable during completion.
+ """
+ _available_shells[name] = cls
+
+ return cls
+
+
+def get_completion_class(shell: str) -> type[ShellComplete] | None:
+ """Look up a registered `ShellComplete` subclass by the name
+ provided by the completion instruction environment variable. If the
+ name isn't registered, returns ``None``.
+ """
+ return _available_shells.get(shell)
+
+
+def split_arg_string(string: str) -> list[str]:
+ """Split an argument string as with `shlex.split`, but don't
+ fail if the string is incomplete. Ignores a missing closing quote or
+ incomplete escape sequence and uses the partial token as-is.
+ """
+ import shlex
+
+ lex = shlex.shlex(string, posix=True)
+ lex.whitespace_split = True
+ lex.commenters = ""
+ out = []
+
+ try:
+ for token in lex:
+ out.append(token)
+ except ValueError: # pragma: no cover
+ # Raised when end-of-string is reached in an invalid state. Use
+ # the partial token as-is. The quote or escape character is in
+ # lex.state, not lex.token.
+ out.append(lex.token)
+
+ return out
+
+
+def _is_incomplete_argument(ctx: Context, param: Parameter) -> bool:
+ """Determine if the given parameter is an argument that can still
+ accept values.
+ """
+ # avoid circular imports
+ from ..core import TyperArgument
+
+ if not isinstance(param, TyperArgument):
+ return False
+
+ assert param.name is not None
+ # Will be None if expose_value is False.
+ value = ctx.params.get(param.name)
+ return (
+ param.nargs == -1
+ or ctx.get_parameter_source(param.name) is not ParameterSource.COMMANDLINE
+ or (
+ param.nargs > 1
+ and isinstance(value, (tuple, list))
+ and len(value) < param.nargs
+ )
+ )
+
+
+def _start_of_option(ctx: Context, value: str) -> bool:
+ """Check if the value looks like the start of an option."""
+ if not value:
+ return False
+
+ c = value[0]
+ return c in ctx._opt_prefixes
+
+
+def _is_incomplete_option(ctx: Context, args: list[str], param: Parameter) -> bool:
+ """Determine if the given parameter is an option that needs a value."""
+ # avoid circular imports
+ from ..core import TyperOption
+
+ if not isinstance(param, TyperOption):
+ return False
+
+ if param.is_flag or param.count:
+ return False
+
+ last_option = None
+
+ for index, arg in enumerate(reversed(args)):
+ if index + 1 > param.nargs:
+ break
+
+ if _start_of_option(ctx, arg):
+ last_option = arg
+ break
+
+ return last_option is not None and last_option in param.opts
+
+
+def _resolve_context(
+ cli: Command,
+ ctx_args: MutableMapping[str, Any],
+ prog_name: str,
+ args: list[str],
+) -> Context:
+ """Produce the context hierarchy starting with the command and
+ traversing the complete arguments. This only follows the commands,
+ it doesn't trigger input prompts or callbacks.
+ """
+ # avoid circular imports
+ from ..core import TyperGroup
+
+ ctx_args["resilient_parsing"] = True
+ with cli.make_context(prog_name, args.copy(), **ctx_args) as ctx:
+ args = ctx._protected_args + ctx.args
+
+ while args:
+ command = ctx.command
+
+ if isinstance(command, TyperGroup):
+ # if not command.chain:
+ name, cmd, args = command.resolve_command(ctx, args)
+
+ if cmd is None:
+ return ctx
+
+ with cmd.make_context(
+ name, args, parent=ctx, resilient_parsing=True
+ ) as sub_ctx:
+ ctx = sub_ctx
+ args = ctx._protected_args + ctx.args
+ else: # pragma: no cover
+ break
+
+ return ctx
+
+
+def _resolve_incomplete(
+ ctx: Context, args: list[str], incomplete: str
+) -> tuple[Command | Parameter, str]:
+ """Find the Click object that will handle the completion of the
+ incomplete value. Return the object and the incomplete value.
+ """
+ # Different shells treat an "=" between a long option name and
+ # value differently. Might keep the value joined, return the "="
+ # as a separate item, or return the split name and value. Always
+ # split and discard the "=" to make completion easier.
+ if incomplete == "=":
+ incomplete = ""
+ elif "=" in incomplete and _start_of_option(ctx, incomplete):
+ name, _, incomplete = incomplete.partition("=")
+ args.append(name)
+
+ # The "--" marker tells Click to stop treating values as options
+ # even if they start with the option character. If it hasn't been
+ # given and the incomplete arg looks like an option, the current
+ # command will provide option name completions.
+ if "--" not in args and _start_of_option(ctx, incomplete):
+ return ctx.command, incomplete
+
+ params = ctx.command.get_params(ctx)
+
+ # If the last complete arg is an option name with an incomplete
+ # value, the option will provide value completions.
+ for param in params:
+ if _is_incomplete_option(ctx, args, param):
+ return param, incomplete
+
+ # It's not an option name or value. The first argument without a
+ # parsed value will provide value completions.
+ for param in params:
+ if _is_incomplete_argument(ctx, param):
+ return param, incomplete
+
+ # There were no unparsed arguments, the command may be a group that
+ # will provide command name completions.
+ return ctx.command, incomplete
diff --git a/typer/_click/termui.py b/typer/_click/termui.py
new file mode 100644
index 0000000000..0a8c82574d
--- /dev/null
+++ b/typer/_click/termui.py
@@ -0,0 +1,430 @@
+import io
+from collections.abc import Callable, Iterable
+from contextlib import AbstractContextManager
+from typing import IO, TYPE_CHECKING, Any, AnyStr, TextIO, TypeVar, overload
+
+from .exceptions import Abort, UsageError
+from .globals import resolve_color_default
+from .types import ParamType, convert_type
+from .utils import LazyFile, echo
+
+if TYPE_CHECKING:
+ from ._termui_impl import ProgressBar
+
+V = TypeVar("V")
+
+# The prompt functions to use. The doc tools currently override these
+# functions to customize how they work.
+visible_prompt_func: Callable[[str], str] = input
+
+_ansi_colors = {
+ "black": 30,
+ "red": 31,
+ "green": 32,
+ "yellow": 33,
+ "blue": 34,
+ "magenta": 35,
+ "cyan": 36,
+ "white": 37,
+ "reset": 39,
+ "bright_black": 90,
+ "bright_red": 91,
+ "bright_green": 92,
+ "bright_yellow": 93,
+ "bright_blue": 94,
+ "bright_magenta": 95,
+ "bright_cyan": 96,
+ "bright_white": 97,
+}
+_ansi_reset_all = "\033[0m"
+
+
+def hidden_prompt_func(prompt: str) -> str:
+ import getpass
+
+ return getpass.getpass(prompt)
+
+
+def _build_prompt(
+ text: str,
+ suffix: str,
+ show_default: bool = False,
+ default: Any | None = None,
+ show_choices: bool = True,
+ type: ParamType | None = None,
+) -> str:
+ # prevent circular imports
+ from .._types import TyperChoice
+
+ prompt = text
+ if type is not None and show_choices and isinstance(type, TyperChoice):
+ prompt += f" ({', '.join(map(str, type.choices))})"
+ if default is not None and show_default:
+ prompt = f"{prompt} [{_format_default(default)}]"
+ return f"{prompt}{suffix}"
+
+
+def _format_default(default: Any) -> Any:
+ if isinstance(default, (io.IOBase, LazyFile)) and hasattr(default, "name"):
+ return default.name
+
+ return default
+
+
+def prompt(
+ text: str,
+ default: Any | None = None,
+ hide_input: bool = False,
+ confirmation_prompt: bool | str = False,
+ type: ParamType | Any | None = None,
+ value_proc: Callable[[str], Any] | None = None,
+ prompt_suffix: str = ": ",
+ show_default: bool = True,
+ err: bool = False,
+ show_choices: bool = True,
+) -> Any:
+ """Prompts a user for input. This is a convenience function that can
+ be used to prompt a user for input later.
+
+ If the user aborts the input by sending an interrupt signal, this
+ function will catch it and raise an `Abort` exception.
+ """
+
+ def prompt_func(text: str) -> str:
+ f = hidden_prompt_func if hide_input else visible_prompt_func
+ try:
+ # Write the prompt separately so that we get nice
+ # coloring through colorama on Windows
+ echo(text[:-1], nl=False, err=err)
+ # Echo the last character to stdout to work around an issue where
+ # readline causes backspace to clear the whole line.
+ return f(text[-1:])
+ except (KeyboardInterrupt, EOFError): # pragma: no cover
+ # getpass doesn't print a newline if the user aborts input with ^C.
+ # Allegedly this behavior is inherited from getpass(3).
+ # A doc bug has been filed at https://bugs.python.org/issue24711
+ if hide_input:
+ echo(None, err=err)
+ raise Abort() from None
+
+ if value_proc is None:
+ value_proc = convert_type(type, default)
+
+ prompt = _build_prompt(
+ text, prompt_suffix, show_default, default, show_choices, type
+ )
+
+ if confirmation_prompt:
+ if confirmation_prompt is True:
+ confirmation_prompt = "Repeat for confirmation"
+
+ confirmation_prompt = _build_prompt(confirmation_prompt, prompt_suffix)
+
+ while True:
+ while True:
+ value = prompt_func(prompt)
+ if value:
+ break
+ elif default is not None:
+ value = default
+ break
+ try:
+ result = value_proc(value)
+ except UsageError as e: # pragma: no cover
+ if hide_input:
+ echo("Error: The value you entered was invalid.", err=err)
+ else:
+ echo(f"Error: {e.message}", err=err)
+ continue
+ if not confirmation_prompt:
+ return result
+ while True:
+ value2 = prompt_func(confirmation_prompt)
+ is_empty = not value and not value2
+ if value2 or is_empty:
+ break
+ if value == value2:
+ return result
+ echo("Error: The two entered values do not match.", err=err)
+
+
+def confirm(
+ text: str,
+ default: bool | None = False,
+ abort: bool = False,
+ prompt_suffix: str = ": ",
+ show_default: bool = True,
+ err: bool = False,
+) -> bool:
+ """Prompts for confirmation (yes/no question).
+
+ If the user aborts the input by sending a interrupt signal this
+ function will catch it and raise an `Abort` exception.
+ """
+ prompt = _build_prompt(
+ text,
+ prompt_suffix,
+ show_default,
+ "y/n" if default is None else ("Y/n" if default else "y/N"),
+ )
+
+ while True:
+ try:
+ # Write the prompt separately so that we get nice
+ # coloring through colorama on Windows
+ echo(prompt[:-1], nl=False, err=err)
+ # Echo the last character to stdout to work around an issue where
+ # readline causes backspace to clear the whole line.
+ value = visible_prompt_func(prompt[-1:]).lower().strip()
+ except (KeyboardInterrupt, EOFError): # pragma: no cover
+ raise Abort() from None
+ if value in ("y", "yes"):
+ rv = True
+ elif value in ("n", "no"):
+ rv = False
+ elif default is not None and value == "":
+ rv = default
+ else: # pragma: no cover
+ echo("Error: invalid input", err=err)
+ continue
+ break
+ if abort and not rv:
+ raise Abort()
+ return rv
+
+
+@overload
+def progressbar(
+ *,
+ length: int,
+ label: str | None = None,
+ hidden: bool = False,
+ show_eta: bool = True,
+ show_percent: bool | None = None,
+ show_pos: bool = False,
+ fill_char: str = "#",
+ empty_char: str = "-",
+ bar_template: str = "%(label)s [%(bar)s] %(info)s",
+ info_sep: str = " ",
+ width: int = 36,
+ file: TextIO | None = None,
+ color: bool | None = None,
+ update_min_steps: int = 1,
+) -> "ProgressBar[int]": ...
+
+
+@overload
+def progressbar(
+ iterable: Iterable[V] | None = None,
+ length: int | None = None,
+ label: str | None = None,
+ hidden: bool = False,
+ show_eta: bool = True,
+ show_percent: bool | None = None,
+ show_pos: bool = False,
+ item_show_func: Callable[[V | None], str | None] | None = None,
+ fill_char: str = "#",
+ empty_char: str = "-",
+ bar_template: str = "%(label)s [%(bar)s] %(info)s",
+ info_sep: str = " ",
+ width: int = 36,
+ file: TextIO | None = None,
+ color: bool | None = None,
+ update_min_steps: int = 1,
+) -> "ProgressBar[V]": ...
+
+
+def progressbar(
+ iterable: Iterable[V] | None = None,
+ length: int | None = None,
+ label: str | None = None,
+ hidden: bool = False,
+ show_eta: bool = True,
+ show_percent: bool | None = None,
+ show_pos: bool = False,
+ item_show_func: Callable[[V | None], str | None] | None = None,
+ fill_char: str = "#",
+ empty_char: str = "-",
+ bar_template: str = "%(label)s [%(bar)s] %(info)s",
+ info_sep: str = " ",
+ width: int = 36,
+ file: TextIO | None = None,
+ color: bool | None = None,
+ update_min_steps: int = 1,
+) -> "ProgressBar[V]":
+ """This function creates an iterable context manager that can be used
+ to iterate over something while showing a progress bar. It will
+ either iterate over the `iterable` or `length` items (that are counted
+ up). While iteration happens, this function will print a rendered
+ progress bar to the given `file` (defaults to stdout) and will attempt
+ to calculate remaining time and more. By default, this progress bar
+ will not be rendered if the file is not a terminal.
+
+ The context manager creates the progress bar. When the context
+ manager is entered the progress bar is already created. With every
+ iteration over the progress bar, the iterable passed to the bar is
+ advanced and the bar is updated. When the context manager exits,
+ a newline is printed and the progress bar is finalized on screen.
+
+ Note: The progress bar is currently designed for use cases where the
+ total progress can be expected to take at least several seconds.
+ Because of this, the ProgressBar class object won't display
+ progress that is considered too fast, and progress where the time
+ between steps is less than a second.
+
+ No printing must happen or the progress bar will be unintentionally
+ destroyed.
+ """
+ from ._termui_impl import ProgressBar
+
+ color = resolve_color_default(color)
+ return ProgressBar(
+ iterable=iterable,
+ length=length,
+ hidden=hidden,
+ show_eta=show_eta,
+ show_percent=show_percent,
+ show_pos=show_pos,
+ item_show_func=item_show_func,
+ fill_char=fill_char,
+ empty_char=empty_char,
+ bar_template=bar_template,
+ info_sep=info_sep,
+ file=file,
+ label=label,
+ width=width,
+ color=color,
+ update_min_steps=update_min_steps,
+ )
+
+
+def _interpret_color(color: int | tuple[int, int, int] | str, offset: int = 0) -> str:
+ if isinstance(color, int):
+ return f"{38 + offset};5;{color:d}"
+
+ if isinstance(color, (tuple, list)):
+ r, g, b = color
+ return f"{38 + offset};2;{r:d};{g:d};{b:d}"
+
+ return str(_ansi_colors[color] + offset)
+
+
+def style(
+ text: Any,
+ fg: int | tuple[int, int, int] | str | None = None,
+ bg: int | tuple[int, int, int] | str | None = None,
+ bold: bool | None = None,
+ dim: bool | None = None,
+ underline: bool | None = None,
+ overline: bool | None = None,
+ italic: bool | None = None,
+ blink: bool | None = None,
+ reverse: bool | None = None,
+ strikethrough: bool | None = None,
+ reset: bool = True,
+) -> str:
+ """Styles a text with ANSI styles and returns the new string. By
+ default the styling is self contained which means that at the end
+ of the string a reset code is issued. This can be prevented by
+ passing ``reset=False``.
+ """
+ if not isinstance(text, str):
+ text = str(text)
+
+ bits = []
+
+ if fg:
+ try:
+ bits.append(f"\033[{_interpret_color(fg)}m")
+ except KeyError:
+ raise TypeError(f"Unknown color {fg!r}") from None
+
+ if bg:
+ try:
+ bits.append(f"\033[{_interpret_color(bg, 10)}m")
+ except KeyError:
+ raise TypeError(f"Unknown color {bg!r}") from None
+
+ if bold is not None:
+ bits.append(f"\033[{1 if bold else 22}m")
+ if dim is not None:
+ bits.append(f"\033[{2 if dim else 22}m")
+ if underline is not None:
+ bits.append(f"\033[{4 if underline else 24}m")
+ if overline is not None:
+ bits.append(f"\033[{53 if overline else 55}m")
+ if italic is not None:
+ bits.append(f"\033[{3 if italic else 23}m")
+ if blink is not None:
+ bits.append(f"\033[{5 if blink else 25}m")
+ if reverse is not None:
+ bits.append(f"\033[{7 if reverse else 27}m")
+ if strikethrough is not None:
+ bits.append(f"\033[{9 if strikethrough else 29}m")
+ bits.append(text)
+ if reset:
+ bits.append(_ansi_reset_all)
+ return "".join(bits)
+
+
+def secho(
+ message: Any | None = None,
+ file: IO[AnyStr] | None = None,
+ nl: bool = True,
+ err: bool = False,
+ color: bool | None = None,
+ **styles: Any,
+) -> None:
+ """This function combines `echo` and `style` into one call."""
+ if message is not None and not isinstance(message, (bytes, bytearray)):
+ message = style(message, **styles)
+
+ return echo(message, file=file, nl=nl, err=err, color=color)
+
+
+def launch(url: str, wait: bool = False, locate: bool = False) -> int:
+ """This function launches the given URL (or filename) in the default
+ viewer application for this file type. If this is an executable, it
+ might launch the executable in a new session. The return value is
+ the exit code of the launched application. Usually, ``0`` indicates
+ success.
+ """
+ from ._termui_impl import open_url
+
+ return open_url(url, wait=wait, locate=locate)
+
+
+# If this is provided, getchar() calls into this instead. This is used
+# for unittesting purposes.
+_getchar: Callable[[bool], str] | None = None
+
+
+def getchar(echo: bool = False) -> str:
+ """Fetches a single character from the terminal and returns it. This
+ will always return a unicode character and under certain rare
+ circumstances this might return more than one character. The
+ situations which more than one character is returned is when for
+ whatever reason multiple characters end up in the terminal buffer or
+ standard input was not actually a terminal.
+
+ Note that this will always read from the terminal, even if something
+ is piped into the standard input.
+
+ Note for Windows: in rare cases when typing non-ASCII characters, this
+ function might wait for a second character and then return both at once.
+ This is because certain Unicode characters look like special-key markers.
+ """
+ global _getchar
+
+ if _getchar is None:
+ from ._termui_impl import getchar as f
+
+ _getchar = f
+
+ return _getchar(echo)
+
+
+def raw_terminal() -> AbstractContextManager[int]:
+ from ._termui_impl import raw_terminal as f
+
+ return f()
diff --git a/typer/_click/testing.py b/typer/_click/testing.py
new file mode 100644
index 0000000000..0d8c03a790
--- /dev/null
+++ b/typer/_click/testing.py
@@ -0,0 +1,366 @@
+import contextlib
+import io
+import os
+import shlex
+import sys
+from collections.abc import Iterator, Mapping, Sequence
+from types import TracebackType
+from typing import IO, TYPE_CHECKING, Any, BinaryIO, cast
+
+from . import _compat, formatting, termui, utils
+from ._compat import _find_binary_reader
+from .core import Command
+
+if TYPE_CHECKING:
+ from _typeshed import ReadableBuffer
+
+
+class BytesIOCopy(io.BytesIO):
+ """Patch ``io.BytesIO`` to let the written stream be copied to another."""
+
+ def __init__(self, copy_to: io.BytesIO) -> None:
+ super().__init__()
+ self.copy_to = copy_to
+
+ def flush(self) -> None:
+ super().flush()
+ self.copy_to.flush()
+
+ def write(self, b: "ReadableBuffer") -> int:
+ self.copy_to.write(b)
+ return super().write(b)
+
+
+class StreamMixer:
+ """Mixes `` and `` streams.
+
+ The result is available in the ``output`` attribute.
+ """
+
+ def __init__(self) -> None:
+ self.output: io.BytesIO = io.BytesIO()
+ self.stdout: io.BytesIO = BytesIOCopy(copy_to=self.output)
+ self.stderr: io.BytesIO = BytesIOCopy(copy_to=self.output)
+
+ def __del__(self) -> None:
+ """
+ Guarantee that embedded file-like objects are closed in a
+ predictable order, protecting against races between
+ self.output being closed and other streams being flushed on close
+ """
+ self.stderr.close()
+ self.stdout.close()
+ self.output.close()
+
+
+class _NamedTextIOWrapper(io.TextIOWrapper):
+ def __init__(self, buffer: BinaryIO, name: str, mode: str, **kwargs: Any) -> None:
+ super().__init__(buffer, **kwargs)
+ self._name = name
+ self._mode = mode
+
+ @property
+ def name(self) -> str:
+ return self._name # pragma: no cover
+
+ @property
+ def mode(self) -> str:
+ return self._mode # pragma: no cover
+
+
+def make_input_stream(input: str | bytes | IO[Any] | None, charset: str) -> BinaryIO:
+ # Is already an input stream.
+ if hasattr(input, "read"):
+ rv = _find_binary_reader(cast("IO[Any]", input))
+
+ if rv is not None:
+ return rv
+
+ raise TypeError(
+ "Could not find binary reader for input stream."
+ ) # pragma: no cover
+
+ if input is None:
+ input = b""
+ elif isinstance(input, str):
+ input = input.encode(charset)
+
+ return io.BytesIO(input)
+
+
+class Result:
+ """Holds the captured result of an invoked CLI script."""
+
+ def __init__(
+ self,
+ runner: "CliRunner",
+ stdout_bytes: bytes,
+ stderr_bytes: bytes,
+ output_bytes: bytes,
+ return_value: Any,
+ exit_code: int,
+ exception: BaseException | None,
+ exc_info: tuple[type[BaseException], BaseException, TracebackType]
+ | None = None,
+ ):
+ self.runner = runner
+ self.stdout_bytes = stdout_bytes
+ self.stderr_bytes = stderr_bytes
+ self.output_bytes = output_bytes
+ self.return_value = return_value
+ self.exit_code = exit_code
+ self.exception = exception
+ self.exc_info = exc_info
+
+ @property
+ def output(self) -> str:
+ """The terminal output as unicode string, as the user would see it."""
+ return self.output_bytes.decode(self.runner.charset, "replace").replace(
+ "\r\n", "\n"
+ )
+
+ @property
+ def stdout(self) -> str:
+ """The standard output as unicode string."""
+ return self.stdout_bytes.decode(self.runner.charset, "replace").replace(
+ "\r\n", "\n"
+ )
+
+ @property
+ def stderr(self) -> str:
+ """The standard error as unicode string."""
+ return self.stderr_bytes.decode(self.runner.charset, "replace").replace(
+ "\r\n", "\n"
+ )
+
+ def __repr__(self) -> str:
+ exc_str = repr(self.exception) if self.exception else "okay"
+ return f"<{type(self).__name__} {exc_str}>"
+
+
+class CliRunner:
+ """The CLI runner provides functionality to invoke a Click command line
+ script for unittesting purposes in a isolated environment. This only
+ works in single-threaded systems without any concurrency as it changes the
+ global interpreter state.
+ """
+
+ def __init__(
+ self,
+ charset: str = "utf-8",
+ env: Mapping[str, str | None] | None = None,
+ catch_exceptions: bool = True,
+ ) -> None:
+ self.charset = charset
+ self.env: Mapping[str, str | None] = env or {}
+ self.catch_exceptions = catch_exceptions
+
+ def get_default_prog_name(self, cli: Command) -> str:
+ """Given a command object it will return the default program name
+ for it. The default is the `name` attribute or ``"root"`` if not
+ set.
+ """
+ return cli.name or "root"
+
+ def make_env(
+ self, overrides: Mapping[str, str | None] | None = None
+ ) -> Mapping[str, str | None]:
+ """Returns the environment overrides for invoking a script."""
+ rv = dict(self.env)
+ if overrides:
+ rv.update(overrides)
+ return rv
+
+ @contextlib.contextmanager
+ def isolation(
+ self,
+ input: str | bytes | IO[Any] | None = None,
+ env: Mapping[str, str | None] | None = None,
+ color: bool = False,
+ ) -> Iterator[tuple[io.BytesIO, io.BytesIO, io.BytesIO]]:
+ """A context manager that sets up the isolation for invoking of a
+ command line tool. This sets up `` with the given input data
+ and `os.environ` with the overrides from the given dictionary.
+ This also rebinds some internals in Click to be mocked (like the
+ prompt functionality).
+ """
+ bytes_input = make_input_stream(input, self.charset)
+
+ old_stdin = sys.stdin
+ old_stdout = sys.stdout
+ old_stderr = sys.stderr
+ old_forced_width = formatting.FORCED_WIDTH
+ formatting.FORCED_WIDTH = 80
+
+ env = self.make_env(env)
+
+ stream_mixer = StreamMixer()
+
+ sys.stdin = text_input = _NamedTextIOWrapper(
+ bytes_input, encoding=self.charset, name="", mode="r"
+ )
+
+ sys.stdout = _NamedTextIOWrapper(
+ stream_mixer.stdout, encoding=self.charset, name="", mode="w"
+ )
+
+ sys.stderr = _NamedTextIOWrapper(
+ stream_mixer.stderr,
+ encoding=self.charset,
+ name="",
+ mode="w",
+ errors="backslashreplace",
+ )
+
+ def visible_input(prompt: str | None = None) -> str:
+ sys.stdout.write(prompt or "")
+ try:
+ val = next(text_input).rstrip("\r\n")
+ except StopIteration as e: # pragma: no cover
+ raise EOFError() from e
+ sys.stdout.write(f"{val}\n")
+ sys.stdout.flush()
+ return val
+
+ def hidden_input(prompt: str | None = None) -> str:
+ sys.stdout.write(f"{prompt or ''}\n")
+ sys.stdout.flush()
+ try:
+ return next(text_input).rstrip("\r\n")
+ except StopIteration as e: # pragma: no cover
+ raise EOFError() from e
+
+ def _getchar(echo: bool) -> str:
+ char = sys.stdin.read(1)
+
+ if echo:
+ sys.stdout.write(char)
+
+ sys.stdout.flush()
+ return char
+
+ default_color = color
+
+ def should_strip_ansi(
+ stream: IO[Any] | None = None, color: bool | None = None
+ ) -> bool:
+ if color is None:
+ return not default_color
+ return not color
+
+ old_visible_prompt_func = termui.visible_prompt_func
+ old_hidden_prompt_func = termui.hidden_prompt_func
+ old__getchar_func = termui._getchar
+ old_should_strip_ansi = utils.should_strip_ansi # type: ignore[attr-defined]
+ old__compat_should_strip_ansi = _compat.should_strip_ansi
+ termui.visible_prompt_func = visible_input
+ termui.hidden_prompt_func = hidden_input # ty: ignore[invalid-assignment]
+ termui._getchar = _getchar
+ utils.should_strip_ansi = should_strip_ansi # type: ignore
+ _compat.should_strip_ansi = should_strip_ansi # ty: ignore[invalid-assignment]
+
+ old_env = {}
+ try:
+ for key, value in env.items():
+ old_env[key] = os.environ.get(key)
+ if value is None:
+ try:
+ del os.environ[key]
+ except Exception: # pragma: no cover
+ pass
+ else:
+ os.environ[key] = value
+ yield (stream_mixer.stdout, stream_mixer.stderr, stream_mixer.output)
+ finally:
+ for key, value in old_env.items():
+ if value is None:
+ try:
+ del os.environ[key]
+ except Exception: # pragma: no cover
+ pass
+ else:
+ os.environ[key] = value
+ sys.stdout = old_stdout
+ sys.stderr = old_stderr
+ sys.stdin = old_stdin
+ termui.visible_prompt_func = old_visible_prompt_func
+ termui.hidden_prompt_func = old_hidden_prompt_func
+ termui._getchar = old__getchar_func
+ utils.should_strip_ansi = old_should_strip_ansi # type: ignore[attr-defined]
+ _compat.should_strip_ansi = old__compat_should_strip_ansi
+ formatting.FORCED_WIDTH = old_forced_width
+
+ def invoke(
+ self,
+ cli: Command,
+ args: str | Sequence[str] | None = None,
+ input: str | bytes | IO[Any] | None = None,
+ env: Mapping[str, str | None] | None = None,
+ catch_exceptions: bool | None = None,
+ color: bool = False,
+ **extra: Any,
+ ) -> Result:
+ """Invokes a command in an isolated environment. The arguments are
+ forwarded directly to the command line script, the `extra` keyword
+ arguments are passed to the `Command.main` function of
+ the command.
+ """
+ exc_info = None
+ if catch_exceptions is None:
+ catch_exceptions = self.catch_exceptions
+
+ with self.isolation(input=input, env=env, color=color) as outstreams:
+ return_value = None
+ exception: BaseException | None = None
+ exit_code = 0
+
+ if isinstance(args, str):
+ args = shlex.split(args)
+
+ try:
+ prog_name = extra.pop("prog_name")
+ except KeyError:
+ prog_name = self.get_default_prog_name(cli)
+
+ try:
+ return_value = cli.main(args=args or (), prog_name=prog_name, **extra)
+ except SystemExit as e:
+ exc_info = sys.exc_info()
+ e_code = cast("int | Any | None", e.code)
+
+ if e_code is None:
+ e_code = 0
+
+ if e_code != 0:
+ exception = e
+
+ if not isinstance(e_code, int):
+ sys.stdout.write(str(e_code))
+ sys.stdout.write("\n")
+ e_code = 1
+
+ exit_code = e_code
+
+ except Exception as e:
+ if not catch_exceptions:
+ raise
+ exception = e
+ exit_code = 1
+ exc_info = sys.exc_info()
+ finally:
+ sys.stdout.flush()
+ sys.stderr.flush()
+ stdout = outstreams[0].getvalue()
+ stderr = outstreams[1].getvalue()
+ output = outstreams[2].getvalue()
+
+ return Result(
+ runner=self,
+ stdout_bytes=stdout,
+ stderr_bytes=stderr,
+ output_bytes=output,
+ return_value=return_value,
+ exit_code=exit_code,
+ exception=exception,
+ exc_info=exc_info, # type: ignore
+ )
diff --git a/typer/_click/types.py b/typer/_click/types.py
new file mode 100644
index 0000000000..5ccf15fe1b
--- /dev/null
+++ b/typer/_click/types.py
@@ -0,0 +1,695 @@
+import os
+import sys
+from collections.abc import Callable, Sequence
+from datetime import datetime
+from typing import (
+ IO,
+ TYPE_CHECKING,
+ Any,
+ ClassVar,
+ Literal,
+ NoReturn,
+ TypedDict,
+ TypeGuard,
+ TypeVar,
+ Union,
+ cast,
+)
+
+from ._compat import _get_argv_encoding, open_stream
+from .exceptions import BadParameter
+from .utils import LazyFile, format_filename, safecall
+
+if TYPE_CHECKING:
+ from .core import Context, Parameter
+ from .shell_completion import CompletionItem
+
+ParamTypeValue = TypeVar("ParamTypeValue")
+
+
+class ParamType:
+ """Represents the type of a parameter. Validates and converts values
+ from the command line or Python into the correct type.
+
+ To implement a custom type, subclass and implement at least the
+ following:
+
+ - The `name` class attribute must be set.
+ - Calling an instance of the type with ``None`` must return
+ ``None``. This is already implemented by default.
+ - `convert` must convert string values to the correct type.
+ - `convert` must accept values that are already the correct
+ type.
+ - It must be able to convert a value if the ``ctx`` and ``param``
+ arguments are ``None``. This can occur when converting prompt
+ input.
+ """
+
+ is_composite: ClassVar[bool] = False
+ arity: ClassVar[int] = 1
+ name: str
+
+ # if a list of this type is expected and the value is pulled from a
+ # string environment variable, this is what splits it up. `None`
+ # means any whitespace. For all parameters the general rule is that
+ # whitespace splits them up. The exception are paths and files which
+ # are split by ``os.path.pathsep`` by default (":" on Unix and ";" on
+ # Windows).
+ envvar_list_splitter: ClassVar[str | None] = None
+
+ def __call__(
+ self,
+ value: Any,
+ param: Union["Parameter", None] = None,
+ ctx: Union["Context", None] = None,
+ ) -> Any:
+ if value is not None:
+ return self.convert(value, param, ctx)
+
+ def get_metavar(self, param: "Parameter", ctx: "Context") -> str | None:
+ """Returns the metavar default for this param if it provides one."""
+ pass # pragma: no cover
+
+ def get_missing_message(
+ self, param: "Parameter", ctx: Union["Context", None]
+ ) -> str | None:
+ """Optionally might return extra information about a missing
+ parameter.
+ """
+ pass # pragma: no cover
+
+ def convert(
+ self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None]
+ ) -> Any:
+ pass # pragma: no cover
+
+ def split_envvar_value(self, rv: str) -> Sequence[str]:
+ """Given a value from an environment variable this splits it up
+ into small chunks depending on the defined envvar list splitter.
+
+ If the splitter is set to `None`, which means that whitespace splits,
+ then leading and trailing whitespace is ignored. Otherwise, leading
+ and trailing splitters usually lead to empty items being included.
+ """
+ return (rv or "").split(self.envvar_list_splitter)
+
+ def fail(
+ self,
+ message: str,
+ param: Union["Parameter", None] = None,
+ ctx: Union["Context", None] = None,
+ ) -> NoReturn:
+ """Helper method to fail with an invalid value message."""
+ raise BadParameter(message, ctx=ctx, param=param)
+
+ def shell_complete(
+ self, ctx: "Context", param: "Parameter", incomplete: str
+ ) -> list["CompletionItem"]:
+ """Return a list of `CompletionItem` objects for the
+ incomplete value. Most types do not provide completions, but
+ some do, and this allows custom types to provide custom
+ completions as well.
+ """
+ return []
+
+
+class CompositeParamType(ParamType):
+ is_composite = True
+
+ @property
+ def arity(self) -> int: # type: ignore
+ raise NotImplementedError() # pragma: no cover
+
+
+class FuncParamType(ParamType):
+ def __init__(self, func: Callable[[Any], Any]) -> None:
+ self.name: str = getattr(func, "__name__", "function")
+ self.func = func
+
+ def convert(
+ self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None]
+ ) -> Any:
+ try:
+ return self.func(value)
+ except ValueError:
+ try:
+ value = str(value)
+ except UnicodeError: # pragma: no cover
+ assert isinstance(value, bytes)
+ value = value.decode("utf-8", "replace")
+
+ self.fail(value, param, ctx)
+
+
+class StringParamType(ParamType):
+ name = "text"
+
+ def convert(
+ self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None]
+ ) -> Any:
+ if isinstance(value, bytes):
+ enc = _get_argv_encoding()
+ try:
+ value = value.decode(enc)
+ except UnicodeError:
+ fs_enc = sys.getfilesystemencoding()
+ if fs_enc != enc:
+ try:
+ value = value.decode(fs_enc)
+ except UnicodeError:
+ value = value.decode("utf-8", "replace")
+ else:
+ value = value.decode("utf-8", "replace")
+ return value
+ return str(value)
+
+ def __repr__(self) -> str:
+ return "STRING"
+
+
+class DateTime(ParamType):
+ """The DateTime type converts date strings into `datetime` objects.
+
+ The format strings which are checked are configurable, but default to some
+ common (non-timezone aware) ISO 8601 formats.
+
+ When specifying *DateTime* formats, you should only pass a list or a tuple.
+ Other iterables, like generators, may lead to surprising results.
+
+ The format strings are processed using ``datetime.strptime``, and this
+ consequently defines the format strings which are allowed.
+
+ Parsing is tried using each format, in order, and the first format which
+ parses successfully is used.
+ """
+
+ name = "datetime"
+
+ def __init__(self, formats: Sequence[str] | None = None):
+ self.formats: Sequence[str] = formats or [
+ "%Y-%m-%d",
+ "%Y-%m-%dT%H:%M:%S",
+ "%Y-%m-%d %H:%M:%S",
+ ]
+
+ def get_metavar(self, param: "Parameter", ctx: "Context") -> str | None:
+ return f"[{'|'.join(self.formats)}]"
+
+ def _try_to_convert_date(self, value: Any, format: str) -> datetime | None:
+ try:
+ return datetime.strptime(value, format)
+ except ValueError:
+ return None
+
+ def convert(
+ self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None]
+ ) -> Any:
+ if isinstance(value, datetime):
+ return value
+
+ for format in self.formats:
+ converted = self._try_to_convert_date(value, format)
+
+ if converted is not None:
+ return converted
+
+ formats_str = ", ".join(map(repr, self.formats))
+ self.fail(
+ f"{value!r} does not match the formats {formats_str}.",
+ param,
+ ctx,
+ )
+
+ def __repr__(self) -> str:
+ return "DateTime"
+
+
+class _NumberParamTypeBase(ParamType):
+ _number_class: ClassVar[type[Any]]
+
+ def convert(
+ self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None]
+ ) -> Any:
+ try:
+ return self._number_class(value)
+ except ValueError:
+ self.fail(
+ f"{value!r} is not a valid {self.name}.",
+ param,
+ ctx,
+ )
+
+
+class _NumberRangeBase(_NumberParamTypeBase):
+ def __init__(
+ self,
+ min: float | None = None,
+ max: float | None = None,
+ min_open: bool = False,
+ max_open: bool = False,
+ clamp: bool = False,
+ ) -> None:
+ self.min = min
+ self.max = max
+ self.min_open = min_open
+ self.max_open = max_open
+ self.clamp = clamp
+
+ def convert(
+ self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None]
+ ) -> Any:
+ import operator
+
+ rv = super().convert(value, param, ctx)
+ lt_min: bool = self.min is not None and (
+ operator.le if self.min_open else operator.lt
+ )(rv, self.min)
+ gt_max: bool = self.max is not None and (
+ operator.ge if self.max_open else operator.gt
+ )(rv, self.max)
+
+ if self.clamp:
+ if lt_min:
+ return self._clamp(self.min, 1, self.min_open) # type: ignore[arg-type]
+
+ if gt_max:
+ return self._clamp(self.max, -1, self.max_open) # type: ignore[arg-type]
+
+ if lt_min or gt_max:
+ self.fail(
+ f"{rv} is not in the range {self._describe_range()}.",
+ param,
+ ctx,
+ )
+
+ return rv
+
+ def _clamp(self, bound: float, dir: Literal[1, -1], open: bool) -> float:
+ """Find the valid value to clamp to bound in the given
+ direction.
+ """
+ raise NotImplementedError # pragma: no cover
+
+ def _describe_range(self) -> str:
+ """Describe the range for use in help text."""
+ if self.min is None:
+ op = "<" if self.max_open else "<="
+ return f"x{op}{self.max}"
+
+ if self.max is None:
+ op = ">" if self.min_open else ">="
+ return f"x{op}{self.min}"
+
+ lop = "<" if self.min_open else "<="
+ rop = "<" if self.max_open else "<="
+ return f"{self.min}{lop}x{rop}{self.max}"
+
+ def __repr__(self) -> str:
+ clamp = " clamped" if self.clamp else ""
+ return f"<{type(self).__name__} {self._describe_range()}{clamp}>"
+
+
+class IntParamType(_NumberParamTypeBase):
+ name = "integer"
+ _number_class = int
+
+ def __repr__(self) -> str:
+ return "INT"
+
+
+class IntRange(_NumberRangeBase, IntParamType):
+ """Restrict an `INT` value to a range of accepted values. See
+
+ If ``min`` or ``max`` are not passed, any value is accepted in that
+ direction. If ``min_open`` or ``max_open`` are enabled, the
+ corresponding boundary is not included in the range.
+
+ If ``clamp`` is enabled, a value outside the range is clamped to the
+ boundary instead of failing.
+ """
+
+ name = "integer range"
+
+ def _clamp( # type: ignore
+ self, bound: int, dir: Literal[1, -1], open: bool
+ ) -> int:
+ if not open:
+ return bound
+
+ return bound + dir
+
+
+class FloatParamType(_NumberParamTypeBase):
+ name = "float"
+ _number_class = float
+
+ def __repr__(self) -> str:
+ return "FLOAT"
+
+
+class FloatRange(_NumberRangeBase, FloatParamType):
+ """Restrict a `FLOAT` value to a range of accepted
+ values. See `ranges`.
+
+ If ``min`` or ``max`` are not passed, any value is accepted in that
+ direction. If ``min_open`` or ``max_open`` are enabled, the
+ corresponding boundary is not included in the range.
+
+ If ``clamp`` is enabled, a value outside the range is clamped to the
+ boundary instead of failing. This is not supported if either
+ boundary is marked ``open``.
+ """
+
+ name = "float range"
+
+ def __init__(
+ self,
+ min: float | None = None,
+ max: float | None = None,
+ min_open: bool = False,
+ max_open: bool = False,
+ clamp: bool = False,
+ ) -> None:
+ super().__init__(
+ min=min, max=max, min_open=min_open, max_open=max_open, clamp=clamp
+ )
+
+ if (min_open or max_open) and clamp:
+ raise TypeError("Clamping is not supported for open bounds.")
+
+ def _clamp(self, bound: float, dir: Literal[1, -1], open: bool) -> float:
+ if not open:
+ return bound
+
+ # Could use math.nextafter here, but clamping an
+ # open float range doesn't seem to be particularly useful. It's
+ # left up to the user to write a callback to do it if needed.
+ raise RuntimeError(
+ "Clamping is not supported for open bounds."
+ ) # pragma: no cover
+
+
+class BoolParamType(ParamType):
+ name = "boolean"
+
+ bool_states: dict[str, bool] = {
+ "1": True,
+ "0": False,
+ "yes": True,
+ "no": False,
+ "true": True,
+ "false": False,
+ "on": True,
+ "off": False,
+ "t": True,
+ "f": False,
+ "y": True,
+ "n": False,
+ # Absence of value is considered False.
+ "": False,
+ }
+ """A mapping of string values to boolean states.
+
+ Mapping is inspired by `configparser.ConfigParser.BOOLEAN_STATES`
+ and extends it.
+ """
+
+ @staticmethod
+ def str_to_bool(value: str | bool) -> bool | None:
+ """Convert a string to a boolean value.
+
+ If the value is already a boolean, it is returned as-is. If the value is a
+ string, it is stripped of whitespaces and lower-cased, then checked against
+ the known boolean states pre-defined in the `BoolParamType.bool_states` mapping
+ above.
+
+ Returns `None` if the value does not match any known boolean state.
+ """
+ if isinstance(value, bool):
+ return value
+ return BoolParamType.bool_states.get(value.strip().lower())
+
+ def convert(
+ self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None]
+ ) -> bool:
+ normalized = self.str_to_bool(value)
+ if normalized is None:
+ states = ", ".join(sorted(self.bool_states))
+ self.fail(
+ f"{value!r} is not a valid boolean. Recognized values: {states}",
+ param,
+ ctx,
+ )
+ return normalized
+
+ def __repr__(self) -> str:
+ return "BOOL"
+
+
+class UUIDParameterType(ParamType):
+ name = "uuid"
+
+ def convert(
+ self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None]
+ ) -> Any:
+ import uuid
+
+ if isinstance(value, uuid.UUID):
+ return value
+
+ value = value.strip()
+
+ try:
+ return uuid.UUID(value)
+ except ValueError:
+ self.fail(f"{value!r} is not a valid UUID.", param, ctx)
+
+ def __repr__(self) -> str:
+ return "UUID"
+
+
+class File(ParamType):
+ """Declares a parameter to be a file for reading or writing. The file
+ is automatically closed once the context tears down (after the command
+ finished working).
+
+ Files can be opened for reading or writing. The special value ``-``
+ indicates stdin or stdout depending on the mode.
+
+ By default, the file is opened for reading text data, but it can also be
+ opened in binary mode or for writing. The encoding parameter can be used
+ to force a specific encoding.
+
+ The `lazy` flag controls if the file should be opened immediately or upon
+ first IO. The default is to be non-lazy for standard input and output
+ streams as well as files opened for reading, `lazy` otherwise. When opening a
+ file lazily for reading, it is still opened temporarily for validation, but
+ will not be held open until first IO. lazy is mainly useful when opening
+ for writing to avoid creating the file until it is needed.
+
+ Files can also be opened atomically in which case all writes go into a
+ separate file in the same folder and upon completion the file will
+ be moved over to the original location. This is useful if a file
+ regularly read by other users is modified.
+ """
+
+ name = "filename"
+ envvar_list_splitter: ClassVar[str] = os.path.pathsep
+
+ def __init__(
+ self,
+ mode: str = "r",
+ encoding: str | None = None,
+ errors: str | None = "strict",
+ lazy: bool | None = None,
+ atomic: bool = False,
+ ) -> None:
+ self.mode = mode
+ self.encoding = encoding
+ self.errors = errors
+ self.lazy = lazy
+ self.atomic = atomic
+
+ def resolve_lazy_flag(self, value: str | os.PathLike[str]) -> bool:
+ if self.lazy is not None:
+ return self.lazy
+ if os.fspath(value) == "-":
+ return False
+ elif "w" in self.mode:
+ return True
+ return False
+
+ def convert(
+ self,
+ value: str | os.PathLike[str] | IO[Any],
+ param: Union["Parameter", None],
+ ctx: Union["Context", None],
+ ) -> IO[Any]:
+ if _is_file_like(value):
+ return value
+
+ value = cast("str | os.PathLike[str]", value)
+
+ try:
+ lazy = self.resolve_lazy_flag(value)
+
+ if lazy:
+ lf = LazyFile(
+ value, self.mode, self.encoding, self.errors, atomic=self.atomic
+ )
+
+ if ctx is not None:
+ ctx.call_on_close(lf.close_intelligently)
+
+ return cast("IO[Any]", lf)
+
+ f, should_close = open_stream(
+ value, self.mode, self.encoding, self.errors, atomic=self.atomic
+ )
+
+ # If a context is provided, we automatically close the file
+ # at the end of the context execution (or flush out). If a
+ # context does not exist, it's the caller's responsibility to
+ # properly close the file. This for instance happens when the
+ # type is used with prompts.
+ if ctx is not None:
+ if should_close:
+ ctx.call_on_close(safecall(f.close))
+ else:
+ ctx.call_on_close(safecall(f.flush))
+
+ return f
+ except OSError as e: # pragma: no cover
+ self.fail(f"'{format_filename(value)}': {e.strerror}", param, ctx)
+
+ def shell_complete(
+ self, ctx: "Context", param: "Parameter", incomplete: str
+ ) -> list["CompletionItem"]:
+ """Return a special completion marker that tells the completion
+ system to use the shell to provide file path completions.
+ """
+ from .shell_completion import CompletionItem
+
+ return [CompletionItem(incomplete, type="file")]
+
+
+def _is_file_like(value: Any) -> TypeGuard[IO[Any]]:
+ return hasattr(value, "read") or hasattr(value, "write")
+
+
+class Tuple(CompositeParamType):
+ """The default behavior of Click is to apply a type on a value directly.
+ This works well in most cases, except for when `nargs` is set to a fixed
+ count and different types should be used for different items. In this
+ case the `Tuple` type can be used. This type can only be used
+ if `nargs` is set to a fixed number.
+
+ For more information see `tuple-type`.
+
+ This can be selected by using a Python tuple literal as a type.
+ """
+
+ def __init__(self, types: Sequence[type[Any] | ParamType]) -> None:
+ self.types: Sequence[ParamType] = [convert_type(ty) for ty in types]
+
+ @property
+ def name(self) -> str: # type: ignore[override]
+ return f"<{' '.join(ty.name for ty in self.types)}>"
+
+ @property
+ def arity(self) -> int: # type: ignore
+ return len(self.types)
+
+ def convert(
+ self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None]
+ ) -> Any:
+ len_type = len(self.types)
+ len_value = len(value)
+
+ if len_value != len_type:
+ self.fail(
+ f"{len_type} values are required, but {len_value} given.",
+ param=param,
+ ctx=ctx,
+ )
+
+ return tuple(
+ ty(x, param, ctx) for ty, x in zip(self.types, value, strict=False)
+ )
+
+
+def convert_type(ty: Any | None, default: Any | None = None) -> ParamType:
+ """Find the most appropriate `ParamType` for the given Python
+ type. If the type isn't provided, it can be inferred from a default
+ value.
+ """
+ guessed_type = False
+
+ if ty is None and default is not None:
+ if isinstance(default, (tuple, list)):
+ # If the default is empty, ty will remain None and will
+ # return STRING.
+ if default:
+ item = default[0]
+
+ # A tuple of tuples needs to detect the inner types.
+ # Can't call convert recursively because that would
+ # incorrectly unwind the tuple to a single type.
+ if isinstance(item, (tuple, list)):
+ ty = tuple(map(type, item))
+ else:
+ ty = type(item)
+ else:
+ ty = type(default)
+
+ guessed_type = True
+
+ if isinstance(ty, tuple):
+ return Tuple(ty)
+
+ if isinstance(ty, ParamType):
+ return ty
+
+ if ty is str or ty is None:
+ return STRING
+
+ if ty is int:
+ return INT
+
+ if ty is float:
+ return FLOAT
+
+ if ty is bool:
+ return BOOL
+
+ if guessed_type:
+ return STRING
+
+ return FuncParamType(ty)
+
+
+# A unicode string parameter type which is the implicit default. This
+# can also be selected by using ``str`` as type.
+STRING = StringParamType()
+
+# An integer parameter. This can also be selected by using ``int`` as
+# type.
+INT = IntParamType()
+
+# A floating point value parameter. This can also be selected by using
+# ``float`` as type.
+FLOAT = FloatParamType()
+
+# A boolean parameter. This is the default for boolean flags. This can
+# also be selected by using ``bool`` as a type.
+BOOL = BoolParamType()
+
+# A UUID parameter.
+UUID = UUIDParameterType()
+
+
+class OptionHelpExtra(TypedDict, total=False):
+ envvars: tuple[str, ...]
+ default: str
+ range: str
+ required: str
diff --git a/typer/_click/utils.py b/typer/_click/utils.py
new file mode 100644
index 0000000000..ac8e5ba3f2
--- /dev/null
+++ b/typer/_click/utils.py
@@ -0,0 +1,470 @@
+import os
+import re
+import sys
+from collections.abc import Callable, Iterable, Iterator
+from functools import update_wrapper
+from types import ModuleType, TracebackType
+from typing import (
+ IO,
+ Any,
+ AnyStr,
+ BinaryIO,
+ Literal,
+ ParamSpec,
+ TextIO,
+ TypeVar,
+ cast,
+)
+
+from ._compat import (
+ WIN,
+ _default_text_stderr,
+ _default_text_stdout,
+ _find_binary_writer,
+ auto_wrap_for_ansi,
+ binary_streams,
+ open_stream,
+ should_strip_ansi,
+ strip_ansi,
+ text_streams,
+)
+from .globals import resolve_color_default
+
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+def _posixify(name: str) -> str:
+ return "-".join(name.split()).lower()
+
+
+def safecall(func: Callable[P, R]) -> Callable[P, R | None]:
+ """Wraps a function so that it swallows exceptions."""
+
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | None:
+ try:
+ return func(*args, **kwargs)
+ except Exception: # pragma: no cover
+ pass
+ return None # pragma: no cover
+
+ return update_wrapper(wrapper, func)
+
+
+def make_default_short_help(help: str, max_length: int = 45) -> str:
+ """Returns a condensed version of help string."""
+ # Consider only the first paragraph.
+ paragraph_end = help.find("\n\n")
+
+ if paragraph_end != -1:
+ help = help[:paragraph_end]
+
+ # Collapse newlines, tabs, and spaces.
+ words = help.split()
+
+ if not words:
+ return ""
+
+ # The first paragraph started with a "no rewrap" marker, ignore it.
+ if words[0] == "\b":
+ words = words[1:]
+
+ total_length = 0
+ last_index = len(words) - 1
+
+ for i, word in enumerate(words):
+ total_length += len(word) + (i > 0)
+
+ if total_length > max_length: # too long, truncate
+ break
+
+ if word[-1] == ".": # sentence end, truncate without "..."
+ return " ".join(words[: i + 1])
+
+ if total_length == max_length and i != last_index:
+ break # not at sentence end, truncate with "..."
+ else:
+ return " ".join(words) # no truncation needed
+
+ # Account for the length of the suffix.
+ total_length += len("...")
+
+ # remove words until the length is short enough
+ while i > 0:
+ total_length -= len(words[i]) + (i > 0)
+
+ if total_length <= max_length:
+ break
+
+ i -= 1
+
+ return " ".join(words[:i]) + "..."
+
+
+class LazyFile:
+ """A lazy file works like a regular file but it does not fully open
+ the file but it does perform some basic checks early to see if the
+ filename parameter does make sense. This is useful for safely opening
+ files for writing.
+ """
+
+ def __init__(
+ self,
+ filename: str | os.PathLike[str],
+ mode: str = "r",
+ encoding: str | None = None,
+ errors: str | None = "strict",
+ atomic: bool = False,
+ ):
+ self.name: str = os.fspath(filename)
+ self.mode = mode
+ self.encoding = encoding
+ self.errors = errors
+ self.atomic = atomic
+ self._f: IO[Any] | None
+ self.should_close: bool
+
+ if self.name == "-":
+ self._f, self.should_close = open_stream(filename, mode, encoding, errors)
+ else:
+ if "r" in mode:
+ # Open and close the file in case we're opening it for
+ # reading so that we can catch at least some errors in
+ # some cases early.
+ open(filename, mode).close()
+ self._f = None
+ self.should_close = True
+
+ def __getattr__(self, name: str) -> Any:
+ return getattr(self.open(), name)
+
+ def __repr__(self) -> str:
+ if self._f is not None:
+ return repr(self._f)
+ return f""
+
+ def open(self) -> IO[Any]:
+ """Opens the file if it's not yet open. This call might fail with
+ a `FileError`. Not handling this error will produce an error
+ that Click shows.
+ """
+ if self._f is not None:
+ return self._f
+ try:
+ rv, self.should_close = open_stream(
+ self.name, self.mode, self.encoding, self.errors, atomic=self.atomic
+ )
+ except OSError as e:
+ from .exceptions import FileError
+
+ raise FileError(self.name, hint=e.strerror) from e
+ self._f = rv
+ return rv
+
+ def close(self) -> None:
+ """Closes the underlying file, no matter what."""
+ if self._f is not None:
+ self._f.close()
+
+ def close_intelligently(self) -> None:
+ """This function only closes the file if it was opened by the lazy
+ file wrapper. For instance this will never close stdin.
+ """
+ if self.should_close:
+ self.close()
+
+ def __enter__(self) -> "LazyFile":
+ return self
+
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ tb: TracebackType | None,
+ ) -> None:
+ self.close_intelligently()
+
+ def __iter__(self) -> Iterator[AnyStr]:
+ self.open()
+ return iter(self._f) # type: ignore
+
+
+def echo(
+ message: Any | None = None,
+ file: IO[Any] | None = None,
+ nl: bool = True,
+ err: bool = False,
+ color: bool | None = None,
+) -> None:
+ """Print a message and newline to stdout or a file. This should be
+ used instead of `print` because it provides better support
+ for different data, files, and environments.
+
+ Compared to `print`, this does the following:
+
+ - Ensures that the output encoding is not misconfigured on Linux.
+ - Supports Unicode in the Windows console.
+ - Supports writing to binary outputs, and supports writing bytes
+ to text outputs.
+ - Supports colors and styles on Windows.
+ - Removes ANSI color and style codes if the output does not look
+ like an interactive terminal.
+ - Always flushes the output.
+ """
+ if file is None:
+ if err:
+ file = _default_text_stderr()
+ else:
+ file = _default_text_stdout()
+
+ # There are no standard streams attached to write to. For example,
+ # pythonw on Windows.
+ if file is None:
+ return
+
+ # Convert non bytes/text into the native string type.
+ if message is not None and not isinstance(message, (str, bytes, bytearray)):
+ out: str | bytes | bytearray | None = str(message)
+ else:
+ out = message
+
+ if nl:
+ out = out or ""
+ if isinstance(out, str):
+ out += "\n"
+ else:
+ out += b"\n"
+
+ if not out:
+ file.flush()
+ return
+
+ # If there is a message and the value looks like bytes, we manually
+ # need to find the binary stream and write the message in there.
+ # This is done separately so that most stream types will work as you
+ # would expect. Eg: you can write to StringIO for other cases.
+ if isinstance(out, (bytes, bytearray)):
+ binary_file = _find_binary_writer(file)
+
+ if binary_file is not None:
+ file.flush()
+ binary_file.write(out)
+ binary_file.flush()
+ return
+
+ # ANSI style code support. For no message or bytes, nothing happens.
+ # When outputting to a file instead of a terminal, strip codes.
+ else:
+ color = resolve_color_default(color)
+
+ if should_strip_ansi(file, color):
+ out = strip_ansi(out)
+ elif WIN:
+ if auto_wrap_for_ansi is not None:
+ file = auto_wrap_for_ansi(file, color) # type: ignore[arg-type,call-arg]
+ elif not color:
+ out = strip_ansi(out)
+
+ file.write(out) # type: ignore[arg-type]
+ file.flush()
+
+
+def get_binary_stream(name: Literal["stdin", "stdout", "stderr"]) -> BinaryIO:
+ """Returns a system stream for byte processing."""
+ opener = binary_streams.get(name)
+ if opener is None:
+ raise TypeError(f"Unknown standard stream '{name}'")
+ return opener()
+
+
+def get_text_stream(
+ name: Literal["stdin", "stdout", "stderr"],
+ encoding: str | None = None,
+ errors: str | None = "strict",
+) -> TextIO:
+ """Returns a system stream for text processing. This usually returns
+ a wrapped stream around a binary stream returned from
+ `get_binary_stream` but it also can take shortcuts for already
+ correctly configured streams.
+ """
+ opener = text_streams.get(name)
+ if opener is None:
+ raise TypeError(f"Unknown standard stream '{name}'")
+ return opener(encoding, errors)
+
+
+def format_filename(
+ filename: str | bytes | os.PathLike[str] | os.PathLike[bytes],
+ shorten: bool = False,
+) -> str:
+ """Format a filename as a string for display. Ensures the filename can be
+ displayed by replacing any invalid bytes or surrogate escapes in the name
+ with the replacement character ``�``.
+
+ Invalid bytes or surrogate escapes will raise an error when written to a
+ stream with ``errors="strict"``. This will typically happen with ``stdout``
+ when the locale is something like ``en_GB.UTF-8``.
+
+ Many scenarios *are* safe to write surrogates though, due to PEP 538 and
+ PEP 540, including:
+
+ - Writing to ``stderr``, which uses ``errors="backslashreplace"``.
+ - The system has ``LANG=C.UTF-8``, ``C``, or ``POSIX``. Python opens
+ stdout and stderr with ``errors="surrogateescape"``.
+ - None of ``LANG/LC_*`` are set. Python assumes ``LANG=C.UTF-8``.
+ - Python is started in UTF-8 mode with ``PYTHONUTF8=1`` or ``-X utf8``.
+ Python opens stdout and stderr with ``errors="surrogateescape"``.
+ """
+ if shorten:
+ filename = os.path.basename(filename)
+ else:
+ filename = os.fspath(filename)
+
+ if isinstance(filename, bytes):
+ filename = filename.decode(sys.getfilesystemencoding(), "replace")
+ else:
+ filename = filename.encode("utf-8", "surrogateescape").decode(
+ "utf-8", "replace"
+ )
+
+ return filename
+
+
+def get_app_dir(app_name: str, roaming: bool = True, force_posix: bool = False) -> str:
+ r"""Returns the config folder for the application. The default behavior
+ is to return whatever is most appropriate for the operating system.
+
+ To give you an idea, for an app called ``"Foo Bar"``, something like
+ the following folders could be returned:
+
+ Mac OS X:
+ ``~/Library/Application Support/Foo Bar``
+ Mac OS X (POSIX):
+ ``~/.foo-bar``
+ Unix:
+ ``~/.config/foo-bar``
+ Unix (POSIX):
+ ``~/.foo-bar``
+ Windows (roaming):
+ ``C:\Users\\AppData\Roaming\Foo Bar``
+ Windows (not roaming):
+ ``C:\Users\\AppData\Local\Foo Bar``
+ """
+ if WIN:
+ key = "APPDATA" if roaming else "LOCALAPPDATA"
+ folder = os.environ.get(key)
+ if folder is None:
+ folder = os.path.expanduser("~")
+ return os.path.join(folder, app_name)
+ if force_posix:
+ return os.path.join(os.path.expanduser(f"~/.{_posixify(app_name)}"))
+ if sys.platform == "darwin":
+ return os.path.join(
+ os.path.expanduser("~/Library/Application Support"), app_name
+ )
+ return os.path.join(
+ os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")),
+ _posixify(app_name),
+ )
+
+
+class PacifyFlushWrapper:
+ """This wrapper is used to catch and suppress BrokenPipeErrors resulting
+ from ``.flush()`` being called on broken pipe during the shutdown/final-GC
+ of the Python interpreter. Notably ``.flush()`` is always called on
+ ``sys.stdout`` and ``sys.stderr``. So as to have minimal impact on any
+ other cleanup code, and the case where the underlying file is not a broken
+ pipe, all calls and attributes are proxied.
+ """
+
+ def __init__(self, wrapped: IO[Any]) -> None:
+ self.wrapped = wrapped
+
+ def flush(self) -> None:
+ try:
+ self.wrapped.flush()
+ except OSError as e: # pragma: no cover
+ import errno
+
+ if e.errno != errno.EPIPE:
+ raise
+
+ def __getattr__(self, attr: str) -> Any:
+ return getattr(self.wrapped, attr)
+
+
+def _detect_program_name(
+ path: str | None = None, _main: ModuleType | None = None
+) -> str:
+ """Determine the command used to run the program, for use in help
+ text. If a file or entry point was executed, the file name is
+ returned. If ``python -m`` was used to execute a module or package,
+ ``python -m name`` is returned.
+
+ This doesn't try to be too precise, the goal is to give a concise
+ name for help text. Files are only shown as their name without the
+ path. ``python`` is only shown for modules, and the full path to
+ ``sys.executable`` is not shown.
+ """
+ if _main is None:
+ _main = sys.modules["__main__"]
+
+ if not path:
+ path = sys.argv[0]
+
+ # The value of __package__ indicates how Python was called. It may
+ # not exist if a setuptools script is installed as an egg. It may be
+ # set incorrectly for entry points created with pip on Windows.
+ # It is set to "" inside a Shiv or PEX zipapp.
+ if getattr(_main, "__package__", None) in {None, ""} or (
+ os.name == "nt"
+ and _main.__package__ == ""
+ and not os.path.exists(path)
+ and os.path.exists(f"{path}.exe")
+ ):
+ # Executed a file, like "python app.py".
+ return os.path.basename(path)
+
+ # Executed a module, like "python -m example".
+ # Rewritten by Python from "-m script" to "/path/to/script.py".
+ # Need to look at main module to determine how it was executed.
+ py_module = cast(str, _main.__package__)
+ name = os.path.splitext(os.path.basename(path))[0]
+
+ # A submodule like "example.cli".
+ if name != "__main__":
+ py_module = f"{py_module}.{name}"
+
+ return f"python -m {py_module.lstrip('.')}"
+
+
+def _expand_args(
+ args: Iterable[str],
+ *,
+ user: bool = True,
+ env: bool = True,
+ glob_recursive: bool = True,
+) -> list[str]:
+ """Simulate Unix shell expansion with Python functions."""
+ from glob import glob
+
+ out = []
+
+ for arg in args:
+ if user:
+ arg = os.path.expanduser(arg)
+
+ if env:
+ arg = os.path.expandvars(arg)
+
+ try:
+ matches = glob(arg, recursive=glob_recursive)
+ except re.error: # pragma: no cover
+ matches = []
+
+ if not matches:
+ out.append(arg)
+ else:
+ out.extend(matches)
+
+ return out
diff --git a/typer/_completion_classes.py b/typer/_completion_classes.py
index 8548fb4d6a..cfae02c9bf 100644
--- a/typer/_completion_classes.py
+++ b/typer/_completion_classes.py
@@ -4,11 +4,9 @@
import sys
from typing import Any
-import click
-import click.parser
-import click.shell_completion
-from click.shell_completion import split_arg_string as click_split_arg_string
-
+from . import _click
+from ._click.shell_completion import CompletionItem, ShellComplete, add_completion_class
+from ._click.shell_completion import split_arg_string as click_split_arg_string
from ._completion_shared import (
COMPLETION_SCRIPT_BASH,
COMPLETION_SCRIPT_FISH,
@@ -27,7 +25,7 @@ def _sanitize_help_text(text: str) -> str:
return rich_utils.rich_render_text(text)
-class BashComplete(click.shell_completion.BashComplete):
+class BashComplete(ShellComplete):
name = Shells.bash.value
source_template = COMPLETION_SCRIPT_BASH
@@ -50,7 +48,7 @@ def get_completion_args(self) -> tuple[list[str], str]:
return args, incomplete
- def format_completion(self, item: click.shell_completion.CompletionItem) -> str:
+ def format_completion(self, item: CompletionItem) -> str:
# TODO: Explore replicating the new behavior from Click, with item types and
# triggering completion for files and directories
# return f"{item.type},{item.value}"
@@ -62,8 +60,42 @@ def complete(self) -> str:
out = [self.format_completion(item) for item in completions]
return "\n".join(out)
+ @staticmethod
+ def _check_version() -> None:
+ import shutil
+ import subprocess
+
+ bash_exe = shutil.which("bash")
+
+ if bash_exe is None:
+ match = None # pragma: no cover
+ else:
+ output = subprocess.run(
+ [bash_exe, "--norc", "-c", 'echo "${BASH_VERSION}"'],
+ stdout=subprocess.PIPE,
+ )
+ match = re.search(r"^(\d+)\.(\d+)\.\d+", output.stdout.decode())
+
+ if match is not None:
+ major, minor = match.groups()
+
+ if major < "4" or major == "4" and minor < "4":
+ _click.utils.echo(
+ "Shell completion is not supported for Bash versions older than 4.4.",
+ err=True,
+ )
+ else:
+ _click.utils.echo(
+ "Couldn't detect Bash version, shell completion is not supported.",
+ err=True,
+ ) # pragma: no cover
+
+ def source(self) -> str:
+ self._check_version()
+ return super().source()
+
-class ZshComplete(click.shell_completion.ZshComplete):
+class ZshComplete(ShellComplete):
name = Shells.zsh.value
source_template = COMPLETION_SCRIPT_ZSH
@@ -85,7 +117,7 @@ def get_completion_args(self) -> tuple[list[str], str]:
incomplete = ""
return args, incomplete
- def format_completion(self, item: click.shell_completion.CompletionItem) -> str:
+ def format_completion(self, item: CompletionItem) -> str:
def escape(s: str) -> str:
return (
s.replace('"', '""')
@@ -114,7 +146,7 @@ def complete(self) -> str:
return "_files"
-class FishComplete(click.shell_completion.FishComplete):
+class FishComplete(ShellComplete):
name = Shells.fish.value
source_template = COMPLETION_SCRIPT_FISH
@@ -136,7 +168,7 @@ def get_completion_args(self) -> tuple[list[str], str]:
incomplete = ""
return args, incomplete
- def format_completion(self, item: click.shell_completion.CompletionItem) -> str:
+ def format_completion(self, item: CompletionItem) -> str:
# TODO: Explore replicating the new behavior from Click, pay attention to
# the difference with and without formatted help
# if item.help:
@@ -167,7 +199,7 @@ def complete(self) -> str:
return "" # pragma: no cover
-class PowerShellComplete(click.shell_completion.ShellComplete):
+class PowerShellComplete(ShellComplete):
name = Shells.powershell.value
source_template = COMPLETION_SCRIPT_POWER_SHELL
@@ -185,15 +217,13 @@ def get_completion_args(self) -> tuple[list[str], str]:
args = cwords[1:-1] if incomplete else cwords[1:]
return args, incomplete
- def format_completion(self, item: click.shell_completion.CompletionItem) -> str:
+ def format_completion(self, item: CompletionItem) -> str:
return f"{item.value}:::{_sanitize_help_text(item.help) if item.help else ' '}"
def completion_init() -> None:
- click.shell_completion.add_completion_class(BashComplete, Shells.bash.value)
- click.shell_completion.add_completion_class(ZshComplete, Shells.zsh.value)
- click.shell_completion.add_completion_class(FishComplete, Shells.fish.value)
- click.shell_completion.add_completion_class(
- PowerShellComplete, Shells.powershell.value
- )
- click.shell_completion.add_completion_class(PowerShellComplete, Shells.pwsh.value)
+ add_completion_class(BashComplete, Shells.bash.value)
+ add_completion_class(ZshComplete, Shells.zsh.value)
+ add_completion_class(FishComplete, Shells.fish.value)
+ add_completion_class(PowerShellComplete, Shells.powershell.value)
+ add_completion_class(PowerShellComplete, Shells.pwsh.value)
diff --git a/typer/_completion_shared.py b/typer/_completion_shared.py
index 5a81dcf68c..8d2c19715c 100644
--- a/typer/_completion_shared.py
+++ b/typer/_completion_shared.py
@@ -4,9 +4,11 @@
from enum import Enum
from pathlib import Path
-import click
import shellingham
+from . import _click
+from ._click.globals import get_current_context
+
class Shells(str, Enum):
bash = "bash"
@@ -78,8 +80,8 @@ def get_completion_script(*, prog_name: str, complete_var: str, shell: str) -> s
cf_name = _invalid_ident_char_re.sub("", prog_name.replace("-", "_"))
script = _completion_scripts.get(shell)
if script is None:
- click.echo(f"Shell {shell} not supported.", err=True)
- raise click.exceptions.Exit(1)
+ _click.echo(f"Shell {shell} not supported.", err=True)
+ raise _click.exceptions.Exit(1)
return (
script
% {
@@ -172,8 +174,8 @@ def install_powershell(*, prog_name: str, complete_var: str, shell: str) -> Path
stdout=subprocess.PIPE,
)
if result.returncode != 0: # pragma: no cover
- click.echo("Couldn't get PowerShell user profile", err=True)
- raise click.exceptions.Exit(result.returncode)
+ _click.echo("Couldn't get PowerShell user profile", err=True)
+ raise _click.exceptions.Exit(result.returncode)
path_str = ""
if isinstance(result.stdout, str): # pragma: no cover
path_str = result.stdout
@@ -185,8 +187,8 @@ def install_powershell(*, prog_name: str, complete_var: str, shell: str) -> Path
except UnicodeDecodeError: # pragma: no cover
pass
if not path_str: # pragma: no cover
- click.echo("Couldn't decode the path automatically", err=True)
- raise click.exceptions.Exit(1)
+ _click.echo("Couldn't decode the path automatically", err=True)
+ raise _click.exceptions.Exit(1)
path_obj = Path(path_str.strip())
parent_dir: Path = path_obj.parent
parent_dir.mkdir(parents=True, exist_ok=True)
@@ -203,7 +205,7 @@ def install(
prog_name: str | None = None,
complete_var: str | None = None,
) -> tuple[str, Path]:
- prog_name = prog_name or click.get_current_context().find_root().info_name
+ prog_name = prog_name or get_current_context().find_root().info_name
assert prog_name
if complete_var is None:
complete_var = "_{}_COMPLETE".format(prog_name.replace("-", "_").upper())
@@ -231,8 +233,8 @@ def install(
)
return shell, installed_path
else:
- click.echo(f"Shell {shell} is not supported.")
- raise click.exceptions.Exit(1)
+ _click.echo(f"Shell {shell} is not supported.")
+ raise _click.exceptions.Exit(1)
def _get_shell_name() -> str | None:
diff --git a/typer/_types.py b/typer/_types.py
index dc9fc63220..09b38afb3f 100644
--- a/typer/_types.py
+++ b/typer/_types.py
@@ -1,21 +1,42 @@
+from collections.abc import Iterable, Mapping, Sequence
from enum import Enum
-from typing import TypeVar
+from typing import Any, Generic, TypeVar
-import click
+from . import _click
+from ._click import types
+from ._click.shell_completion import CompletionItem
ParamTypeValue = TypeVar("ParamTypeValue")
-class TyperChoice(click.Choice[ParamTypeValue]):
+class TyperChoice(types.ParamType, Generic[ParamTypeValue]):
+ # Code adapted from Click 8.3.1, with Typer using enum values in normalize_choice
+ name = "choice"
+
+ def __init__(
+ self, choices: Iterable[ParamTypeValue], case_sensitive: bool = True
+ ) -> None:
+ self.choices: Sequence[ParamTypeValue] = tuple(choices)
+ self.case_sensitive = case_sensitive
+
+ def _normalized_mapping(
+ self, ctx: _click.Context | None = None
+ ) -> Mapping[ParamTypeValue, str]:
+ """
+ Returns mapping where keys are the original choices and the values are
+ the normalized values that are accepted via the command line.
+ """
+ return {
+ choice: self.normalize_choice(
+ choice=choice,
+ ctx=ctx,
+ )
+ for choice in self.choices
+ }
+
def normalize_choice(
- self, choice: ParamTypeValue, ctx: click.Context | None
+ self, choice: ParamTypeValue, ctx: _click.Context | None
) -> str:
- # Click 8.2.0 added a new method `normalize_choice` to the `Choice` class
- # to support enums, but it uses the enum names, while Typer has always used the
- # enum values.
- # This class overrides that method to maintain the previous behavior.
- # In Click:
- # normed_value = choice.name if isinstance(choice, Enum) else str(choice)
normed_value = str(choice.value) if isinstance(choice, Enum) else str(choice)
if ctx is not None and ctx.token_normalize_func is not None:
@@ -25,3 +46,75 @@ def normalize_choice(
normed_value = normed_value.casefold()
return normed_value
+
+ def get_metavar(self, param: _click.Parameter, ctx: _click.Context) -> str | None:
+ if param.param_type_name == "option" and not param.show_choices: # type: ignore
+ choice_metavars = [
+ types.convert_type(type(choice)).name.upper() for choice in self.choices
+ ]
+ choices_str = "|".join([*dict.fromkeys(choice_metavars)])
+ else:
+ choices_str = "|".join(
+ [str(i) for i in self._normalized_mapping(ctx=ctx).values()]
+ )
+
+ # Use curly braces to indicate a required argument.
+ if param.required and param.param_type_name == "argument":
+ return f"{{{choices_str}}}"
+
+ # Use square braces to indicate an option or optional argument.
+ return f"[{choices_str}]"
+
+ def get_missing_message(
+ self, param: _click.Parameter, ctx: _click.Context | None
+ ) -> str:
+ """Message shown when no choice is passed."""
+ choices = ",\n\t".join(self._normalized_mapping(ctx=ctx).values())
+ return f"Choose from:\n\t{choices}"
+
+ def convert(
+ self, value: Any, param: _click.Parameter | None, ctx: _click.Context | None
+ ) -> ParamTypeValue:
+ """
+ For a given value from the parser, normalize it and find its
+ matching normalized value in the list of choices. Then return the
+ matched "original" choice.
+ """
+ normed_value = self.normalize_choice(choice=value, ctx=ctx)
+ normalized_mapping = self._normalized_mapping(ctx=ctx)
+
+ try:
+ return next(
+ original
+ for original, normalized in normalized_mapping.items()
+ if normalized == normed_value
+ )
+ except StopIteration:
+ self.fail(
+ self.get_invalid_choice_message(value=value, ctx=ctx),
+ param=param,
+ ctx=ctx,
+ )
+
+ def get_invalid_choice_message(self, value: Any, ctx: _click.Context | None) -> str:
+ """Get the error message when the given choice is invalid."""
+ choices_str = ", ".join(map(repr, self._normalized_mapping(ctx=ctx).values()))
+ return f"{value!r} is not one of {choices_str}."
+
+ def __repr__(self) -> str:
+ return f"Choice({list(self.choices)})"
+
+ def shell_complete(
+ self, ctx: _click.Context, param: _click.Parameter, incomplete: str
+ ) -> list[CompletionItem]:
+ """Complete choices that start with the incomplete value."""
+
+ str_choices = map(str, self.choices)
+
+ if self.case_sensitive:
+ matched = (c for c in str_choices if c.startswith(incomplete))
+ else:
+ incomplete = incomplete.lower()
+ matched = (c for c in str_choices if c.lower().startswith(incomplete))
+
+ return [CompletionItem(c) for c in matched]
diff --git a/typer/cli.py b/typer/cli.py
index 2a7d78c3a4..665bcf5a59 100644
--- a/typer/cli.py
+++ b/typer/cli.py
@@ -4,13 +4,12 @@
from pathlib import Path
from typing import Any
-import click
import typer
import typer.core
-from click import Command, Group, Option
-from . import __version__
-from .core import HAS_RICH, MARKUP_MODE_KEY
+from . import __version__, _click
+from ._click import Command
+from .core import HAS_RICH, MARKUP_MODE_KEY, TyperGroup, TyperOption
default_app_names = ("app", "cli", "main")
default_func_names = ("main", "cli", "app")
@@ -31,7 +30,7 @@ def __init__(self) -> None:
state = State()
-def maybe_update_state(ctx: click.Context) -> None:
+def maybe_update_state(ctx: _click.Context) -> None:
path_or_module = ctx.params.get("path_or_module")
if path_or_module:
file_path = Path(path_or_module)
@@ -53,19 +52,19 @@ def maybe_update_state(ctx: click.Context) -> None:
class TyperCLIGroup(typer.core.TyperGroup):
- def list_commands(self, ctx: click.Context) -> list[str]:
+ def list_commands(self, ctx: _click.Context) -> list[str]:
self.maybe_add_run(ctx)
return super().list_commands(ctx)
- def get_command(self, ctx: click.Context, name: str) -> Command | None: # ty: ignore[invalid-method-override]
+ def get_command(self, ctx: _click.Context, name: str) -> Command | None: # ty: ignore[invalid-method-override]
self.maybe_add_run(ctx)
return super().get_command(ctx, name)
- def invoke(self, ctx: click.Context) -> Any:
+ def invoke(self, ctx: _click.Context) -> Any:
self.maybe_add_run(ctx)
return super().invoke(ctx)
- def maybe_add_run(self, ctx: click.Context) -> None:
+ def maybe_add_run(self, ctx: _click.Context) -> None:
maybe_update_state(ctx)
maybe_add_run_to_cli(self)
@@ -138,7 +137,7 @@ def get_typer_from_state() -> typer.Typer | None:
return obj
-def maybe_add_run_to_cli(cli: click.Group) -> None:
+def maybe_add_run_to_cli(cli: TyperGroup) -> None:
if "run" not in cli.commands:
if state.file or state.module:
obj = get_typer_from_state()
@@ -151,7 +150,7 @@ def maybe_add_run_to_cli(cli: click.Group) -> None:
cli.add_command(click_obj)
-def print_version(ctx: click.Context, param: Option, value: bool) -> None:
+def print_version(ctx: _click.Context, param: TyperOption, value: bool) -> None:
if not value or ctx.resilient_parsing:
return
typer.echo(f"Typer version: {__version__}")
@@ -242,7 +241,7 @@ def get_docs_for_click(
docs += "\n"
if obj.epilog:
docs += f"{obj.epilog}\n\n"
- if isinstance(obj, Group):
+ if isinstance(obj, TyperGroup):
group = obj
commands = group.list_commands(ctx)
if commands:
diff --git a/typer/completion.py b/typer/completion.py
index 0d621e411d..f63692ddf3 100644
--- a/typer/completion.py
+++ b/typer/completion.py
@@ -3,8 +3,8 @@
from collections.abc import MutableMapping
from typing import Any
-import click
-
+from . import _click
+from ._click import shell_completion
from ._completion_classes import completion_init
from ._completion_shared import Shells, _get_shell_name, get_completion_script, install
from .models import ParamMeta
@@ -27,19 +27,19 @@ def get_completion_inspect_parameters() -> tuple[ParamMeta, ParamMeta]:
return install_param, show_param
-def install_callback(ctx: click.Context, param: click.Parameter, value: Any) -> Any:
+def install_callback(ctx: _click.Context, param: _click.Parameter, value: Any) -> Any:
if not value or ctx.resilient_parsing:
return value # pragma: no cover
if isinstance(value, str):
shell, path = install(shell=value)
else:
shell, path = install()
- click.secho(f"{shell} completion installed in {path}", fg="green")
- click.echo("Completion will take effect once you restart the terminal")
+ _click.termui.secho(f"{shell} completion installed in {path}", fg="green")
+ _click.echo("Completion will take effect once you restart the terminal")
sys.exit(0)
-def show_callback(ctx: click.Context, param: click.Parameter, value: Any) -> Any:
+def show_callback(ctx: _click.Context, param: _click.Parameter, value: Any) -> Any:
if not value or ctx.resilient_parsing:
return value # pragma: no cover
prog_name = ctx.find_root().info_name
@@ -56,7 +56,7 @@ def show_callback(ctx: click.Context, param: click.Parameter, value: Any) -> Any
script_content = get_completion_script(
prog_name=prog_name, complete_var=complete_var, shell=shell
)
- click.echo(script_content)
+ _click.echo(script_content)
sys.exit(0)
@@ -103,17 +103,16 @@ def _install_completion_no_auto_placeholder_function(
# And to add extra error messages, for compatibility with Typer in previous versions
# This is only called in new Command method, only used by Click 8.x+
def shell_complete(
- cli: click.Command,
+ cli: _click.Command,
ctx_args: MutableMapping[str, Any],
prog_name: str,
complete_var: str,
instruction: str,
) -> int:
- import click
- import click.shell_completion
+ from . import _click
if "_" not in instruction:
- click.echo("Invalid completion instruction.", err=True)
+ _click.echo("Invalid completion instruction.", err=True)
return 1
# Click 8 changed the order/style of shell instructions from e.g.
@@ -124,23 +123,23 @@ def shell_complete(
instruction, _, shell = instruction.partition("_")
# Typer override end
- comp_cls = click.shell_completion.get_completion_class(shell)
+ comp_cls = shell_completion.get_completion_class(shell)
if comp_cls is None:
- click.echo(f"Shell {shell} not supported.", err=True)
+ _click.echo(f"Shell {shell} not supported.", err=True)
return 1
comp = comp_cls(cli, ctx_args, prog_name, complete_var)
if instruction == "source":
- click.echo(comp.source())
+ _click.echo(comp.source())
return 0
# Typer override to print the completion help msg with Rich
if instruction == "complete":
- click.echo(comp.complete())
+ _click.echo(comp.complete())
return 0
# Typer override end
- click.echo(f'Completion instruction "{instruction}" not supported.', err=True)
+ _click.echo(f'Completion instruction "{instruction}" not supported.', err=True)
return 1
diff --git a/typer/core.py b/typer/core.py
index 48fee64e34..edd9370eb2 100644
--- a/typer/core.py
+++ b/typer/core.py
@@ -2,7 +2,7 @@
import inspect
import os
import sys
-from collections.abc import Callable, MutableMapping, Sequence
+from collections.abc import Callable, Mapping, MutableMapping, Sequence
from difflib import get_close_matches
from enum import Enum
from gettext import gettext as _
@@ -13,13 +13,10 @@
cast,
)
-import click
-import click.core
-import click.formatting
-import click.shell_completion
-import click.types
-import click.utils
-
+from . import _click
+from ._click import types
+from ._click.parser import _OptionParser
+from ._click.shell_completion import CompletionItem
from ._typing import Literal
from .utils import parse_boolean_env_var
@@ -34,7 +31,7 @@
DEFAULT_MARKUP_MODE = None
-# Copy from click.parser._split_opt
+# Copy from _click.parser._split_opt
def _split_opt(opt: str) -> tuple[str, str]:
first = opt[:1]
if first.isalnum():
@@ -45,10 +42,10 @@ def _split_opt(opt: str) -> tuple[str, str]:
def _typer_param_setup_autocompletion_compat(
- self: click.Parameter,
+ self: _click.Parameter,
*,
autocompletion: Callable[
- [click.Context, list[str], str], list[tuple[str, str] | str]
+ [_click.Context, list[str], str], list[tuple[str, str] | str]
]
| None = None,
) -> None:
@@ -65,10 +62,8 @@ def _typer_param_setup_autocompletion_compat(
if autocompletion is not None:
def compat_autocompletion(
- ctx: click.Context, param: click.core.Parameter, incomplete: str
- ) -> list["click.shell_completion.CompletionItem"]:
- from click.shell_completion import CompletionItem
-
+ ctx: _click.Context, param: _click.core.Parameter, incomplete: str
+ ) -> list[CompletionItem]:
out = []
for c in autocompletion(ctx, [], incomplete):
@@ -89,11 +84,11 @@ def compat_autocompletion(
def _get_default_string(
obj: Union["TyperArgument", "TyperOption"],
*,
- ctx: click.Context,
+ ctx: _click.Context,
show_default_is_str: bool,
default_value: list[Any] | tuple[Any, ...] | str | Callable[..., Any] | Any,
) -> str:
- # Extracted from click.core.Option.get_help_record() to be reused by
+ # Extracted from _click.core.Option.get_help_record() to be reused by
# rich_utils avoiding RegEx hacks
if show_default_is_str:
default_string = f"({obj.show_default})"
@@ -112,7 +107,7 @@ def _get_default_string(
# For boolean flags that have distinct True/False opts,
# use the opt without prefix instead of the value.
# Typer override, original commented
- # default_string = click.parser.split_opt(
+ # default_string = _click.parser.split_opt(
# (self.opts if self.default else self.secondary_opts)[0]
# )[1]
if obj.default:
@@ -136,9 +131,9 @@ def _get_default_string(
def _extract_default_help_str(
- obj: Union["TyperArgument", "TyperOption"], *, ctx: click.Context
+ obj: Union["TyperArgument", "TyperOption"], *, ctx: _click.Context
) -> Any | Callable[[], Any] | None:
- # Extracted from click.core.Option.get_help_record() to be reused by
+ # Extracted from _click.core.Option.get_help_record() to be reused by
# rich_utils avoiding RegEx hacks
# Temporarily enable resilient parsing to avoid type casting
# failing for the default. Might be possible to extend this to
@@ -154,7 +149,7 @@ def _extract_default_help_str(
def _main(
- self: click.Command,
+ self: _click.Command,
*,
args: Sequence[str] | None = None,
prog_name: str | None = None,
@@ -164,7 +159,7 @@ def _main(
rich_markup_mode: MarkupMode = DEFAULT_MARKUP_MODE,
**extra: Any,
) -> Any:
- # Typer override, duplicated from click.main() to handle custom rich exceptions
+ # Typer override, duplicated from _click.main() to handle custom rich exceptions
# Verify that the environment is configured correctly, or reject
# further execution to avoid a broken script.
if args is None:
@@ -172,12 +167,12 @@ def _main(
# Covered in Click tests
if os.name == "nt" and windows_expand_args: # pragma: no cover
- args = click.utils._expand_args(args)
+ args = _click.utils._expand_args(args)
else:
args = list(args)
if prog_name is None:
- prog_name = click.utils._detect_program_name()
+ prog_name = _click.utils._detect_program_name()
# Process shell completion requests and exit early.
self._main_shell_completion(extra, prog_name, complete_var)
@@ -197,11 +192,11 @@ def _main(
# by its truthiness/falsiness
ctx.exit()
except EOFError as e:
- click.echo(file=sys.stderr)
- raise click.Abort() from e
+ _click.echo(file=sys.stderr)
+ raise _click.exceptions.Abort() from e
except KeyboardInterrupt as e:
- raise click.exceptions.Exit(130) from e
- except click.ClickException as e:
+ raise _click.exceptions.Exit(130) from e
+ except _click.exceptions.ClickException as e:
if not standalone_mode:
raise
# Typer override
@@ -215,12 +210,12 @@ def _main(
sys.exit(e.exit_code)
except OSError as e:
if e.errno == errno.EPIPE:
- sys.stdout = cast(TextIO, click.utils.PacifyFlushWrapper(sys.stdout))
- sys.stderr = cast(TextIO, click.utils.PacifyFlushWrapper(sys.stderr))
+ sys.stdout = cast(TextIO, _click.utils.PacifyFlushWrapper(sys.stdout))
+ sys.stderr = cast(TextIO, _click.utils.PacifyFlushWrapper(sys.stderr))
sys.exit(1)
else:
raise
- except click.exceptions.Exit as e:
+ except _click.exceptions.Exit as e:
if standalone_mode:
sys.exit(e.exit_code)
else:
@@ -233,7 +228,7 @@ def _main(
# `ctx.exit(1)` and to `return 1`, the caller won't be able to
# tell the difference between the two
return e.exit_code
- except click.Abort:
+ except _click.exceptions.Abort:
if not standalone_mode:
raise
# Typer override
@@ -242,19 +237,21 @@ def _main(
rich_utils.rich_abort_error()
else:
- click.echo(_("Aborted!"), file=sys.stderr)
+ _click.echo(_("Aborted!"), file=sys.stderr)
# Typer override end
sys.exit(1)
-class TyperArgument(click.core.Argument):
+class TyperArgument(_click.core.Parameter):
+ param_type_name = "argument"
+
def __init__(
self,
*,
# Parameter
param_decls: list[str],
type: Any | None = None,
- required: bool | None = None,
+ required: bool = False,
default: Any | None = None,
callback: Callable[..., Any] | None = None,
nargs: int | None = None,
@@ -265,8 +262,8 @@ def __init__(
# Note that shell_complete is not fully supported and will be removed in future versions
# TODO: Remove shell_complete in a future version (after 0.16.0)
shell_complete: Callable[
- [click.Context, click.Parameter, str],
- list["click.shell_completion.CompletionItem"] | list[str],
+ [_click.Context, _click.Parameter, str],
+ list[CompletionItem] | list[str],
]
| None = None,
autocompletion: Callable[..., Any] | None = None,
@@ -301,10 +298,17 @@ def __init__(
)
_typer_param_setup_autocompletion_compat(self, autocompletion=autocompletion)
+ @property
+ def human_readable_name(self) -> str:
+ if self.metavar is not None:
+ return self.metavar
+ assert self.name is not None, "self.name or self.metavar should be set"
+ return self.name.upper()
+
def _get_default_string(
self,
*,
- ctx: click.Context,
+ ctx: _click.Context,
show_default_is_str: bool,
default_value: list[Any] | tuple[Any, ...] | str | Callable[..., Any] | Any,
) -> str:
@@ -316,12 +320,12 @@ def _get_default_string(
)
def _extract_default_help_str(
- self, *, ctx: click.Context
+ self, *, ctx: _click.Context
) -> Any | Callable[[], Any] | None:
return _extract_default_help_str(self, ctx=ctx)
- def get_help_record(self, ctx: click.Context) -> tuple[str, str] | None:
- # Modified version of click.core.Option.get_help_record()
+ def get_help_record(self, ctx: _click.Context) -> tuple[str, str] | None:
+ # Modified version of _click.core.Option.get_help_record()
# to support Arguments
if self.hidden:
return None
@@ -376,8 +380,8 @@ def get_help_record(self, ctx: click.Context) -> tuple[str, str] | None:
help = f"{help} {extra_str}" if help else f"{extra_str}"
return name, help
- def make_metavar(self, ctx: click.Context) -> str:
- # Modified version of click.core.Argument.make_metavar()
+ def make_metavar(self, ctx: _click.Context) -> str:
+ # Modified version of _click.core.Argument.make_metavar()
# to include Argument name
if self.metavar is not None:
var = self.metavar
@@ -397,15 +401,45 @@ def make_metavar(self, ctx: click.Context) -> str:
def value_is_missing(self, value: Any) -> bool:
return _value_is_missing(self, value)
+ def _parse_decls(
+ self, decls: Sequence[str], expose_value: bool
+ ) -> tuple[str | None, list[str], list[str]]:
+ if not decls:
+ if not expose_value:
+ return None, [], []
+ raise TypeError("Argument is marked as exposed, but does not have a name.")
+ if len(decls) == 1:
+ name = arg = decls[0]
+ name = name.replace("-", "_").lower()
+ else:
+ raise TypeError(
+ "Arguments take exactly one parameter declaration, got"
+ f" {len(decls)}: {decls}."
+ )
+ return name, [arg], []
+
+ def get_usage_pieces(self, ctx: _click.Context) -> list[str]:
+ return [self.make_metavar(ctx)]
+
+ def get_error_hint(self, ctx: _click.Context) -> str:
+ return f"'{self.make_metavar(ctx)}'"
+
+ def add_to_parser(self, parser: _OptionParser, ctx: _click.Context) -> None:
+ parser.add_argument(dest=self.name, nargs=self.nargs, obj=self)
+
+
+class TyperOption(_click.Parameter):
+ param_type_name = "option"
+
+ _depr_flag_value: bool | None
-class TyperOption(click.core.Option):
def __init__(
self,
*,
# Parameter
param_decls: list[str],
- type: click.types.ParamType | Any | None = None,
- required: bool | None = None,
+ type: types.ParamType | Any | None = None,
+ required: bool = False,
default: Any | None = None,
callback: Callable[..., Any] | None = None,
nargs: int | None = None,
@@ -416,8 +450,8 @@ def __init__(
# Note that shell_complete is not fully supported and will be removed in future versions
# TODO: Remove shell_complete in a future version (after 0.16.0)
shell_complete: Callable[
- [click.Context, click.Parameter, str],
- list["click.shell_completion.CompletionItem"] | list[str],
+ [_click.Context, _click.Parameter, str],
+ list[CompletionItem] | list[str],
]
| None = None,
autocompletion: Callable[..., Any] | None = None,
@@ -438,9 +472,13 @@ def __init__(
# Rich settings
rich_help_panel: str | None = None,
):
+ if help:
+ help = inspect.cleandoc(help)
+
super().__init__(
- param_decls=param_decls,
+ param_decls,
type=type,
+ multiple=multiple,
required=required,
default=default,
callback=callback,
@@ -449,28 +487,240 @@ def __init__(
expose_value=expose_value,
is_eager=is_eager,
envvar=envvar,
- show_default=show_default,
- prompt=prompt,
- confirmation_prompt=confirmation_prompt,
- hide_input=hide_input,
- is_flag=is_flag,
- multiple=multiple,
- count=count,
- allow_from_autoenv=allow_from_autoenv,
- help=help,
- hidden=hidden,
- show_choices=show_choices,
- show_envvar=show_envvar,
- prompt_required=prompt_required,
shell_complete=shell_complete,
)
+
+ if prompt is True:
+ if self.name is None:
+ raise TypeError("'name' is required with 'prompt=True'.")
+
+ prompt_text: str | None = self.name.replace("_", " ").capitalize()
+ elif prompt is False:
+ prompt_text = None
+ else:
+ prompt_text = prompt
+
+ self.prompt = prompt_text
+ self.confirmation_prompt = confirmation_prompt
+ self.prompt_required = prompt_required
+ self.hide_input = hide_input
+ self.hidden = hidden
+
+ # TODO: revisit all of this flag stuff
+ if is_flag and type is None:
+ self.type: types.ParamType = types.BoolParamType()
+
+ self.is_flag: bool = bool(is_flag)
+ self.is_bool_flag: bool = bool(
+ is_flag and isinstance(self.type, types.BoolParamType)
+ )
+
+ if self.is_flag:
+ self._depr_flag_value = True
+ else:
+ self._depr_flag_value = None
+
+ # Counting. TODO: test or remove? Not currently in coverage.
+ self.count = count
+ if count and type is None:
+ self.type = types.IntRange(min=0)
+
+ self.allow_from_autoenv = allow_from_autoenv
+ self.help = help
+ self.show_default = show_default
+ self.show_choices = show_choices
+ self.show_envvar = show_envvar
+
_typer_param_setup_autocompletion_compat(self, autocompletion=autocompletion)
self.rich_help_panel = rich_help_panel
+ def get_error_hint(self, ctx: _click.Context) -> str:
+ result = super().get_error_hint(ctx)
+ if self.show_envvar and self.envvar is not None:
+ result += f" (env var: '{self.envvar}')"
+ return result
+
+ def _parse_decls(
+ self, decls: Sequence[str], expose_value: bool
+ ) -> tuple[str | None, list[str], list[str]]:
+ opts = []
+ secondary_opts = []
+ name = None
+ possible_names = []
+
+ for decl in decls:
+ if decl.isidentifier():
+ if name is not None:
+ raise TypeError(f"Name '{name}' defined twice")
+ name = decl
+ else:
+ split_char = ";" if decl[:1] == "/" else "/"
+ if split_char in decl:
+ first, second = decl.split(split_char, 1)
+ first = first.rstrip()
+ if first:
+ possible_names.append(_split_opt(first))
+ opts.append(first)
+ second = second.lstrip()
+ if second:
+ secondary_opts.append(second.lstrip())
+ if first == second:
+ raise ValueError(
+ f"Boolean option {decl!r} cannot use the"
+ " same flag for true/false."
+ )
+ else:
+ possible_names.append(_split_opt(decl))
+ opts.append(decl)
+
+ if name is None and possible_names:
+ possible_names.sort(key=lambda x: -len(x[0])) # group long options first
+ name = possible_names[0][1].replace("-", "_").lower()
+ if not name.isidentifier():
+ name = None
+
+ return name, opts, secondary_opts
+
+ def add_to_parser(self, parser: _OptionParser, ctx: _click.Context) -> None:
+ if self.multiple:
+ action = "append"
+ elif self.count:
+ action = "count"
+ else:
+ action = "store"
+
+ if self.is_flag:
+ action = f"{action}_const"
+
+ if self.is_bool_flag and self.secondary_opts:
+ parser.add_option(
+ obj=self, opts=self.opts, dest=self.name, action=action, const=True
+ )
+ parser.add_option(
+ obj=self,
+ opts=self.secondary_opts,
+ dest=self.name,
+ action=action,
+ const=False,
+ )
+ else:
+ parser.add_option(
+ obj=self,
+ opts=self.opts,
+ dest=self.name,
+ action=action,
+ const=self._depr_flag_value,
+ )
+ else:
+ parser.add_option(
+ obj=self,
+ opts=self.opts,
+ dest=self.name,
+ action=action,
+ nargs=self.nargs,
+ )
+
+ def prompt_for_value(self, ctx: _click.Context) -> Any:
+ """This is an alternative flow that can be activated in the full
+ value processing if a value does not exist. It will prompt the
+ user until a valid value exists and then returns the processed
+ value as result.
+ """
+ assert self.prompt is not None
+
+ # Calculate the default before prompting anything to lock in the value before
+ # attempting any user interaction.
+ default = self.get_default(ctx)
+
+ # A boolean flag can use a simplified [y/n] confirmation prompt.
+ if self.is_bool_flag:
+ # Nothing prevent you to declare an option that is simultaneously:
+ # 1) auto-detected as a boolean flag,
+ # 2) allowed to prompt, and
+ # 3) still declare a non-boolean default.
+ # This forced casting into a boolean is necessary to align any non-boolean
+ # default to the prompt, which is going to be a [y/n]-style confirmation
+ # because the option is still a boolean flag. That way, instead of [y/n],
+ # we get [Y/n] or [y/N] depending on the truthy value of the default.
+ # Refs: https://github.com/pallets/click/pull/3030#discussion_r2289180249
+ if default is not None:
+ default = bool(default)
+ return _click.termui.confirm(self.prompt, default)
+
+ # If show_default is set to True/False, provide this to `prompt` as well. For
+ # non-bool values of `show_default`, we use `prompt`'s default behavior
+ prompt_kwargs: Any = {}
+ if isinstance(self.show_default, bool):
+ prompt_kwargs["show_default"] = self.show_default
+
+ return _click.termui.prompt(
+ self.prompt,
+ # Use ``None`` to inform the prompt() function to reiterate until a valid
+ # value is provided by the user if we have no default.
+ default=default,
+ type=self.type,
+ hide_input=self.hide_input,
+ show_choices=self.show_choices,
+ confirmation_prompt=self.confirmation_prompt,
+ value_proc=lambda x: self.process_value(ctx, x),
+ **prompt_kwargs,
+ )
+
+ def value_from_envvar(self, ctx: _click.Context) -> Any:
+ rv = self.resolve_envvar_value(ctx)
+
+ # Absent environment variable or an empty string is interpreted as unset.
+ if rv is None:
+ return None
+
+ def resolve_envvar_value(self, ctx: _click.Context) -> str | None:
+ rv = super().resolve_envvar_value(ctx)
+
+ if rv is not None:
+ return rv
+
+ if (
+ self.allow_from_autoenv
+ and ctx.auto_envvar_prefix is not None
+ and self.name is not None
+ ):
+ envvar = f"{ctx.auto_envvar_prefix}_{self.name.upper()}"
+ rv = os.environ.get(envvar)
+
+ if rv:
+ return rv
+
+ return None
+
+ def consume_value(
+ self, ctx: _click.Context, opts: Mapping[str, _click.Parameter]
+ ) -> tuple[Any, _click.core.ParameterSource]:
+ """For `Option`, the value can be collected from an interactive prompt
+ if the option is a flag that needs a value (and the `prompt` property is
+ set).
+
+ Additionally, this method handles flag option that are activated without a
+ value, in which case the `flag_value` is returned.
+ """
+ value, source = super().consume_value(ctx, opts)
+
+ # The value wasn't set, or used the param's default, prompt for one to the user
+ # if prompting is enabled.
+ if (
+ source in {None, _click.core.ParameterSource.DEFAULT}
+ and self.prompt is not None
+ and (self.required or self.prompt_required)
+ and not ctx.resilient_parsing
+ ):
+ value = self.prompt_for_value(ctx)
+ source = _click.core.ParameterSource.PROMPT
+
+ return value, source
+
def _get_default_string(
self,
*,
- ctx: click.Context,
+ ctx: _click.Context,
show_default_is_str: bool,
default_value: list[Any] | tuple[Any, ...] | str | Callable[..., Any] | Any,
) -> str:
@@ -482,14 +732,14 @@ def _get_default_string(
)
def _extract_default_help_str(
- self, *, ctx: click.Context
+ self, *, ctx: _click.Context
) -> Any | Callable[[], Any] | None:
return _extract_default_help_str(self, ctx=ctx)
- def make_metavar(self, ctx: click.Context) -> str:
+ def make_metavar(self, ctx: _click.Context) -> str:
return super().make_metavar(ctx=ctx)
- def get_help_record(self, ctx: click.Context) -> tuple[str, str] | None:
+ def get_help_record(self, ctx: _click.Context) -> tuple[str, str] | None:
# Duplicate all of Click's logic only to modify a single line, to allow boolean
# flags with only names for False values as it's currently supported by Typer
# Ref: https://typer.tiangolo.com/tutorial/parameter-types/bool/#only-names-for-false
@@ -501,7 +751,7 @@ def get_help_record(self, ctx: click.Context) -> tuple[str, str] | None:
def _write_opts(opts: Sequence[str]) -> str:
nonlocal any_prefix_is_slash
- rv, any_slashes = click.formatting.join_options(opts)
+ rv, any_slashes = _click.formatting.join_options(opts)
if any_slashes:
any_prefix_is_slash = True
@@ -559,7 +809,7 @@ def _write_opts(opts: Sequence[str]) -> str:
if default_string:
extra.append(_("default: {default}").format(default=default_string))
- if isinstance(self.type, click.types._NumberRangeBase):
+ if isinstance(self.type, types._NumberRangeBase):
range_str = self.type._describe_range()
if range_str:
@@ -588,14 +838,10 @@ def value_is_missing(self, value: Any) -> bool:
return _value_is_missing(self, value)
-def _value_is_missing(param: click.Parameter, value: Any) -> bool:
+def _value_is_missing(param: _click.Parameter, value: Any) -> bool:
if value is None:
return True
- # Click 8.3 and beyond
- # if value is UNSET:
- # return True
-
if (param.nargs != 1 or param.multiple) and value == ():
return True # pragma: no cover
@@ -603,7 +849,7 @@ def _value_is_missing(param: click.Parameter, value: Any) -> bool:
def _typer_format_options(
- self: click.core.Command, *, ctx: click.Context, formatter: click.HelpFormatter
+ self: _click.core.Command, *, ctx: _click.Context, formatter: _click.HelpFormatter
) -> None:
args = []
opts = []
@@ -624,7 +870,7 @@ def _typer_format_options(
def _typer_main_shell_completion(
- self: click.core.Command,
+ self: _click.core.Command,
*,
ctx_args: MutableMapping[str, Any],
prog_name: str,
@@ -644,14 +890,14 @@ def _typer_main_shell_completion(
sys.exit(rv)
-class TyperCommand(click.core.Command):
+class TyperCommand(_click.core.Command):
def __init__(
self,
name: str | None,
*,
context_settings: dict[str, Any] | None = None,
callback: Callable[..., Any] | None = None,
- params: list[click.Parameter] | None = None,
+ params: list[_click.Parameter] | None = None,
help: str | None = None,
epilog: str | None = None,
short_help: str | None = None,
@@ -682,7 +928,7 @@ def __init__(
self.rich_help_panel = rich_help_panel
def format_options(
- self, ctx: click.Context, formatter: click.HelpFormatter
+ self, ctx: _click.Context, formatter: _click.HelpFormatter
) -> None:
_typer_format_options(self, ctx=ctx, formatter=formatter)
@@ -716,7 +962,7 @@ def main(
**extra,
)
- def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
+ def format_help(self, ctx: _click.Context, formatter: _click.HelpFormatter) -> None:
if not HAS_RICH or self.rich_markup_mode is None:
if not hasattr(ctx, "obj") or ctx.obj is None:
ctx.ensure_object(dict)
@@ -732,25 +978,151 @@ def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> Non
)
-class TyperGroup(click.core.Group):
+class TyperGroup(_click.Command):
+ allow_extra_args = True
+ allow_interspersed_args = False
+ command_class: type[_click.Command] | None = None
+ group_class: type["TyperGroup"] | type[type] | None = None
+
def __init__(
self,
*,
name: str | None = None,
- commands: dict[str, click.Command] | Sequence[click.Command] | None = None,
+ commands: dict[str, _click.Command] | Sequence[_click.Command] | None = None,
# Rich settings
rich_markup_mode: MarkupMode = DEFAULT_MARKUP_MODE,
rich_help_panel: str | None = None,
suggest_commands: bool = True,
+ # Click settings
+ invoke_without_command: bool = False,
+ no_args_is_help: bool = False,
+ subcommand_metavar: str | None = None,
+ result_callback: Callable[..., Any] | None = None,
**attrs: Any,
) -> None:
- super().__init__(name=name, commands=commands, **attrs)
+ super().__init__(name=name, **attrs)
self.rich_markup_mode: MarkupMode = rich_markup_mode
self.rich_help_panel = rich_help_panel
self.suggest_commands = suggest_commands
+ # copied from Click's init
+ if commands is None:
+ commands = {}
+ elif isinstance(commands, Sequence):
+ commands = {
+ c.name: c
+ for c in commands
+ if isinstance(c, _click.Command) and c.name is not None
+ }
+
+ self.commands = cast(MutableMapping[str, _click.Command], commands)
+ self.no_args_is_help = no_args_is_help
+ self.invoke_without_command = invoke_without_command
+
+ if subcommand_metavar is None:
+ subcommand_metavar = "COMMAND [ARGS]..."
+
+ self.subcommand_metavar = subcommand_metavar
+ self._result_callback = result_callback
+
+ def add_command(self, cmd: _click.Command, name: str | None = None) -> None:
+ name = name or cmd.name
+ if name is None:
+ raise TypeError("Command has no name.")
+ self.commands[name] = cmd
+
+ def get_command(self, ctx: _click.Context, cmd_name: str) -> _click.Command | None:
+ return self.commands.get(cmd_name)
+
+ def collect_usage_pieces(self, ctx: _click.Context) -> list[str]:
+ rv = super().collect_usage_pieces(ctx)
+ rv.append(self.subcommand_metavar)
+ return rv
+
+ def format_commands(
+ self, ctx: _click.Context, formatter: _click.HelpFormatter
+ ) -> None:
+ commands = []
+ for subcommand in self.list_commands(ctx):
+ cmd = self.get_command(ctx, subcommand)
+
+ commands.append((subcommand, cmd))
+
+ # allow for 3 times the default spacing
+ if len(commands):
+ limit = formatter.width - 6 - max(len(cmd[0]) for cmd in commands)
+
+ rows = []
+ for subcommand, cmd in commands:
+ assert cmd is not None
+ help = cmd.get_short_help_str(limit)
+ rows.append((subcommand, help))
+
+ if rows:
+ with formatter.section(_("Commands")):
+ formatter.write_dl(rows)
+
+ def parse_args(self, ctx: _click.Context, args: list[str]) -> list[str]:
+ if not args and self.no_args_is_help and not ctx.resilient_parsing:
+ raise _click.exceptions.NoArgsIsHelpError(ctx)
+
+ rest = super().parse_args(ctx, args)
+
+ if rest:
+ ctx._protected_args, ctx.args = rest[:1], rest[1:]
+
+ return ctx.args
+
+ def invoke(self, ctx: _click.Context) -> Any:
+ def _process_result(value: Any) -> Any:
+ if self._result_callback is not None:
+ value = ctx.invoke(self._result_callback, value, **ctx.params)
+ return value
+
+ if not ctx._protected_args:
+ if self.invoke_without_command:
+ # No subcommand was invoked, so the result callback is
+ # invoked with the group return value for regular
+ # groups, or an empty list for chained groups.
+ with ctx:
+ rv = super().invoke(ctx)
+ # return _process_result([] if self.chain else rv)
+ return _process_result(rv)
+ ctx.fail(_("Missing command."))
+
+ # Fetch args back out
+ args = [*ctx._protected_args, *ctx.args]
+ ctx.args = []
+ ctx._protected_args = []
+
+ # Make sure the context is entered so we do not clean up
+ # resources until the result processor has worked.
+ with ctx:
+ cmd_name, cmd, args = self.resolve_command(ctx, args)
+ assert cmd is not None
+ ctx.invoked_subcommand = cmd_name
+ super().invoke(ctx)
+ sub_ctx = cmd.make_context(cmd_name, args, parent=ctx)
+ with sub_ctx:
+ return _process_result(sub_ctx.command.invoke(sub_ctx))
+
+ def shell_complete(
+ self, ctx: _click.Context, incomplete: str
+ ) -> list[CompletionItem]:
+ """Return a list of completions for the incomplete value. Looks
+ at the names of options, subcommands, and chained
+ multi-commands.
+ """
+
+ results = [
+ CompletionItem(name, help=command.get_short_help_str())
+ for name, command in _click.core._complete_visible_commands(ctx, incomplete)
+ ]
+ results.extend(super().shell_complete(ctx, incomplete))
+ return results
+
def format_options(
- self, ctx: click.Context, formatter: click.HelpFormatter
+ self, ctx: _click.Context, formatter: _click.HelpFormatter
) -> None:
_typer_format_options(self, ctx=ctx, formatter=formatter)
self.format_commands(ctx, formatter)
@@ -765,12 +1137,31 @@ def _main_shell_completion(
self, ctx_args=ctx_args, prog_name=prog_name, complete_var=complete_var
)
+ def _click_resolve_command(
+ self, ctx: _click.Context, args: list[str]
+ ) -> tuple[str | None, _click.Command | None, list[str]]:
+ cmd_name = args[0]
+ original_cmd_name = cmd_name
+
+ # Get the command
+ cmd = self.get_command(ctx, cmd_name)
+
+ if cmd is None and ctx.token_normalize_func is not None:
+ cmd_name = ctx.token_normalize_func(cmd_name)
+ cmd = self.get_command(ctx, cmd_name)
+
+ if cmd is None and not ctx.resilient_parsing:
+ if _split_opt(cmd_name)[0]:
+ self.parse_args(ctx, args)
+ ctx.fail(_("No such command {name!r}.").format(name=original_cmd_name))
+ return cmd_name if cmd else None, cmd, args[1:]
+
def resolve_command(
- self, ctx: click.Context, args: list[str]
- ) -> tuple[str | None, click.Command | None, list[str]]:
+ self, ctx: _click.Context, args: list[str]
+ ) -> tuple[str | None, _click.Command | None, list[str]]:
try:
- return super().resolve_command(ctx, args)
- except click.UsageError as e:
+ return self._click_resolve_command(ctx, args)
+ except _click.exceptions.UsageError as e:
if self.suggest_commands:
available_commands = list(self.commands.keys())
if available_commands and args:
@@ -802,7 +1193,7 @@ def main(
**extra,
)
- def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
+ def format_help(self, ctx: _click.Context, formatter: _click.HelpFormatter) -> None:
if not HAS_RICH or self.rich_markup_mode is None:
return super().format_help(ctx, formatter)
from . import rich_utils
@@ -813,8 +1204,6 @@ def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> Non
markup_mode=self.rich_markup_mode,
)
- def list_commands(self, ctx: click.Context) -> list[str]:
- """Returns a list of subcommand names.
- Note that in Click's Group class, these are sorted.
- In Typer, we wish to maintain the original order of creation (cf Issue #933)"""
+ def list_commands(self, ctx: _click.Context) -> list[str]:
+ """Returns a list of subcommand names, maintaining the original order of creation (cf Issue #933)"""
return [n for n, c in self.commands.items()]
diff --git a/typer/main.py b/typer/main.py
index ebcf639a2c..43c7a921f7 100644
--- a/typer/main.py
+++ b/typer/main.py
@@ -15,10 +15,12 @@
from typing import Annotated, Any
from uuid import UUID
-import click
from annotated_doc import Doc
from typer._types import TyperChoice
+from . import _click
+from ._click import types
+from ._click.globals import get_current_context
from ._typing import get_args, get_origin, is_literal_type, is_union, literal_values
from .completion import get_completion_inspect_parameters
from .core import (
@@ -73,7 +75,7 @@ def except_hook(
_original_except_hook(exc_type, exc_value, tb)
return
typer_path = os.path.dirname(__file__)
- click_path = os.path.dirname(click.__file__)
+ click_path = os.path.dirname(_click.__file__)
internal_dir_names = [typer_path, click_path]
exc = exc_value
if HAS_RICH:
@@ -107,7 +109,7 @@ def except_hook(
return
-def get_install_completion_arguments() -> tuple[click.Parameter, click.Parameter]:
+def get_install_completion_arguments() -> tuple[_click.Parameter, _click.Parameter]:
install_param, show_param = get_completion_inspect_parameters()
click_install_param, _ = get_click_param(install_param)
click_show_param, _ = get_click_param(show_param)
@@ -1168,7 +1170,7 @@ def get_group(typer_instance: Typer) -> TyperGroup:
return group
-def get_command(typer_instance: Typer) -> click.Command:
+def get_command(typer_instance: Typer) -> _click.Command:
if typer_instance._add_completion:
click_install_param, click_show_param = get_install_completion_arguments()
if (
@@ -1178,7 +1180,7 @@ def get_command(typer_instance: Typer) -> click.Command:
or len(typer_instance.registered_commands) > 1
):
# Create a Group
- click_command: click.Command = get_group(typer_instance)
+ click_command: _click.Command = get_group(typer_instance)
if typer_instance._add_completion:
click_command.params.append(click_install_param)
click_command.params.append(click_show_param)
@@ -1288,7 +1290,7 @@ def get_group_from_info(
assert group_info.typer_instance, (
"A Typer instance is needed to generate a Click Group"
)
- commands: dict[str, click.Command] = {}
+ commands: dict[str, _click.Command] = {}
for command_info in group_info.typer_instance.registered_commands:
command = get_command_from_info(
command_info=command_info,
@@ -1330,7 +1332,6 @@ def get_group_from_info(
invoke_without_command=solved_info.invoke_without_command,
no_args_is_help=solved_info.no_args_is_help,
subcommand_metavar=solved_info.subcommand_metavar,
- chain=solved_info.chain,
result_callback=solved_info.result_callback,
context_settings=solved_info.context_settings,
callback=get_callback(
@@ -1362,14 +1363,14 @@ def get_command_name(name: str) -> str:
def get_params_convertors_ctx_param_name_from_function(
callback: Callable[..., Any] | None,
-) -> tuple[list[click.Argument | click.Option], dict[str, Any], str | None]:
+) -> tuple[list[TyperArgument | TyperOption], dict[str, Any], str | None]:
params = []
convertors = {}
context_param_name = None
if callback:
parameters = get_params_from_function(callback)
for param_name, param in parameters.items():
- if lenient_issubclass(param.annotation, click.Context):
+ if lenient_issubclass(param.annotation, _click.Context):
context_param_name = param_name
continue
click_param, convertor = get_click_param(param)
@@ -1384,7 +1385,7 @@ def get_command_from_info(
*,
pretty_exceptions_short: bool,
rich_markup_mode: MarkupMode,
-) -> click.Command:
+) -> _click.Command:
assert command_info.callback, "A command must have a callback function"
name = command_info.name or get_command_name(command_info.callback.__name__) # ty: ignore
use_help = command_info.help
@@ -1486,7 +1487,7 @@ def internal_convertor(
def get_callback(
*,
callback: Callable[..., Any] | None = None,
- params: Sequence[click.Parameter] = [],
+ params: Sequence[_click.Parameter] = [],
convertors: dict[str, Callable[[str], Any]] | None = None,
context_param_name: str | None = None,
pretty_exceptions_short: bool,
@@ -1510,7 +1511,7 @@ def wrapper(**kwargs: Any) -> Any:
else:
use_params[k] = v
if context_param_name:
- use_params[context_param_name] = click.get_current_context()
+ use_params[context_param_name] = get_current_context()
return callback(**use_params)
update_wrapper(wrapper, callback)
@@ -1519,15 +1520,15 @@ def wrapper(**kwargs: Any) -> Any:
def get_click_type(
*, annotation: Any, parameter_info: ParameterInfo
-) -> click.ParamType:
+) -> types.ParamType:
if parameter_info.click_type is not None:
return parameter_info.click_type
elif parameter_info.parser is not None:
- return click.types.FuncParamType(parameter_info.parser)
+ return types.FuncParamType(parameter_info.parser)
elif annotation is str:
- return click.STRING
+ return types.STRING
elif annotation is int:
if parameter_info.min is not None or parameter_info.max is not None:
min_ = None
@@ -1536,24 +1537,24 @@ def get_click_type(
min_ = int(parameter_info.min)
if parameter_info.max is not None:
max_ = int(parameter_info.max)
- return click.IntRange(min=min_, max=max_, clamp=parameter_info.clamp)
+ return types.IntRange(min=min_, max=max_, clamp=parameter_info.clamp)
else:
- return click.INT
+ return types.INT
elif annotation is float:
if parameter_info.min is not None or parameter_info.max is not None:
- return click.FloatRange(
+ return types.FloatRange(
min=parameter_info.min,
max=parameter_info.max,
clamp=parameter_info.clamp,
)
else:
- return click.FLOAT
+ return types.FLOAT
elif annotation is bool:
- return click.BOOL
+ return types.BOOL
elif annotation == UUID:
- return click.UUID
+ return types.UUID
elif annotation == datetime:
- return click.DateTime(formats=parameter_info.formats)
+ return types.DateTime(formats=parameter_info.formats)
elif (
annotation == Path
or parameter_info.allow_dash
@@ -1571,7 +1572,7 @@ def get_click_type(
path_type=parameter_info.path_type,
)
elif lenient_issubclass(annotation, FileTextWrite):
- return click.File(
+ return types.File(
mode=parameter_info.mode or "w",
encoding=parameter_info.encoding,
errors=parameter_info.errors,
@@ -1579,7 +1580,7 @@ def get_click_type(
atomic=parameter_info.atomic,
)
elif lenient_issubclass(annotation, FileText):
- return click.File(
+ return types.File(
mode=parameter_info.mode or "r",
encoding=parameter_info.encoding,
errors=parameter_info.errors,
@@ -1587,7 +1588,7 @@ def get_click_type(
atomic=parameter_info.atomic,
)
elif lenient_issubclass(annotation, FileBinaryRead):
- return click.File(
+ return types.File(
mode=parameter_info.mode or "rb",
encoding=parameter_info.encoding,
errors=parameter_info.errors,
@@ -1595,7 +1596,7 @@ def get_click_type(
atomic=parameter_info.atomic,
)
elif lenient_issubclass(annotation, FileBinaryWrite):
- return click.File(
+ return types.File(
mode=parameter_info.mode or "wb",
encoding=parameter_info.encoding,
errors=parameter_info.errors,
@@ -1603,17 +1604,12 @@ def get_click_type(
atomic=parameter_info.atomic,
)
elif lenient_issubclass(annotation, Enum):
- # The custom TyperChoice is only needed for Click < 8.2.0, to parse the
- # command line values matching them to the enum values. Click 8.2.0 added
- # support for enum values but reading enum names.
- # Passing here the list of enum values (instead of just the enum) accounts for
- # Click < 8.2.0.
return TyperChoice(
[item.value for item in annotation],
case_sensitive=parameter_info.case_sensitive,
)
elif is_literal_type(annotation):
- return click.Choice(
+ return TyperChoice(
literal_values(annotation),
case_sensitive=parameter_info.case_sensitive,
)
@@ -1626,7 +1622,7 @@ def lenient_issubclass(cls: Any, class_or_tuple: AnyType | tuple[AnyType, ...])
def get_click_param(
param: ParamMeta,
-) -> tuple[click.Argument | click.Option, Any]:
+) -> tuple[TyperArgument | TyperOption, Any]:
# First, find out what will be:
# * ParamInfo (ArgumentInfo or OptionInfo)
# * default_value
@@ -1784,7 +1780,7 @@ def get_click_param(
),
convertor,
)
- raise AssertionError("A click.Parameter should be returned") # pragma: no cover
+ raise AssertionError("A _click.Parameter should be returned") # pragma: no cover
def get_param_callback(
@@ -1800,9 +1796,9 @@ def get_param_callback(
value_name = None
untyped_names: list[str] = []
for param_name, param_sig in parameters.items():
- if lenient_issubclass(param_sig.annotation, click.Context):
+ if lenient_issubclass(param_sig.annotation, _click.Context):
ctx_name = param_name
- elif lenient_issubclass(param_sig.annotation, click.Parameter):
+ elif lenient_issubclass(param_sig.annotation, _click.Parameter):
click_param_name = param_name
else:
untyped_names.append(param_name)
@@ -1817,11 +1813,11 @@ def get_param_callback(
if untyped_names:
click_param_name = untyped_names.pop(0)
if untyped_names:
- raise click.ClickException(
+ raise _click.ClickException(
"Too many CLI parameter callback function parameters"
)
- def wrapper(ctx: click.Context, param: click.Parameter, value: Any) -> Any:
+ def wrapper(ctx: _click.Context, param: _click.Parameter, value: Any) -> Any:
use_params: dict[str, Any] = {}
if ctx_name:
use_params[ctx_name] = ctx
@@ -1851,7 +1847,7 @@ def get_param_completion(
unassigned_params = list(parameters.values())
for param_sig in unassigned_params[:]:
origin = get_origin(param_sig.annotation)
- if lenient_issubclass(param_sig.annotation, click.Context):
+ if lenient_issubclass(param_sig.annotation, _click.Context):
ctx_name = param_sig.name
unassigned_params.remove(param_sig)
elif lenient_issubclass(origin, list):
@@ -1874,11 +1870,11 @@ def get_param_completion(
# Extract value param name first
if unassigned_params:
show_params = " ".join([param.name for param in unassigned_params])
- raise click.ClickException(
+ raise _click.ClickException(
f"Invalid autocompletion callback parameters: {show_params}"
)
- def wrapper(ctx: click.Context, args: list[str], incomplete: str | None) -> Any:
+ def wrapper(ctx: _click.Context, args: list[str], incomplete: str | None) -> Any:
use_params: dict[str, Any] = {}
if ctx_name:
use_params[ctx_name] = ctx
@@ -2010,4 +2006,4 @@ def launch(
return 0
else:
- return click.launch(url, wait=wait, locate=locate)
+ return _click.launch(url, wait=wait, locate=locate)
diff --git a/typer/models.py b/typer/models.py
index 3285a96a24..00385c38ce 100644
--- a/typer/models.py
+++ b/typer/models.py
@@ -1,15 +1,20 @@
import inspect
import io
+import os
+import stat
from collections.abc import Callable, Sequence
from typing import (
TYPE_CHECKING,
Any,
+ ClassVar,
Optional,
TypeVar,
+ cast,
)
-import click
-import click.shell_completion
+from . import _click
+from ._click import types
+from ._click.shell_completion import CompletionItem
if TYPE_CHECKING: # pragma: no cover
from .core import TyperCommand, TyperGroup
@@ -23,7 +28,7 @@
Required = ...
-class Context(click.Context):
+class Context(_click.Context):
"""
The [`Context`](https://click.palletsprojects.com/en/stable/api/#click.Context) has some additional data about the current execution of your program.
When declaring it in a [callback](https://typer.tiangolo.com/tutorial/options/callback-and-context/) function,
@@ -153,7 +158,7 @@ def main(file: Annotated[typer.FileBinaryWrite, typer.Option()]):
pass
-class CallbackParam(click.Parameter):
+class CallbackParam(_click.Parameter):
"""
In a callback function, you can declare a function parameter with type `CallbackParam`
to access the specific Click [`Parameter`](https://click.palletsprojects.com/en/stable/api/#click.Parameter) object.
@@ -286,15 +291,15 @@ def __init__(
# Note that shell_complete is not fully supported and will be removed in future versions
# TODO: Remove shell_complete in a future version (after 0.16.0)
shell_complete: Callable[
- [click.Context, click.Parameter, str],
- list["click.shell_completion.CompletionItem"] | list[str],
+ [_click.Context, _click.Parameter, str],
+ list["CompletionItem"] | list[str],
]
| None = None,
autocompletion: Callable[..., Any] | None = None,
default_factory: Callable[[], Any] | None = None,
# Custom type
parser: Callable[[str], Any] | None = None,
- click_type: click.ParamType | None = None,
+ click_type: types.ParamType | None = None,
# TyperArgument
show_default: bool | str = True,
show_choices: bool = True,
@@ -395,15 +400,15 @@ def __init__(
# Note that shell_complete is not fully supported and will be removed in future versions
# TODO: Remove shell_complete in a future version (after 0.16.0)
shell_complete: Callable[
- [click.Context, click.Parameter, str],
- list["click.shell_completion.CompletionItem"] | list[str],
+ [_click.Context, _click.Parameter, str],
+ list["CompletionItem"] | list[str],
]
| None = None,
autocompletion: Callable[..., Any] | None = None,
default_factory: Callable[[], Any] | None = None,
# Custom type
parser: Callable[[str], Any] | None = None,
- click_type: click.ParamType | None = None,
+ click_type: types.ParamType | None = None,
# Option
show_default: bool | str = True,
prompt: bool | str = False,
@@ -523,15 +528,15 @@ def __init__(
# Note that shell_complete is not fully supported and will be removed in future versions
# TODO: Remove shell_complete in a future version (after 0.16.0)
shell_complete: Callable[
- [click.Context, click.Parameter, str],
- list["click.shell_completion.CompletionItem"] | list[str],
+ [_click.Context, _click.Parameter, str],
+ list["CompletionItem"] | list[str],
]
| None = None,
autocompletion: Callable[..., Any] | None = None,
default_factory: Callable[[], Any] | None = None,
# Custom type
parser: Callable[[str], Any] | None = None,
- click_type: click.ParamType | None = None,
+ click_type: types.ParamType | None = None,
# TyperArgument
show_default: bool | str = True,
show_choices: bool = True,
@@ -640,11 +645,98 @@ def __init__(
self.pretty_exceptions_short = pretty_exceptions_short
-class TyperPath(click.Path):
- # Overwrite Click's behaviour to be compatible with Typer's autocompletion system
+class TyperPath(types.ParamType):
+ # Based originally on code from Click 8.3.1
+ # Partly rewritten and added an override for shell_complete
+
+ envvar_list_splitter: ClassVar[str] = os.path.pathsep
+
+ def __init__(
+ self,
+ exists: bool = False,
+ file_okay: bool = True,
+ dir_okay: bool = True,
+ writable: bool = False,
+ readable: bool = True,
+ resolve_path: bool = False,
+ allow_dash: bool = False,
+ path_type: type[Any] | None = None,
+ ):
+ self.exists = exists
+ self.file_okay = file_okay
+ self.dir_okay = dir_okay
+ self.readable = readable
+ self.writable = writable
+ self.resolve_path = resolve_path
+ self.allow_dash = allow_dash
+ self.type = path_type
+
+ if self.file_okay and not self.dir_okay:
+ self.name = "file"
+ elif self.dir_okay and not self.file_okay:
+ self.name = "directory"
+ else:
+ self.name = "path"
+
+ def coerce_path_result(
+ self, value: str | os.PathLike[str]
+ ) -> str | bytes | os.PathLike[str]:
+ if self.type is not None and not isinstance(value, self.type):
+ if (
+ self.type is str
+ ): # pragma: no cover # TODO: perhaps this branch can't be hit and should be removed
+ return os.fsdecode(value)
+ elif self.type is bytes:
+ return os.fsencode(value)
+ else:
+ return cast("os.PathLike[str]", self.type(value))
+
+ return value
+
+ def convert( # ty: ignore[invalid-method-override]
+ self,
+ value: str | os.PathLike[str],
+ param: _click.Parameter | None,
+ ctx: Context | None, # type: ignore[override]
+ ) -> str | bytes | os.PathLike[str]:
+ rv = value
+
+ is_dash = self.file_okay and self.allow_dash and rv in (b"-", "-")
+
+ if not is_dash:
+ if self.resolve_path:
+ rv = os.path.realpath(rv)
+
+ try:
+ st = os.stat(rv)
+ except OSError:
+ if not self.exists:
+ return self.coerce_path_result(rv)
+ self.fail(
+ f"{self.name.title()} {_click.utils.format_filename(value)!r} does not exist.",
+ param,
+ ctx,
+ )
+
+ name = self.name.title()
+ loc = repr(_click.utils.format_filename(value))
+ if not self.file_okay and stat.S_ISREG(st.st_mode):
+ self.fail(f"{name} {loc} is a file.", param, ctx)
+
+ if not self.dir_okay and stat.S_ISDIR(st.st_mode):
+ self.fail(f"{name} {loc} is a directory.", param, ctx)
+
+ if self.readable and not os.access(rv, os.R_OK):
+ self.fail(f"{name} {loc} is not readable.", param, ctx)
+
+ if self.writable and not os.access(rv, os.W_OK):
+ self.fail(f"{name} {loc} is not writable.", param, ctx)
+
+ return self.coerce_path_result(rv)
+
def shell_complete(
- self, ctx: click.Context, param: click.Parameter, incomplete: str
- ) -> list[click.shell_completion.CompletionItem]:
+ self, ctx: _click.Context, param: _click.Parameter, incomplete: str
+ ) -> list[CompletionItem]:
"""Return an empty list so that the autocompletion functionality
will work properly from the commandline.
"""
diff --git a/typer/params.py b/typer/params.py
index b325b273c4..833461fa78 100644
--- a/typer/params.py
+++ b/typer/params.py
@@ -1,13 +1,15 @@
from collections.abc import Callable
from typing import TYPE_CHECKING, Annotated, Any, overload
-import click
from annotated_doc import Doc
+from . import _click
+from ._click import types
+from ._click.shell_completion import CompletionItem
from .models import ArgumentInfo, OptionInfo
if TYPE_CHECKING: # pragma: no cover
- import click.shell_completion
+ pass
# Overload for Option created with custom type 'parser'
@@ -24,8 +26,8 @@ def Option(
# Note that shell_complete is not fully supported and will be removed in future versions
# TODO: Remove shell_complete in a future version (after 0.16.0)
shell_complete: Callable[
- [click.Context, click.Parameter, str],
- list["click.shell_completion.CompletionItem"] | list[str],
+ [_click.Context, _click.Parameter, str],
+ list["CompletionItem"] | list[str],
]
| None = None,
autocompletion: Callable[..., Any] | None = None,
@@ -89,14 +91,14 @@ def Option(
# Note that shell_complete is not fully supported and will be removed in future versions
# TODO: Remove shell_complete in a future version (after 0.16.0)
shell_complete: Callable[
- [click.Context, click.Parameter, str],
- list["click.shell_completion.CompletionItem"] | list[str],
+ [_click.Context, _click.Parameter, str],
+ list["CompletionItem"] | list[str],
]
| None = None,
autocompletion: Callable[..., Any] | None = None,
default_factory: Callable[[], Any] | None = None,
# Custom type
- click_type: click.ParamType | None = None,
+ click_type: types.ParamType | None = None,
# Option
show_default: bool | str = True,
prompt: bool | str = False,
@@ -265,8 +267,8 @@ def main(user: Annotated[str, typer.Option(envvar="ME")]):
# TODO: Remove shell_complete in a future version (after 0.16.0)
shell_complete: Annotated[
Callable[
- [click.Context, click.Parameter, str],
- list["click.shell_completion.CompletionItem"] | list[str],
+ [_click.Context, _click.Parameter, str],
+ list["CompletionItem"] | list[str],
]
| None,
Doc(
@@ -343,7 +345,7 @@ def main(opt: Annotated[CustomClass, typer.Option(parser=my_parser)] = "Foo"):
),
] = None,
click_type: Annotated[
- click.ParamType | None,
+ types.ParamType | None,
Doc(
"""
Define this parameter to use a [custom Click type](https://click.palletsprojects.com/en/stable/parameters/#implementing-custom-types) in your Typer applications.
@@ -1014,8 +1016,8 @@ def Argument(
# Note that shell_complete is not fully supported and will be removed in future versions
# TODO: Remove shell_complete in a future version (after 0.16.0)
shell_complete: Callable[
- [click.Context, click.Parameter, str],
- list["click.shell_completion.CompletionItem"] | list[str],
+ [_click.Context, _click.Parameter, str],
+ list["CompletionItem"] | list[str],
]
| None = None,
autocompletion: Callable[..., Any] | None = None,
@@ -1070,14 +1072,14 @@ def Argument(
# Note that shell_complete is not fully supported and will be removed in future versions
# TODO: Remove shell_complete in a future version (after 0.16.0)
shell_complete: Callable[
- [click.Context, click.Parameter, str],
- list["click.shell_completion.CompletionItem"] | list[str],
+ [_click.Context, _click.Parameter, str],
+ list["CompletionItem"] | list[str],
]
| None = None,
autocompletion: Callable[..., Any] | None = None,
default_factory: Callable[[], Any] | None = None,
# Custom type
- click_type: click.ParamType | None = None,
+ click_type: types.ParamType | None = None,
# TyperArgument
show_default: bool | str = True,
show_choices: bool = True,
@@ -1219,8 +1221,8 @@ def main(name: Annotated[str, typer.Argument(envvar="ME")]):
# TODO: Remove shell_complete in a future version (after 0.16.0)
shell_complete: Annotated[
Callable[
- [click.Context, click.Parameter, str],
- list["click.shell_completion.CompletionItem"] | list[str],
+ [_click.Context, _click.Parameter, str],
+ list["CompletionItem"] | list[str],
]
| None,
Doc(
@@ -1297,7 +1299,7 @@ def main(arg: Annotated[CustomClass, typer.Argument(parser=my_parser):
),
] = None,
click_type: Annotated[
- click.ParamType | None,
+ types.ParamType | None,
Doc(
"""
Define this parameter to use a [custom Click type](https://click.palletsprojects.com/en/stable/parameters/#implementing-custom-types) in your Typer applications.
diff --git a/typer/rich_utils.py b/typer/rich_utils.py
index 69be631207..e5b106126e 100644
--- a/typer/rich_utils.py
+++ b/typer/rich_utils.py
@@ -8,7 +8,6 @@
from os import getenv
from typing import Any, Literal
-import click
from rich import box
from rich.align import Align
from rich.columns import Columns
@@ -25,6 +24,10 @@
from rich.traceback import Traceback
from typer.models import DeveloperExceptionConfig
+from . import _click
+from ._click import types
+from .core import TyperArgument, TyperGroup, TyperOption
+
# Default styles
STYLE_OPTION = "bold cyan"
STYLE_SWITCH = "bold green"
@@ -184,7 +187,7 @@ def _make_rich_text(
@group()
def _get_help_text(
*,
- obj: click.Command | click.Group,
+ obj: _click.Command | TyperGroup,
markup_mode: MarkupModeStrict,
) -> Iterable[Markdown | Text]:
"""Build primary help text for a click command or group.
@@ -231,8 +234,8 @@ def _get_help_text(
def _get_parameter_help(
*,
- param: click.Option | click.Argument | click.Parameter,
- ctx: click.Context,
+ param: TyperOption | TyperArgument | _click.Parameter,
+ ctx: _click.Context,
markup_mode: MarkupModeStrict,
) -> Columns:
"""Build primary help text for a click option or argument.
@@ -348,8 +351,8 @@ def _make_command_help(
def _print_options_panel(
*,
name: str,
- params: list[click.Option] | list[click.Argument],
- ctx: click.Context,
+ params: list[TyperOption] | list[TyperArgument],
+ ctx: _click.Context,
markup_mode: MarkupModeStrict,
console: Console,
) -> None:
@@ -377,7 +380,7 @@ def _print_options_panel(
metavar_str = param.make_metavar(ctx=ctx)
# Do it ourselves if this is a positional argument
if (
- isinstance(param, click.Argument)
+ isinstance(param, TyperArgument)
and param.name
and metavar_str == param.name.upper()
):
@@ -391,8 +394,8 @@ def _print_options_panel(
# https://github.com/pallets/click/blob/c63c70dabd3f86ca68678b4f00951f78f52d0270/src/click/core.py#L2698-L2706 # noqa: E501
# skip count with default range type
if (
- isinstance(param.type, click.types._NumberRangeBase)
- and isinstance(param, click.Option)
+ isinstance(param.type, types._NumberRangeBase)
+ and isinstance(param, TyperOption)
and not (param.count and param.type.min == 0 and param.type.max is None)
):
range_str = param.type._describe_range()
@@ -459,7 +462,7 @@ def _print_options_panel(
def _print_commands_panel(
*,
name: str,
- commands: list[click.Command],
+ commands: list[_click.Command],
markup_mode: MarkupModeStrict,
console: Console,
cmd_len: int,
@@ -534,8 +537,8 @@ def _print_commands_panel(
def rich_format_help(
*,
- obj: click.Command | click.Group,
- ctx: click.Context,
+ obj: _click.Command | TyperGroup,
+ ctx: _click.Context,
markup_mode: MarkupModeStrict,
) -> None:
"""Print nicely formatted help text using rich.
@@ -568,18 +571,18 @@ def rich_format_help(
(0, 1, 1, 1),
)
)
- panel_to_arguments: defaultdict[str, list[click.Argument]] = defaultdict(list)
- panel_to_options: defaultdict[str, list[click.Option]] = defaultdict(list)
+ panel_to_arguments: defaultdict[str, list[TyperArgument]] = defaultdict(list)
+ panel_to_options: defaultdict[str, list[TyperOption]] = defaultdict(list)
for param in obj.get_params(ctx):
# Skip if option is hidden
if getattr(param, "hidden", False):
continue
- if isinstance(param, click.Argument):
+ if isinstance(param, TyperArgument):
panel_name = (
getattr(param, _RICH_HELP_PANEL_NAME, None) or ARGUMENTS_PANEL_TITLE
)
panel_to_arguments[panel_name].append(param)
- elif isinstance(param, click.Option):
+ elif isinstance(param, TyperOption):
panel_name = (
getattr(param, _RICH_HELP_PANEL_NAME, None) or OPTIONS_PANEL_TITLE
)
@@ -623,8 +626,8 @@ def rich_format_help(
console=console,
)
- if isinstance(obj, click.Group):
- panel_to_commands: defaultdict[str, list[click.Command]] = defaultdict(list)
+ if isinstance(obj, TyperGroup):
+ panel_to_commands: defaultdict[str, list[_click.Command]] = defaultdict(list)
for command_name in obj.list_commands(ctx):
command = obj.get_command(ctx, command_name)
if command and not command.hidden:
@@ -674,18 +677,18 @@ def rich_format_help(
console.print(Padding(Align(epilogue_text, pad=False), 1))
-def rich_format_error(self: click.ClickException) -> None:
+def rich_format_error(self: _click.ClickException) -> None:
"""Print richly formatted click errors.
Called by custom exception handler to print richly formatted click errors.
- Mimics original click.ClickException.echo() function but with rich formatting.
+ Mimics original _click.ClickException.echo() function but with rich formatting.
"""
# Don't do anything when it's a NoArgsIsHelpError (without importing it, cf. #1278)
if self.__class__.__name__ == "NoArgsIsHelpError":
return
console = _get_rich_console(stderr=True)
- ctx: click.Context | None = getattr(self, "ctx", None)
+ ctx: _click.Context | None = getattr(self, "ctx", None)
if ctx is not None:
console.print(ctx.get_usage())
diff --git a/typer/testing.py b/typer/testing.py
index 09711e66fd..6035867662 100644
--- a/typer/testing.py
+++ b/typer/testing.py
@@ -1,11 +1,12 @@
from collections.abc import Mapping, Sequence
from typing import IO, Any
-from click.testing import CliRunner as ClickCliRunner # noqa
-from click.testing import Result
from typer.main import Typer
from typer.main import get_command as _get_command
+from ._click.testing import CliRunner as ClickCliRunner # noqa
+from ._click.testing import Result
+
class CliRunner(ClickCliRunner):
def invoke( # type: ignore
diff --git a/uv.lock b/uv.lock
index cf893201e8..207311edeb 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1787,7 +1787,7 @@ name = "typer"
source = { editable = "." }
dependencies = [
{ name = "annotated-doc" },
- { name = "click" },
+ { name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "rich" },
{ name = "shellingham" },
]
@@ -1850,7 +1850,7 @@ tests = [
[package.metadata]
requires-dist = [
{ name = "annotated-doc", specifier = ">=0.0.2" },
- { name = "click", specifier = ">=8.2.1,<8.4" },
+ { name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "rich", specifier = ">=13.8.0" },
{ name = "shellingham", specifier = ">=1.3.0" },
]