Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions internal/engine/postgresql/catalog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"strings"
"testing"

"github.com/sqlc-dev/sqlc/internal/sql/ast"
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -168,3 +170,119 @@ func TestUpdateErrors(t *testing.T) {
})
}
}

func TestDropTableCascadeViewRecreate(t *testing.T) {
// Regression test for https://github.com/sqlc-dev/sqlc/issues/4416
// DROP TABLE CASCADE should remove dependent views from the catalog,
// allowing a subsequent CREATE VIEW with the same name to succeed.
p := NewParser()

// First: create the table
stmts1, err := p.Parse(strings.NewReader(`
CREATE TABLE reference_rates (id BIGSERIAL PRIMARY KEY);
`))
if err != nil {
t.Fatalf("parse error: %v", err)
}

c := NewCatalog()
if err := c.Build(stmts1); err != nil {
t.Fatalf("create table error: %v", err)
}

// Manually add a view that depends on reference_rates to the catalog
var schema *catalog.Schema
for _, s := range c.Schemas {
if s.Name == "public" {
schema = s
}
}
schema.Tables = append(schema.Tables, &catalog.Table{
Rel: &ast.TableName{Schema: "public", Name: "vw_reference_rates"},
Columns: []*catalog.Column{{Name: "id"}},
DependsOnTables: []*ast.TableName{
{Schema: "public", Name: "reference_rates"},
},
})

// Verify the view exists
if !viewExists(schema, "vw_reference_rates") {
t.Fatal("view not found in catalog before drop")
}

// Second: DROP TABLE CASCADE
stmts2, err := p.Parse(strings.NewReader(`
DROP TABLE reference_rates CASCADE;
`))
if err != nil {
t.Fatalf("parse error: %v", err)
}
if err := c.Build(stmts2); err != nil {
t.Fatalf("DROP TABLE CASCADE error: %v", err)
}

// Verify the view was removed
if viewExists(schema, "vw_reference_rates") {
t.Fatal("expected view to be removed by CASCADE, but it still exists")
}
}

func TestDropTableCascadeWithoutCascadeFails(t *testing.T) {
// Without CASCADE, dropping a table that has a dependent view leaves the view
// in the catalog (matching current sqlc behavior, though real PostgreSQL would
// reject DROP TABLE without CASCADE when views depend on it).
p := NewParser()

// Create the table
stmts1, err := p.Parse(strings.NewReader(`
CREATE TABLE reference_rates (id BIGSERIAL PRIMARY KEY);
`))
if err != nil {
t.Fatalf("parse error: %v", err)
}

c := NewCatalog()
if err := c.Build(stmts1); err != nil {
t.Fatalf("create table error: %v", err)
}

// Manually add a view that depends on reference_rates
schema := c.Schemas[0]
for _, s := range c.Schemas {
if s.Name == "public" {
schema = s
}
}
schema.Tables = append(schema.Tables, &catalog.Table{
Rel: &ast.TableName{Schema: "public", Name: "vw_reference_rates"},
Columns: []*catalog.Column{{Name: "id"}},
DependsOnTables: []*ast.TableName{
{Schema: "public", Name: "reference_rates"},
},
})

// DROP TABLE without CASCADE
stmts2, err := p.Parse(strings.NewReader(`
DROP TABLE reference_rates;
`))
if err != nil {
t.Fatalf("parse error: %v", err)
}
if err := c.Build(stmts2); err != nil {
t.Fatalf("DROP TABLE error: %v", err)
}

// Without CASCADE, the view should still exist in the catalog
if !viewExists(schema, "vw_reference_rates") {
t.Fatal("expected view to still exist without CASCADE, but it was removed")
}
}

