Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Features
* Make `--progress` and `--checkpoint` strictly by statement.
* Allow more characters in passwords read from a file.
* Show sponsors and contributors separately in startup messages.
* Add support for expired password (sandbox) mode (#440)


Bug Fixes
Expand Down
4 changes: 4 additions & 0 deletions mycli/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@

DEFAULT_WIDTH = 80
DEFAULT_HEIGHT = 25

# MySQL error codes not available in pymysql.constants.ER
ER_MUST_CHANGE_PASSWORD_LOGIN = 1862
ER_MUST_CHANGE_PASSWORD = 1820
18 changes: 18 additions & 0 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
DEFAULT_HOST,
DEFAULT_PORT,
DEFAULT_WIDTH,
ER_MUST_CHANGE_PASSWORD_LOGIN,
ISSUES_URL,
REPO_URL,
)
Expand Down Expand Up @@ -152,6 +153,7 @@ def __init__(
self.prompt_session: PromptSession | None = None
self._keepalive_counter = 0
self.keepalive_ticks: int | None = 0
self.sandbox_mode: bool = False

# self.cnf_files is a class variable that stores the list of mysql
# config files to read in at launch.
Expand Down Expand Up @@ -750,6 +752,13 @@ def _connect(
keyring_retrieved_cleanly=keyring_retrieved_cleanly,
keyring_save_eligible=keyring_save_eligible,
)
elif e1.args[0] == ER_MUST_CHANGE_PASSWORD_LOGIN:
self.echo(
"Your password has expired and the server rejected the connection.",
err=True,
fg='red',
)
raise e1
elif e1.args[0] == CR_SERVER_LOST:
self.echo(
(
Expand Down Expand Up @@ -803,6 +812,15 @@ def _connect(
sys.exit(1)

_connect(keyring_retrieved_cleanly=keyring_retrieved_cleanly)

# Check if SQLExecute detected sandbox mode during connection
if self.sqlexecute and self.sqlexecute.sandbox_mode:
self.sandbox_mode = True
self.echo(
"Your password has expired. Use ALTER USER to set a new password, or quit.",
err=True,
fg='yellow',
)
except Exception as e: # Connecting to a database could fail.
self.logger.debug("Database connection failed: %r.", e)
self.logger.error("traceback: %r", traceback.format_exc())
Expand Down
99 changes: 89 additions & 10 deletions mycli/main_modes/repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from mycli.constants import (
DEFAULT_HOST,
DEFAULT_WIDTH,
ER_MUST_CHANGE_PASSWORD,
HOME_URL,
ISSUES_URL,
)
Expand Down Expand Up @@ -132,7 +133,8 @@ def _show_startup_banner(
if mycli.less_chatty:
return

print(sqlexecute.server_info)
if sqlexecute.server_info is not None:
print(sqlexecute.server_info)
print('mycli', mycli_package.__version__)
print(SUPPORT_INFO)
if random.random() <= 0.25:
Expand Down Expand Up @@ -232,8 +234,6 @@ def get_prompt(
) -> str:
sqlexecute = mycli.sqlexecute
assert sqlexecute is not None
assert sqlexecute.server_info is not None
assert sqlexecute.server_info.species is not None
if mycli.login_path and mycli.login_path_as_host:
prompt_host = mycli.login_path
elif sqlexecute.host is not None:
Expand All @@ -250,7 +250,8 @@ def get_prompt(
string = string.replace('\\h', prompt_host or '(none)')
string = string.replace('\\H', short_prompt_host or '(none)')
string = string.replace('\\d', sqlexecute.dbname or '(none)')
string = string.replace('\\t', sqlexecute.server_info.species.name)
species_name = sqlexecute.server_info.species.name if sqlexecute.server_info and sqlexecute.server_info.species else 'MySQL'
string = string.replace('\\t', species_name)
string = string.replace('\\n', '\n')
string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
string = string.replace('\\m', now.strftime('%M'))
Expand Down Expand Up @@ -518,6 +519,52 @@ def _build_prompt_session(
mycli.prompt_session.app.ttimeoutlen = mycli.emacs_ttimeoutlen


_SANDBOX_ALLOWED_RE = re.compile(
r'^\s*(ALTER\s+USER|SET\s+PASSWORD|QUIT|EXIT|\\q)\b',
re.IGNORECASE,
)

_PASSWORD_CHANGE_RE = re.compile(
r'^\s*(ALTER\s+USER|SET\s+PASSWORD)\b',
re.IGNORECASE,
)


def _is_sandbox_allowed(text: str) -> bool:
"""Return True if the command is allowed in expired-password sandbox mode."""
stripped = text.strip()
if not stripped:
return True
return bool(_SANDBOX_ALLOWED_RE.match(stripped))


def _is_password_change(text: str) -> bool:
"""Return True if the command is a password change statement."""
return bool(_PASSWORD_CHANGE_RE.match(text.strip()))


_IDENTIFIED_BY_RE = re.compile(
r"IDENTIFIED\s+BY\s+'([^']*)'",
re.IGNORECASE,
)

_SET_PASSWORD_RE = re.compile(
r"SET\s+PASSWORD\s*=\s*'([^']*)'",
re.IGNORECASE,
)


def _extract_new_password(text: str) -> str | None:
"""Extract the new password from an ALTER USER or SET PASSWORD statement."""
m = _IDENTIFIED_BY_RE.search(text)
if m:
return m.group(1)
m = _SET_PASSWORD_RE.search(text)
if m:
return m.group(1)
return None


def _one_iteration(
mycli: 'MyCli',
state: ReplState,
Expand Down Expand Up @@ -615,6 +662,14 @@ def _one_iteration(
mycli.echo(str(e), err=True, fg='red')
return

if mycli.sandbox_mode and not _is_sandbox_allowed(text):
mycli.echo(
"ERROR 1820: You must reset your password using ALTER USER or SET PASSWORD before executing this statement.",
err=True,
fg='red',
)
return

if mycli.destructive_warning:
destroy = confirm_destructive_query(mycli.destructive_keywords, text)
if destroy is None:
Expand Down Expand Up @@ -674,20 +729,44 @@ def _one_iteration(
mycli.echo('Not Yet Implemented.', fg='yellow')
except pymysql.OperationalError as e1:
mycli.logger.debug('Exception: %r', e1)
if e1.args[0] in (2003, 2006, 2013):
if e1.args[0] == ER_MUST_CHANGE_PASSWORD:
mycli.sandbox_mode = True
mycli.echo(
"ERROR 1820: You must reset your password using ALTER USER or SET PASSWORD before executing this statement.",
err=True,
fg='red',
)
elif e1.args[0] in (2003, 2006, 2013):
if not mycli.reconnect():
return
_one_iteration(mycli, state, text)
return

mycli.logger.error('sql: %r, error: %r', text, e1)
mycli.logger.error('traceback: %r', traceback.format_exc())
mycli.echo(str(e1), err=True, fg='red')
else:
mycli.logger.error('sql: %r, error: %r', text, e1)
mycli.logger.error('traceback: %r', traceback.format_exc())
mycli.echo(str(e1), err=True, fg='red')
except Exception as e:
mycli.logger.error('sql: %r, error: %r', text, e)
mycli.logger.error('traceback: %r', traceback.format_exc())
mycli.echo(str(e), err=True, fg='red')
else:
if mycli.sandbox_mode and _is_password_change(text):
new_password = _extract_new_password(text)
if new_password is not None:
sqlexecute.password = new_password
try:
sqlexecute.connect()
mycli.sandbox_mode = False
mycli.echo("Password changed successfully. Reconnected.", err=True, fg='green')
mycli.refresh_completions()
except Exception as e:
mycli.sandbox_mode = False
mycli.echo(
f"Password changed but reconnection failed: {e}\nPlease restart mycli with your new password.",
err=True,
fg='yellow',
)

if is_dropping_database(text, sqlexecute.dbname):
sqlexecute.dbname = None
sqlexecute.connect()
Expand Down Expand Up @@ -756,7 +835,7 @@ def main_repl(mycli: 'MyCli') -> None:
state = ReplState()

mycli.configure_pager()
if mycli.smart_completion:
if mycli.smart_completion and not mycli.sandbox_mode:
mycli.refresh_completions()

history = _create_history(mycli)
Expand Down
84 changes: 61 additions & 23 deletions mycli/sqlexecute.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pymysql.converters import conversions, convert_date, convert_datetime, convert_time, decoders
from pymysql.cursors import Cursor

from mycli.constants import ER_MUST_CHANGE_PASSWORD
from mycli.packages.special import iocommands
from mycli.packages.special.main import CommandNotFound, execute
from mycli.packages.sqlresult import SQLResult
Expand Down Expand Up @@ -280,32 +281,50 @@ def connect(
client_flag = pymysql.constants.CLIENT.INTERACTIVE
if init_command and len(list(iocommands.split_queries(init_command))) > 1:
client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS
client_flag |= pymysql.constants.CLIENT.HANDLE_EXPIRED_PASSWORDS

ssl_context = None
if ssl:
ssl_context = self._create_ssl_ctx(ssl)

conn = pymysql.connect(
database=db,
user=user,
password=password or '',
host=host,
port=port or 0,
unix_socket=socket,
use_unicode=True,
charset=character_set or '',
autocommit=True,
client_flag=client_flag,
local_infile=local_infile or False,
conv=conv,
ssl=ssl_context, # type: ignore[arg-type]
program_name="mycli",
defer_connect=defer_connect,
init_command=init_command or None,
cursorclass=pymysql.cursors.SSCursor if unbuffered else pymysql.cursors.Cursor,
) # type: ignore[misc]
connect_kwargs: dict[str, Any] = {
"database": db,
"user": user,
"password": password or '',
"host": host,
"port": port or 0,
"unix_socket": socket,
"use_unicode": True,
"charset": character_set or '',
"autocommit": True,
"client_flag": client_flag,
"local_infile": local_infile or False,
"conv": conv,
"ssl": ssl_context, # type: ignore[arg-type]
"program_name": "mycli",
"defer_connect": defer_connect,
"init_command": init_command or None,
"cursorclass": pymysql.cursors.SSCursor if unbuffered else pymysql.cursors.Cursor,
}

self.sandbox_mode = False
try:
conn = pymysql.connect(**connect_kwargs) # type: ignore[misc]
except pymysql.OperationalError as e:
if e.args[0] == ER_MUST_CHANGE_PASSWORD:
# Post-handshake queries (SET NAMES, SET AUTOCOMMIT, init_command)
# fail with ER_MUST_CHANGE_PASSWORD in sandbox mode.
# Reconnect with only the raw handshake.
connect_kwargs['defer_connect'] = True
connect_kwargs['autocommit'] = None
connect_kwargs['init_command'] = None
conn = pymysql.connect(**connect_kwargs) # type: ignore[misc]
self._connect_sandbox(conn)
self.sandbox_mode = True
else:
raise

if ssh_host:
if ssh_host and not self.sandbox_mode:
##### paramiko.Channel is a bad socket implementation overall if you want SSL through an SSH tunnel
#####
# instead let's open a tunnel and rewrite host:port to local bind
Expand Down Expand Up @@ -343,9 +362,10 @@ def connect(
self.ssl = ssl
self.init_command = init_command
self.unbuffered = unbuffered
# retrieve connection id
self.reset_connection_id()
self.server_info = ServerInfo.from_version_string(conn.server_version) # type: ignore[attr-defined]
# retrieve connection id (skip in sandbox mode as queries will fail)
if not self.sandbox_mode:
self.reset_connection_id()
self.server_info = ServerInfo.from_version_string(conn.server_version) # type: ignore[attr-defined]

def run(self, statement: str) -> Generator[SQLResult, None, None]:
"""Execute the sql in the database and return the results."""
Expand Down Expand Up @@ -576,6 +596,24 @@ def change_db(self, db: str) -> None:
self.conn.select_db(db)
self.dbname = db

@staticmethod
def _connect_sandbox(conn: Connection) -> None:
"""Connect in sandbox mode, performing only the handshake.

pymysql's normal connect() runs post-handshake queries (SET NAMES,
SET AUTOCOMMIT, init_command) that all fail with ER_MUST_CHANGE_PASSWORD
in sandbox mode. This method performs the raw socket connection and
authentication handshake only.
"""
# Reuse pymysql internals for the handshake + auth, but
# temporarily stub out set_character_set so it becomes a no-op.
original_set_charset = conn.set_character_set
conn.set_character_set = lambda *_args, **_kwargs: None # type: ignore[assignment]
try:
conn.connect()
finally:
conn.set_character_set = original_set_charset # type: ignore[assignment]

def _create_ssl_ctx(self, sslp: dict) -> ssl.SSLContext:
ca = sslp.get("ca")
capath = sslp.get("capath")
Expand Down
Loading
Loading