Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions alembic/autogenerate/compare/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
from typing import TYPE_CHECKING

from . import check_constraints
from . import comments
from . import constraints
from . import schema
Expand Down Expand Up @@ -60,3 +61,6 @@ def _produce_net_changes(
server_defaults, "alembic.autogenerate.defaults"
)
Plugin.setup_plugin_from_module(comments, "alembic.autogenerate.comments")
Plugin.setup_plugin_from_module(
check_constraints, "alembic.ext.checkconstraint"
)
186 changes: 186 additions & 0 deletions alembic/autogenerate/compare/check_constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# mypy: allow-untyped-defs, allow-untyped-calls, allow-incomplete-defs

from __future__ import annotations

import logging
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union

from sqlalchemy import schema as sa_schema

from .util import _InspectorConv
from ...operations import ops
from ...util import PriorityDispatchResult
from ...util import sqla_compat

if TYPE_CHECKING:
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import CheckConstraint
from sqlalchemy.sql.schema import Table

from ...autogenerate.api import AutogenContext
from ...ddl.impl import DefaultImpl
from ...operations.ops import ModifyTableOps
from ...runtime.plugins import Plugin


log = logging.getLogger(__name__)


def _make_check_constraint(
impl: DefaultImpl,
params: dict,
conn_table: Table,
) -> CheckConstraint:
const = sa_schema.CheckConstraint(
params["sqltext"],
name=params["name"],
**impl.adjust_reflected_dialect_options(params, "check_constraint"),
)
conn_table.append_constraint(const)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Federico Caselli (CaselIT) wrote:

I don't think this is needed.

View this in Gerrit at https://gerrit.sqlalchemy.org/c/sqlalchemy/alembic/+/6672

return const


def _compare_check_constraints(
autogen_context: AutogenContext,
modify_table_ops: ModifyTableOps,
schema: Optional[str],
tname: Union[quoted_name, str],
conn_table: Optional[Table],
metadata_table: Optional[Table],
) -> PriorityDispatchResult:
if conn_table is None or metadata_table is None:
return PriorityDispatchResult.CONTINUE

inspector = autogen_context.inspector
impl = autogen_context.migration_context.impl

metadata_ck_constraints = {
ck
for ck in metadata_table.constraints
if isinstance(ck, sa_schema.CheckConstraint)
and not sqla_compat._is_type_bound(ck)
}

try:
conn_ck_list = _InspectorConv(inspector).get_check_constraints(
tname, schema=schema
)
except NotImplementedError:
return PriorityDispatchResult.CONTINUE

conn_ck_list = [
ck
for ck in conn_ck_list
if ck.get("name") is not None
and autogen_context.run_name_filters(
ck["name"],
"check_constraint",
{"table_name": tname, "schema_name": schema},
)
]

conn_ck_objs = {
_make_check_constraint(impl, ck_def, conn_table)
for ck_def in conn_ck_list
}

metadata_ck_sig = {
impl._create_metadata_constraint_sig(ck)
for ck in metadata_ck_constraints
if sqla_compat._constraint_is_named(ck, autogen_context.dialect)
}

conn_ck_sig = {
impl._create_reflected_constraint_sig(ck) for ck in conn_ck_objs
}

metadata_ck_by_name = {c.name: c for c in metadata_ck_sig if c.name}
conn_ck_by_name = {c.name: c for c in conn_ck_sig if c.name}

for removed_name in sorted(
set(conn_ck_by_name).difference(metadata_ck_by_name)
):
conn_obj = conn_ck_by_name[removed_name]
if autogen_context.run_object_filters(
conn_obj.const,
conn_obj.name,
"check_constraint",
True,
None,
):
modify_table_ops.ops.append(
ops.DropConstraintOp.from_constraint(conn_obj.const)
)
log.info(
"Detected removed check constraint %r on table %r",
conn_obj.name,
tname,
)

for existing_name in sorted(
set(metadata_ck_by_name).intersection(conn_ck_by_name)
):
metadata_obj = metadata_ck_by_name[existing_name]
conn_obj = conn_ck_by_name[existing_name]