func viewExists(schema *catalog.Schema, name string) bool {
for _, tbl := range schema.Tables {
if tbl.Rel.Name == name {
return true
}
}
return false
}
1 change: 1 addition & 0 deletions internal/engine/postgresql/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,7 @@ func translate(node *nodes.Node) (ast.Node, error) {

case nodes.ObjectType_OBJECT_TABLE, nodes.ObjectType_OBJECT_VIEW, nodes.ObjectType_OBJECT_MATVIEW:
drop := &ast.DropTableStmt{
Behavior: ast.DropBehavior(n.Behavior),
IfExists: n.MissingOk,
}
for _, obj := range n.Objects {
Expand Down
5 changes: 3 additions & 2 deletions internal/sql/ast/drop_table_stmt.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package ast

type DropTableStmt struct {
IfExists bool
Tables []*TableName
Behavior DropBehavior
IfExists bool
Tables []*TableName
}

func (n *DropTableStmt) Pos() int {
Expand Down
35 changes: 35 additions & 0 deletions internal/sql/catalog/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ type Table struct {
Rel *ast.TableName
Columns []*Column
Comment string

// If non-nil, this Table represents a view and depends on the listed tables.
// Only set when the Table is created via CREATE VIEW.
DependsOnTables []*ast.TableName
}

func checkMissing(err error, missingOK bool) error {
Expand Down Expand Up @@ -373,6 +377,24 @@ func (c *Catalog) dropTable(stmt *ast.DropTableStmt) error {
return err
}

// When CASCADE, drop dependent views first
// DROP_CASCADE = 2 in pg_query_go protobuf enum
if stmt.Behavior == 2 {
for i := len(schema.Tables) - 1; i >= 0; i-- {
view := schema.Tables[i]
if len(view.DependsOnTables) == 0 {
continue
}
for _, dep := range view.DependsOnTables {
if dep.Name == name.Name && tablesSameSchema(dep, name, c.DefaultSchema) {
// This view depends on the table being dropped
schema.Tables = append(schema.Tables[:i], schema.Tables[i+1:]...)
break
}
}
}
}

drop := &ast.DropTypeStmt{}
for _, col := range tbl.Columns {
if !col.linkedType {
Expand All @@ -389,6 +411,19 @@ func (c *Catalog) dropTable(stmt *ast.DropTableStmt) error {
return nil
}

// tablesSameSchema checks if two table references point to the same schema.
func tablesSameSchema(a, b *ast.TableName, defaultSchema string) bool {
aSchema := a.Schema
bSchema := b.Schema
if aSchema == "" {
aSchema = defaultSchema
}
if bSchema == "" {
bSchema = defaultSchema
}
return aSchema == bSchema
}

func (c *Catalog) renameColumn(stmt *ast.RenameColumnStmt) error {
_, tbl, err := c.getTable(stmt.Table)
if err != nil {
Expand Down
38 changes: 38 additions & 0 deletions internal/sql/catalog/view.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package catalog

import (
"github.com/sqlc-dev/sqlc/internal/sql/ast"
"github.com/sqlc-dev/sqlc/internal/sql/astutils"
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
)

Expand Down Expand Up @@ -29,6 +30,9 @@ func (c *Catalog) createView(stmt *ast.ViewStmt, colGen columnGenerator) error {
Columns: cols,
}

// Extract table dependencies from the view's SELECT query
tbl.DependsOnTables = extractTableDeps(stmt.Query)

ns := tbl.Rel.Schema
if ns == "" {
ns = c.DefaultSchema
Expand All @@ -50,3 +54,37 @@ func (c *Catalog) createView(stmt *ast.ViewStmt, colGen columnGenerator) error {

return nil
}

// extractTableDeps walks the SELECT query AST and returns all table references (RangeVar nodes).
func extractTableDeps(node ast.Node) []*ast.TableName {
var deps []*ast.TableName
seen := make(map[string]bool)

astutils.Walk(astutils.VisitorFunc(func(n ast.Node) {
rv, ok := n.(*ast.RangeVar)
if !ok || rv.Relname == nil {
return
}
schema := ""
if rv.Schemaname != nil {
schema = *rv.Schemaname
}
key := schema + "." + *rv.Relname
if seen[key] {
return
}
seen[key] = true

// Skip system catalogs and information schema
if schema == "pg_catalog" || schema == "information_schema" {
return
}

deps = append(deps, &ast.TableName{
Schema: schema,
Name: *rv.Relname,
})
}), node)

return deps
}
Loading