Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 24 additions & 53 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,11 @@ def __init__(
# Load config.
config_files: list[str | IO[str]] = self.system_config_files + [myclirc] + [self.pwd_config_file]
c = self.config = read_config_files(config_files)
# this parallel config exists to
# * compare with my.cnf
# * support the --checkup feature
# this parallel config exists only to compare with my.cnf and can be removed with my.cnf support
self.config_without_package_defaults = read_config_files(config_files, ignore_package_defaults=True)
for toplevel in ['main', 'connection']:
if not self.config_without_package_defaults.get(toplevel):
self.config_without_package_defaults[toplevel] = {}
self.multi_line = c["main"].as_bool("multi_line")
self.key_bindings = c["main"]["key_bindings"]
special.set_timing_enabled(c["main"].as_bool("timing"))
Expand All @@ -168,6 +169,7 @@ def __init__(
self.post_redirect_command = c['main'].get('post_redirect_command')
self.null_string = c['main'].get('null_string')
self.numeric_alignment = c['main'].get('numeric_alignment', 'right')
self.binary_display = c['main'].get('binary_display')

# set ssl_mode if a valid option is provided in a config file, otherwise None
ssl_mode = c["main"].get("ssl_mode", None)
Expand Down Expand Up @@ -519,15 +521,14 @@ def connect(
host = host or cnf["host"]
port = port or cnf["port"]
ssl_config: dict[str, Any] = ssl or {}
user_connection_config = self.config_without_package_defaults.get('connection', {})

int_port = port and int(port)
if not int_port:
int_port = 3306
if not host or host == "localhost":
socket = (
socket
or user_connection_config.get("default_socket")
or self.config_without_package_defaults["connection"].get("default_socket")
or cnf["socket"]
or cnf["default_socket"]
or guess_socket_location()
Expand All @@ -554,7 +555,7 @@ def connect(
use_local_infile = False
for local_infile_option in (
local_infile,
user_connection_config.get('default_local_infile'),
self.config_without_package_defaults['connection'].get('default_local_infile'),
cnf['local_infile'],
cnf['local-infile'],
cnf['loose_local_infile'],
Expand All @@ -568,16 +569,16 @@ def connect(
pass

# temporary my.cnf override mappings
if 'default_ssl_ca' in user_connection_config:
cnf['ssl-ca'] = user_connection_config.get('default_ssl_ca') or None
if 'default_ssl_cert' in user_connection_config:
cnf['ssl-cert'] = user_connection_config.get('default_ssl_cert') or None
if 'default_ssl_key' in user_connection_config:
cnf['ssl-key'] = user_connection_config.get('default_ssl_key') or None
if 'default_ssl_cipher' in user_connection_config:
cnf['ssl-cipher'] = user_connection_config.get('default_ssl_cipher') or None
if 'default_ssl_verify_server_cert' in user_connection_config:
cnf['ssl-verify-server-cert'] = user_connection_config.get('default_ssl_verify_server_cert') or None
if 'default_ssl_ca' in self.config_without_package_defaults['connection']:
cnf['ssl-ca'] = self.config_without_package_defaults['connection']['default_ssl_ca'] or None
if 'default_ssl_cert' in self.config_without_package_defaults['connection']:
cnf['ssl-cert'] = self.config_without_package_defaults['connection']['default_ssl_cert'] or None
if 'default_ssl_key' in self.config_without_package_defaults['connection']:
cnf['ssl-key'] = self.config_without_package_defaults['connection']['default_ssl_key'] or None
if 'default_ssl_cipher' in self.config_without_package_defaults['connection']:
cnf['ssl-cipher'] = self.config_without_package_defaults['connection']['default_ssl_cipher'] or None
if 'default_ssl_verify_server_cert' in self.config_without_package_defaults['connection']:
cnf['ssl-verify-server-cert'] = self.config_without_package_defaults['connection']['default_ssl_verify_server_cert'] or None

# todo: rewrite the merge method using self.config['connection'] instead of cnf, after removing my.cnf support
ssl_config_or_none: dict[str, Any] | None = self.merge_ssl_with_cnf(ssl_config, cnf)
Expand Down Expand Up @@ -888,6 +889,7 @@ def output_res(results: Generator[SQLResult], start: float) -> None:
special.is_redirected(),
self.null_string,
self.numeric_alignment,
self.binary_display,
max_width,
)

Expand Down Expand Up @@ -926,6 +928,7 @@ def output_res(results: Generator[SQLResult], start: float) -> None:
special.is_redirected(),
self.null_string,
self.numeric_alignment,
self.binary_display,
max_width,
)
self.echo("")
Expand Down Expand Up @@ -1404,6 +1407,7 @@ def run_query(
special.is_redirected(),
self.null_string,
self.numeric_alignment,
self.binary_display,
)
for line in output:
self.log_output(line)
Expand All @@ -1424,6 +1428,7 @@ def run_query(
special.is_redirected(),
self.null_string,
self.numeric_alignment,
self.binary_display,
)
for line in output:
click.echo(line, nl=new_line)
Expand All @@ -1440,6 +1445,7 @@ def format_output(
is_redirected: bool = False,
null_string: str | None = None,
numeric_alignment: str = 'right',
binary_display: str | None = None,
max_width: int | None = None,
) -> itertools.chain[str]:
if is_redirected:
Expand All @@ -1461,7 +1467,7 @@ def format_output(
if null_string is not None and default_kwargs.get('missing_value') == DEFAULT_MISSING_VALUE:
output_kwargs['missing_value'] = null_string

if use_formatter.format_name not in sql_format.supported_formats:
if use_formatter.format_name not in sql_format.supported_formats and binary_display != 'utf8':
# will run before preprocessors defined as part of the format in cli_helpers
output_kwargs["preprocessors"] = (preprocessors.convert_to_undecoded_string,)

Expand Down Expand Up @@ -1624,7 +1630,6 @@ def get_last_query(self) -> str | None:
default=None,
help='Store and retrieve passwords from the system keyring: true/false/reset.',
)
@click.option("--checkup", is_flag=True, help="Run a checkup on your config file.")
@click.pass_context
def cli(
ctx: click.Context,
Expand Down Expand Up @@ -1678,7 +1683,6 @@ def cli(
batch_format: str | None,
throttle: float,
use_keyring_cli_opt: str | None,
checkup: bool,
) -> None:
"""A MySQL terminal client with auto-completion and syntax highlighting.

Expand Down Expand Up @@ -1745,10 +1749,6 @@ def get_password_from_file(password_file: str | None) -> str | None:
myclirc=myclirc,
)

if checkup:
do_config_checkup(mycli)
sys.exit(0)

if csv and batch_format not in [None, 'csv']:
click.secho("Conflicting --csv and --format arguments.", err=True, fg="red")
sys.exit(1)
Expand Down Expand Up @@ -1999,8 +1999,7 @@ def get_password_from_file(password_file: str | None) -> str | None:
and mycli.my_cnf[mycnf_section_name].get(mycnf_item_name.replace('-', '_')) is None
):
continue
user_section = mycli.config_without_package_defaults.get(myclirc_section_name, {})
if user_section.get(myclirc_item_name) is None:
if mycli.config_without_package_defaults[myclirc_section_name].get(myclirc_item_name) is None:
cnf_value = mycli.my_cnf[mycnf_section_name].get(mycnf_item_name)
if cnf_value is None:
cnf_value = mycli.my_cnf[mycnf_section_name].get(mycnf_item_name.replace('-', '_'))
Expand Down Expand Up @@ -2224,33 +2223,5 @@ def read_ssh_config(ssh_config_path: str):
return ssh_config


def do_config_checkup(mycli: MyCli) -> None:
did_output = False

if not list(mycli.config.keys()):
print('\nThe local ~/,myclirc is missing or empty.\n')
did_output = True
else:
for section_name in mycli.config.keys():
if section_name not in mycli.config_without_package_defaults:
if not did_output:
print('\nMissing in user ~/.myclirc:\n')
print(f'The entire section:\n\n [{section_name}]\n')
did_output = True
continue
for item_name in mycli.config[section_name]:
if item_name not in mycli.config_without_package_defaults[section_name]:
if not did_output:
print('\nMissing in user ~/.myclirc:\n')
print(f'The item:\n\n [{section_name}]\n {item_name} =\n')
did_output = True
if did_output:
print(
'For more info on new features, see the commentary and defaults at:\n\n * https://github.com/dbcli/mycli/blob/main/mycli/myclirc\n'
)
else:
print('User configuration all up to date!')


if __name__ == "__main__":
cli()
5 changes: 5 additions & 0 deletions mycli/myclirc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ null_string = <null>
# How to align numeric data in tabular output: right or left.
numeric_alignment = right

# How to display binary values in tabular output: "hex", or "utf8". "utf8"
# means attempt to render valid UTF-8 sequences as strings, then fall back
# to hex rendering if not possible.
binary_display = hex

# A command to run after a successful output redirect, with {} to be replaced
# with the escaped filename. Mac example: echo {} | pbcopy. Escaping is not
# reliable/safe on Windows.
Expand Down
5 changes: 5 additions & 0 deletions test/myclirc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ null_string = <nope>
# How to align numeric data in tabular output: right or left.
numeric_alignment = right

# How to display binary values in tabular output: "hex", or "utf8". "utf8"
# means attempt to render valid UTF-8 sequences as strings, then fall back
# to hex rendering if not possible.
binary_display = hex

# A command to run after a successful output redirect, with {} to be replaced
# with the escaped filename. Mac example: echo {} | pbcopy. Escaping is not
# reliable/safe on Windows.
Expand Down
78 changes: 78 additions & 0 deletions test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,84 @@
]


@dbtest
def test_binary_display_hex(executor, capsys):
m = MyCli()
m.sqlexecute = SQLExecute(
None,
USER,
PASSWORD,
HOST,
PORT,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
m.explicit_pager = False
sqlresult = next(m.sqlexecute.run("select b'01101010' AS binary_test"))
formatted = m.format_output(
sqlresult.title,
sqlresult.results,
sqlresult.headers,
False,
False,
"<nope>",
"right",
"hex",
None,
)
m.output(formatted, sqlresult.status)
expected = "+-------------+\n| binary_test |\n+-------------+\n| 0x6a |\n+-------------+\n1 row in set\n"
stdout = capsys.readouterr().out
assert expected in stdout


@dbtest
def test_binary_display_utf8(executor, capsys):
m = MyCli()
m.sqlexecute = SQLExecute(
None,
USER,
PASSWORD,
HOST,
PORT,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
m.explicit_pager = False
sqlresult = next(m.sqlexecute.run("select b'01101010' AS binary_test"))
formatted = m.format_output(
sqlresult.title,
sqlresult.results,
sqlresult.headers,
False,
False,
"<nope>",
"right",
"utf8",
None,
)
m.output(formatted, sqlresult.status)
expected = "+-------------+\n| binary_test |\n+-------------+\n| j |\n+-------------+\n1 row in set\n"
stdout = capsys.readouterr().out
assert expected in stdout


@dbtest
def test_select_from_empty_table(executor):
run(executor, """create table t1(id int)""")
Expand Down