comparison = metadata_obj.compare_to_reflected(conn_obj)

if comparison.is_different:
if autogen_context.run_object_filters(
metadata_obj.const,
metadata_obj.name,
"check_constraint",
False,
conn_obj.const,
):
log.info(
"Detected changed check constraint %r on table %r: %s",
existing_name,
tname,
comparison.message,
)
modify_table_ops.ops.append(
ops.DropConstraintOp.from_constraint(conn_obj.const)
)
modify_table_ops.ops.append(
ops.AddConstraintOp.from_constraint(metadata_obj.const)
)
elif comparison.is_skip:
log.info(
"Cannot compare check constraint %r, "
"assuming equal and skipping. %s",
existing_name,
comparison.message,
)

for added_name in sorted(
set(metadata_ck_by_name).difference(conn_ck_by_name)
):
metadata_obj = metadata_ck_by_name[added_name]
if autogen_context.run_object_filters(
metadata_obj.const,
metadata_obj.name,
"check_constraint",
False,
None,
):
modify_table_ops.ops.append(
ops.AddConstraintOp.from_constraint(metadata_obj.const)
)
log.info(
"Detected added check constraint %r on table %r",
metadata_obj.name,
tname,
)

return PriorityDispatchResult.CONTINUE


def setup(plugin: Plugin) -> None:
plugin.add_autogenerate_comparator(
_compare_check_constraints,
"table",
"checkconstraints",
)
25 changes: 25 additions & 0 deletions alembic/autogenerate/compare/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
if TYPE_CHECKING:
from sqlalchemy import Table
from sqlalchemy.engine import Inspector
from sqlalchemy.engine.interfaces import ReflectedCheckConstraint
from sqlalchemy.engine.interfaces import ReflectedForeignKeyConstraint
from sqlalchemy.engine.interfaces import ReflectedIndex
from sqlalchemy.engine.interfaces import ReflectedUniqueConstraint
Expand Down Expand Up @@ -78,6 +79,11 @@ def get_foreign_keys(
) -> list[ReflectedForeignKeyConstraint]:
raise NotImplementedError()

def get_check_constraints(
self, tname: str, schema: str | None
) -> list[ReflectedCheckConstraint]:
raise NotImplementedError()

def reflect_table(self, table: Table) -> None:
raise NotImplementedError()

Expand Down Expand Up @@ -123,6 +129,13 @@ def get_foreign_keys(
self.inspector.get_foreign_keys(tname, schema=schema)
)

def get_check_constraints(
self, tname: str, schema: str | None
) -> list[ReflectedCheckConstraint]:
return self._apply_reflectinfo_conv(
self.inspector.get_check_constraints(tname, schema=schema)
)

def reflect_table(self, table: Table) -> None:
self.inspector.reflect_table(table, include_columns=None)

Expand Down Expand Up @@ -252,6 +265,18 @@ def get_foreign_keys(
apply_constraint_conv=True,
)

def get_check_constraints(
self, tname: str, schema: str | None
) -> list[ReflectedCheckConstraint]:
return self._return_from_cache(
tname,
schema,
"alembic_check_constraints",
self.inspector.get_check_constraints,
apply_constraint_conv=True,
optional=False,
)

def _apply_reflectinfo_conv(self, consts):
if not consts:
return consts
Expand Down
20 changes: 18 additions & 2 deletions alembic/autogenerate/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,24 @@ def _add_pk_constraint(constraint, autogen_context):


@renderers.dispatch_for(ops.CreateCheckConstraintOp)
def _add_check_constraint(constraint, autogen_context):
raise NotImplementedError()
def _add_check_constraint(
autogen_context: AutogenContext, op: ops.CreateCheckConstraintOp
) -> str:
constraint = op.to_constraint()
args = [repr(_render_gen_name(autogen_context, op.constraint_name))]
if not autogen_context._has_batch:
args.append(repr(_ident(op.table_name)))
args.append(
_render_potential_expr(
constraint.sqltext, autogen_context, wrap_in_element=False
)
)
if not autogen_context._has_batch and op.schema:
args.append("schema=%r" % _ident(op.schema))
return "%(prefix)screate_check_constraint(%(args)s)" % {
"prefix": _alembic_autogenerate_prefix(autogen_context),
"args": ", ".join(args),
}


