11import abc
22import functools
3- from dataclasses import field
3+ import random
44from datetime import datetime
55import math
66import sys
1414import decimal
1515import contextvars
1616
17- from runtype import dataclass
17+ import attrs
1818from typing_extensions import Self
1919
2020from data_diff .abcs .compiler import AbstractCompiler
@@ -90,12 +90,7 @@ class CompileError(Exception):
9090 pass
9191
9292
93- # TODO: remove once switched to attrs, where ForwardRef[]/strings are resolved.
94- class _RuntypeHackToFixCicularRefrencedDatabase :
95- dialect : "BaseDialect"
96-
97-
98- @dataclass
93+ @attrs .define (frozen = True )
9994class Compiler (AbstractCompiler ):
10095 """
10196 Compiler bears the context for a single compilation.
@@ -107,16 +102,16 @@ class Compiler(AbstractCompiler):
107102 # Database is needed to normalize tables. Dialect is needed for recursive compilations.
108103 # In theory, it is many-to-many relations: e.g. a generic ODBC driver with multiple dialects.
109104 # In practice, we currently bind the dialects to the specific database classes.
110- database : _RuntypeHackToFixCicularRefrencedDatabase
105+ database : "Database"
111106
112107 in_select : bool = False # Compilation runtime flag
113108 in_join : bool = False # Compilation runtime flag
114109
115- _table_context : List = field (default_factory = list ) # List[ITable]
116- _subqueries : Dict [str , Any ] = field (default_factory = dict ) # XXX not thread-safe
110+ _table_context : List = attrs . field (factory = list ) # List[ITable]
111+ _subqueries : Dict [str , Any ] = attrs . field (factory = dict ) # XXX not thread-safe
117112 root : bool = True
118113
119- _counter : List = field (default_factory = lambda : [0 ])
114+ _counter : List = attrs . field (factory = lambda : [0 ])
120115
121116 @property
122117 def dialect (self ) -> "BaseDialect" :
@@ -136,7 +131,7 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath:
136131 return self .database .dialect .parse_table_name (table_name )
137132
138133 def add_table_context (self , * tables : Sequence , ** kw ) -> Self :
139- return self . replace ( _table_context = self ._table_context + list (tables ), ** kw )
134+ return attrs . evolve ( self , table_context = self ._table_context + list (tables ), ** kw )
140135
141136
142137def parse_table_name (t ):
@@ -271,7 +266,7 @@ def _compile(self, compiler: Compiler, elem) -> str:
271266 if elem is None :
272267 return "NULL"
273268 elif isinstance (elem , Compilable ):
274- return self .render_compilable (compiler . replace ( root = False ), elem )
269+ return self .render_compilable (attrs . evolve ( compiler , root = False ), elem )
275270 elif isinstance (elem , str ):
276271 return f"'{ elem } '"
277272 elif isinstance (elem , (int , float )):
@@ -381,7 +376,7 @@ def render_column(self, c: Compiler, elem: Column) -> str:
381376 return self .quote (elem .name )
382377
383378 def render_cte (self , parent_c : Compiler , elem : Cte ) -> str :
384- c : Compiler = parent_c . replace ( _table_context = [], in_select = False )
379+ c : Compiler = attrs . evolve ( parent_c , table_context = [], in_select = False )
385380 compiled = self .compile (c , elem .source_table )
386381
387382 name = elem .name or parent_c .new_unique_name ()
@@ -494,7 +489,7 @@ def render_tablealias(self, c: Compiler, elem: TableAlias) -> str:
494489 return f"{ self .compile (c , elem .source_table )} { self .quote (elem .name )} "
495490
496491 def render_tableop (self , parent_c : Compiler , elem : TableOp ) -> str :
497- c : Compiler = parent_c . replace ( in_select = False )
492+ c : Compiler = attrs . evolve ( parent_c , in_select = False )
498493 table_expr = f"{ self .compile (c , elem .table1 )} { elem .op } { self .compile (c , elem .table2 )} "
499494 if parent_c .in_select :
500495 table_expr = f"({ table_expr } ) { c .new_unique_name ()} "
@@ -506,7 +501,7 @@ def render__resolvecolumn(self, c: Compiler, elem: _ResolveColumn) -> str:
506501 return self .compile (c , elem ._get_resolved ())
507502
508503 def render_select (self , parent_c : Compiler , elem : Select ) -> str :
509- c : Compiler = parent_c . replace ( in_select = True ) # .add_table_context(self.table)
504+ c : Compiler = attrs . evolve ( parent_c , in_select = True ) # .add_table_context(self.table)
510505 compile_fn = functools .partial (self .compile , c )
511506
512507 columns = ", " .join (map (compile_fn , elem .columns )) if elem .columns else "*"
@@ -544,7 +539,8 @@ def render_select(self, parent_c: Compiler, elem: Select) -> str:
544539
545540 def render_join (self , parent_c : Compiler , elem : Join ) -> str :
546541 tables = [
547- t if isinstance (t , TableAlias ) else TableAlias (t , parent_c .new_unique_name ()) for t in elem .source_tables
542+ t if isinstance (t , TableAlias ) else TableAlias (source_table = t , name = parent_c .new_unique_name ())
543+ for t in elem .source_tables
548544 ]
549545 c = parent_c .add_table_context (* tables , in_join = True , in_select = False )
550546 op = " JOIN " if elem .op is None else f" { elem .op } JOIN "
@@ -577,7 +573,8 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
577573 if isinstance (elem .table , Select ) and elem .table .columns is None and elem .table .group_by_exprs is None :
578574 return self .compile (
579575 c ,
580- elem .table .replace (
576+ attrs .evolve (
577+ elem .table ,
581578 columns = columns ,
582579 group_by_exprs = [Code (k ) for k in keys ],
583580 having_exprs = elem .having_exprs ,
@@ -589,7 +586,7 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
589586 having_str = (
590587 " HAVING " + " AND " .join (map (compile_fn , elem .having_exprs )) if elem .having_exprs is not None else ""
591588 )
592- select = f"SELECT { columns_str } FROM { self .compile (c . replace ( in_select = True ), elem .table )} GROUP BY { keys_str } { having_str } "
589+ select = f"SELECT { columns_str } FROM { self .compile (attrs . evolve ( c , in_select = True ), elem .table )} GROUP BY { keys_str } { having_str } "
593590
594591 if c .in_select :
595592 select = f"({ select } ) { c .new_unique_name ()} "
@@ -815,7 +812,7 @@ def set_timezone_to_utc(self) -> str:
815812T = TypeVar ("T" , bound = BaseDialect )
816813
817814
818- @dataclass
815+ @attrs . define ( frozen = True )
819816class QueryResult :
820817 rows : list
821818 columns : Optional [list ] = None
@@ -830,7 +827,7 @@ def __getitem__(self, i):
830827 return self .rows [i ]
831828
832829
833- class Database (abc .ABC , _RuntypeHackToFixCicularRefrencedDatabase ):
830+ class Database (abc .ABC ):
834831 """Base abstract class for databases.
835832
836833 Used for providing connection code and implementation specific SQL utilities.
0 commit comments