diff --git a/.changes/next-release/enhancement-Performance-97002.json b/.changes/next-release/enhancement-Performance-97002.json new file mode 100644 index 000000000000..c1d5f0a59b63 --- /dev/null +++ b/.changes/next-release/enhancement-Performance-97002.json @@ -0,0 +1,5 @@ +{ + "type": "enhancement", + "category": "Performance", + "description": "Defer some imports (e.g. ``prompt_toolkit``) until they are needed to reduce command initialization time (e.g. loading all imported modules)." +} diff --git a/awscli/autoprompt/core.py b/awscli/autoprompt/core.py index 47d9b5d54a78..29e6dc465317 100644 --- a/awscli/autoprompt/core.py +++ b/awscli/autoprompt/core.py @@ -10,20 +10,13 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -from botocore.exceptions import ProfileNotFound - from awscli.autocomplete.filters import fuzzy_filter from awscli.autocomplete.main import create_autocompleter from awscli.autoprompt.prompttoolkit import PromptToolkitPrompter -from awscli.customizations.exceptions import ParamValidationError from awscli.errorhandler import SilenceParamValidationMsgErrorHandler class AutoPromptDriver: - _NO_PROMPT_ARGS = ['help', '--version'] - _CLI_AUTO_PROMPT_OPTION = '--cli-auto-prompt' - _NO_CLI_AUTO_PROMPT_OPTION = '--no-cli-auto-prompt' - def __init__(self, driver, completion_source=None, prompter=None): self._completion_source = completion_source self._prompter = prompter @@ -42,34 +35,6 @@ def prompter(self): ) return self._prompter - def validate_auto_prompt_args_are_mutually_exclusive(self, args): - no_cli_auto_prompt = self._NO_CLI_AUTO_PROMPT_OPTION in args - cli_auto_prompt = self._CLI_AUTO_PROMPT_OPTION in args - if cli_auto_prompt and no_cli_auto_prompt: - raise ParamValidationError( - 'Both --cli-auto-prompt and --no-cli-auto-prompt cannot be ' - 'specified at the same time.' - ) - - def resolve_mode(self, args): - # Order of precedence to check: - # - check if any arg rom NO_PROMPT_ARGS in args - # - check if '--no-cli-auto-prompt' was specified - # - check if '--cli-auto-prompt' was specified - # - check configuration chain - self.validate_auto_prompt_args_are_mutually_exclusive(args) - if any(arg in args for arg in self._NO_PROMPT_ARGS): - return 'off' - if self._NO_CLI_AUTO_PROMPT_OPTION in args: - return 'off' - if self._CLI_AUTO_PROMPT_OPTION in args: - return 'on' - try: - config = self._session.get_config_variable('cli_auto_prompt') - return config.lower() - except ProfileNotFound: - return 'off' - def inject_silence_param_error_msg_handler(self, driver): driver.error_handler.inject_handler( 0, SilenceParamValidationMsgErrorHandler() diff --git a/awscli/autoprompt/exceptions.py b/awscli/autoprompt/exceptions.py new file mode 100644 index 000000000000..08d6c6cc7f13 --- /dev/null +++ b/awscli/autoprompt/exceptions.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +class PrompterKeyboardInterrupt(KeyboardInterrupt): + pass diff --git a/awscli/autoprompt/factory.py b/awscli/autoprompt/factory.py index 59cffbc6daea..3d037d7dbfc8 100644 --- a/awscli/autoprompt/factory.py +++ b/awscli/autoprompt/factory.py @@ -28,6 +28,7 @@ from prompt_toolkit.layout.processors import BeforeInput from prompt_toolkit.widgets import SearchToolbar, VerticalLine +from awscli.autoprompt.exceptions import PrompterKeyboardInterrupt from awscli.autoprompt.filters import ( doc_section_visible, doc_window_has_focus, @@ -46,10 +47,6 @@ ) -class PrompterKeyboardInterrupt(KeyboardInterrupt): - pass - - class CLIPromptBuffer(Buffer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/awscli/clidriver.py b/awscli/clidriver.py index 7a5e95ebaeab..643331c50d46 100644 --- a/awscli/clidriver.py +++ b/awscli/clidriver.py @@ -30,6 +30,7 @@ ScopedConfigProvider, ) from botocore.context import start_as_current_context +from botocore.exceptions import ProfileNotFound from botocore.history import get_global_history_recorder from awscli import __version__ @@ -49,7 +50,6 @@ ListArgument, UnknownArgumentError, ) -from awscli.autoprompt.core import AutoPromptDriver from awscli.commands import CLICommand from awscli.compat import ( default_pager, @@ -58,6 +58,7 @@ validate_preferred_output_encoding, ) from awscli.constants import PARAM_VALIDATION_ERROR_RC +from awscli.customizations.exceptions import ParamValidationError from awscli.errorhandler import ( construct_cli_error_handlers_chain, construct_entry_point_handlers_chain, @@ -90,6 +91,9 @@ ) HISTORY_RECORDER = get_global_history_recorder() METADATA_FILENAME = 'metadata.json' +_NO_AUTO_PROMPT_ARGS = ['help', '--version'] +_CLI_AUTO_PROMPT_OPTION = '--cli-auto-prompt' +_NO_CLI_AUTO_PROMPT_OPTION = '--no-cli-auto-prompt' # Don't remove this line. The idna encoding # is used by getaddrinfo when dealing with unicode hostnames, # and in some cases, there appears to be a race condition @@ -126,6 +130,36 @@ def create_clidriver(args=None): return driver +def validate_auto_prompt_args_are_mutually_exclusive(args): + no_cli_auto_prompt = _NO_CLI_AUTO_PROMPT_OPTION in args + cli_auto_prompt = _CLI_AUTO_PROMPT_OPTION in args + if cli_auto_prompt and no_cli_auto_prompt: + raise ParamValidationError( + 'Both --cli-auto-prompt and --no-cli-auto-prompt cannot be ' + 'specified at the same time.' + ) + + +def resolve_auto_prompt_mode(args, session): + # Order of precedence to check: + # - check if any arg from _NO_AUTO_PROMPT_ARGS in args + # - check if '--no-cli-auto-prompt' was specified + # - check if '--cli-auto-prompt' was specified + # - check configuration chain + validate_auto_prompt_args_are_mutually_exclusive(args) + if any(arg in args for arg in _NO_AUTO_PROMPT_ARGS): + return 'off' + if _NO_CLI_AUTO_PROMPT_OPTION in args: + return 'off' + if _CLI_AUTO_PROMPT_OPTION in args: + return 'on' + try: + config = session.get_config_variable('cli_auto_prompt') + return config.lower() + except ProfileNotFound: + return 'off' + + def _get_distribution_source(): metadata_file = os.path.join( os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data'), @@ -220,12 +254,17 @@ def _do_main(self, args): driver = self._driver if driver is None: driver = create_clidriver(args) - autoprompt_driver = AutoPromptDriver(driver) - auto_prompt_mode = autoprompt_driver.resolve_mode(args) + auto_prompt_mode = resolve_auto_prompt_mode(args, driver.session) if auto_prompt_mode == 'on': + from awscli.autoprompt.core import AutoPromptDriver + + autoprompt_driver = AutoPromptDriver(driver) args = autoprompt_driver.prompt_for_args(args) rc = self._run_driver(driver, args, prompt_mode='on') elif auto_prompt_mode == 'on-partial': + from awscli.autoprompt.core import AutoPromptDriver + + autoprompt_driver = AutoPromptDriver(driver) autoprompt_driver.inject_silence_param_error_msg_handler(driver) rc = self._run_driver(driver, args, prompt_mode='off') if rc == PARAM_VALIDATION_ERROR_RC: diff --git a/awscli/customizations/configure/configure.py b/awscli/customizations/configure/configure.py index 1d39832b07b0..b62ce34be34c 100644 --- a/awscli/customizations/configure/configure.py +++ b/awscli/customizations/configure/configure.py @@ -28,7 +28,7 @@ from awscli.customizations.configure.listprofiles import ListProfilesCommand from awscli.customizations.configure.mfalogin import ConfigureMFALoginCommand from awscli.customizations.configure.set import ConfigureSetCommand -from awscli.customizations.configure.sso import ( +from awscli.customizations.configure.sso_commands import ( ConfigureSSOCommand, ConfigureSSOSessionCommand, ) diff --git a/awscli/customizations/configure/sso.py b/awscli/customizations/configure/sso.py index c09ff0084c3a..673c5ce8ae55 100644 --- a/awscli/customizations/configure/sso.py +++ b/awscli/customizations/configure/sso.py @@ -14,14 +14,8 @@ import itertools import json import logging -import os import re -import colorama -from botocore import UNSIGNED -from botocore.config import Config -from botocore.configprovider import ConstantProvider -from botocore.exceptions import ProfileNotFound from botocore.utils import is_valid_endpoint_url from prompt_toolkit import prompt as ptk_prompt from prompt_toolkit.application import get_app @@ -30,42 +24,16 @@ from prompt_toolkit.styles import Style from prompt_toolkit.validation import ValidationError, Validator -from awscli.customizations.configure import ( - get_section_header, - profile_to_section, -) -from awscli.customizations.configure.writer import ConfigFileWriter from awscli.customizations.sso.utils import ( - LOGIN_ARGS, - BaseSSOCommand, - PrintOnlyHandler, - do_sso_login, parse_sso_registration_scopes, ) -from awscli.customizations.utils import uni_print -from awscli.customizations.wizard.ui.selectmenu import select_menu -from awscli.formatter import CLI_OUTPUT_FORMATS logger = logging.getLogger(__name__) -_CMD_PROMPT_USAGE = ( - 'To keep an existing value, hit enter when prompted for the value. When ' - 'you are prompted for information, the current value will be displayed in ' - '[brackets]. If the config item has no value, it is displayed as ' - '[None] or omitted entirely.\n\n' -) -_CONFIG_EXTRA_INFO = ( - 'Note: The configuration is saved in the shared configuration file. ' - 'By default, ``~/.aws/config``. For more information, see the ' - '"Configuring the AWS CLI to use AWS IAM Identity Center" section in the ' - 'AWS CLI User Guide:' - '\n\nhttps://docs.aws.amazon.com/cli/latest/userguide/cli-configure-sso.html' -) - class ValidatorWithDefault(Validator): def __init__(self, default=None): - super(ValidatorWithDefault, self).__init__() + super().__init__() self._default = default def _raise_validation_error(self, document, message): @@ -167,34 +135,6 @@ def _get_prompt_style(self): ) -def display_account(account): - """Converts an SSO account response into a display string. - - All fields should be present in the API response but we've seen some - cases where the account name or email address is not set. Considering - we only need the account name and email address for display purposes - we should be defensive just in case they don't come back. - """ - if 'accountName' not in account and 'emailAddress' not in account: - account_template = '{accountId}' - elif 'emailAddress' not in account: - account_template = '{accountName} ({accountId})' - elif 'accountName' not in account: - account_template = '{emailAddress} ({accountId})' - else: - account_template = '{accountName}, {emailAddress} ({accountId})' - return account_template.format(**account) - - -def get_account_sorting_key(account): - only_account_id = ('accountName' not in account and 'emailAddress' not in account) - for key in ('accountName', 'emailAddress', 'accountId'): - value = account.get(key, None) - if value is not None: - return (only_account_id, value.lower()) - return (only_account_id, None) - - class SSOSessionConfigurationPrompter: _DEFAULT_SSO_SCOPE = 'sso:account:access' _KNOWN_SSO_SCOPES = { @@ -357,400 +297,3 @@ def _get_previously_used_scopes_to_sso_sessions(self): for parsed_scope in parsed_scopes: scopes_to_sessions[parsed_scope].append(sso_session) return scopes_to_sessions - - -class BaseSSOConfigurationCommand(BaseSSOCommand): - def __init__(self, session, prompter=None, config_writer=None): - super(BaseSSOConfigurationCommand, self).__init__(session) - if prompter is None: - prompter = PTKPrompt() - self._prompter = prompter - if config_writer is None: - config_writer = ConfigFileWriter() - self._config_writer = config_writer - self._sso_sessions = self._session.full_config.get('sso_sessions', {}) - self._sso_session_prompter = SSOSessionConfigurationPrompter( - botocore_session=session, - prompter=self._prompter, - ) - - def _write_sso_configuration(self): - self._update_section( - section_header=get_section_header( - 'sso-session', self._sso_session_prompter.sso_session - ), - new_values=self._sso_session_prompter.sso_session_config, - ) - - def _update_section(self, section_header, new_values): - config_path = self._session.get_config_variable('config_file') - config_path = os.path.expanduser(config_path) - new_values['__section__'] = section_header - self._config_writer.update_config(new_values, config_path) - - -class ConfigureSSOCommand(BaseSSOConfigurationCommand): - NAME = 'sso' - SYNOPSIS = 'aws configure sso [--profile profile-name]' - DESCRIPTION = ( - 'The ``aws configure sso`` command interactively prompts for the ' - 'configuration values required to create a profile that sources ' - 'temporary AWS credentials from AWS IAM Identity Center.\n\n' - f'{_CMD_PROMPT_USAGE}' - 'When providing the ``--profile`` parameter the named profile ' - 'will be created or updated. When a profile is not explicitly set ' - 'the profile name will be prompted for.\n\n' - f'{_CONFIG_EXTRA_INFO}' - ) - # TODO: Add CLI parameters to skip prompted values, --start-url, etc. - ARG_TABLE = LOGIN_ARGS - - def __init__( - self, - session, - prompter=None, - selector=None, - config_writer=None, - sso_token_cache=None, - sso_login=None, - ): - super(ConfigureSSOCommand, self).__init__( - session, prompter=prompter, config_writer=config_writer - ) - if selector is None: - selector = select_menu - self._selector = selector - if sso_login is None: - sso_login = do_sso_login - self._sso_login = sso_login - self._sso_token_cache = sso_token_cache - - self._new_profile_config_values = {} - self._original_profile_name = self._session.profile - try: - self._profile_config = self._session.get_scoped_config() - except ProfileNotFound: - self._profile_config = {} - self._set_sso_session_if_configured_in_profile() - - def _set_sso_session_if_configured_in_profile(self): - if 'sso_session' in self._profile_config: - self._sso_session_prompter.sso_session = self._profile_config[ - 'sso_session' - ] - - def _handle_single_account(self, accounts): - sso_account_id = accounts[0]['accountId'] - single_account_msg = 'The only AWS account available to you is: {}\n' - uni_print(single_account_msg.format(sso_account_id)) - return sso_account_id - - def _handle_multiple_accounts(self, accounts): - available_accounts_msg = ( - 'There are {} AWS accounts available to you.\n' - ) - uni_print(available_accounts_msg.format(len(accounts))) - sorted_accounts = sorted(accounts, key=get_account_sorting_key) - selected_account = self._selector( - sorted_accounts, display_format=display_account - ) - sso_account_id = selected_account['accountId'] - return sso_account_id - - def _get_all_accounts(self, sso, sso_token): - paginator = sso.get_paginator('list_accounts') - results = paginator.paginate(accessToken=sso_token['accessToken']) - return results.build_full_result() - - def _prompt_for_account(self, sso, sso_token): - accounts = self._get_all_accounts(sso, sso_token)['accountList'] - if not accounts: - raise RuntimeError('No AWS accounts are available to you.') - if len(accounts) == 1: - sso_account_id = self._handle_single_account(accounts) - else: - sso_account_id = self._handle_multiple_accounts(accounts) - uni_print(f'Using the account ID {sso_account_id}\n') - self._new_profile_config_values['sso_account_id'] = sso_account_id - return sso_account_id - - def _handle_single_role(self, roles): - sso_role_name = roles[0]['roleName'] - available_roles_msg = 'The only role available to you is: {}\n' - uni_print(available_roles_msg.format(sso_role_name)) - return sso_role_name - - def _handle_multiple_roles(self, roles): - available_roles_msg = 'There are {} roles available to you.\n' - uni_print(available_roles_msg.format(len(roles))) - sorted_roles = sorted(roles, key=lambda x: x['roleName'].lower()) - role_names = [r['roleName'] for r in sorted_roles] - sso_role_name = self._selector(role_names) - return sso_role_name - - def _get_all_roles(self, sso, sso_token, sso_account_id): - paginator = sso.get_paginator('list_account_roles') - results = paginator.paginate( - accountId=sso_account_id, accessToken=sso_token['accessToken'] - ) - return results.build_full_result() - - def _prompt_for_role(self, sso, sso_token, sso_account_id): - roles = self._get_all_roles(sso, sso_token, sso_account_id)['roleList'] - if not roles: - error_msg = 'No roles are available for the account {}' - raise RuntimeError(error_msg.format(sso_account_id)) - if len(roles) == 1: - sso_role_name = self._handle_single_role(roles) - else: - sso_role_name = self._handle_multiple_roles(roles) - uni_print(f'Using the role name "{sso_role_name}"\n') - self._new_profile_config_values['sso_role_name'] = sso_role_name - return sso_role_name - - def _prompt_for_profile(self, sso_account_id=None, sso_role_name=None): - if self._original_profile_name: - profile_name = self._original_profile_name - else: - text = 'Profile name' - default_profile = None - if sso_account_id and sso_role_name: - default_profile = f'{sso_role_name}-{sso_account_id}' - validator = RequiredInputValidator(default_profile) - profile_name = self._prompter.get_value( - default_profile, text, validator=validator - ) - return profile_name - - def _prompt_for_cli_default_region(self): - # TODO: figure out a way to get a list of reasonable client regions - return self._prompt_for_profile_config( - 'region', 'Default client Region' - ) - - def _prompt_for_cli_output_format(self): - return self._prompt_for_profile_config( - 'output', - 'CLI default output format (json if not specified)', - completions=list(CLI_OUTPUT_FORMATS.keys()), - ) - - def _prompt_for_profile_config(self, config_name, text, completions=None): - current_value = self._profile_config.get(config_name) - - new_value = self._prompter.get_value( - current_value, - text, - completions=completions, - ) - if new_value: - self._new_profile_config_values[config_name] = new_value - return new_value - - def _unset_session_profile(self): - # The profile provided to the CLI as --profile may not exist. - # This means we cannot use the session as is to create clients. - # By overriding the profile provider we ensure that a non-existant - # profile won't cause us to fail to create clients. - # No configuration from the profile is needed for the SSO APIs. - # It might be good to see if we can address this in a better way - # in botocore. - config_store = self._session.get_component('config_store') - config_store.set_config_provider('profile', ConstantProvider(None)) - - def _run_main(self, parsed_args, parsed_globals): - self._unset_session_profile() - on_pending_authorization = None - if parsed_args.no_browser: - on_pending_authorization = PrintOnlyHandler() - sso_registration_args = self._prompt_for_sso_registration_args() - sso_token = self._sso_login( - self._session, - parsed_globals=parsed_globals, - token_cache=self._sso_token_cache, - on_pending_authorization=on_pending_authorization, - use_device_code=parsed_args.use_device_code, - **sso_registration_args, - ) - - # Construct an SSO client to explore the accounts / roles - client_config = Config( - signature_version=UNSIGNED, - region_name=sso_registration_args['sso_region'], - ) - sso = self._session.create_client('sso', config=client_config) - - sso_account_id, sso_role_name = self._prompt_for_sso_account_and_role( - sso, sso_token - ) - configured_for_aws_credentials = all((sso_account_id, sso_role_name)) - - # General CLI configuration - self._prompt_for_cli_default_region() - self._prompt_for_cli_output_format() - - profile_name = self._prompt_for_profile(sso_account_id, sso_role_name) - - self._write_new_config(profile_name) - self._print_conclusion(configured_for_aws_credentials, profile_name) - return 0 - - def _prompt_for_sso_registration_args(self): - sso_session = self._sso_session_prompter.prompt_for_sso_session( - required=False - ) - if sso_session is None: - self._warn_configuring_using_legacy_format() - return self._prompt_for_registration_args_with_legacy_format() - else: - self._set_sso_session_in_profile_config(sso_session) - if sso_session in self._sso_sessions: - return self._get_sso_registration_args_from_sso_config( - sso_session - ) - else: - return self._prompt_for_registration_args_for_new_sso_session( - sso_session=sso_session - ) - - def _prompt_for_registration_args_with_legacy_format(self): - self._store_sso_session_prompter_answers_to_profile_config() - self._set_sso_session_defaults_from_profile_config() - start_url, sso_region = self._prompt_for_sso_start_url_and_sso_region() - return {'start_url': start_url, 'sso_region': sso_region} - - def _get_sso_registration_args_from_sso_config(self, sso_session): - sso_config = self._get_sso_session_config(sso_session) - return { - 'session_name': sso_session, - 'start_url': sso_config['sso_start_url'], - 'sso_region': sso_config['sso_region'], - 'registration_scopes': sso_config.get('registration_scopes'), - } - - def _prompt_for_registration_args_for_new_sso_session(self, sso_session): - self._set_sso_session_defaults_from_profile_config() - start_url, sso_region = self._prompt_for_sso_start_url_and_sso_region() - scopes = ( - self._sso_session_prompter.prompt_for_sso_registration_scopes() - ) - return { - 'session_name': sso_session, - 'start_url': start_url, - 'sso_region': sso_region, - 'registration_scopes': scopes, - # We force refresh for any new SSO sessions to ensure we are not - # using any cached tokens from any previous of attempts to - # create/authenticate a new SSO session as part of the configure - # sso flow. - 'force_refresh': True, - } - - def _store_sso_session_prompter_answers_to_profile_config(self): - # Wire the SSO session prompter to set config values to the - # dictionary used for writing to the profile section - self._sso_session_prompter.sso_session_config = ( - self._new_profile_config_values - ) - - def _set_sso_session_in_profile_config(self, sso_session): - self._new_profile_config_values['sso_session'] = sso_session - - def _set_sso_session_defaults_from_profile_config(self): - # This is to ensure the SSO session prompter pulls in existing - # SSO configuration as part of the prompt if a profile was explicitly - # provided that already had SSO configuration - if 'sso_start_url' in self._profile_config: - self._sso_session_prompter.sso_session_config['sso_start_url'] = ( - self._profile_config['sso_start_url'] - ) - if 'sso_region' in self._profile_config: - self._sso_session_prompter.sso_session_config['sso_region'] = ( - self._profile_config['sso_region'] - ) - - def _prompt_for_sso_start_url_and_sso_region(self): - start_url = self._sso_session_prompter.prompt_for_sso_start_url() - sso_region = self._sso_session_prompter.prompt_for_sso_region() - return start_url, sso_region - - def _warn_configuring_using_legacy_format(self): - uni_print( - f'{colorama.Style.BRIGHT}WARNING: Configuring using legacy format ' - f'(e.g. without an SSO session).\n' - f'Consider re-running "configure sso" command and providing ' - f'a session name.\n{colorama.Style.RESET_ALL}' - ) - - def _prompt_for_sso_account_and_role(self, sso, sso_token): - sso_account_id = None - sso_role_name = None - try: - sso_account_id = self._prompt_for_account(sso, sso_token) - sso_role_name = self._prompt_for_role( - sso, sso_token, sso_account_id - ) - except sso.exceptions.UnauthorizedException: - uni_print( - 'Unable to list AWS accounts and/or roles. ' - 'Skipping configuring AWS credential provider for profile.\n' - ) - return sso_account_id, sso_role_name - - def _write_new_config(self, profile): - if self._new_profile_config_values: - profile_section = profile_to_section(profile) - self._update_section( - profile_section, self._new_profile_config_values - ) - if self._sso_session_prompter.sso_session: - self._write_sso_configuration() - - def _print_conclusion(self, configured_for_aws_credentials, profile_name): - if configured_for_aws_credentials: - if profile_name.lower() == 'default': - msg = ( - 'The AWS CLI is now configured to use the default profile.\n' - 'Run the following command to verify your configuration:\n\n' - 'aws sts get-caller-identity\n' - ) - else: - msg = ( - 'To use this profile, specify the profile name using ' - '--profile, as shown:\n\n' - 'aws sts get-caller-identity --profile {}\n' - ) - else: - msg = 'Successfully configured SSO for profile: {}\n' - uni_print(msg.format(profile_name)) - - -class ConfigureSSOSessionCommand(BaseSSOConfigurationCommand): - NAME = 'sso-session' - SYNOPSIS = 'aws configure sso-session' - DESCRIPTION = ( - 'The ``aws configure sso-session`` command interactively prompts for ' - 'the configuration values required to create a SSO session. ' - 'The SSO session can then be associated to a profile to retrieve ' - 'SSO access tokens and AWS credentials.\n\n' - f'{_CMD_PROMPT_USAGE}' - f'{_CONFIG_EXTRA_INFO}' - ) - - def _run_main(self, parsed_args, parsed_globals): - self._sso_session_prompter.prompt_for_sso_session() - self._sso_session_prompter.prompt_for_sso_start_url() - self._sso_session_prompter.prompt_for_sso_region() - self._sso_session_prompter.prompt_for_sso_registration_scopes() - self._write_sso_configuration() - self._print_configuration_success() - return 0 - - def _print_configuration_success(self): - sso_session = self._sso_session_prompter.sso_session - uni_print( - f'\nCompleted configuring SSO session: {sso_session}\n' - f'Run the following to login and refresh access token for ' - f'this session:\n\n' - f'aws sso login --sso-session {sso_session}\n' - ) diff --git a/awscli/customizations/configure/sso_commands.py b/awscli/customizations/configure/sso_commands.py new file mode 100644 index 000000000000..e9d80cc548d3 --- /dev/null +++ b/awscli/customizations/configure/sso_commands.py @@ -0,0 +1,496 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +""" +Defines SSO-specific configure commands, to be used with the `aws configure` +command. + +The main reason it lives in its own separate module instead of in +`awscli/customizations/configure/sso.py` is so that these commands can be +referenced without importing `awscli/customizations/configure/sso.py`, +which imports from `prompt_toolkit`. Importing from `prompt_toolkit` has +historically increased command execution times. + +This separation helps us limit our imports from `prompt_toolkit` to when it +is actually needed, improving execution time across most commands. +""" + +import logging +import os + +import colorama +from botocore import UNSIGNED +from botocore.config import Config +from botocore.configprovider import ConstantProvider +from botocore.exceptions import ProfileNotFound + +from awscli.customizations.configure import ( + get_section_header, + profile_to_section, +) +from awscli.customizations.configure.writer import ConfigFileWriter +from awscli.customizations.sso.utils import ( + LOGIN_ARGS, + BaseSSOCommand, + PrintOnlyHandler, + do_sso_login, +) +from awscli.customizations.utils import uni_print +from awscli.formatter import CLI_OUTPUT_FORMATS + +logger = logging.getLogger(__name__) + +_CMD_PROMPT_USAGE = ( + 'To keep an existing value, hit enter when prompted for the value. When ' + 'you are prompted for information, the current value will be displayed in ' + '[brackets]. If the config item has no value, it is displayed as ' + '[None] or omitted entirely.\n\n' +) +_CONFIG_EXTRA_INFO = ( + 'Note: The configuration is saved in the shared configuration file. ' + 'By default, ``~/.aws/config``. For more information, see the ' + '"Configuring the AWS CLI to use AWS IAM Identity Center" section in the ' + 'AWS CLI User Guide:' + '\n\nhttps://docs.aws.amazon.com/cli/latest/userguide/cli-configure-sso.html' +) + + +def display_account(account): + """Converts an SSO account response into a display string. + + All fields should be present in the API response but we've seen some + cases where the account name or email address is not set. Considering + we only need the account name and email address for display purposes + we should be defensive just in case they don't come back. + """ + if 'accountName' not in account and 'emailAddress' not in account: + account_template = '{accountId}' + elif 'emailAddress' not in account: + account_template = '{accountName} ({accountId})' + elif 'accountName' not in account: + account_template = '{emailAddress} ({accountId})' + else: + account_template = '{accountName}, {emailAddress} ({accountId})' + return account_template.format(**account) + + +def get_account_sorting_key(account): + only_account_id = ( + 'accountName' not in account and 'emailAddress' not in account + ) + for key in ('accountName', 'emailAddress', 'accountId'): + value = account.get(key, None) + if value is not None: + return (only_account_id, value.lower()) + return (only_account_id, None) + + +class BaseSSOConfigurationCommand(BaseSSOCommand): + def __init__(self, session, prompter=None, config_writer=None): + super().__init__(session) + self._prompter = prompter + if config_writer is None: + config_writer = ConfigFileWriter() + self._config_writer = config_writer + self._sso_sessions = self._session.full_config.get('sso_sessions', {}) + # Initialize self._sso_session_prompter to None. It will be + # initialized lazily during command execution. + self._sso_session_prompter = None + + def _init_prompt_toolkit(self): + from awscli.customizations.configure.sso import ( + PTKPrompt, + SSOSessionConfigurationPrompter, + ) + + if self._prompter is None: + self._prompter = PTKPrompt() + + self._sso_session_prompter = SSOSessionConfigurationPrompter( + botocore_session=self._session, + prompter=self._prompter, + ) + + def _write_sso_configuration(self): + self._update_section( + section_header=get_section_header( + 'sso-session', self._sso_session_prompter.sso_session + ), + new_values=self._sso_session_prompter.sso_session_config, + ) + + def _update_section(self, section_header, new_values): + config_path = self._session.get_config_variable('config_file') + config_path = os.path.expanduser(config_path) + new_values['__section__'] = section_header + self._config_writer.update_config(new_values, config_path) + + def _run_main(self, parsed_args, parsed_globals): + self._init_prompt_toolkit() + + +class ConfigureSSOCommand(BaseSSOConfigurationCommand): + NAME = 'sso' + SYNOPSIS = 'aws configure sso [--profile profile-name]' + DESCRIPTION = ( + 'The ``aws configure sso`` command interactively prompts for the ' + 'configuration values required to create a profile that sources ' + 'temporary AWS credentials from AWS IAM Identity Center.\n\n' + f'{_CMD_PROMPT_USAGE}' + 'When providing the ``--profile`` parameter the named profile ' + 'will be created or updated. When a profile is not explicitly set ' + 'the profile name will be prompted for.\n\n' + f'{_CONFIG_EXTRA_INFO}' + ) + # TODO: Add CLI parameters to skip prompted values, --start-url, etc. + ARG_TABLE = LOGIN_ARGS + + def __init__( + self, + session, + prompter=None, + selector=None, + config_writer=None, + sso_token_cache=None, + sso_login=None, + ): + super().__init__( + session, prompter=prompter, config_writer=config_writer + ) + self._selector = selector + if sso_login is None: + sso_login = do_sso_login + self._sso_login = sso_login + self._sso_token_cache = sso_token_cache + + self._new_profile_config_values = {} + self._original_profile_name = self._session.profile + try: + self._profile_config = self._session.get_scoped_config() + except ProfileNotFound: + self._profile_config = {} + + def _init_prompt_toolkit(self): + super()._init_prompt_toolkit() + if self._selector is None: + from awscli.customizations.wizard.ui.selectmenu import select_menu + + self._selector = select_menu + self._set_sso_session_if_configured_in_profile() + + def _set_sso_session_if_configured_in_profile(self): + if 'sso_session' in self._profile_config: + self._sso_session_prompter.sso_session = self._profile_config[ + 'sso_session' + ] + + def _handle_single_account(self, accounts): + sso_account_id = accounts[0]['accountId'] + single_account_msg = 'The only AWS account available to you is: {}\n' + uni_print(single_account_msg.format(sso_account_id)) + return sso_account_id + + def _handle_multiple_accounts(self, accounts): + available_accounts_msg = ( + 'There are {} AWS accounts available to you.\n' + ) + uni_print(available_accounts_msg.format(len(accounts))) + sorted_accounts = sorted(accounts, key=get_account_sorting_key) + selected_account = self._selector( + sorted_accounts, display_format=display_account + ) + sso_account_id = selected_account['accountId'] + return sso_account_id + + def _get_all_accounts(self, sso, sso_token): + paginator = sso.get_paginator('list_accounts') + results = paginator.paginate(accessToken=sso_token['accessToken']) + return results.build_full_result() + + def _prompt_for_account(self, sso, sso_token): + accounts = self._get_all_accounts(sso, sso_token)['accountList'] + if not accounts: + raise RuntimeError('No AWS accounts are available to you.') + if len(accounts) == 1: + sso_account_id = self._handle_single_account(accounts) + else: + sso_account_id = self._handle_multiple_accounts(accounts) + uni_print(f'Using the account ID {sso_account_id}\n') + self._new_profile_config_values['sso_account_id'] = sso_account_id + return sso_account_id + + def _handle_single_role(self, roles): + sso_role_name = roles[0]['roleName'] + available_roles_msg = 'The only role available to you is: {}\n' + uni_print(available_roles_msg.format(sso_role_name)) + return sso_role_name + + def _handle_multiple_roles(self, roles): + available_roles_msg = 'There are {} roles available to you.\n' + uni_print(available_roles_msg.format(len(roles))) + sorted_roles = sorted(roles, key=lambda x: x['roleName'].lower()) + role_names = [r['roleName'] for r in sorted_roles] + sso_role_name = self._selector(role_names) + return sso_role_name + + def _get_all_roles(self, sso, sso_token, sso_account_id): + paginator = sso.get_paginator('list_account_roles') + results = paginator.paginate( + accountId=sso_account_id, accessToken=sso_token['accessToken'] + ) + return results.build_full_result() + + def _prompt_for_role(self, sso, sso_token, sso_account_id): + roles = self._get_all_roles(sso, sso_token, sso_account_id)['roleList'] + if not roles: + error_msg = 'No roles are available for the account {}' + raise RuntimeError(error_msg.format(sso_account_id)) + if len(roles) == 1: + sso_role_name = self._handle_single_role(roles) + else: + sso_role_name = self._handle_multiple_roles(roles) + uni_print(f'Using the role name "{sso_role_name}"\n') + self._new_profile_config_values['sso_role_name'] = sso_role_name + return sso_role_name + + def _prompt_for_profile(self, sso_account_id=None, sso_role_name=None): + from awscli.customizations.configure.sso import RequiredInputValidator + + if self._original_profile_name: + profile_name = self._original_profile_name + else: + text = 'Profile name' + default_profile = None + if sso_account_id and sso_role_name: + default_profile = f'{sso_role_name}-{sso_account_id}' + validator = RequiredInputValidator(default_profile) + profile_name = self._prompter.get_value( + default_profile, text, validator=validator + ) + return profile_name + + def _prompt_for_cli_default_region(self): + # TODO: figure out a way to get a list of reasonable client regions + return self._prompt_for_profile_config( + 'region', 'Default client Region' + ) + + def _prompt_for_cli_output_format(self): + return self._prompt_for_profile_config( + 'output', + 'CLI default output format (json if not specified)', + completions=list(CLI_OUTPUT_FORMATS.keys()), + ) + + def _prompt_for_profile_config(self, config_name, text, completions=None): + current_value = self._profile_config.get(config_name) + + new_value = self._prompter.get_value( + current_value, + text, + completions=completions, + ) + if new_value: + self._new_profile_config_values[config_name] = new_value + return new_value + + def _unset_session_profile(self): + config_store = self._session.get_component('config_store') + config_store.set_config_provider('profile', ConstantProvider(None)) + + def _run_main(self, parsed_args, parsed_globals): + super()._run_main(parsed_args, parsed_globals) + self._unset_session_profile() + on_pending_authorization = None + if parsed_args.no_browser: + on_pending_authorization = PrintOnlyHandler() + sso_registration_args = self._prompt_for_sso_registration_args() + sso_token = self._sso_login( + self._session, + parsed_globals=parsed_globals, + token_cache=self._sso_token_cache, + on_pending_authorization=on_pending_authorization, + use_device_code=parsed_args.use_device_code, + **sso_registration_args, + ) + + client_config = Config( + signature_version=UNSIGNED, + region_name=sso_registration_args['sso_region'], + ) + sso = self._session.create_client('sso', config=client_config) + + sso_account_id, sso_role_name = self._prompt_for_sso_account_and_role( + sso, sso_token + ) + configured_for_aws_credentials = all((sso_account_id, sso_role_name)) + + self._prompt_for_cli_default_region() + self._prompt_for_cli_output_format() + + profile_name = self._prompt_for_profile(sso_account_id, sso_role_name) + + self._write_new_config(profile_name) + self._print_conclusion(configured_for_aws_credentials, profile_name) + return 0 + + def _prompt_for_sso_registration_args(self): + sso_session = self._sso_session_prompter.prompt_for_sso_session( + required=False + ) + if sso_session is None: + self._warn_configuring_using_legacy_format() + return self._prompt_for_registration_args_with_legacy_format() + else: + self._set_sso_session_in_profile_config(sso_session) + if sso_session in self._sso_sessions: + return self._get_sso_registration_args_from_sso_config( + sso_session + ) + else: + return self._prompt_for_registration_args_for_new_sso_session( + sso_session=sso_session + ) + + def _prompt_for_registration_args_with_legacy_format(self): + self._store_sso_session_prompter_answers_to_profile_config() + self._set_sso_session_defaults_from_profile_config() + start_url, sso_region = self._prompt_for_sso_start_url_and_sso_region() + return {'start_url': start_url, 'sso_region': sso_region} + + def _get_sso_registration_args_from_sso_config(self, sso_session): + sso_config = self._get_sso_session_config(sso_session) + return { + 'session_name': sso_session, + 'start_url': sso_config['sso_start_url'], + 'sso_region': sso_config['sso_region'], + 'registration_scopes': sso_config.get('registration_scopes'), + } + + def _prompt_for_registration_args_for_new_sso_session(self, sso_session): + self._set_sso_session_defaults_from_profile_config() + start_url, sso_region = self._prompt_for_sso_start_url_and_sso_region() + scopes = ( + self._sso_session_prompter.prompt_for_sso_registration_scopes() + ) + return { + 'session_name': sso_session, + 'start_url': start_url, + 'sso_region': sso_region, + 'registration_scopes': scopes, + 'force_refresh': True, + } + + def _store_sso_session_prompter_answers_to_profile_config(self): + self._sso_session_prompter.sso_session_config = ( + self._new_profile_config_values + ) + + def _set_sso_session_in_profile_config(self, sso_session): + self._new_profile_config_values['sso_session'] = sso_session + + def _set_sso_session_defaults_from_profile_config(self): + if 'sso_start_url' in self._profile_config: + self._sso_session_prompter.sso_session_config['sso_start_url'] = ( + self._profile_config['sso_start_url'] + ) + if 'sso_region' in self._profile_config: + self._sso_session_prompter.sso_session_config['sso_region'] = ( + self._profile_config['sso_region'] + ) + + def _prompt_for_sso_start_url_and_sso_region(self): + start_url = self._sso_session_prompter.prompt_for_sso_start_url() + sso_region = self._sso_session_prompter.prompt_for_sso_region() + return start_url, sso_region + + def _warn_configuring_using_legacy_format(self): + uni_print( + f'{colorama.Style.BRIGHT}WARNING: Configuring using legacy format ' + f'(e.g. without an SSO session).\n' + f'Consider re-running "configure sso" command and providing ' + f'a session name.\n{colorama.Style.RESET_ALL}' + ) + + def _prompt_for_sso_account_and_role(self, sso, sso_token): + sso_account_id = None + sso_role_name = None + try: + sso_account_id = self._prompt_for_account(sso, sso_token) + sso_role_name = self._prompt_for_role( + sso, sso_token, sso_account_id + ) + except sso.exceptions.UnauthorizedException: + uni_print( + 'Unable to list AWS accounts and/or roles. ' + 'Skipping configuring AWS credential provider for profile.\n' + ) + return sso_account_id, sso_role_name + + def _write_new_config(self, profile): + if self._new_profile_config_values: + profile_section = profile_to_section(profile) + self._update_section( + profile_section, self._new_profile_config_values + ) + if self._sso_session_prompter.sso_session: + self._write_sso_configuration() + + def _print_conclusion(self, configured_for_aws_credentials, profile_name): + if configured_for_aws_credentials: + if profile_name.lower() == 'default': + msg = ( + 'The AWS CLI is now configured to use the default profile.\n' + 'Run the following command to verify your configuration:\n\n' + 'aws sts get-caller-identity\n' + ) + else: + msg = ( + 'To use this profile, specify the profile name using ' + '--profile, as shown:\n\n' + 'aws sts get-caller-identity --profile {}\n' + ) + else: + msg = 'Successfully configured SSO for profile: {}\n' + uni_print(msg.format(profile_name)) + + +class ConfigureSSOSessionCommand(BaseSSOConfigurationCommand): + NAME = 'sso-session' + SYNOPSIS = 'aws configure sso-session' + DESCRIPTION = ( + 'The ``aws configure sso-session`` command interactively prompts for ' + 'the configuration values required to create a SSO session. ' + 'The SSO session can then be associated to a profile to retrieve ' + 'SSO access tokens and AWS credentials.\n\n' + f'{_CMD_PROMPT_USAGE}' + f'{_CONFIG_EXTRA_INFO}' + ) + + def _run_main(self, parsed_args, parsed_globals): + super()._run_main(parsed_args, parsed_globals) + self._sso_session_prompter.prompt_for_sso_session() + self._sso_session_prompter.prompt_for_sso_start_url() + self._sso_session_prompter.prompt_for_sso_region() + self._sso_session_prompter.prompt_for_sso_registration_scopes() + self._write_sso_configuration() + self._print_configuration_success() + return 0 + + def _print_configuration_success(self): + sso_session = self._sso_session_prompter.sso_session + uni_print( + f'\nCompleted configuring SSO session: {sso_session}\n' + f'Run the following to login and refresh access token for ' + f'this session:\n\n' + f'aws sso login --sso-session {sso_session}\n' + ) diff --git a/awscli/customizations/ecs/monitorexpressgatewayservice.py b/awscli/customizations/ecs/monitorexpressgatewayservice.py index 66128f619b5e..c089807230d0 100644 --- a/awscli/customizations/ecs/monitorexpressgatewayservice.py +++ b/awscli/customizations/ecs/monitorexpressgatewayservice.py @@ -54,7 +54,6 @@ InteractiveDisplayStrategy, TextOnlyDisplayStrategy, ) -from awscli.customizations.ecs.prompt_toolkit_display import Display from awscli.customizations.ecs.serviceviewcollector import ServiceViewCollector from awscli.customizations.utils import uni_print @@ -281,6 +280,10 @@ def _create_display_strategy(self): if self.display_mode == 'TEXT-ONLY': return TextOnlyDisplayStrategy(use_color=self.use_color) elif self.display_mode == 'INTERACTIVE': + from awscli.customizations.ecs.prompt_toolkit_display import ( + Display, + ) + return InteractiveDisplayStrategy( display=Display(), use_color=self.use_color ) diff --git a/awscli/customizations/login/login.py b/awscli/customizations/login/login.py index 9a1695b16f8e..171b76c69871 100644 --- a/awscli/customizations/login/login.py +++ b/awscli/customizations/login/login.py @@ -15,10 +15,6 @@ from awscli.compat import compat_input from awscli.customizations.commands import BasicCommand -from awscli.customizations.configure.sso import ( - PTKPrompt, - RequiredInputValidator, -) from awscli.customizations.configure.writer import ConfigFileWriter from awscli.customizations.exceptions import ConfigurationError from awscli.customizations.login.utils import ( @@ -237,6 +233,11 @@ def _resolve_region(self, parsed_globals): return self._prompt_for_region() def _prompt_for_region(self): + from awscli.customizations.configure.sso import ( + PTKPrompt, + RequiredInputValidator, + ) + prompter = PTKPrompt() self._prompted_for_region = True uni_print( diff --git a/awscli/customizations/logs/__init__.py b/awscli/customizations/logs/__init__.py index 6cdb6292c1c5..d8a5d00cdbc1 100644 --- a/awscli/customizations/logs/__init__.py +++ b/awscli/customizations/logs/__init__.py @@ -10,10 +10,6 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -from awscli.customizations.logs.startlivetail import StartLiveTailCommand -from awscli.customizations.logs.tail import TailCommand - - def register_logs_commands(event_emitter): event_emitter.register('building-command-table.logs', inject_tail_command) event_emitter.register( @@ -22,8 +18,12 @@ def register_logs_commands(event_emitter): def inject_tail_command(command_table, session, **kwargs): + from awscli.customizations.logs.tail import TailCommand + command_table['tail'] = TailCommand(session) def inject_start_live_tail_command(command_table, session, **kwargs): + from awscli.customizations.logs.startlivetail import StartLiveTailCommand + command_table['start-live-tail'] = StartLiveTailCommand(session) diff --git a/awscli/customizations/logs/startlivetail.py b/awscli/customizations/logs/startlivetail.py index 1875c7e96d0e..480b709358e3 100644 --- a/awscli/customizations/logs/startlivetail.py +++ b/awscli/customizations/logs/startlivetail.py @@ -10,33 +10,6 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -import asyncio -import contextlib -import json -import re -import signal -import sys -import time -from enum import Enum -from functools import partial -from threading import Thread - -import colorama -from prompt_toolkit.application import Application, get_app -from prompt_toolkit.buffer import Buffer -from prompt_toolkit.filters import Condition -from prompt_toolkit.formatted_text import ( - ANSI, - fragment_list_to_text, - to_formatted_text, -) -from prompt_toolkit.key_binding import KeyBindings, KeyPressEvent -from prompt_toolkit.layout import Layout, Window, WindowAlign -from prompt_toolkit.layout.containers import HSplit, VSplit -from prompt_toolkit.layout.controls import BufferControl, FormattedTextControl -from prompt_toolkit.layout.dimension import Dimension -from prompt_toolkit.layout.processors import Processor, Transformation - from awscli.compat import get_stdout_text_writer from awscli.customizations.commands import BasicCommand from awscli.customizations.exceptions import ParamValidationError @@ -160,766 +133,10 @@ } -def signal_handler(printer, signum, frame): - printer.interrupt_session = True - - -@contextlib.contextmanager -def handle_signal(printer): - signal_list = [signal.SIGINT, signal.SIGTERM] - if sys.platform != "win32": - signal_list.append(signal.SIGPIPE) - actual_signals = [] - for user_signal in signal_list: - actual_signals.append( - signal.signal(user_signal, partial(signal_handler, printer)) - ) - try: - yield - finally: - for sig, user_signal in enumerate(signal_list): - signal.signal(user_signal, actual_signals[sig]) - - -MAX_KEYWORDS_ALLOWED = 5 -COLOR_LIST = [ - colorama.Fore.CYAN, - colorama.Fore.MAGENTA, - colorama.Fore.BLUE, - colorama.Fore.LIGHTYELLOW_EX, - colorama.Fore.LIGHTGREEN_EX, -] - - class ColorNotAllowedInInteractiveModeError(ParamValidationError): pass -class OutputFormat(Enum): - JSON = "JSON" - PLAIN_TEXT = "Plain text" - - -class InputState(Enum): - HIGHLIGHT = "highlight" - CLEAR = "clear" - DISABLED = "disabled" - - -class LiveTailSessionMetadata: - def __init__(self) -> None: - self._session_start_time = time.time() - self._is_sampled = False - - @property - def session_start_time(self): - return self._session_start_time - - @property - def is_sampled(self): - return self._is_sampled - - def update_metadata(self, session_metadata): - self._is_sampled = session_metadata["sampled"] - - -class Keyword: - def __init__(self, text: str, color=None): - self._text = text.lower().strip() - self._color = color or self._get_color() - self._occurrence_count = 0 - self._regex_pattern = re.compile(re.escape(self._text)) - self._removal_regex_pattern = re.compile( - re.escape(self._color + self._text + colorama.Style.RESET_ALL) - ) - - @property - def text(self): - return self._text - - @property - def color(self): - return self._color - - def _get_color(self): - if self._text == "error": - return colorama.Fore.RED - elif self._text == "info": - return colorama.Fore.GREEN - elif self._text == "warn": - return colorama.Fore.YELLOW - - def get_string_to_print(self): - text_to_print = self._text - if len(self._text) > 15: - text_to_print = self._text[:10] + "..." - return ( - self._color - + text_to_print - + colorama.Style.RESET_ALL - + ": " - + str(self._occurrence_count) - ) - - def _add_color_to_string(self, log_event: str, start, end): - string_to_replace = log_event[start:end] - return ( - log_event[:start] - + self._color - + string_to_replace - + colorama.Style.RESET_ALL - + log_event[end:] - ) - - def _remove_color_from_string(self, log_event: str, start, end): - string_with_color = log_event[start:end] - string_without_color = string_with_color.split(self._color)[-1].split( - colorama.Style.RESET_ALL - )[0] - return log_event[:start] + string_without_color + log_event[end:] - - def highlight(self, log_event: str): - matchings = list(self._regex_pattern.finditer(log_event.lower()))[::-1] - self._occurrence_count += len(matchings) - for match in matchings: - log_event = self._add_color_to_string( - log_event, match.start(), match.end() - ) - - return log_event - - def remove_highlighting(self, log_event: str): - matchings = list( - self._removal_regex_pattern.finditer(log_event.lower()) - )[::-1] - self._occurrence_count -= len(matchings) - for match in matchings: - log_event = self._remove_color_from_string( - log_event, match.start(), match.end() - ) - - return log_event - - -class LiveTailKeyBindings(KeyBindings): - def __init__( - self, - ui, - prompt_buffer: Buffer, - output_buffer: Buffer, - keywords_to_highlight: dict, - ) -> None: - super().__init__() - self._ui = ui - self._input_buffer = prompt_buffer - self._input_state = InputState.DISABLED - self._log_output_buffer = output_buffer - self._available_colors = COLOR_LIST.copy() - self._keywords_to_highlight = keywords_to_highlight - self._is_exit_set = False - self._attach_keybindings() - - @property - def input_state(self): - return self._input_state - - def _attach_keybindings(self): - @Condition - def is_input_disabled(): - return self._input_state == InputState.DISABLED - - @Condition - def is_input_highlight(): - return self._input_state == InputState.HIGHLIGHT - - @Condition - def is_input_clear(): - return self._input_state == InputState.CLEAR - - @Condition - def is_prompt_active(): - return ( - get_app().layout.current_control - == self._ui.prompt_buffer_control - ) - - @Condition - def is_cursor_at_bottom(): - return self._log_output_buffer.cursor_position == len( - self._log_output_buffer.text - ) - - @Condition - def is_keyword_addition_allowed(): - return len(self._keywords_to_highlight) < MAX_KEYWORDS_ALLOWED - - @Condition - def is_clear_keyword_allowed(): - return len(self._keywords_to_highlight) > 0 - - @Condition - def is_exit_set(): - return self._is_exit_set - - @self.add("", filter=is_input_disabled) - def _(event: KeyPressEvent): - pass - - @self.add("", filter=is_input_highlight) - def _(event: KeyPressEvent): - self._input_buffer.insert_text(event.data) - - @self.add("", filter=is_input_clear) - def _(event: KeyPressEvent): - if event.data.isdigit(): - keyword_idx = int(event.data) - 1 - keywords = list(self._keywords_to_highlight.keys()) - if 0 <= keyword_idx < len(keywords): - removed_keyword = self._keywords_to_highlight.pop( - keywords[keyword_idx] - ) - self._available_colors.insert(0, removed_keyword.color) - event.app.create_background_task( - self._ui.remove_term_from_buffer(removed_keyword) - ) - self._reset(event) - - @self.add("h", filter=is_input_disabled & is_keyword_addition_allowed) - def _(event: KeyPressEvent): - self._input_state = InputState.HIGHLIGHT - self._ui.update_bottom_toolbar(self._ui.HIGHLIGHT_INSTRUCTIONS) - self._ui.update_quit_button(self._ui.EXIT_HIGHLIGHT) - event.app.invalidate() - - @self.add("c", filter=is_input_disabled & is_clear_keyword_allowed) - def _(event: KeyPressEvent): - self._input_state = InputState.CLEAR - - clear_intructions = "" - keyword_counter = 1 - for keyword in self._keywords_to_highlight.keys(): - clear_intructions += ( - str(keyword_counter) + ": " + keyword + " " - ) - keyword_counter += 1 - clear_intructions += "ENTER: All" - - self._ui.update_bottom_toolbar(clear_intructions) - self._ui.update_quit_button(self._ui.EXIT_CLEAR) - event.app.invalidate() - - @self.add("t", filter=is_input_disabled) - def _(event: KeyPressEvent): - self._ui.toggle_formatting() - self._ui.update_bottom_toolbar(self._ui.get_instructions()) - - @self.add("q", filter=is_input_disabled & ~is_prompt_active) - def _(event: KeyPressEvent): - self._ui.handle_scrolling(False) - event.app.layout.focus(self._input_buffer) - - @self.add("enter", filter=is_input_highlight) - def _(event: KeyPressEvent): - if ( - self._input_buffer.text.lower() - not in self._keywords_to_highlight.keys() - and len(self._input_buffer.text) > 0 - ): - keyword_color = self._available_colors.pop(0) - keyword = Keyword(self._input_buffer.text, keyword_color) - self._keywords_to_highlight[keyword.text] = keyword - event.app.create_background_task( - self._ui.highlight_term_in_buffer(keyword) - ) - - self._reset(event) - - @self.add("enter", filter=is_input_clear) - def _(event: KeyPressEvent): - for keyword in self._keywords_to_highlight.values(): - event.app.create_background_task( - self._ui.remove_term_from_buffer(keyword) - ) - - self._keywords_to_highlight.clear() - self._available_colors = COLOR_LIST.copy() - self._reset(event) - - @self.add("backspace") - def _(event: KeyPressEvent): - self._input_buffer.text = self._input_buffer.text[:-1] - - @self.add("escape", filter=~is_input_disabled) - def _(event: KeyPressEvent): - self._reset(event) - - @self.add("c-c", filter=is_prompt_active & ~is_exit_set) - @self.add("escape", filter=is_input_disabled & ~is_exit_set) - def _(event: KeyPressEvent): - self._is_exit_set = True - event.app.exit() - - @self.add("up", filter=is_prompt_active) - def _(event: KeyPressEvent): - event.app.layout.focus(self._log_output_buffer) - self._ui.handle_scrolling(True) - - @self.add("c-u") - def _(event: KeyPressEvent): - if is_prompt_active(): - event.app.layout.focus(self._log_output_buffer) - self._ui.handle_scrolling(True) - else: - self._log_output_buffer.cursor_up(20) - - @self.add("down", filter=is_cursor_at_bottom) - def _(event: KeyPressEvent): - self._ui.handle_scrolling(False) - event.app.layout.focus(self._input_buffer) - - @self.add("c-d") - def _(event: KeyPressEvent): - if is_cursor_at_bottom(): - self._ui.handle_scrolling(False) - event.app.layout.focus(self._input_buffer) - else: - self._log_output_buffer.cursor_down(20) - - def _reset(self, event): - self._input_state = InputState.DISABLED - self._ui.update_bottom_toolbar(self._ui.get_instructions()) - self._ui.update_quit_button(self._ui.EXIT_SESSION) - self._input_buffer.reset() - self._ui.update_metadata() - event.app.invalidate() - - -class BaseLiveTailPrinter: - def run(self): - raise NotImplementedError() - - -class BaseLiveTailUI: - def exit(self): - raise NotImplementedError() - - def run(self): - raise NotImplementedError() - - -class LiveTailBuffer(Buffer): - def __init__(self): - super().__init__() - self._pause_buffer = Buffer() - - @property - def pause_buffer(self): - return self._pause_buffer - - def add_text(self, data: str) -> bool: - if not get_app().layout.has_focus(self): - self.text += data - self.cursor_position = len(self.text) - else: - self._pause_buffer.text += data - - -class InteractivePrinter(BaseLiveTailPrinter): - _PROTECTED_KEYWORDS = [Keyword("ERROR"), Keyword("INFO"), Keyword("WARN")] - - def __init__( - self, - ui, - output: LiveTailBuffer, - log_events: list, - session_metadata: LiveTailSessionMetadata, - keywords_to_highlight: dict, - ) -> None: - self._ui = ui - self._output = output - self._log_events = log_events - self._session_metadata = session_metadata - self._log_events_displayed = 0 - self._is_sampled = False - self._keywords_to_highlight = keywords_to_highlight - self._format = OutputFormat.JSON - colorama.init(autoreset=True, strip=False) # noqa - - @property - def log_events_displayed(self): - return self._log_events_displayed - - @property - def is_sampled(self): - return self._is_sampled - - @property - def output_format(self): - return self._format - - def toggle_formatting(self): - self._format = ( - OutputFormat.PLAIN_TEXT - if self._format == OutputFormat.JSON - else OutputFormat.JSON - ) - - def _color_log_event(self, log_event: str): - for keyword in ( - list(self._keywords_to_highlight.values()) - + self._PROTECTED_KEYWORDS - ): - log_event = keyword.highlight(log_event) - - return log_event - - def _format_log_event(self, log_event: str): - if self._format == OutputFormat.JSON: - try: - log_event = json.loads(log_event) - return json.dumps(log_event, indent=4) - except json.decoder.JSONDecodeError: - pass - - return log_event - - def _print_log_events(self): - self._log_events_displayed = len(self._log_events) - self._is_sampled = self._session_metadata.is_sampled - self._ui.update_metadata() - for log_event in self._log_events: - log_event = self._format_log_event(log_event) - self._output.add_text(self._color_log_event(log_event) + "\n") - - self._log_events.clear() - - async def run(self): - while True: - self._print_log_events() - await asyncio.sleep(1) - - -class BufferControlColorProcessor(Processor): - def apply_transformation(self, transformation_input): - fragments = to_formatted_text( - ANSI(fragment_list_to_text(transformation_input.fragments)) - ) - return Transformation(fragments) - - -class InteractiveUI(BaseLiveTailUI): - EXIT_SESSION = "Esc: Exit" - EXIT_HIGHLIGHT = "Esc: Exit Highlight" - EXIT_CLEAR = "Ecs: Exit Clear" - INSTRUCTIONS = "h: Highlight Terms (MAX 5) c: Clear Highlighted Terms t: Toggle Formatting ({}/{}) up/down: Scroll ctrl+u/ctrl+d: Fast Scroll" - HIGHLIGHT_INSTRUCTIONS = "Type Term and press ENTER" - _MAX_LINE_COUNT = 1000 - - def __init__( - self, - log_events, - session_metadata: LiveTailSessionMetadata, - app_output=None, - app_input=None, - ) -> None: - self._log_events = log_events - self._session_metadata = session_metadata - self._keywords_to_highlight = {} - self._output = LiveTailBuffer() - self._is_scroll_active = False - self._log_events_printer = InteractivePrinter( - self, - self._output, - self._log_events, - self._session_metadata, - self._keywords_to_highlight, - ) - self._create_ui(app_output, app_input) - - def _create_ui(self, app_output, app_input): - prompt_buffer = Buffer() - self._prompt_buffer_control = BufferControl(prompt_buffer) - prompt_buffer_window = Window(self._prompt_buffer_control) - prompt_text_window = Window( - FormattedTextControl(">"), dont_extend_width=True, width=1 - ) - self._key_bindings = LiveTailKeyBindings( - self, prompt_buffer, self._output, self._keywords_to_highlight - ) - - output_buffer_control = BufferControl( - self._output, input_processors=[BufferControlColorProcessor()] - ) - log_output_container = Window(output_buffer_control, wrap_lines=True) - - dashed_line_container = Window(height=1, char="-") - metadata_container = self._create_metadata() - bottom_toolbar_container = self._create_bottom_toolbar() - - containers = HSplit( - [ - log_output_container, - dashed_line_container, - VSplit( - [ - prompt_text_window, - prompt_buffer_window, - ], - height=1, - ), - metadata_container, - bottom_toolbar_container, - ] - ) - layout = Layout(containers, prompt_buffer_window) - - self._application = Application( - layout, - key_bindings=self._key_bindings, - refresh_interval=1, - output=app_output, - input=app_input, - ) - - @property - def prompt_buffer_control(self): - return self._prompt_buffer_control - - def _create_bottom_toolbar(self): - self._quit_button = FormattedTextControl(self.EXIT_SESSION) - self._bottom_toolbar = FormattedTextControl(self.get_instructions()) - - return HSplit( - [ - Window( - self._bottom_toolbar, - wrap_lines=True, - dont_extend_height=True, - height=Dimension(min=1), - char=" ", - style="class:bottom-toolbar.text", - ), - Window( - self._quit_button, - wrap_lines=True, - dont_extend_height=True, - height=Dimension(min=1), - style="class:bottom-toolbar.text", - align=WindowAlign.RIGHT, - ), - ] - ) - - def update_bottom_toolbar(self, new_text): - self._bottom_toolbar.text = new_text - - def update_quit_button(self, new_text): - self._quit_button.text = new_text - - def toggle_formatting(self): - self._log_events_printer.toggle_formatting() - - def get_instructions(self): - if self._log_events_printer.output_format == OutputFormat.JSON: - instructions = self.INSTRUCTIONS.format( - colorama.Fore.GREEN - + OutputFormat.JSON.value - + colorama.Style.RESET_ALL, - OutputFormat.PLAIN_TEXT.value, - ) - else: - instructions = self.INSTRUCTIONS.format( - OutputFormat.JSON.value, - colorama.Fore.GREEN - + OutputFormat.PLAIN_TEXT.value - + colorama.Style.RESET_ALL, - ) - - if self._is_scroll_active: - instructions += " q: Scroll to latest" - - return ANSI(instructions) - - def handle_scrolling(self, is_scroll_active): - self._is_scroll_active = is_scroll_active - - if not self._is_scroll_active: - self._output.text += self._output.pause_buffer.text - self._output.cursor_position = len(self._output.text) - self._output.pause_buffer.reset() - self._application.create_background_task( - self._trim_buffer(self._output) - ) - - if self._key_bindings.input_state == InputState.DISABLED: - self.update_bottom_toolbar(self.get_instructions()) - self._application.invalidate() - - def _create_metadata(self): - self._metadata = FormattedTextControl( - text="Highlighted Terms: {}, 0 events/sec, Sampled: No | 00:00:00" - ) - return Window( - self._metadata, - wrap_lines=True, - dont_extend_height=True, - align=WindowAlign.RIGHT, - ) - - def update_metadata(self): - current_time = time.time() - elapsed_time = int( - current_time - self._session_metadata.session_start_time - ) - hours = f"{elapsed_time // 3600:02d}" - minutes = f"{(elapsed_time // 60) % 60:02d}" - seconds = f"{elapsed_time % 60:02d}" - keyword_count_map = ", ".join( - [ - value.get_string_to_print() - for value in self._keywords_to_highlight.values() - ] - ) - events_per_second = self._log_events_printer.log_events_displayed - is_sampled = "Yes" if self._log_events_printer.is_sampled else "No" - - self._metadata.text = ANSI( - f"Highlighted Terms: {{{keyword_count_map}}}, {events_per_second} events/sec, Sampled: {is_sampled} | {hours}:{minutes}:{seconds}" - ) - - async def highlight_term_in_buffer(self, keyword: Keyword): - self._output.text = keyword.highlight(self._output.text) - self._output.pause_buffer.text = keyword.highlight( - self._output.pause_buffer.text - ) - - async def remove_term_from_buffer(self, keyword: Keyword): - self._output.text = keyword.remove_highlighting(self._output.text) - self._output.pause_buffer.text = keyword.remove_highlighting( - self._output.pause_buffer.text - ) - - async def _trim_buffer(self, buffer: Buffer): - lines_to_be_removed = max( - buffer.document.line_count - self._MAX_LINE_COUNT - 1, 0 - ) - return buffer.text.split("\n", lines_to_be_removed)[-1] - - async def _trim_buffers(self): - while True: - if self._is_scroll_active: - self._output.pause_buffer.text = await self._trim_buffer( - self._output.pause_buffer - ) - else: - self._output.text = await self._trim_buffer(self._output) - - await asyncio.sleep(2) - - async def _render_metadata(self): - while True: - self.update_metadata() - await asyncio.sleep(1) - - def exit(self): - self._application.exit() - - async def _run_ui(self): - self._application.create_background_task( - self._log_events_printer.run() - ) - self._application.create_background_task(self._render_metadata()) - self._application.create_background_task(self._trim_buffers()) - - await self._application.run_async() - - def run(self): - asyncio.get_event_loop().run_until_complete(self._run_ui()) - - -class PrintOnlyPrinter(BaseLiveTailPrinter): - def __init__(self, output, log_events) -> None: - self._output = output - self._log_events = log_events - self.interrupt_session = False - - def _print_log_events(self): - for log_event in self._log_events: - self._output.write(log_event + "\n") - self._output.flush() - - self._log_events.clear() - - def run(self): - try: - while True: - self._print_log_events() - - if self.interrupt_session: - break - - time.sleep(1) - except (BrokenPipeError, KeyboardInterrupt): - pass - - -class PrintOnlyUI(BaseLiveTailUI): - def __init__(self, output, log_events) -> None: - self._log_events = log_events - self._printer = PrintOnlyPrinter(output, self._log_events) - - def exit(self): - self._printer.interrupt_session = True - - def run(self): - with handle_signal(self._printer): - self._printer.run() - - -class LiveTailLogEventsCollector(Thread): - def __init__( - self, - output, - ui, - response_stream, - log_events: list, - session_metadata: LiveTailSessionMetadata, - ) -> None: - super().__init__() - self._output = output - self._ui = ui - self._response_stream = response_stream - self._log_events = log_events - self._session_metadata = session_metadata - self._exception = None - - def _collect_log_events(self): - try: - for event in self._response_stream: - if "sessionUpdate" not in event: - continue - - session_update = event["sessionUpdate"] - self._session_metadata.update_metadata( - session_update["sessionMetadata"] - ) - logEvents = session_update["sessionResults"] - for logEvent in logEvents: - self._log_events.append(logEvent["message"]) - except Exception as e: - self._exception = e - - self._ui.exit() - - def stop(self): - if self._exception is not None: - self._output.write(str(self._exception) + "\n") - self._output.flush() - - def run(self): - self._collect_log_events() - - class StartLiveTailCommand(BasicCommand): NAME = "start-live-tail" DESCRIPTION = DESCRIPTION @@ -932,7 +149,7 @@ class StartLiveTailCommand(BasicCommand): ] def __init__(self, session): - super(StartLiveTailCommand, self).__init__(session) + super().__init__(session) self._output = get_stdout_text_writer() def _get_client(self, parsed_globals): @@ -967,6 +184,13 @@ def _is_color_allowed(self, color): return is_a_tty() def _run_main(self, parsed_args, parsed_globals): + from awscli.customizations.logs.ui import ( + InteractiveUI, + LiveTailLogEventsCollector, + LiveTailSessionMetadata, + PrintOnlyUI, + ) + self._client = self._get_client(parsed_globals) start_live_tail_kwargs = self._get_start_live_tail_kwargs(parsed_args) diff --git a/awscli/customizations/logs/ui.py b/awscli/customizations/logs/ui.py new file mode 100644 index 000000000000..c4cd0558391f --- /dev/null +++ b/awscli/customizations/logs/ui.py @@ -0,0 +1,793 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import asyncio +import contextlib +import json +import re +import signal +import sys +import time +from enum import Enum +from functools import partial +from threading import Thread + +import colorama +from prompt_toolkit.application import Application, get_app +from prompt_toolkit.buffer import Buffer +from prompt_toolkit.filters import Condition +from prompt_toolkit.formatted_text import ( + ANSI, + fragment_list_to_text, + to_formatted_text, +) +from prompt_toolkit.key_binding import KeyBindings, KeyPressEvent +from prompt_toolkit.layout import Layout, Window, WindowAlign +from prompt_toolkit.layout.containers import HSplit, VSplit +from prompt_toolkit.layout.controls import BufferControl, FormattedTextControl +from prompt_toolkit.layout.dimension import Dimension +from prompt_toolkit.layout.processors import Processor, Transformation + + +def signal_handler(printer, signum, frame): + printer.interrupt_session = True + + +@contextlib.contextmanager +def handle_signal(printer): + signal_list = [signal.SIGINT, signal.SIGTERM] + if sys.platform != "win32": + signal_list.append(signal.SIGPIPE) + actual_signals = [] + for user_signal in signal_list: + actual_signals.append( + signal.signal(user_signal, partial(signal_handler, printer)) + ) + try: + yield + finally: + for sig, user_signal in enumerate(signal_list): + signal.signal(user_signal, actual_signals[sig]) + + +MAX_KEYWORDS_ALLOWED = 5 +COLOR_LIST = [ + colorama.Fore.CYAN, + colorama.Fore.MAGENTA, + colorama.Fore.BLUE, + colorama.Fore.LIGHTYELLOW_EX, + colorama.Fore.LIGHTGREEN_EX, +] + + +class OutputFormat(Enum): + JSON = "JSON" + PLAIN_TEXT = "Plain text" + + +class InputState(Enum): + HIGHLIGHT = "highlight" + CLEAR = "clear" + DISABLED = "disabled" + + +class LiveTailSessionMetadata: + def __init__(self) -> None: + self._session_start_time = time.time() + self._is_sampled = False + + @property + def session_start_time(self): + return self._session_start_time + + @property + def is_sampled(self): + return self._is_sampled + + def update_metadata(self, session_metadata): + self._is_sampled = session_metadata["sampled"] + + +class Keyword: + def __init__(self, text: str, color=None): + self._text = text.lower().strip() + self._color = color or self._get_color() + self._occurrence_count = 0 + self._regex_pattern = re.compile(re.escape(self._text)) + self._removal_regex_pattern = re.compile( + re.escape(self._color + self._text + colorama.Style.RESET_ALL) + ) + + @property + def text(self): + return self._text + + @property + def color(self): + return self._color + + def _get_color(self): + if self._text == "error": + return colorama.Fore.RED + elif self._text == "info": + return colorama.Fore.GREEN + elif self._text == "warn": + return colorama.Fore.YELLOW + + def get_string_to_print(self): + text_to_print = self._text + if len(self._text) > 15: + text_to_print = self._text[:10] + "..." + return ( + self._color + + text_to_print + + colorama.Style.RESET_ALL + + ": " + + str(self._occurrence_count) + ) + + def _add_color_to_string(self, log_event: str, start, end): + string_to_replace = log_event[start:end] + return ( + log_event[:start] + + self._color + + string_to_replace + + colorama.Style.RESET_ALL + + log_event[end:] + ) + + def _remove_color_from_string(self, log_event: str, start, end): + string_with_color = log_event[start:end] + string_without_color = string_with_color.split(self._color)[-1].split( + colorama.Style.RESET_ALL + )[0] + return log_event[:start] + string_without_color + log_event[end:] + + def highlight(self, log_event: str): + matchings = list(self._regex_pattern.finditer(log_event.lower()))[::-1] + self._occurrence_count += len(matchings) + for match in matchings: + log_event = self._add_color_to_string( + log_event, match.start(), match.end() + ) + + return log_event + + def remove_highlighting(self, log_event: str): + matchings = list( + self._removal_regex_pattern.finditer(log_event.lower()) + )[::-1] + self._occurrence_count -= len(matchings) + for match in matchings: + log_event = self._remove_color_from_string( + log_event, match.start(), match.end() + ) + + return log_event + + +class LiveTailKeyBindings(KeyBindings): + def __init__( + self, + ui, + prompt_buffer: Buffer, + output_buffer: Buffer, + keywords_to_highlight: dict, + ) -> None: + super().__init__() + self._ui = ui + self._input_buffer = prompt_buffer + self._input_state = InputState.DISABLED + self._log_output_buffer = output_buffer + self._available_colors = COLOR_LIST.copy() + self._keywords_to_highlight = keywords_to_highlight + self._is_exit_set = False + self._attach_keybindings() + + @property + def input_state(self): + return self._input_state + + def _attach_keybindings(self): + @Condition + def is_input_disabled(): + return self._input_state == InputState.DISABLED + + @Condition + def is_input_highlight(): + return self._input_state == InputState.HIGHLIGHT + + @Condition + def is_input_clear(): + return self._input_state == InputState.CLEAR + + @Condition + def is_prompt_active(): + return ( + get_app().layout.current_control + == self._ui.prompt_buffer_control + ) + + @Condition + def is_cursor_at_bottom(): + return self._log_output_buffer.cursor_position == len( + self._log_output_buffer.text + ) + + @Condition + def is_keyword_addition_allowed(): + return len(self._keywords_to_highlight) < MAX_KEYWORDS_ALLOWED + + @Condition + def is_clear_keyword_allowed(): + return len(self._keywords_to_highlight) > 0 + + @Condition + def is_exit_set(): + return self._is_exit_set + + @self.add("", filter=is_input_disabled) + def _(event: KeyPressEvent): + pass + + @self.add("", filter=is_input_highlight) + def _(event: KeyPressEvent): + self._input_buffer.insert_text(event.data) + + @self.add("", filter=is_input_clear) + def _(event: KeyPressEvent): + if event.data.isdigit(): + keyword_idx = int(event.data) - 1 + keywords = list(self._keywords_to_highlight.keys()) + if 0 <= keyword_idx < len(keywords): + removed_keyword = self._keywords_to_highlight.pop( + keywords[keyword_idx] + ) + self._available_colors.insert(0, removed_keyword.color) + event.app.create_background_task( + self._ui.remove_term_from_buffer(removed_keyword) + ) + self._reset(event) + + @self.add("h", filter=is_input_disabled & is_keyword_addition_allowed) + def _(event: KeyPressEvent): + self._input_state = InputState.HIGHLIGHT + self._ui.update_bottom_toolbar(self._ui.HIGHLIGHT_INSTRUCTIONS) + self._ui.update_quit_button(self._ui.EXIT_HIGHLIGHT) + event.app.invalidate() + + @self.add("c", filter=is_input_disabled & is_clear_keyword_allowed) + def _(event: KeyPressEvent): + self._input_state = InputState.CLEAR + + clear_intructions = "" + keyword_counter = 1 + for keyword in self._keywords_to_highlight.keys(): + clear_intructions += ( + str(keyword_counter) + ": " + keyword + " " + ) + keyword_counter += 1 + clear_intructions += "ENTER: All" + + self._ui.update_bottom_toolbar(clear_intructions) + self._ui.update_quit_button(self._ui.EXIT_CLEAR) + event.app.invalidate() + + @self.add("t", filter=is_input_disabled) + def _(event: KeyPressEvent): + self._ui.toggle_formatting() + self._ui.update_bottom_toolbar(self._ui.get_instructions()) + + @self.add("q", filter=is_input_disabled & ~is_prompt_active) + def _(event: KeyPressEvent): + self._ui.handle_scrolling(False) + event.app.layout.focus(self._input_buffer) + + @self.add("enter", filter=is_input_highlight) + def _(event: KeyPressEvent): + if ( + self._input_buffer.text.lower() + not in self._keywords_to_highlight.keys() + and len(self._input_buffer.text) > 0 + ): + keyword_color = self._available_colors.pop(0) + keyword = Keyword(self._input_buffer.text, keyword_color) + self._keywords_to_highlight[keyword.text] = keyword + event.app.create_background_task( + self._ui.highlight_term_in_buffer(keyword) + ) + + self._reset(event) + + @self.add("enter", filter=is_input_clear) + def _(event: KeyPressEvent): + for keyword in self._keywords_to_highlight.values(): + event.app.create_background_task( + self._ui.remove_term_from_buffer(keyword) + ) + + self._keywords_to_highlight.clear() + self._available_colors = COLOR_LIST.copy() + self._reset(event) + + @self.add("backspace") + def _(event: KeyPressEvent): + self._input_buffer.text = self._input_buffer.text[:-1] + + @self.add("escape", filter=~is_input_disabled) + def _(event: KeyPressEvent): + self._reset(event) + + @self.add("c-c", filter=is_prompt_active & ~is_exit_set) + @self.add("escape", filter=is_input_disabled & ~is_exit_set) + def _(event: KeyPressEvent): + self._is_exit_set = True + event.app.exit() + + @self.add("up", filter=is_prompt_active) + def _(event: KeyPressEvent): + event.app.layout.focus(self._log_output_buffer) + self._ui.handle_scrolling(True) + + @self.add("c-u") + def _(event: KeyPressEvent): + if is_prompt_active(): + event.app.layout.focus(self._log_output_buffer) + self._ui.handle_scrolling(True) + else: + self._log_output_buffer.cursor_up(20) + + @self.add("down", filter=is_cursor_at_bottom) + def _(event: KeyPressEvent): + self._ui.handle_scrolling(False) + event.app.layout.focus(self._input_buffer) + + @self.add("c-d") + def _(event: KeyPressEvent): + if is_cursor_at_bottom(): + self._ui.handle_scrolling(False) + event.app.layout.focus(self._input_buffer) + else: + self._log_output_buffer.cursor_down(20) + + def _reset(self, event): + self._input_state = InputState.DISABLED + self._ui.update_bottom_toolbar(self._ui.get_instructions()) + self._ui.update_quit_button(self._ui.EXIT_SESSION) + self._input_buffer.reset() + self._ui.update_metadata() + event.app.invalidate() + + +class BaseLiveTailPrinter: + def run(self): + raise NotImplementedError() + + +class BaseLiveTailUI: + def exit(self): + raise NotImplementedError() + + def run(self): + raise NotImplementedError() + + +class LiveTailBuffer(Buffer): + def __init__(self): + super().__init__() + self._pause_buffer = Buffer() + + @property + def pause_buffer(self): + return self._pause_buffer + + def add_text(self, data: str) -> bool: + if not get_app().layout.has_focus(self): + self.text += data + self.cursor_position = len(self.text) + else: + self._pause_buffer.text += data + + +class InteractivePrinter(BaseLiveTailPrinter): + _PROTECTED_KEYWORDS = [Keyword("ERROR"), Keyword("INFO"), Keyword("WARN")] + + def __init__( + self, + ui, + output: LiveTailBuffer, + log_events: list, + session_metadata: LiveTailSessionMetadata, + keywords_to_highlight: dict, + ) -> None: + self._ui = ui + self._output = output + self._log_events = log_events + self._session_metadata = session_metadata + self._log_events_displayed = 0 + self._is_sampled = False + self._keywords_to_highlight = keywords_to_highlight + self._format = OutputFormat.JSON + colorama.init(autoreset=True, strip=False) # noqa + + @property + def log_events_displayed(self): + return self._log_events_displayed + + @property + def is_sampled(self): + return self._is_sampled + + @property + def output_format(self): + return self._format + + def toggle_formatting(self): + self._format = ( + OutputFormat.PLAIN_TEXT + if self._format == OutputFormat.JSON + else OutputFormat.JSON + ) + + def _color_log_event(self, log_event: str): + for keyword in ( + list(self._keywords_to_highlight.values()) + + self._PROTECTED_KEYWORDS + ): + log_event = keyword.highlight(log_event) + + return log_event + + def _format_log_event(self, log_event: str): + if self._format == OutputFormat.JSON: + try: + log_event = json.loads(log_event) + return json.dumps(log_event, indent=4) + except json.decoder.JSONDecodeError: + pass + + return log_event + + def _print_log_events(self): + self._log_events_displayed = len(self._log_events) + self._is_sampled = self._session_metadata.is_sampled + self._ui.update_metadata() + for log_event in self._log_events: + log_event = self._format_log_event(log_event) + self._output.add_text(self._color_log_event(log_event) + "\n") + + self._log_events.clear() + + async def run(self): + while True: + self._print_log_events() + await asyncio.sleep(1) + + +class BufferControlColorProcessor(Processor): + def apply_transformation(self, transformation_input): + fragments = to_formatted_text( + ANSI(fragment_list_to_text(transformation_input.fragments)) + ) + return Transformation(fragments) + + +class InteractiveUI(BaseLiveTailUI): + EXIT_SESSION = "Esc: Exit" + EXIT_HIGHLIGHT = "Esc: Exit Highlight" + EXIT_CLEAR = "Ecs: Exit Clear" + INSTRUCTIONS = "h: Highlight Terms (MAX 5) c: Clear Highlighted Terms t: Toggle Formatting ({}/{}) up/down: Scroll ctrl+u/ctrl+d: Fast Scroll" + HIGHLIGHT_INSTRUCTIONS = "Type Term and press ENTER" + _MAX_LINE_COUNT = 1000 + + def __init__( + self, + log_events, + session_metadata: LiveTailSessionMetadata, + app_output=None, + app_input=None, + ) -> None: + self._log_events = log_events + self._session_metadata = session_metadata + self._keywords_to_highlight = {} + self._output = LiveTailBuffer() + self._is_scroll_active = False + self._log_events_printer = InteractivePrinter( + self, + self._output, + self._log_events, + self._session_metadata, + self._keywords_to_highlight, + ) + self._create_ui(app_output, app_input) + + def _create_ui(self, app_output, app_input): + prompt_buffer = Buffer() + self._prompt_buffer_control = BufferControl(prompt_buffer) + prompt_buffer_window = Window(self._prompt_buffer_control) + prompt_text_window = Window( + FormattedTextControl(">"), dont_extend_width=True, width=1 + ) + self._key_bindings = LiveTailKeyBindings( + self, prompt_buffer, self._output, self._keywords_to_highlight + ) + + output_buffer_control = BufferControl( + self._output, input_processors=[BufferControlColorProcessor()] + ) + log_output_container = Window(output_buffer_control, wrap_lines=True) + + dashed_line_container = Window(height=1, char="-") + metadata_container = self._create_metadata() + bottom_toolbar_container = self._create_bottom_toolbar() + + containers = HSplit( + [ + log_output_container, + dashed_line_container, + VSplit( + [ + prompt_text_window, + prompt_buffer_window, + ], + height=1, + ), + metadata_container, + bottom_toolbar_container, + ] + ) + layout = Layout(containers, prompt_buffer_window) + self._application = Application( + layout, + key_bindings=self._key_bindings, + refresh_interval=1, + output=app_output, + input=app_input, + ) + + @property + def prompt_buffer_control(self): + return self._prompt_buffer_control + + def _create_bottom_toolbar(self): + self._quit_button = FormattedTextControl(self.EXIT_SESSION) + self._bottom_toolbar = FormattedTextControl(self.get_instructions()) + + return HSplit( + [ + Window( + self._bottom_toolbar, + wrap_lines=True, + dont_extend_height=True, + height=Dimension(min=1), + char=" ", + style="class:bottom-toolbar.text", + ), + Window( + self._quit_button, + wrap_lines=True, + dont_extend_height=True, + height=Dimension(min=1), + style="class:bottom-toolbar.text", + align=WindowAlign.RIGHT, + ), + ] + ) + + def update_bottom_toolbar(self, new_text): + self._bottom_toolbar.text = new_text + + def update_quit_button(self, new_text): + self._quit_button.text = new_text + + def toggle_formatting(self): + self._log_events_printer.toggle_formatting() + + def get_instructions(self): + if self._log_events_printer.output_format == OutputFormat.JSON: + instructions = self.INSTRUCTIONS.format( + colorama.Fore.GREEN + + OutputFormat.JSON.value + + colorama.Style.RESET_ALL, + OutputFormat.PLAIN_TEXT.value, + ) + else: + instructions = self.INSTRUCTIONS.format( + OutputFormat.JSON.value, + colorama.Fore.GREEN + + OutputFormat.PLAIN_TEXT.value + + colorama.Style.RESET_ALL, + ) + + if self._is_scroll_active: + instructions += " q: Scroll to latest" + + return ANSI(instructions) + + def handle_scrolling(self, is_scroll_active): + self._is_scroll_active = is_scroll_active + + if not self._is_scroll_active: + self._output.text += self._output.pause_buffer.text + self._output.cursor_position = len(self._output.text) + self._output.pause_buffer.reset() + self._application.create_background_task( + self._trim_buffer(self._output) + ) + + if self._key_bindings.input_state == InputState.DISABLED: + self.update_bottom_toolbar(self.get_instructions()) + self._application.invalidate() + + def _create_metadata(self): + self._metadata = FormattedTextControl( + text="Highlighted Terms: {}, 0 events/sec, Sampled: No | 00:00:00" + ) + return Window( + self._metadata, + wrap_lines=True, + dont_extend_height=True, + align=WindowAlign.RIGHT, + ) + + def update_metadata(self): + current_time = time.time() + elapsed_time = int( + current_time - self._session_metadata.session_start_time + ) + hours = f"{elapsed_time // 3600:02d}" + minutes = f"{(elapsed_time // 60) % 60:02d}" + seconds = f"{elapsed_time % 60:02d}" + keyword_count_map = ", ".join( + [ + value.get_string_to_print() + for value in self._keywords_to_highlight.values() + ] + ) + events_per_second = self._log_events_printer.log_events_displayed + is_sampled = "Yes" if self._log_events_printer.is_sampled else "No" + + self._metadata.text = ANSI( + f"Highlighted Terms: {{{keyword_count_map}}}, {events_per_second} events/sec, Sampled: {is_sampled} | {hours}:{minutes}:{seconds}" + ) + + async def highlight_term_in_buffer(self, keyword: Keyword): + self._output.text = keyword.highlight(self._output.text) + self._output.pause_buffer.text = keyword.highlight( + self._output.pause_buffer.text + ) + + async def remove_term_from_buffer(self, keyword: Keyword): + self._output.text = keyword.remove_highlighting(self._output.text) + self._output.pause_buffer.text = keyword.remove_highlighting( + self._output.pause_buffer.text + ) + + async def _trim_buffer(self, buffer: Buffer): + lines_to_be_removed = max( + buffer.document.line_count - self._MAX_LINE_COUNT - 1, 0 + ) + return buffer.text.split("\n", lines_to_be_removed)[-1] + + async def _trim_buffers(self): + while True: + if self._is_scroll_active: + self._output.pause_buffer.text = await self._trim_buffer( + self._output.pause_buffer + ) + else: + self._output.text = await self._trim_buffer(self._output) + + await asyncio.sleep(2) + + async def _render_metadata(self): + while True: + self.update_metadata() + await asyncio.sleep(1) + + def exit(self): + self._application.exit() + + async def _run_ui(self): + self._application.create_background_task( + self._log_events_printer.run() + ) + self._application.create_background_task(self._render_metadata()) + self._application.create_background_task(self._trim_buffers()) + + await self._application.run_async() + + def run(self): + asyncio.get_event_loop().run_until_complete(self._run_ui()) + + +class PrintOnlyPrinter(BaseLiveTailPrinter): + def __init__(self, output, log_events) -> None: + self._output = output + self._log_events = log_events + self.interrupt_session = False + + def _print_log_events(self): + for log_event in self._log_events: + self._output.write(log_event + "\n") + self._output.flush() + + self._log_events.clear() + + def run(self): + try: + while True: + self._print_log_events() + + if self.interrupt_session: + break + + time.sleep(1) + except (BrokenPipeError, KeyboardInterrupt): + pass + + +class PrintOnlyUI(BaseLiveTailUI): + def __init__(self, output, log_events) -> None: + self._log_events = log_events + self._printer = PrintOnlyPrinter(output, self._log_events) + + def exit(self): + self._printer.interrupt_session = True + + def run(self): + with handle_signal(self._printer): + self._printer.run() + + +class LiveTailLogEventsCollector(Thread): + def __init__( + self, + output, + ui, + response_stream, + log_events: list, + session_metadata: LiveTailSessionMetadata, + ) -> None: + super().__init__() + self._output = output + self._ui = ui + self._response_stream = response_stream + self._log_events = log_events + self._session_metadata = session_metadata + self._exception = None + + def _collect_log_events(self): + try: + for event in self._response_stream: + if "sessionUpdate" not in event: + continue + + session_update = event["sessionUpdate"] + self._session_metadata.update_metadata( + session_update["sessionMetadata"] + ) + logEvents = session_update["sessionResults"] + for logEvent in logEvents: + self._log_events.append(logEvent["message"]) + except Exception as e: + self._exception = e + + self._ui.exit() + + def stop(self): + if self._exception is not None: + self._output.write(str(self._exception) + "\n") + self._output.flush() + + def run(self): + self._collect_log_events() diff --git a/awscli/customizations/wizard/commands.py b/awscli/customizations/wizard/commands.py index 1ab72137aaa2..06071d0e90c9 100644 --- a/awscli/customizations/wizard/commands.py +++ b/awscli/customizations/wizard/commands.py @@ -12,11 +12,12 @@ # language governing permissions and limitations under the License. from awscli.customizations.commands import BasicCommand, BasicHelp from awscli.customizations.exceptions import ParamValidationError -from awscli.customizations.wizard import devcommands, factory from awscli.customizations.wizard.loader import WizardLoader def register_wizard_commands(event_handlers): + from awscli.customizations.wizard import devcommands + devcommands.register_dev_commands(event_handlers) loader = WizardLoader() commands = loader.list_commands_with_wizards() @@ -26,18 +27,15 @@ def register_wizard_commands(event_handlers): def _register_wizards_for_commands(commands, event_handlers): for command in commands: event_handlers.register( - 'building-command-table.%s' % command, _add_wizard_command + f'building-command-table.{command}', _add_wizard_command ) def _add_wizard_command(session, command_object, command_table, **kwargs): - v1_runner = factory.create_default_wizard_v1_runner(session) - v2_runner = factory.create_default_wizard_v2_runner(session) cmd = TopLevelWizardCommand( session=session, loader=WizardLoader(), parent_command=command_object.name, - runner={'0.1': v1_runner, '0.2': v2_runner}, ) command_table['wizard'] = cmd @@ -49,26 +47,42 @@ class TopLevelWizardCommand(BasicCommand): ) def __init__( - self, session, loader, parent_command, runner, wizard_name='_main' + self, session, loader, parent_command, runner=None, wizard_name='_main' ): - super(TopLevelWizardCommand, self).__init__(session) + super().__init__(session) self._session = session self._loader = loader self._parent_command = parent_command self._runner = runner self._wizard_name = wizard_name + def _get_runner(self): + # If a runner was not provided during initialization, compute the + # default. This defers default computation to runtime, when the + # wizard command is actually invoked. The benefit of this is so that + # we defer importing `awscli.customizations.wizard.factory` until + # it's actually needed. The wizard factory module imports from + # `prompt_toolkit`, and importing `prompt_toolkit` while executing + # commands that don't actually use it has historically led to + # unnecessarily higher command execution time and wasted compute. + if self._runner is None: + from awscli.customizations.wizard import factory + + self._runner = { + '0.1': factory.create_default_wizard_v1_runner(self._session), + '0.2': factory.create_default_wizard_v2_runner(self._session), + } + return self._runner + def _build_subcommand_table(self): - subcommand_table = super( - TopLevelWizardCommand, self - )._build_subcommand_table() + subcommand_table = super()._build_subcommand_table() wizards = self._get_available_wizards() for name in wizards: cmd = SingleWizardCommand( self._session, self._loader, self._parent_command, - self._runner, + runner=self._runner, wizard_name=name, ) subcommand_table[name] = cmd @@ -96,11 +110,12 @@ def _run_wizard(self): self._parent_command, self._wizard_name ) version = loaded.get('version') - if version in self._runner: - self._runner[version].run(loaded) + runner = self._get_runner() + if version in runner: + runner[version].run(loaded) else: raise ParamValidationError( - 'Definition file has unsupported version %s ' % version + f'Definition file has unsupported version {version} ' ) def create_help_command(self): @@ -114,8 +129,12 @@ def create_help_command(self): class SingleWizardCommand(TopLevelWizardCommand): def __init__(self, session, loader, parent_command, runner, wizard_name): - super(SingleWizardCommand, self).__init__( - session, loader, parent_command, runner, wizard_name + super().__init__( + session, + loader, + parent_command, + runner=runner, + wizard_name=wizard_name, ) self._session = session self._loader = loader @@ -144,7 +163,5 @@ class WizardHelpCommand(BasicHelp): def __init__( self, session, command_object, command_table, arg_table, loaded_wizard ): - super(WizardHelpCommand, self).__init__( - session, command_object, command_table, arg_table - ) + super().__init__(session, command_object, command_table, arg_table) self._description = loaded_wizard.get('description', '') diff --git a/awscli/customizations/wizard/devcommands.py b/awscli/customizations/wizard/devcommands.py index d02ca7c101ba..aa0245fc181c 100644 --- a/awscli/customizations/wizard/devcommands.py +++ b/awscli/customizations/wizard/devcommands.py @@ -13,7 +13,6 @@ from ruamel.yaml import YAML from awscli.customizations.commands import BasicCommand -from awscli.customizations.wizard.factory import create_wizard_app def register_dev_commands(event_handlers): @@ -43,6 +42,8 @@ def __init__(self, wizard_loader, session): def run_wizard(self, wizard_contents): """Run a single wizard given the contents as a string.""" + from awscli.customizations.wizard.factory import create_wizard_app + loaded = self._wizard_loader.load(wizard_contents) app = create_wizard_app(loaded, self._session) app.run() @@ -66,7 +67,7 @@ class WizardDev(BasicCommand): ] def __init__(self, session, dev_runner=None): - super(WizardDev, self).__init__(session) + super().__init__(session) if dev_runner is None: dev_runner = create_default_wizard_dev_runner(session) self._dev_runner = dev_runner diff --git a/awscli/errorhandler.py b/awscli/errorhandler.py index 91a9f1f45dfd..e09d4c740220 100644 --- a/awscli/errorhandler.py +++ b/awscli/errorhandler.py @@ -26,7 +26,7 @@ from awscli.argparser import USAGE, ArgParseException from awscli.argprocess import ParamError, ParamSyntaxError from awscli.arguments import UnknownArgumentError -from awscli.autoprompt.factory import PrompterKeyboardInterrupt +from awscli.autoprompt.exceptions import PrompterKeyboardInterrupt from awscli.constants import ( CLIENT_ERROR_RC, CONFIGURATION_ERROR_RC, diff --git a/tests/functional/configure/test_sso.py b/tests/functional/configure/test_sso.py index f0ada15d4650..2a9afc2b33c4 100644 --- a/tests/functional/configure/test_sso.py +++ b/tests/functional/configure/test_sso.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. from unittest.mock import patch -from awscli.customizations.configure.sso import ConfigureSSOCommand +from awscli.customizations.configure.sso_commands import ConfigureSSOCommand from awscli.testutils import BaseAWSCommandParamsTest, mock from tests.functional.sso import BaseSSOTest diff --git a/tests/functional/dependencies/test_lazy_imports.py b/tests/functional/dependencies/test_lazy_imports.py new file mode 100644 index 000000000000..549e9b2e5d81 --- /dev/null +++ b/tests/functional/dependencies/test_lazy_imports.py @@ -0,0 +1,86 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import subprocess +import sys + +import pytest + +# Commands that should never trigger a prompt_toolkit import. +BASIC_COMMANDS = [ + 's3 ls', + 'configure list', + 'logs describe-log-groups', + 'ecs describe-services', +] + +_CHECK_SCRIPT = """\ +import os +import sys + +from unittest.mock import patch +from awscli.botocore.awsrequest import AWSResponse +from awscli.clidriver import create_clidriver + +env = {{ + 'AWS_DATA_PATH': os.environ.get('AWS_DATA_PATH', ''), + 'AWS_DEFAULT_REGION': 'us-east-1', + 'AWS_ACCESS_KEY_ID': 'testing', + 'AWS_SECRET_ACCESS_KEY': 'testing', + 'AWS_CONFIG_FILE': '', + 'AWS_SHARED_CREDENTIALS_FILE': '', +}} +if os.environ.get('ComSpec'): + env['ComSpec'] = os.environ['ComSpec'] + +http_response = AWSResponse(None, 200, {{}}, None) + +with patch('os.environ', env), \\ + patch('awscli.botocore.endpoint.Endpoint.make_request', + return_value=(http_response, {{}})): + driver = create_clidriver() + try: + driver.main({args}) + except SystemExit: + pass + +mods = [m for m in sys.modules if m.startswith('prompt_toolkit')] +if mods: + print('FAIL: prompt_toolkit modules loaded: ' + ', '.join(sorted(mods))) + sys.exit(1) +else: + print('OK') +""" + + +@pytest.mark.parametrize('cmd', BASIC_COMMANDS) +def test_prompt_toolkit_not_imported(cmd): + # Historically prompt_toolkit has contributed to significant + # unnecessary initialization time. This test verifies it + # is not imported for a handful of commands for which we know + # it is not needed. + script = _CHECK_SCRIPT.format(args=repr(cmd.split())) + # Since prompt_toolkit might be imported during test-running, we execute + # the commands in a subprocess to ensure we are testing only the modules + # loaded for the commands under test. + # We apply a timeout to prevent the subprocess for silently stalling our + # scripts. + result = subprocess.run( + [sys.executable, '-c', script], + capture_output=True, + text=True, + timeout=60, + ) + assert result.returncode == 0, ( + f"prompt_toolkit was unexpectedly imported for 'aws {cmd}':\n" + f"stdout: {result.stdout}\nstderr: {result.stderr}" + ) diff --git a/tests/unit/autoprompt/test_core.py b/tests/unit/autoprompt/test_core.py index 176d0cc4f40c..64e6598eb3c7 100644 --- a/tests/unit/autoprompt/test_core.py +++ b/tests/unit/autoprompt/test_core.py @@ -28,70 +28,6 @@ def setUp(self): self.driver, prompter=self.prompter ) - def test_throw_error_if_both_args_specified(self): - args = ['--cli-auto-prompt', '--no-cli-auto-prompt'] - self.assertRaises( - ParamValidationError, - self.prompt_driver.validate_auto_prompt_args_are_mutually_exclusive, - args, - ) - - -def _generate_auto_prompt_resolve_cases(): - # Each case is a 5-namedtuple with the following meaning: - # "args" is a list of arguments that command got as input from - # command line - # "config_variable" is the result from get_config_variable - # This takes a value of either 'on' , 'off' or 'on-partial' - # "expected_result" is a boolean indicating whether auto-prompt - # should be used or not. - # - # Note: This set of tests assumes that only one of --no-cli-auto-prompt - # or --cli-auto-prompt overrides can be specified. - # TestCLIAutoPrompt.test_throw_error_if_both_args_specified tests - # that these command line overrides are mutually exclusive. - Case = namedtuple( - 'Case', - [ - 'args', - 'config_variable', - 'expected_result', - ], - ) - return [ - Case([], 'off', 'off'), - Case([], 'on', 'on'), - Case(['--cli-auto-prompt'], 'off', 'on'), - Case(['--cli-auto-prompt'], 'on', 'on'), - Case(['--no-cli-auto-prompt'], 'off', 'off'), - Case(['--no-cli-auto-prompt'], 'on', 'off'), - Case([], 'on', 'on'), - Case([], 'on-partial', 'on-partial'), - Case(['--cli-auto-prompt'], 'on-partial', 'on'), - Case(['--no-cli-auto-prompt'], 'on-partial', 'off'), - Case(['--version'], 'on', 'off'), - Case(['help'], 'on', 'off'), - ] - - -@pytest.mark.parametrize('case', _generate_auto_prompt_resolve_cases()) -def test_auto_prompt_resolve_mode(case): - driver = create_clidriver() - driver.session.set_config_variable('cli_auto_prompt', case.config_variable) - prompter = mock.Mock(spec=core.AutoPrompter) - prompt_driver = core.AutoPromptDriver(driver, prompter=prompter) - result = prompt_driver.resolve_mode(args=case.args) - assert result == case.expected_result - - -def test_auto_prompt_resolve_mode_on_non_existing_profile(): - driver = create_clidriver() - driver.session.set_config_variable('profile', 'not_exist') - prompter = mock.Mock(spec=core.AutoPrompter) - prompt_driver = core.AutoPromptDriver(driver, prompter=prompter) - result = prompt_driver.resolve_mode(args=[]) - assert result == 'off' - class TestAutoPrompter(unittest.TestCase): def setUp(self): diff --git a/tests/unit/customizations/configure/test_sso.py b/tests/unit/customizations/configure/test_sso.py index a40a73ee3b6b..27cd34cf3165 100644 --- a/tests/unit/customizations/configure/test_sso.py +++ b/tests/unit/customizations/configure/test_sso.py @@ -29,13 +29,15 @@ ) from awscli.customizations.configure.sso import ( - ConfigureSSOCommand, - ConfigureSSOSessionCommand, PTKPrompt, RequiredInputValidator, ScopesValidator, SSOSessionConfigurationPrompter, StartUrlValidator, +) +from awscli.customizations.configure.sso_commands import ( + ConfigureSSOCommand, + ConfigureSSOSessionCommand, display_account, get_account_sorting_key, ) @@ -363,12 +365,16 @@ def account_id_select(account_id): expected_choices=sorted( [ selected_account, - {"accountId": "1234567890", "emailAddress": "account2@site.com"}, + { + "accountId": "1234567890", + "emailAddress": "account2@site.com", + }, ], key=get_account_sorting_key, - ) + ), ) + @pytest.fixture def role_name_select(role_name): return SelectMenu(answer=role_name, expected_choices=[role_name, "roleB"]) @@ -636,7 +642,7 @@ class UserInput: @dataclasses.dataclass class Prompt(UserInput): expected_validator_cls: typing.Optional[Validator] = None - expected_completions: typing.Optional[typing.List[str]] = None + expected_completions: typing.Optional[list[str]] = None _expected_message: typing.Optional[str] = dataclasses.field( init=False, repr=False, default=None ) @@ -739,7 +745,7 @@ class ProfilePrompt(PromptWithDefault): @dataclasses.dataclass class SelectMenu(UserInput): - expected_choices: typing.Optional[typing.List[typing.Any]] = None + expected_choices: typing.Optional[list[typing.Any]] = None @dataclasses.dataclass @@ -2239,8 +2245,6 @@ def passes_validator(validator, text): (ScopesValidator, "value-1, value-2 value3", None, False), ], ) - - def test_validators(validator_cls, input_value, default, is_valid): validator = validator_cls(default) assert passes_validator(validator, input_value) == is_valid diff --git a/tests/unit/customizations/logs/test_startlivetail.py b/tests/unit/customizations/logs/test_ui.py similarity index 98% rename from tests/unit/customizations/logs/test_startlivetail.py rename to tests/unit/customizations/logs/test_ui.py index 6ff407eb2055..1ce28f567219 100644 --- a/tests/unit/customizations/logs/test_startlivetail.py +++ b/tests/unit/customizations/logs/test_ui.py @@ -1,4 +1,4 @@ -# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of @@ -15,12 +15,12 @@ import colorama from prompt_toolkit.application import Application from prompt_toolkit.buffer import Buffer +from prompt_toolkit.input import create_pipe_input from prompt_toolkit.key_binding import KeyPressEvent from prompt_toolkit.output import DummyOutput -from prompt_toolkit.input import create_pipe_input from awscli.compat import StringIO -from awscli.customizations.logs.startlivetail import ( +from awscli.customizations.logs.ui import ( COLOR_LIST, InputState, InteractivePrinter, @@ -247,9 +247,6 @@ def test_remove_color_from_string(self): log_event = "This is an INFO log" self.keyword = Keyword(text) - colored_log_event = self.keyword._add_color_to_string( - log_event, 11, 15 - ) uncolored_log_event = self.keyword._remove_color_from_string( log_event, 15, 19 ) @@ -608,8 +605,10 @@ def setUp(self) -> None: self.log_events = [] self.session_metadata = LiveTailSessionMetadata() self.ui = InteractiveUI( - self.log_events, self.session_metadata, app_output=DummyOutput(), - app_input=create_pipe_input() + self.log_events, + self.session_metadata, + app_output=DummyOutput(), + app_input=create_pipe_input(), ) def test_update_toolbar(self): diff --git a/tests/unit/test_clidriver.py b/tests/unit/test_clidriver.py index 645a27ef5c1b..cd5315212728 100644 --- a/tests/unit/test_clidriver.py +++ b/tests/unit/test_clidriver.py @@ -17,6 +17,7 @@ import platform import re import sys +from collections import namedtuple import awscrt.io import botocore.model @@ -43,9 +44,12 @@ ServiceOperation, construct_cli_error_handlers_chain, create_clidriver, + resolve_auto_prompt_mode, + validate_auto_prompt_args_are_mutually_exclusive, ) from awscli.compat import StringIO from awscli.customizations.commands import BasicCommand +from awscli.customizations.exceptions import ParamValidationError from awscli.paramfile import URIArgumentHandler from awscli.testutils import BaseAWSCommandParamsTest, mock, unittest @@ -177,6 +181,43 @@ } +def _generate_auto_prompt_resolve_cases(): + # Each case is a 5-namedtuple with the following meaning: + # "args" is a list of arguments that command got as input from + # command line + # "config_variable" is the result from get_config_variable + # This takes a value of either 'on' , 'off' or 'on-partial' + # "expected_result" is a boolean indicating whether auto-prompt + # should be used or not. + # + # Note: This set of tests assumes that only one of --no-cli-auto-prompt + # or --cli-auto-prompt overrides can be specified. + # TestCLIAutoPrompt.test_throw_error_if_both_args_specified tests + # that these command line overrides are mutually exclusive. + Case = namedtuple( + 'Case', + [ + 'args', + 'config_variable', + 'expected_result', + ], + ) + return [ + Case([], 'off', 'off'), + Case([], 'on', 'on'), + Case(['--cli-auto-prompt'], 'off', 'on'), + Case(['--cli-auto-prompt'], 'on', 'on'), + Case(['--no-cli-auto-prompt'], 'off', 'off'), + Case(['--no-cli-auto-prompt'], 'on', 'off'), + Case([], 'on', 'on'), + Case([], 'on-partial', 'on-partial'), + Case(['--cli-auto-prompt'], 'on-partial', 'on'), + Case(['--no-cli-auto-prompt'], 'on-partial', 'off'), + Case(['--version'], 'on', 'off'), + Case(['help'], 'on', 'off'), + ] + + class FakeSession: def __init__(self, emitter=None): self.operation = None @@ -413,6 +454,26 @@ def test_no_debug_disables_crt_logging(self, mock_init_logging): awscrt.io.LogLevel.NoLogs, ) + def test_throw_error_if_both_args_specified(self): + args = ['--cli-auto-prompt', '--no-cli-auto-prompt'] + with pytest.raises(ParamValidationError): + validate_auto_prompt_args_are_mutually_exclusive(args) + + @pytest.mark.parametrize('case', _generate_auto_prompt_resolve_cases()) + def test_auto_prompt_resolve_mode(self, case): + driver = create_clidriver() + driver.session.set_config_variable( + 'cli_auto_prompt', case.config_variable + ) + result = resolve_auto_prompt_mode(case.args, driver.session) + assert result == case.expected_result + + def test_auto_prompt_resolve_mode_on_non_existing_profile(self): + driver = create_clidriver() + driver.session.set_config_variable('profile', 'not_exist') + result = resolve_auto_prompt_mode([], driver.session) + assert result == 'off' + class TestCliDriverHooks(unittest.TestCase): # These tests verify the proper hooks are emitted in clidriver. @@ -1085,12 +1146,18 @@ def _create_fake_cli_driver(*args): self.driver.session.user_agent_extra = '' return self.driver - self.prompt_patch = mock.patch('awscli.clidriver.AutoPromptDriver') + self.prompt_patch = mock.patch( + 'awscli.autoprompt.core.AutoPromptDriver' + ) self.crete_driver_patch = mock.patch( 'awscli.clidriver.create_clidriver' ) + self.resolve_mode_patch = mock.patch( + 'awscli.clidriver.resolve_auto_prompt_mode' + ) prompt_driver_class = self.prompt_patch.start() self.create_clidriver = self.crete_driver_patch.start() + self.resolve_auto_prompt_mode = self.resolve_mode_patch.start() self.create_clidriver.side_effect = _create_fake_cli_driver self.prompt_driver = mock.Mock() prompt_driver_class.return_value = self.prompt_driver @@ -1098,9 +1165,10 @@ def _create_fake_cli_driver(*args): def tearDown(self): self.prompt_patch.stop() self.crete_driver_patch.stop() + self.resolve_mode_patch.stop() def test_recreate_driver_in_partial_mode_on_param_err(self): - self.prompt_driver.resolve_mode.return_value = 'on-partial' + self.resolve_auto_prompt_mode.return_value = 'on-partial' self.driver.main.return_value = 252 entry_point = awscli.clidriver.AWSCLIEntryPoint() rc = entry_point.main([]) @@ -1108,7 +1176,7 @@ def test_recreate_driver_in_partial_mode_on_param_err(self): self.assertEqual(rc, 252) def test_not_recreate_driver_in_partial_mode_on_success(self): - self.prompt_driver.resolve_mode.return_value = 'on-partial' + self.resolve_auto_prompt_mode.return_value = 'on-partial' self.driver.main.return_value = 0 entry_point = awscli.clidriver.AWSCLIEntryPoint() rc = entry_point.main([]) @@ -1116,7 +1184,7 @@ def test_not_recreate_driver_in_partial_mode_on_success(self): self.assertEqual(rc, 0) def test_not_recreate_driver_in_on_mode(self): - self.prompt_driver.resolve_mode.return_value = 'on' + self.resolve_auto_prompt_mode.return_value = 'on' self.driver.main.return_value = 252 entry_point = awscli.clidriver.AWSCLIEntryPoint() rc = entry_point.main([]) @@ -1124,7 +1192,7 @@ def test_not_recreate_driver_in_on_mode(self): self.assertEqual(rc, 252) def test_not_recreate_driver_in_off_mode(self): - self.prompt_driver.resolve_mode.return_value = 'off' + self.resolve_auto_prompt_mode.return_value = 'off' self.driver.main.return_value = 252 entry_point = awscli.clidriver.AWSCLIEntryPoint() rc = entry_point.main([]) @@ -1132,7 +1200,7 @@ def test_not_recreate_driver_in_off_mode(self): self.assertEqual(rc, 252) def test_handle_exception_in_main(self): - self.prompt_driver.resolve_mode.return_value = 'on' + self.resolve_auto_prompt_mode.return_value = 'on' self.prompt_driver.prompt_for_args.side_effect = Exception('error') entry_point = awscli.clidriver.AWSCLIEntryPoint() fake_stderr = io.StringIO() @@ -1142,21 +1210,21 @@ def test_handle_exception_in_main(self): self.assertIn('error', fake_stderr.getvalue()) def test_update_user_agent_in_on_mode(self): - self.prompt_driver.resolve_mode.return_value = 'on' + self.resolve_auto_prompt_mode.return_value = 'on' self.driver.main.return_value = 252 entry_point = awscli.clidriver.AWSCLIEntryPoint() entry_point.main([]) self.assertEqual(self.driver.session.user_agent_extra, 'md/prompt#on') def test_not_update_user_agent_in_off_mode(self): - self.prompt_driver.resolve_mode.return_value = 'off' + self.resolve_auto_prompt_mode.return_value = 'off' self.driver.main.return_value = 252 entry_point = awscli.clidriver.AWSCLIEntryPoint() entry_point.main([]) self.assertEqual(self.driver.session.user_agent_extra, 'md/prompt#off') def test_update_user_agent_in_partial_mode_on_param_err(self): - self.prompt_driver.resolve_mode.return_value = 'on-partial' + self.resolve_auto_prompt_mode.return_value = 'on-partial' self.driver.main.return_value = 252 entry_point = awscli.clidriver.AWSCLIEntryPoint() entry_point.main([]) @@ -1165,7 +1233,7 @@ def test_update_user_agent_in_partial_mode_on_param_err(self): ) def test_not_update_user_agent_in_partial_mode_on_success(self): - self.prompt_driver.resolve_mode.return_value = 'on-partial' + self.resolve_auto_prompt_mode.return_value = 'on-partial' self.driver.main.return_value = 0 entry_point = awscli.clidriver.AWSCLIEntryPoint() entry_point.main([])