@renderers.dispatch_for(ops.DropConstraintOp)
Expand Down
35 changes: 35 additions & 0 deletions alembic/ddl/_autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import TypeVar
from typing import Union

from sqlalchemy.sql.schema import CheckConstraint
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
Expand Down Expand Up @@ -86,6 +87,7 @@ class _constraint_sig(Generic[_C]):
_is_index: ClassVar[bool] = False
_is_fk: ClassVar[bool] = False
_is_uq: ClassVar[bool] = False
_is_ck: ClassVar[bool] = False

_is_metadata: bool

Expand Down Expand Up @@ -325,5 +327,38 @@ def is_uq_sig(sig: _constraint_sig) -> TypeGuard[_uq_constraint_sig]:
return sig._is_uq


class _ck_constraint_sig(_constraint_sig[CheckConstraint]):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Federico Caselli (CaselIT) wrote:

nit: let's move this before the other typeguards

View this in Gerrit at https://gerrit.sqlalchemy.org/c/sqlalchemy/alembic/+/6672

_is_ck = True

@classmethod
def _register(cls) -> None:
_clsreg["check_constraint"] = cls
_clsreg["table_or_column_check_constraint"] = cls
_clsreg["column_check_constraint"] = cls

def __init__(
self,
is_metadata: bool,
impl: DefaultImpl,
const: CheckConstraint,
) -> None:
self._is_metadata = is_metadata
self.impl = impl
self.const = const
self.name = sqla_compat.constraint_name_or_none(const.name)
self._sig = (self.name,)

def _compare_to_reflected(
self, other: _constraint_sig[_C]
) -> ComparisonResult:
assert self._is_metadata
assert is_ck_sig(other)
return self.impl.compare_check_constraint(self.const, other.const)


def is_ck_sig(sig: _constraint_sig) -> TypeGuard[_ck_constraint_sig]:
return sig._is_ck


def is_fk_sig(sig: _constraint_sig) -> TypeGuard[_fk_constraint_sig]:
return sig._is_fk
12 changes: 11 additions & 1 deletion alembic/ddl/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.engine.interfaces import ReflectedCheckConstraint
from sqlalchemy.engine.interfaces import ReflectedForeignKeyConstraint
from sqlalchemy.engine.interfaces import ReflectedIndex
from sqlalchemy.engine.interfaces import ReflectedPrimaryKeyConstraint
Expand All @@ -51,6 +52,7 @@
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql import Executable
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import CheckConstraint
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
Expand All @@ -64,7 +66,8 @@
from ..operations.batch import BatchOperationsImpl

_ReflectedConstraint = (
ReflectedForeignKeyConstraint
ReflectedCheckConstraint
| ReflectedForeignKeyConstraint
| ReflectedPrimaryKeyConstraint
| ReflectedIndex
| ReflectedUniqueConstraint
Expand Down Expand Up @@ -840,6 +843,13 @@ def compare_unique_constraint(
else:
return ComparisonResult.Equal()

def compare_check_constraint(
self,
metadata_constraint: CheckConstraint,
reflected_constraint: CheckConstraint,
) -> ComparisonResult:
return ComparisonResult.Equal()

def _skip_functional_indexes(self, metadata_indexes, conn_indexes):
conn_indexes_by_name = {c.name: c for c in conn_indexes}

Expand Down
4 changes: 4 additions & 0 deletions alembic/testing/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ def check_constraints_w_enforcement(self):

return exclusions.open()

@property
def check_constraint_reflection(self):
return exclusions.open()

@property
def reflects_pk_names(self):
return exclusions.closed()
Expand Down
Loading
Loading