From 765e641d4af075cb6e3517a4628220cea7c309f4 Mon Sep 17 00:00:00 2001 From: wucm667 Date: Wed, 29 Apr 2026 13:26:01 +0800 Subject: [PATCH] fix: handle DROP TABLE CASCADE for dependent views Signed-off-by: wucm667 --- internal/engine/postgresql/catalog_test.go | 118 +++++++++++++++++++++ internal/engine/postgresql/parse.go | 1 + internal/sql/ast/drop_table_stmt.go | 5 +- internal/sql/catalog/table.go | 35 ++++++ internal/sql/catalog/view.go | 38 +++++++ 5 files changed, 195 insertions(+), 2 deletions(-) diff --git a/internal/engine/postgresql/catalog_test.go b/internal/engine/postgresql/catalog_test.go index 875ea7e458..1092b7dcde 100644 --- a/internal/engine/postgresql/catalog_test.go +++ b/internal/engine/postgresql/catalog_test.go @@ -6,6 +6,8 @@ import ( "strings" "testing" + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" "github.com/google/go-cmp/cmp" @@ -168,3 +170,119 @@ func TestUpdateErrors(t *testing.T) { }) } } + +func TestDropTableCascadeViewRecreate(t *testing.T) { + // Regression test for https://github.com/sqlc-dev/sqlc/issues/4416 + // DROP TABLE CASCADE should remove dependent views from the catalog, + // allowing a subsequent CREATE VIEW with the same name to succeed. + p := NewParser() + + // First: create the table + stmts1, err := p.Parse(strings.NewReader(` + CREATE TABLE reference_rates (id BIGSERIAL PRIMARY KEY); + `)) + if err != nil { + t.Fatalf("parse error: %v", err) + } + + c := NewCatalog() + if err := c.Build(stmts1); err != nil { + t.Fatalf("create table error: %v", err) + } + + // Manually add a view that depends on reference_rates to the catalog + var schema *catalog.Schema + for _, s := range c.Schemas { + if s.Name == "public" { + schema = s + } + } + schema.Tables = append(schema.Tables, &catalog.Table{ + Rel: &ast.TableName{Schema: "public", Name: "vw_reference_rates"}, + Columns: []*catalog.Column{{Name: "id"}}, + DependsOnTables: []*ast.TableName{ + {Schema: "public", Name: "reference_rates"}, + }, + }) + + // Verify the view exists + if !viewExists(schema, "vw_reference_rates") { + t.Fatal("view not found in catalog before drop") + } + + // Second: DROP TABLE CASCADE + stmts2, err := p.Parse(strings.NewReader(` + DROP TABLE reference_rates CASCADE; + `)) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if err := c.Build(stmts2); err != nil { + t.Fatalf("DROP TABLE CASCADE error: %v", err) + } + + // Verify the view was removed + if viewExists(schema, "vw_reference_rates") { + t.Fatal("expected view to be removed by CASCADE, but it still exists") + } +} + +func TestDropTableCascadeWithoutCascadeFails(t *testing.T) { + // Without CASCADE, dropping a table that has a dependent view leaves the view + // in the catalog (matching current sqlc behavior, though real PostgreSQL would + // reject DROP TABLE without CASCADE when views depend on it). + p := NewParser() + + // Create the table + stmts1, err := p.Parse(strings.NewReader(` + CREATE TABLE reference_rates (id BIGSERIAL PRIMARY KEY); + `)) + if err != nil { + t.Fatalf("parse error: %v", err) + } + + c := NewCatalog() + if err := c.Build(stmts1); err != nil { + t.Fatalf("create table error: %v", err) + } + + // Manually add a view that depends on reference_rates + schema := c.Schemas[0] + for _, s := range c.Schemas { + if s.Name == "public" { + schema = s + } + } + schema.Tables = append(schema.Tables, &catalog.Table{ + Rel: &ast.TableName{Schema: "public", Name: "vw_reference_rates"}, + Columns: []*catalog.Column{{Name: "id"}}, + DependsOnTables: []*ast.TableName{ + {Schema: "public", Name: "reference_rates"}, + }, + }) + + // DROP TABLE without CASCADE + stmts2, err := p.Parse(strings.NewReader(` + DROP TABLE reference_rates; + `)) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if err := c.Build(stmts2); err != nil { + t.Fatalf("DROP TABLE error: %v", err) + } + + // Without CASCADE, the view should still exist in the catalog + if !viewExists(schema, "vw_reference_rates") { + t.Fatal("expected view to still exist without CASCADE, but it was removed") + } +} + +func viewExists(schema *catalog.Schema, name string) bool { + for _, tbl := range schema.Tables { + if tbl.Rel.Name == name { + return true + } + } + return false +} diff --git a/internal/engine/postgresql/parse.go b/internal/engine/postgresql/parse.go index 0c6b3a0fc2..649055d2bd 100644 --- a/internal/engine/postgresql/parse.go +++ b/internal/engine/postgresql/parse.go @@ -578,6 +578,7 @@ func translate(node *nodes.Node) (ast.Node, error) { case nodes.ObjectType_OBJECT_TABLE, nodes.ObjectType_OBJECT_VIEW, nodes.ObjectType_OBJECT_MATVIEW: drop := &ast.DropTableStmt{ + Behavior: ast.DropBehavior(n.Behavior), IfExists: n.MissingOk, } for _, obj := range n.Objects { diff --git a/internal/sql/ast/drop_table_stmt.go b/internal/sql/ast/drop_table_stmt.go index 7485ceb887..82eb48b5ff 100644 --- a/internal/sql/ast/drop_table_stmt.go +++ b/internal/sql/ast/drop_table_stmt.go @@ -1,8 +1,9 @@ package ast type DropTableStmt struct { - IfExists bool - Tables []*TableName + Behavior DropBehavior + IfExists bool + Tables []*TableName } func (n *DropTableStmt) Pos() int { diff --git a/internal/sql/catalog/table.go b/internal/sql/catalog/table.go index dc30acfa1e..a876db2dd0 100644 --- a/internal/sql/catalog/table.go +++ b/internal/sql/catalog/table.go @@ -16,6 +16,10 @@ type Table struct { Rel *ast.TableName Columns []*Column Comment string + + // If non-nil, this Table represents a view and depends on the listed tables. + // Only set when the Table is created via CREATE VIEW. + DependsOnTables []*ast.TableName } func checkMissing(err error, missingOK bool) error { @@ -373,6 +377,24 @@ func (c *Catalog) dropTable(stmt *ast.DropTableStmt) error { return err } + // When CASCADE, drop dependent views first + // DROP_CASCADE = 2 in pg_query_go protobuf enum + if stmt.Behavior == 2 { + for i := len(schema.Tables) - 1; i >= 0; i-- { + view := schema.Tables[i] + if len(view.DependsOnTables) == 0 { + continue + } + for _, dep := range view.DependsOnTables { + if dep.Name == name.Name && tablesSameSchema(dep, name, c.DefaultSchema) { + // This view depends on the table being dropped + schema.Tables = append(schema.Tables[:i], schema.Tables[i+1:]...) + break + } + } + } + } + drop := &ast.DropTypeStmt{} for _, col := range tbl.Columns { if !col.linkedType { @@ -389,6 +411,19 @@ func (c *Catalog) dropTable(stmt *ast.DropTableStmt) error { return nil } +// tablesSameSchema checks if two table references point to the same schema. +func tablesSameSchema(a, b *ast.TableName, defaultSchema string) bool { + aSchema := a.Schema + bSchema := b.Schema + if aSchema == "" { + aSchema = defaultSchema + } + if bSchema == "" { + bSchema = defaultSchema + } + return aSchema == bSchema +} + func (c *Catalog) renameColumn(stmt *ast.RenameColumnStmt) error { _, tbl, err := c.getTable(stmt.Table) if err != nil { diff --git a/internal/sql/catalog/view.go b/internal/sql/catalog/view.go index d5222c4d03..ae4c6d6c3e 100644 --- a/internal/sql/catalog/view.go +++ b/internal/sql/catalog/view.go @@ -2,6 +2,7 @@ package catalog import ( "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/astutils" "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" ) @@ -29,6 +30,9 @@ func (c *Catalog) createView(stmt *ast.ViewStmt, colGen columnGenerator) error { Columns: cols, } + // Extract table dependencies from the view's SELECT query + tbl.DependsOnTables = extractTableDeps(stmt.Query) + ns := tbl.Rel.Schema if ns == "" { ns = c.DefaultSchema @@ -50,3 +54,37 @@ func (c *Catalog) createView(stmt *ast.ViewStmt, colGen columnGenerator) error { return nil } + +// extractTableDeps walks the SELECT query AST and returns all table references (RangeVar nodes). +func extractTableDeps(node ast.Node) []*ast.TableName { + var deps []*ast.TableName + seen := make(map[string]bool) + + astutils.Walk(astutils.VisitorFunc(func(n ast.Node) { + rv, ok := n.(*ast.RangeVar) + if !ok || rv.Relname == nil { + return + } + schema := "" + if rv.Schemaname != nil { + schema = *rv.Schemaname + } + key := schema + "." + *rv.Relname + if seen[key] { + return + } + seen[key] = true + + // Skip system catalogs and information schema + if schema == "pg_catalog" || schema == "information_schema" { + return + } + + deps = append(deps, &ast.TableName{ + Schema: schema, + Name: *rv.Relname, + }) + }), node) + + return deps +}