diff --git a/metis/cmd/daemon.go b/metis/cmd/daemon.go index 686e10dc7..151430d39 100644 --- a/metis/cmd/daemon.go +++ b/metis/cmd/daemon.go @@ -17,12 +17,15 @@ limitations under the License. package main import ( + "context" "os" + "os/signal" + "syscall" "github.com/spf13/cobra" cliflag "k8s.io/component-base/cli/flag" "k8s.io/klog/v2" - "k8s.io/metis/daemon" + "k8s.io/metis/pkg/daemon" ) func newDaemonCommand() *cobra.Command { @@ -42,7 +45,11 @@ func newDaemonCommand() *cobra.Command { var cfg daemon.Config _ = opts.ApplyTo(&cfg) d := daemon.NewDaemon(cfg) - if err := d.Run(); err != nil { + + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + if err := d.Run(ctx); err != nil { klog.ErrorS(err, "Daemon failed to run") os.Exit(1) } diff --git a/metis/cmd/daemonoptions.go b/metis/cmd/daemonoptions.go index 181c36fa8..62437a6e8 100644 --- a/metis/cmd/daemonoptions.go +++ b/metis/cmd/daemonoptions.go @@ -20,7 +20,8 @@ import ( "time" cliflag "k8s.io/component-base/cli/flag" - "k8s.io/metis/daemon" + "k8s.io/metis/pkg" + "k8s.io/metis/pkg/daemon" ) // DaemonOptions holds the metis daemon options. @@ -36,9 +37,10 @@ func (o *DaemonOptions) AddFlags() cliflag.NamedFlagSets { } fs := fss.FlagSet("daemon") - // We apply default values directly within Flags for now, or assume they are pre-initialized fs.DurationVar(&o.MonitorInterval, "monitor-interval", 5*time.Second, "Monitor interval (e.g., 5s, 1m)") fs.DurationVar(&o.ReleaseCooldown, "release-cooldown", 1*time.Minute, "Release cooldown duration (e.g., 5m)") + fs.StringVar(&o.DBPath, "db-path", pkg.DefaultDBPath, "Path to the SQLite database file") + fs.StringVar(&o.SocketPath, "socket-path", pkg.DefaultSockPath, "Path to the Unix domain socket") return fss } @@ -51,6 +53,8 @@ func (o *DaemonOptions) ApplyTo(cfg *daemon.Config) error { cfg.MonitorInterval = o.MonitorInterval cfg.ReleaseCooldown = o.ReleaseCooldown + cfg.DBPath = o.DBPath + cfg.SocketPath = o.SocketPath return nil } diff --git a/metis/daemon/daemon.go b/metis/daemon/daemon.go deleted file mode 100644 index 8797cb857..000000000 --- a/metis/daemon/daemon.go +++ /dev/null @@ -1,40 +0,0 @@ -/* -Copyright 2026 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -package daemon - -import ( - "fmt" - - "k8s.io/klog/v2" -) - -// Daemon represents the metis daemon process. -type Daemon struct { - Config Config -} - -// NewDaemon creates a new Daemon instance with the given configuration. -func NewDaemon(cfg Config) *Daemon { - return &Daemon{ - Config: cfg, - } -} - -// Run starts the daemon process. -func (d *Daemon) Run() error { - klog.InfoS("metis daemon has started successfully", "config", fmt.Sprintf("%+v", d.Config)) - return nil -} diff --git a/metis/daemon/daemon_configurations.go b/metis/daemon/daemon_configurations.go deleted file mode 100644 index 541f20b86..000000000 --- a/metis/daemon/daemon_configurations.go +++ /dev/null @@ -1,24 +0,0 @@ -/* -Copyright 2026 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -package daemon - -import "time" - -// Config contains the configuration parameters for the daemon. -type Config struct { - MonitorInterval time.Duration - ReleaseCooldown time.Duration -} diff --git a/metis/daemon/daemon_test.go b/metis/daemon/daemon_test.go deleted file mode 100644 index 20a5284ad..000000000 --- a/metis/daemon/daemon_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package daemon - -import ( - "testing" - "time" -) - -func TestDaemon_Run(t *testing.T) { - cfg := Config{ - MonitorInterval: 5 * time.Second, - ReleaseCooldown: 1 * time.Minute, - } - - d := NewDaemon(cfg) - - // Since Run() just logs and returns nil currently, it shouldn't error. - err := d.Run() - if err != nil { - t.Fatalf("expected no error from Run(), got: %v", err) - } -} diff --git a/metis/go.mod b/metis/go.mod index 2266e9ee4..2731b62c6 100644 --- a/metis/go.mod +++ b/metis/go.mod @@ -11,6 +11,7 @@ require ( github.com/spf13/pflag v1.0.10 google.golang.org/grpc v1.72.2 google.golang.org/protobuf v1.36.8 + k8s.io/apimachinery v0.35.1 k8s.io/component-base v0.35.1 k8s.io/klog/v2 v2.140.0 ) @@ -42,7 +43,6 @@ require ( golang.org/x/text v0.34.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250528174236-200df99c418a // indirect gopkg.in/inf.v0 v0.9.1 // indirect - k8s.io/apimachinery v0.35.1 // indirect k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912 // indirect k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 // indirect sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect diff --git a/metis/ipam/store.go b/metis/ipam/store.go deleted file mode 100644 index c3f4316a1..000000000 --- a/metis/ipam/store.go +++ /dev/null @@ -1,174 +0,0 @@ -/* -Copyright 2026 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package ipam - -import ( - "database/sql" - _ "embed" - "fmt" - "os" - "path/filepath" - - "github.com/go-logr/logr" - _ "github.com/mattn/go-sqlite3" // SQLite driver -) - -//go:embed schema.sql -var schemaSQL string - -const ( - // dbSchemaVersion tracks the SQLite schema version to allow safe local - // migrations and prevent state corruption across daemon restarts. - dbSchemaVersion = 1 - maxOpenConns = 10 - maxIdleConns = 10 -) - -// Store manages database operations for IPAM. -type Store struct { - db *sql.DB - log logr.Logger -} - -// NewStore creates a new Store instance and initializes the database. -func NewStore(log logr.Logger, dbPath string) (*Store, error) { - if dbPath == "" { - return nil, fmt.Errorf("dbPath cannot be empty: an absolute path must be explicitly provided") - } - - log.Info("Opening or creating database", "path", dbPath) - - if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil { - return nil, fmt.Errorf("failed to create db directory: %w", err) - } - - // SQLite is configured directly through the DSN string. This approach - // guarantees every new connection spawned by the sql.DB pool inherits these - // exact configurations natively. - dsn := dbPath + - // Enables Write-Ahead Logging (WAL) mode. This significantly improves - // concurrency by allowing multiple readers to access the database - // simultaneously without blocking a writer, which is critical for burst - // requests. - // See: https://www.sqlite.org/pragma.html#pragma_journal_mode - "?_journal_mode=WAL" + - // Enforces foreign key constraints. SQLite ignores these by default. - // This is required to ensure ON DELETE CASCADE functions correctly on the - // ip_addresses table when a draining CIDR block is officially removed. - // See: https://www.sqlite.org/pragma.html#pragma_foreign_keys - "&_foreign_keys=on" + - // Sets the busy timeout to 5000 milliseconds. If the database is locked - // by another transaction, this tells the SQLite driver to wait for up - // to 5 seconds before giving up and returning a locked error. - // See: https://www.sqlite.org/pragma.html#pragma_busy_timeout - "&_busy_timeout=5000" + - // Instructs the Go driver to send "BEGIN IMMEDIATE" instead of standard - // "BEGIN" when starting a transaction. This grabs a write lock instantly, - // preventing deadlocks when concurrent requests try to upgrade their - // read locks to write locks simultaneously. Note: This is a go-sqlite3 - // driver feature, not a native SQLite PRAGMA. - // See: https://github.com/mattn/go-sqlite3#connection-string - "&_txlock=immediate" + - // Maps to PRAGMA synchronous = NORMAL. In WAL mode, this is the optimal - // setting for high-concurrency daemons. It prevents database corruption - // during power loss or hard crashes while offering much faster write - // performance than FULL mode, sacrificing only a few milliseconds of - // un-checkpointed durability. - // See: https://www.sqlite.org/pragma.html#pragma_synchronous - "&_synchronous=1" - - db, err := sql.Open("sqlite3", dsn) - if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) - } - - db.SetMaxOpenConns(maxOpenConns) - db.SetMaxIdleConns(maxIdleConns) - // Sets the maximum amount of time a connection may be reused to infinity - // (0). This guarantees the single connection never expires. - db.SetConnMaxLifetime(0) - - store := &Store{ - db: db, - log: log, - } - - // Only a single process enters this execution block at a time. - if err := store.initSchema(); err != nil { - db.Close() - return nil, fmt.Errorf("failed to initialize schema: %w", err) - } - - log.Info("Initialized or updated database schema", "path", dbPath) - - return store, nil -} - -// initSchema creates the necessary tables if they don't exist. -func (s *Store) initSchema() error { - var currentVersion int - err := s.db.QueryRow("PRAGMA user_version").Scan(¤tVersion) - if err != nil { - return fmt.Errorf("failed to check schema version: %w", err) - } - - if currentVersion == dbSchemaVersion { - s.log.V(4).Info("Database schema already initialized", "version", currentVersion) - return nil - } - - s.log.Info("Initializing DB schema", "currentVersion", currentVersion, "expectedVersion", dbSchemaVersion) - - // 1. Begin an atomic transaction - tx, err := s.db.Begin() - if err != nil { - return fmt.Errorf("failed to begin transaction: %w", err) - } - // Safe to defer; Rollback does nothing if Commit() is successful - defer tx.Rollback() - - // 2. Execute the embedded schema.sql file - if _, err := tx.Exec(schemaSQL); err != nil { - return fmt.Errorf("failed to execute schema.sql: %w", err) - } - - // 3. Set User Version - setVersion := fmt.Sprintf("PRAGMA user_version = %d;", dbSchemaVersion) - if _, err := tx.Exec(setVersion); err != nil { - return fmt.Errorf("failed to set user_version: %w", err) - } - - // 4. Commit everything atomically - if err := tx.Commit(); err != nil { - return fmt.Errorf("failed to commit schema transaction: %w", err) - } - - s.log.Info("Database schema initialized or updated successfully") - return nil -} - -// Close safely closes the database connection and releases any file locks. -// This should be called during the daemon's graceful shutdown sequence. -func (s *Store) Close() error { - s.log.Info("Closing IPAM database connection") - - if err := s.db.Close(); err != nil { - return fmt.Errorf("failed to close database connection: %w", err) - } - - return nil -} diff --git a/metis/ipam/store_test.go b/metis/ipam/store_test.go deleted file mode 100644 index c6382e432..000000000 --- a/metis/ipam/store_test.go +++ /dev/null @@ -1,285 +0,0 @@ -/* -Copyright 2026 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package ipam - -import ( - "context" - "database/sql" - "fmt" - "path/filepath" - "sync" - "testing" - "time" - - "github.com/go-logr/logr" - _ "github.com/mattn/go-sqlite3" // SQLite driver -) - -// TestNewStore_SuccessAndClose verifies that a new Store can be created, -// the schema is successfully initialized, and the database closes cleanly. -func TestNewStore_SuccessAndClose(t *testing.T) { - tmpDir := t.TempDir() - dbPath := filepath.Join(tmpDir, "test_ipam.db") - logger := logr.Discard() - - store, err := NewStore(logger, dbPath) - if err != nil { - t.Fatalf("NewStore failed: %v", err) - } - if store == nil { - t.Fatal("Expected a valid Store instance, got nil") - } - - // Verify tables were actually created. - var tableName string - err = store.db.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='cidr_blocks';").Scan(&tableName) - if err != nil || tableName != "cidr_blocks" { - t.Errorf("Expected table 'cidr_blocks' to exist, got error: %v", err) - } - - // Verify the database connection is alive. - if err := store.db.Ping(); err != nil { - t.Errorf("Database ping failed: %v", err) - } - - if err := store.Close(); err != nil { - t.Errorf("Store.Close() failed: %v", err) - } -} - -// TestNewStore_Idempotency verifies idempotency by ensuring a second -// initialization safely skips schema creation. By intentionally dropping an -// index after the first run, the test deterministically proves that the second -// execution short-circuits and leaves the index uncreated. -func TestNewStore_Idempotency(t *testing.T) { - tmpDir := t.TempDir() - dbPath := filepath.Join(tmpDir, "idempotency_test.db") - logger := logr.Discard() - - // Initial creation sets up the full schema and user_version = 1. - store1, err := NewStore(logger, dbPath) - if err != nil { - t.Fatalf("First NewStore call failed: %v", err) - } - if err := store1.Close(); err != nil { - t.Fatalf("Failed to close first store: %v", err) - } - - // Sabotage the schema slightly to prove the bypass works. Manually connect - // and drop an index that initSchema would normally create. - db, err := sql.Open("sqlite3", dbPath) - if err != nil { - t.Fatalf("Failed to open DB for manual intervention: %v", err) - } - if _, err := db.Exec("DROP INDEX idx_ip_idempotency;"); err != nil { - t.Fatalf("Failed to manually drop index: %v", err) - } - db.Close() - - // If the short-circuit works, it will see user_version=1 and return early, - // meaning it will NOT execute the CREATE statements to fix the missing index. - store2, err := NewStore(logger, dbPath) - if err != nil { - t.Fatalf("Second NewStore call failed: %v", err) - } - defer store2.Close() - - // Verify the index is still missing. - var name string - query := "SELECT name FROM sqlite_master WHERE type='index' AND name='idx_ip_idempotency';" - err = store2.db.QueryRow(query).Scan(&name) - - if err == nil { - // If err is nil, it means the query found the index, which means the - // schema block executed and recreated it. - t.Errorf("Expected index 'idx_ip_idempotency' to be missing, but it was recreated. Short-circuit failed.") - } else if err != sql.ErrNoRows { - // If an error other than ErrNoRows occurs, the query failed unexpectedly. - t.Fatalf("Unexpected error querying sqlite_master: %v", err) - } -} - -// TestNewStore_SchemaVerification rigorously checks the sqlite_master table -// to ensure all expected tables, indexes, and triggers were successfully -// created during initialization. -func TestNewStore_SchemaVerification(t *testing.T) { - tmpDir := t.TempDir() - dbPath := filepath.Join(tmpDir, "schema_test.db") - logger := logr.Discard() - - store, err := NewStore(logger, dbPath) - if err != nil { - t.Fatalf("NewStore failed: %v", err) - } - defer store.Close() - - // Define the expected schema components. - expectedTables := []string{"cidr_blocks", "ip_addresses"} - expectedIndexes := []string{"idx_available_ips", "idx_ip_idempotency"} - expectedTriggers := []string{"update_cidr_blocks_updated_at", "update_ip_addresses_updated_at"} - - // Verify Tables. - for _, table := range expectedTables { - var name string - query := "SELECT name FROM sqlite_master WHERE type='table' AND name=?;" - if err := store.db.QueryRow(query, table).Scan(&name); err != nil { - t.Errorf("Schema verification failed: expected table '%s' not found: %v", table, err) - } - } - - // Verify Indexes. - for _, index := range expectedIndexes { - var name string - query := "SELECT name FROM sqlite_master WHERE type='index' AND name=?;" - if err := store.db.QueryRow(query, index).Scan(&name); err != nil { - t.Errorf("Schema verification failed: expected index '%s' not found: %v", index, err) - } - } - - // Verify Triggers. - for _, trigger := range expectedTriggers { - var name string - query := "SELECT name FROM sqlite_master WHERE type='trigger' AND name=?;" - if err := store.db.QueryRow(query, trigger).Scan(&name); err != nil { - t.Errorf("Schema verification failed: expected trigger '%s' not found: %v", trigger, err) - } - } -} - -// TestStore_Concurrency verifies that the SQLite connection pool and -// transaction locks (_txlock=immediate, MaxOpenConns=10) can successfully -// handle high bursts of concurrent requests without throwing SQLITE_BUSY -// errors. It achieves maximum contention by spawning multiple goroutines and -// using a broadcast channel as a "starting gun" to release them at the exact -// same logical moment. -func TestStore_Concurrency(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "ipam-concurrency.db") - s, err := NewStore(logr.Discard(), dbPath) - if err != nil { - t.Fatalf("Failed to initialize store: %v", err) - } - defer s.Close() - - var wg sync.WaitGroup - numGoroutines := 10 - - // Create the Starting Line channel. - startLine := make(chan struct{}) - - // Simulate 10 concurrent gRPC requests hitting the database. - for i := 0; i < numGoroutines; i++ { - wg.Add(1) - go func(id int) { - defer wg.Done() - - // Block this goroutine until the starting gun fires. - <-startLine - - // Simulate an Insert. - cidr := fmt.Sprintf("10.0.%d.0/24", id) - insertQuery := `INSERT INTO cidr_blocks (cidr, network, ip_family, total_ips, allocated_ips, state) - VALUES (?, 'test-network', 'ipv4', 256, 0, 'Ready')` - - _, err := s.db.Exec(insertQuery, cidr) - if err != nil { - t.Errorf("Goroutine %d failed to insert: %v", id, err) - return - } - - // Simulate a Read. - var state string - readQuery := `SELECT state FROM cidr_blocks WHERE cidr = ?` - err = s.db.QueryRow(readQuery, cidr).Scan(&state) - if err != nil { - t.Errorf("Goroutine %d failed to read: %v", id, err) - return - } - }(i) - } - - // Fire the starting gun! - // Closing the channel instantly releases all 10 blocked goroutines at the exact same time. - close(startLine) - - // Wait for all goroutines to finish. - wg.Wait() -} - -// TestStore_MaxOpenConns_Limit verifies that the connection pool can fan out -// to the configured maxOpenConns limit. It proves this by holding -// (maxOpenConns - 1) read connections hostage and ensuring the final allowed -// concurrent query can still execute successfully without hitting a pool bottleneck. -func TestStore_MaxOpenConns_Limit(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "ipam-pool-limit.db") - s, err := NewStore(logr.Discard(), dbPath) - if err != nil { - t.Fatalf("Failed to initialize store: %v", err) - } - defer s.Close() - - // Dynamically scale the test based on the Store's configuration - hostageCount := s.db.Stats().MaxOpenConnections - 1 - connAcquired := make(chan struct{}, hostageCount) - releaseHostages := make(chan struct{}) - var wg sync.WaitGroup - - // 1. Take (maxOpenConns - 1) connections hostage using unclosed read queries - for i := range hostageCount { - wg.Add(1) - go func(id int) { - defer wg.Done() - - rows, err := s.db.Query(`SELECT state FROM cidr_blocks`) - if err != nil { - t.Errorf("Goroutine %d failed to execute read query: %v", id, err) - return - } - defer rows.Close() - - connAcquired <- struct{}{} - <-releaseHostages - }(i) - } - - // Wait for all hostage connections to be checked out of the pool. - // We use a 1-second timeout to prevent the test from hanging indefinitely - // if the connection pool is configured smaller than hostageCount. - timeout := time.After(1 * time.Second) - for i := 0; i < hostageCount; i++ { - select { - case <-connAcquired: - // Connection successfully acquired - case <-timeout: - t.Fatalf("Test timed out waiting to acquire %d connections. MaxOpenConns is likely configured lower than the expected limit.", hostageCount) - } - } - - // 2. Execute the final allowed query to reach maxOpenConns - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - var state string - err = s.db.QueryRowContext(ctx, `SELECT state FROM cidr_blocks LIMIT 1`).Scan(&state) - - if err != nil && err != sql.ErrNoRows { - t.Fatalf("Final concurrent query failed (pool limit reached prematurely): %v", err) - } - - // 3. Clean up - close(releaseHostages) - wg.Wait() -} diff --git a/metis/ipam/consts.go b/metis/pkg/consts.go similarity index 80% rename from metis/ipam/consts.go rename to metis/pkg/consts.go index 64a7c53f4..c643d0ff2 100644 --- a/metis/ipam/consts.go +++ b/metis/pkg/consts.go @@ -14,6 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. */ -package ipam +package pkg -const DefaultDBPath = "/var/lib/cni/metis/metis.sqlite" +const ( + DefaultDBPath = "/var/lib/cni/metis/metis.sqlite" + DefaultSockPath = "/var/lib/cni/metis/metis-adaptive-ipam.sock" +) diff --git a/metis/pkg/daemon/daemon.go b/metis/pkg/daemon/daemon.go new file mode 100644 index 000000000..f12920cbe --- /dev/null +++ b/metis/pkg/daemon/daemon.go @@ -0,0 +1,84 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package daemon + +import ( + "context" + "fmt" + "time" + + "k8s.io/klog/v2" + "k8s.io/metis/pkg" + "k8s.io/metis/pkg/store" +) + +// Config contains the configuration parameters for the daemon. +type Config struct { + MonitorInterval time.Duration + ReleaseCooldown time.Duration + DBPath string + SocketPath string +} + +// Daemon represents the metis daemon process. +type Daemon struct { + Config Config +} + +// NewDaemon creates a new Daemon instance with the given configuration. +func NewDaemon(cfg Config) *Daemon { + return &Daemon{ + Config: cfg, + } +} + +// Run starts the daemon process and listens for gRPC requests on a domain socket. +func (d *Daemon) Run(ctx context.Context) error { + klog.InfoS("metis daemon is starting", "config", fmt.Sprintf("%+v", d.Config)) + + dbPath := d.Config.DBPath + if dbPath == "" { + dbPath = pkg.DefaultDBPath + } + + logger := klog.Background().WithName("metis").WithName("daemon") // klog/v2 provides a logr.Logger + + storeInstance, err := store.NewStore(ctx, logger, dbPath) + if err != nil { + return fmt.Errorf("failed to initialize sqlite store: %w", err) + } + defer storeInstance.Close() + + server := newAdaptiveIpamServer(logger, storeInstance, d.Config.SocketPath, d.Config.ReleaseCooldown, store.DefaultBusyTimeout) + + errCh := make(chan error, 1) + go func() { + errCh <- server.start() + }() + + select { + case err := <-errCh: + if err != nil { + return fmt.Errorf("server failed: %w", err) + } + case <-ctx.Done(): + klog.InfoS("Context cancelled, shutting down daemon") + server.stop() + } + + return nil +} diff --git a/metis/pkg/daemon/daemon_server.go b/metis/pkg/daemon/daemon_server.go new file mode 100644 index 000000000..35c40da44 --- /dev/null +++ b/metis/pkg/daemon/daemon_server.go @@ -0,0 +1,189 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package daemon + +import ( + "context" + "errors" + "fmt" + "net" + "os" + "strings" + "time" + + "github.com/go-logr/logr" + "google.golang.org/grpc" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/metis/api/adaptiveipam/v1" + "k8s.io/metis/pkg" + "k8s.io/metis/pkg/store" +) + +type adaptiveIpamServer struct { + adaptiveipam.UnimplementedAdaptiveIpamServer + store *store.Store + sockPath string + releaseCooldown time.Duration + busyTimeout time.Duration + grpcServer *grpc.Server + logger logr.Logger +} + +func newAdaptiveIpamServer(logger logr.Logger, storeInstance *store.Store, socketPath string, releaseCooldown time.Duration, busyTimeout time.Duration) *adaptiveIpamServer { + server := &adaptiveIpamServer{ + store: storeInstance, + sockPath: socketPath, + releaseCooldown: releaseCooldown, + busyTimeout: busyTimeout, + logger: logger, + } + + return server +} + +func (s *adaptiveIpamServer) AllocatePodIP(ctx context.Context, req *adaptiveipam.AllocatePodIPRequest) (*adaptiveipam.AllocatePodIPResponse, error) { + s.logger.Info("AllocatePodIP request received", + "network", req.Network, + "podName", req.PodName, + "podNamespace", req.PodNamespace, + "ipv4Config", fmt.Sprintf("%+v", req.Ipv4Config), + "ipv6Config", fmt.Sprintf("%+v", req.Ipv6Config)) + + if req.Ipv4Config == nil && req.Ipv6Config == nil { + err := fmt.Errorf("both ipv4_config and ipv6_config are missing for pod %s/%s", req.PodNamespace, req.PodName) + s.logger.Error(err, "AllocatePodIP validation failed", "podName", req.PodName, "podNamespace", req.PodNamespace) + return nil, err + } + + var ipv4Alloc *adaptiveipam.PodIP + if req.Ipv4Config != nil { + if req.Ipv4Config.InitialPodCidr != "" { + exists, err := s.store.GetCIDRBlockByCIDR(ctx, req.Ipv4Config.InitialPodCidr) + if err != nil { + s.logger.Error(err, "failed to check if initial cidr block exists", "network", req.Network, "cidr", req.Ipv4Config.InitialPodCidr) + return nil, fmt.Errorf("failed to check if initial cidr block %s exists for network %s: %w", req.Ipv4Config.InitialPodCidr, req.Network, err) + } + if !exists { + if err := s.store.AddCIDR(ctx, req.Network, req.Ipv4Config.InitialPodCidr); err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") { + s.logger.Info("Initial CIDR block already added by another thread", "network", req.Network, "cidr", req.Ipv4Config.InitialPodCidr) + } else { + s.logger.Error(err, "failed to add initial cidr block", "network", req.Network, "cidr", req.Ipv4Config.InitialPodCidr) + return nil, fmt.Errorf("failed to add initial cidr block %s for network %s: %w", req.Ipv4Config.InitialPodCidr, req.Network, err) + } + } + } + } + + var ip, cidr string + var lastErr error + timeout := s.busyTimeout + if timeout == 0 { + timeout = store.DefaultBusyTimeout + } + // The total timeout is set to timeout to align with the SQLite busy_timeout + // configured in the DSN in store.go. + // TODO: Measure the store allocation query time and update the interval appropriately. + err := wait.PollUntilContextTimeout(ctx, 50*time.Millisecond, timeout, true, func(ctx context.Context) (bool, error) { + ip, cidr, lastErr = s.store.AllocateIPv4(ctx, req.Network, req.Ipv4Config.InterfaceName, req.Ipv4Config.ContainerId) + if lastErr == nil { + return true, nil // Success + } + if errors.Is(lastErr, store.ErrNoAvailableIPs) { + return true, lastErr // Stop immediately on non-retryable error + } + if ctx.Err() != nil { + return true, ctx.Err() // Stop immediately if context is done + } + s.logger.V(4).Info("Retrying AllocateIPv4 due to transient error", "err", lastErr, "network", req.Network) + return false, nil // Retry + }) + + if err != nil { + if (errors.Is(err, wait.ErrWaitTimeout) || errors.Is(err, context.DeadlineExceeded)) && lastErr != nil { + err = lastErr // Use last error if timed out + } + s.logger.Error(err, "failed to allocate ipv4", "network", req.Network, "podName", req.PodName, "podNamespace", req.PodNamespace) + return nil, fmt.Errorf("failed to allocate ipv4 for pod %s/%s: %w", req.PodNamespace, req.PodName, err) + } + ipv4Alloc = &adaptiveipam.PodIP{ + IpAddress: ip, + Cidr: cidr, + } + } + + if req.Ipv6Config != nil { + // TODO: add ipv6 allocation + } + + return &adaptiveipam.AllocatePodIPResponse{ + Ipv4: ipv4Alloc, + }, nil +} + +func (s *adaptiveIpamServer) DeallocatePodIP(ctx context.Context, req *adaptiveipam.DeallocatePodIPRequest) (*adaptiveipam.DeallocatePodIPResponse, error) { + s.logger.Info("DeallocatePodIP request received", + "network", req.Network, + "containerID", req.ContainerId, + "interfaceName", req.InterfaceName, + "podName", req.PodName, + "podNamespace", req.PodNamespace) + + count, err := s.store.ReleaseIPByOwner(ctx, req.Network, req.ContainerId, req.InterfaceName, s.releaseCooldown) + if err != nil { + s.logger.Error(err, "failed to deallocate ips", "network", req.Network, "podName", req.PodName, "podNamespace", req.PodNamespace) + return nil, fmt.Errorf("failed to deallocate ips for pod %s/%s: %w", req.PodNamespace, req.PodName, err) + } + + if count == 0 { + s.logger.Info("No IP addresses were released (likely already deallocated or didn't exist)", "network", req.Network, "podName", req.PodName, "podNamespace", req.PodNamespace) + } else { + s.logger.Info("Successfully deallocated ips", "network", req.Network, "podName", req.PodName, "podNamespace", req.PodNamespace, "count", count) + } + + return &adaptiveipam.DeallocatePodIPResponse{}, nil +} + +func (s *adaptiveIpamServer) start() error { + sockPath := s.sockPath + if sockPath == "" { + sockPath = pkg.DefaultSockPath + } + + if err := os.Remove(sockPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove existing socket: %w", err) + } + + listener, err := net.Listen("unix", sockPath) + if err != nil { + return fmt.Errorf("failed to listen on uds %s: %w", sockPath, err) + } + defer listener.Close() + + s.grpcServer = grpc.NewServer() + adaptiveipam.RegisterAdaptiveIpamServer(s.grpcServer, s) + + s.logger.Info("gRPC server is listening", "socket", sockPath) + return s.grpcServer.Serve(listener) +} + +func (s *adaptiveIpamServer) stop() { + if s.grpcServer != nil { + s.logger.Info("Stopping gRPC server gracefully") + s.grpcServer.GracefulStop() + } +} diff --git a/metis/pkg/daemon/daemon_server_test.go b/metis/pkg/daemon/daemon_server_test.go new file mode 100644 index 000000000..c00f6a3d8 --- /dev/null +++ b/metis/pkg/daemon/daemon_server_test.go @@ -0,0 +1,404 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package daemon + +import ( + "context" + "errors" + "fmt" + "net" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "github.com/go-logr/logr" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "k8s.io/klog/v2" + "k8s.io/metis/api/adaptiveipam/v1" + "k8s.io/metis/pkg/store" +) + +func TestAdaptiveIpamServer_Start(t *testing.T) { + tempDir := t.TempDir() + sockPath := filepath.Join(tempDir, "metis_test_server.sock") + + server := &adaptiveIpamServer{sockPath: sockPath} + + errCh := make(chan error, 1) + go func() { + errCh <- server.start() + }() + + select { + case err := <-errCh: + t.Fatalf("Server failed on start: %v", err) + case <-time.After(5 * time.Second): + // Expect it to listen for 5s + } + + if _, err := os.Stat(sockPath); os.IsNotExist(err) { + t.Errorf("Expected socket to be created at %s, but doesn't exist", sockPath) + } +} + +func TestAdaptiveIpamServer_AllocatePodIP(t *testing.T) { + logger := klog.Background() + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "metis_server_test.sqlite") + + storeInstance, err := store.NewStore(context.Background(), logger, dbPath) + if err != nil { + t.Fatalf("Failed to create store: %v", err) + } + defer storeInstance.Close() + + server := &adaptiveIpamServer{store: storeInstance} + + network := "test-network" + cidr := "10.0.1.0/24" + + req := &adaptiveipam.AllocatePodIPRequest{ + Network: network, + PodName: "test-pod", + PodNamespace: "default", + Ipv4Config: &adaptiveipam.IPConfig{ + InterfaceName: "eth0", + ContainerId: "test-container", + InitialPodCidr: cidr, + }, + } + + resp, err := server.AllocatePodIP(context.Background(), req) + if err != nil { + t.Fatalf("AllocatePodIP failed: %v", err) + } + + if resp.Ipv4 == nil { + t.Fatal("Expected Ipv4 allocation, got nil") + } + + if resp.Ipv4.IpAddress == "" { + t.Fatal("Expected IP address, got empty string") + } +} + +func TestAdaptiveIpamServer_AllocatePodIP_Concurrency(t *testing.T) { + logger := klog.Background() + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "metis_server_concurrency_test.sqlite") + + storeInstance, err := store.NewStore(context.Background(), logger, dbPath) + if err != nil { + t.Fatalf("Failed to create store: %v", err) + } + defer storeInstance.Close() + + server := &adaptiveIpamServer{store: storeInstance} + + network := "test-network" + cidr := "10.0.1.0/24" + + numGoroutines := 10 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + ips := make([]string, numGoroutines) + errs := make([]error, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(index int) { + defer wg.Done() + req := &adaptiveipam.AllocatePodIPRequest{ + Network: network, + PodName: fmt.Sprintf("test-pod-%d", index), + PodNamespace: "default", + Ipv4Config: &adaptiveipam.IPConfig{ + InterfaceName: "eth0", + ContainerId: fmt.Sprintf("test-container-%d", index/2), + InitialPodCidr: cidr, + }, + } + resp, err := server.AllocatePodIP(context.Background(), req) + if err != nil { + errs[index] = err + return + } + if resp.Ipv4 != nil { + ips[index] = resp.Ipv4.IpAddress + } + }(i) + } + + wg.Wait() + + for i, err := range errs { + if err != nil { + t.Errorf("Goroutine %d failed: %v", i, err) + } + } + + for i, ip := range ips { + if ip == "" { + t.Errorf("Goroutine %d returned empty IP", i) + } + } + + for i := 0; i < numGoroutines; i += 2 { + if ips[i] != "" && ips[i] != ips[i+1] { + t.Errorf("Idempotency check failed for pair %d and %d: expected same IP, got %s and %s", i, i+1, ips[i], ips[i+1]) + } + } + + uniqueIpMap := make(map[string]bool) + for _, ip := range ips { + if ip != "" { + uniqueIpMap[ip] = true + } + } + if len(uniqueIpMap) != numGoroutines/2 { + t.Errorf("Expected %d unique IPs, got %d (ips: %v)", numGoroutines/2, len(uniqueIpMap), ips) + } + +} + +func TestAdaptiveIpamServer_DeallocatePodIP(t *testing.T) { + logger := logr.Discard() + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "metis_daemon_test_release.sqlite") + + s, err := store.NewStore(context.Background(), logger, dbPath) + if err != nil { + t.Fatalf("NewStore returned unexpected error: %v", err) + } + defer s.Close() + + server := &adaptiveIpamServer{store: s, sockPath: "", releaseCooldown: 1 * time.Minute} + + network := "gke-pod-network" + cidr := "10.0.1.0/24" + containerID := "test-container-release" + interfaceName := "eth0" + podName := "test-pod" + podNamespace := "default" + + // 1. Allocate first + reqAlloc := &adaptiveipam.AllocatePodIPRequest{ + Network: network, + PodName: podName, + PodNamespace: podNamespace, + Ipv4Config: &adaptiveipam.IPConfig{ + InterfaceName: interfaceName, + ContainerId: containerID, + InitialPodCidr: cidr, + }, + } + + allocResp, err := server.AllocatePodIP(context.Background(), reqAlloc) + if err != nil { + t.Fatalf("AllocatePodIP failed in deallocate test setup: %v", err) + } + + if allocResp.Ipv4 == nil || allocResp.Ipv4.IpAddress == "" { + t.Fatalf("AllocatePodIP response empty") + } + + // 2. Deallocate + reqDealloc := &adaptiveipam.DeallocatePodIPRequest{ + Network: network, + InterfaceName: interfaceName, + ContainerId: containerID, + PodName: podName, + PodNamespace: podNamespace, + } + + deallocResp, err := server.DeallocatePodIP(context.Background(), reqDealloc) + if err != nil { + t.Fatalf("DeallocatePodIP failed: %v", err) + } + + if deallocResp == nil { + t.Errorf("DeallocatePodIP returned nil response") + } + + // 3. Verify via store + var isAlloc bool + err = s.DB().QueryRow("SELECT is_allocated FROM ip_addresses WHERE address = ?", allocResp.Ipv4.IpAddress).Scan(&isAlloc) + if err != nil { + t.Fatalf("Failed to query DB for IP status: %v", err) + } + if isAlloc { + t.Errorf("Expected IP to be unallocated") + } +} + +func TestAdaptiveIpamServer_GRPCClientIntegration(t *testing.T) { + logger := logr.Discard() + tempDir := t.TempDir() + sockPath := filepath.Join(tempDir, "metis_test_client_integration.sock") + dbPath := filepath.Join(tempDir, "metis_client_integration.sqlite") + + s, err := store.NewStore(context.Background(), logger, dbPath) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + defer s.Close() + + server := &adaptiveIpamServer{store: s, sockPath: sockPath} + + // 1. Start server in background + errCh := make(chan error, 1) + go func() { + errCh <- server.start() + }() + + // Wait for socket to appear + time.Sleep(100 * time.Millisecond) + + // 2. Dial using gRPC client + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + conn, err := grpc.DialContext(ctx, sockPath, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + return net.Dial("unix", addr) + })) + if err != nil { + t.Fatalf("Failed to dial UDS %s: %v", sockPath, err) + } + defer conn.Close() + + client := adaptiveipam.NewAdaptiveIpamClient(conn) + + // 3. Prepare data and call + network := "integration-network" + cidr := "10.0.1.0/24" + req := &adaptiveipam.AllocatePodIPRequest{ + Network: network, + PodName: "test-pod", + PodNamespace: "default", + Ipv4Config: &adaptiveipam.IPConfig{ + InterfaceName: "eth0", + ContainerId: "test-container-integration", + InitialPodCidr: cidr, + }, + } + + resp, err := client.AllocatePodIP(ctx, req) + if err != nil { + t.Fatalf("gRPC Client AllocatePodIP failed: %v", err) + } + + if resp.Ipv4 == nil || resp.Ipv4.IpAddress == "" { + t.Errorf("Expected valid IP address from gRPC client, got response: %v", resp) + } +} + +func TestAdaptiveIpamServer_AllocatePodIP_RetryOnDBError(t *testing.T) { + logger := klog.Background() + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "metis_server_retry_test.sqlite") + + storeInstance, err := store.NewStore(context.Background(), logger, dbPath) + if err != nil { + t.Fatalf("Failed to create store: %v", err) + } + + server := &adaptiveIpamServer{store: storeInstance, busyTimeout: 500 * time.Millisecond} + + network := "test-network" + cidr := "10.0.1.0/24" + + if err := storeInstance.AddCIDR(context.Background(), network, cidr); err != nil { + t.Fatalf("Failed to add CIDR: %v", err) + } + + req := &adaptiveipam.AllocatePodIPRequest{ + Network: network, + PodName: "test-pod", + PodNamespace: "default", + Ipv4Config: &adaptiveipam.IPConfig{ + InterfaceName: "eth0", + ContainerId: "test-container", + }, + } + + // Close the DB to simulate transient error + storeInstance.Close() + + startTime := time.Now() + _, err = server.AllocatePodIP(context.Background(), req) + duration := time.Since(startTime) + + if err == nil { + t.Fatal("Expected error after closing DB, got nil") + } + + // Expect it to have retried, so duration should be at least 300ms + if duration < 300*time.Millisecond { + t.Errorf("Expected test to take at least 300ms due to retries, took %v", duration) + } + + if !strings.Contains(err.Error(), "database is closed") { + t.Errorf("Expected error to contain 'database is closed', got: %v", err) + } +} + +func TestAdaptiveIpamServer_AllocatePodIP_NoRetryOnExhaustion(t *testing.T) { + logger := klog.Background() + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "metis_server_exhaust_test.sqlite") + + storeInstance, err := store.NewStore(context.Background(), logger, dbPath) + if err != nil { + t.Fatalf("Failed to create store: %v", err) + } + defer storeInstance.Close() + + server := &adaptiveIpamServer{store: storeInstance} + + network := "test-network" + + req := &adaptiveipam.AllocatePodIPRequest{ + Network: network, + PodName: "test-pod", + PodNamespace: "default", + Ipv4Config: &adaptiveipam.IPConfig{ + InterfaceName: "eth0", + ContainerId: "test-container", + }, + } + + startTime := time.Now() + _, err = server.AllocatePodIP(context.Background(), req) + duration := time.Since(startTime) + + if err == nil { + t.Fatal("Expected error for exhausted store, got nil") + } + + // Expect it to fail fast, so duration should be small (much less than 100ms backoff) + if duration >= 100*time.Millisecond { + t.Errorf("Expected test to fail fast, but took %v", duration) + } + + if !errors.Is(err, store.ErrNoAvailableIPs) { + t.Errorf("Expected error to be ErrNoAvailableIPs, got: %v", err) + } +} diff --git a/metis/pkg/daemon/daemon_test.go b/metis/pkg/daemon/daemon_test.go new file mode 100644 index 000000000..ffe6844cf --- /dev/null +++ b/metis/pkg/daemon/daemon_test.go @@ -0,0 +1,73 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package daemon + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" +) + +func TestDaemon_Run(t *testing.T) { + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "metis_daemon_test.sqlite") + sockPath := filepath.Join(tempDir, "metis_test.sock") + + cfg := Config{ + MonitorInterval: 5 * time.Second, + ReleaseCooldown: 1 * time.Minute, + DBPath: dbPath, + SocketPath: sockPath, + } + + d := NewDaemon(cfg) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // Clean up after test + + errCh := make(chan error, 1) + go func() { + errCh <- d.Run(ctx) + }() + + // Wait for server to start and create socket + time.Sleep(500 * time.Millisecond) + + if _, err := os.Stat(sockPath); os.IsNotExist(err) { + t.Errorf("Expected socket to be created at %s, but doesn't exist", sockPath) + } + + if _, err := os.Stat(dbPath); os.IsNotExist(err) { + t.Errorf("Expected database to be created at %s, but doesn't exist", dbPath) + } + + // Trigger exit path! + cancel() + + select { + case err := <-errCh: + if err != nil { + t.Errorf("Daemon exited with error: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatalf("Daemon failed to shut down within timeout") + } + + // If select completes without timing out, Run() exited, meaning `defer storeInstance.Close()` was executed! +} diff --git a/metis/ipam/schema.sql b/metis/pkg/store/schema.sql similarity index 100% rename from metis/ipam/schema.sql rename to metis/pkg/store/schema.sql diff --git a/metis/pkg/store/store.go b/metis/pkg/store/store.go new file mode 100644 index 000000000..87fb0c550 --- /dev/null +++ b/metis/pkg/store/store.go @@ -0,0 +1,511 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package store + +import ( + "context" + "database/sql" + _ "embed" + "errors" + "fmt" + "net" + "os" + "path/filepath" + "time" + + "github.com/go-logr/logr" + _ "github.com/mattn/go-sqlite3" // SQLite driver +) + +//go:embed schema.sql +var schemaSQL string + +const ( + // dbSchemaVersion tracks the SQLite schema version to allow safe local + // migrations and prevent state corruption across daemon restarts. + dbSchemaVersion = 1 + maxOpenConns = 10 + maxIdleConns = 10 + // DefaultBusyTimeout is the default timeout for SQLite busy handler. + DefaultBusyTimeout = 5000 * time.Millisecond +) + +// ErrNoAvailableIPs is returned when no available IPs can be found in any CIDR block. +var ErrNoAvailableIPs = errors.New("no available IPs in store") + +// Store manages database operations for IPAM. +type Store struct { + db *sql.DB + log logr.Logger +} + +// NewStore creates a new Store instance and initializes the database. +func NewStore(ctx context.Context, log logr.Logger, dbPath string) (*Store, error) { + if dbPath == "" { + return nil, fmt.Errorf("dbPath cannot be empty: an absolute path must be explicitly provided") + } + + log.Info("Opening or creating database", "path", dbPath) + + if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil { + return nil, fmt.Errorf("failed to create db directory: %w", err) + } + + // SQLite is configured directly through the DSN string. This approach + // guarantees every new connection spawned by the sql.DB pool inherits these + // exact configurations natively. + dsn := dbPath + + // Enables Write-Ahead Logging (WAL) mode. This significantly improves + // concurrency by allowing multiple readers to access the database + // simultaneously without blocking a writer, which is critical for burst + // requests. + // See: https://www.sqlite.org/pragma.html#pragma_journal_mode + "?_journal_mode=WAL" + + // Enforces foreign key constraints. SQLite ignores these by default. + // This is required to ensure ON DELETE CASCADE functions correctly on the + // ip_addresses table when a draining CIDR block is officially removed. + // See: https://www.sqlite.org/pragma.html#pragma_foreign_keys + "&_foreign_keys=on" + + // Sets the busy timeout. If the database is locked + // by another transaction, this tells the SQLite driver to wait for up + // to this duration before giving up and returning a locked error. + // See: https://www.sqlite.org/pragma.html#pragma_busy_timeout + fmt.Sprintf("&_busy_timeout=%d", DefaultBusyTimeout.Milliseconds()) + + // Instructs the Go driver to send "BEGIN IMMEDIATE" instead of standard + // "BEGIN" when starting a transaction. This grabs a write lock instantly, + // preventing deadlocks when concurrent requests try to upgrade their + // read locks to write locks simultaneously. Note: This is a go-sqlite3 + // driver feature, not a native SQLite PRAGMA. + // See: https://github.com/mattn/go-sqlite3#connection-string + "&_txlock=immediate" + + // Maps to PRAGMA synchronous = NORMAL. In WAL mode, this is the optimal + // setting for high-concurrency daemons. It prevents database corruption + // during power loss or hard crashes while offering much faster write + // performance than FULL mode, sacrificing only a few milliseconds of + // un-checkpointed durability. + // See: https://www.sqlite.org/pragma.html#pragma_synchronous + "&_synchronous=1" + + db, err := sql.Open("sqlite3", dsn) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + + db.SetMaxOpenConns(maxOpenConns) + db.SetMaxIdleConns(maxIdleConns) + // Sets the maximum amount of time a connection may be reused to infinity + // (0). This guarantees the single connection never expires. + db.SetConnMaxLifetime(0) + + store := &Store{ + db: db, + log: log, + } + + // Only a single process enters this execution block at a time. + if err := store.initSchema(ctx); err != nil { + db.Close() + return nil, fmt.Errorf("failed to initialize schema: %w", err) + } + + log.Info("Initialized or updated database schema", "path", dbPath) + + return store, nil +} + +// initSchema creates the necessary tables if they don't exist. +func (s *Store) initSchema(ctx context.Context) error { + var currentVersion int + err := s.db.QueryRowContext(ctx, "PRAGMA user_version").Scan(¤tVersion) + if err != nil { + return fmt.Errorf("failed to check schema version: %w", err) + } + + if currentVersion == dbSchemaVersion { + s.log.V(4).Info("Database schema already initialized", "version", currentVersion) + return nil + } + + s.log.Info("Initializing DB schema", "currentVersion", currentVersion, "expectedVersion", dbSchemaVersion) + + // 1. Begin an atomic transaction + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + // Safe to defer; Rollback does nothing if Commit() is successful + defer tx.Rollback() + + // 2. Execute the embedded schema.sql file + if _, err := tx.ExecContext(ctx, schemaSQL); err != nil { + return fmt.Errorf("failed to execute schema.sql: %w", err) + } + + // 3. Set User Version + setVersion := fmt.Sprintf("PRAGMA user_version = %d;", dbSchemaVersion) + if _, err := tx.ExecContext(ctx, setVersion); err != nil { + return fmt.Errorf("failed to set user_version: %w", err) + } + + // 4. Commit everything atomically + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit schema transaction: %w", err) + } + + s.log.Info("Database schema initialized or updated successfully") + return nil +} + +// Close safely closes the database connection and releases any file locks. +// This should be called during the daemon's graceful shutdown sequence. +func (s *Store) Close() error { + s.log.Info("Closing IPAM database connection") + + if err := s.db.Close(); err != nil { + return fmt.Errorf("failed to close database connection: %w", err) + } + + return nil +} + +// DB returns the underlying sql.DB connection for direct queries. +func (s *Store) DB() *sql.DB { + return s.db +} + +// allocateIPv4Tx is a helper that executes the IP allocation within an existing transaction. +// It returns sql.ErrNoRows if the CIDR block is full or not found, allowing the caller to try another block. +func (s *Store) allocateIPv4Tx(ctx context.Context, tx *sql.Tx, cidrBlockID int64, interfaceName, containerID string) (string, string, error) { + // 1. Fetch CIDR range for the given ID and verify it is not full + var cidrRange string + err := tx.QueryRowContext(ctx, ` + SELECT cidr FROM cidr_blocks + WHERE id = ? AND ip_family = 'ipv4' AND total_ips > allocated_ips AND state = 'Ready' + `, cidrBlockID).Scan(&cidrRange) + + if err != nil { + if err == sql.ErrNoRows { + return "", "", sql.ErrNoRows + } + return "", "", fmt.Errorf("failed to query cidr_block: %w", err) + } + + // 2. Find the first available entry and mark it as allocated + var address string + err = tx.QueryRowContext(ctx, ` + UPDATE ip_addresses + SET is_allocated = TRUE, container_id = ?, interface_name = ?, allocated_at = CURRENT_TIMESTAMP + WHERE id = ( + SELECT id FROM ip_addresses + WHERE cidr_block_id = ? AND is_allocated = FALSE AND (release_at IS NULL OR release_at <= CURRENT_TIMESTAMP) + ORDER BY id ASC + LIMIT 1 + ) + RETURNING address + `, containerID, interfaceName, cidrBlockID).Scan(&address) + + if err != nil { + if err == sql.ErrNoRows { + return "", "", sql.ErrNoRows + } + return "", "", fmt.Errorf("failed to allocate ip: %w", err) + } + + // Also increment allocated_ips in cidr_blocks to keep it in sync + _, err = tx.ExecContext(ctx, ` + UPDATE cidr_blocks + SET allocated_ips = allocated_ips + 1 + WHERE id = ? + `, cidrBlockID) + + if err != nil { + return "", "", fmt.Errorf("failed to update allocated_ips in cidr_blocks: %w", err) + } + + return address, cidrRange, nil +} + +// AllocateIPv4 finds the first available IP from Ready CIDR blocks for a given network and allocates it. +// It also performs an idempotency check to see if the container already has an IP allocated. +func (s *Store) AllocateIPv4(ctx context.Context, network, interfaceName, containerID string) (string, string, error) { + // 1. Idempotency check (Fast Path - outside write transaction) + var address string + var cidrRange string + err := s.db.QueryRowContext(ctx, ` + SELECT i.address, c.cidr + FROM ip_addresses i + JOIN cidr_blocks c ON i.cidr_block_id = c.id + WHERE i.container_id = ? AND i.interface_name = ? AND i.is_allocated = TRUE AND c.ip_family = 'ipv4' + LIMIT 1 + `, containerID, interfaceName).Scan(&address, &cidrRange) + + if err == nil { + s.log.Info("Idempotency check hit (fast path), returning existing allocation", "containerID", containerID, "interfaceName", interfaceName, "address", address, "cidr", cidrRange) + return address, cidrRange, nil + } + if err != sql.ErrNoRows { + return "", "", fmt.Errorf("failed during fast-path idempotency check: %w", err) + } + + // 2. Query available CIDRs (Outside write transaction) + rows, err := s.db.QueryContext(ctx, ` + SELECT id FROM cidr_blocks + WHERE network = ? AND ip_family = 'ipv4' AND total_ips > allocated_ips AND state = 'Ready' + `, network) + if err != nil { + return "", "", fmt.Errorf("failed to query available cidr blocks: %w", err) + } + defer rows.Close() + + var cidrBlockIDs []int64 + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return "", "", fmt.Errorf("failed to scan cidr block id: %w", err) + } + cidrBlockIDs = append(cidrBlockIDs, id) + } + + if len(cidrBlockIDs) == 0 { + return "", "", fmt.Errorf("%w: no available cidr blocks found for network %s", ErrNoAvailableIPs, network) + } + + // 3. Loop and try to allocate with short transactions + for _, cidrBlockID := range cidrBlockIDs { + // Start a short BEGIN IMMEDIATE transaction for this specific CIDR + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return "", "", fmt.Errorf("failed to begin transaction: %w", err) + } + + // Re-check idempotency inside the lock to be 100% safe against concurrent race to different CIDRs + err = tx.QueryRowContext(ctx, ` + SELECT i.address, c.cidr + FROM ip_addresses i + JOIN cidr_blocks c ON i.cidr_block_id = c.id + WHERE i.container_id = ? AND i.interface_name = ? AND i.is_allocated = TRUE AND c.ip_family = 'ipv4' + LIMIT 1 + `, containerID, interfaceName).Scan(&address, &cidrRange) + + if err == nil { + tx.Rollback() // Release lock! + s.log.Info("Idempotency check hit (slow path), returning existing allocation", "containerID", containerID, "interfaceName", interfaceName, "address", address, "cidr", cidrRange) + return address, cidrRange, nil + } + if err != sql.ErrNoRows { + tx.Rollback() + return "", "", fmt.Errorf("failed during slow-path idempotency check: %w", err) + } + + ip, cidr, err := s.allocateIPv4Tx(ctx, tx, cidrBlockID, interfaceName, containerID) + if err == nil { + if err := tx.Commit(); err != nil { + return "", "", fmt.Errorf("failed to commit transaction: %w", err) + } + return ip, cidr, nil + } + + tx.Rollback() // Rollback if failed (full or error) + + if err == sql.ErrNoRows { + s.log.V(4).Info("No available IPs in cidr block, tried next one", "cidrBlockID", cidrBlockID) + continue + } + return "", "", fmt.Errorf("failed to allocate ipv4 in cidr block %d: %w", cidrBlockID, err) + } + + return "", "", fmt.Errorf("%w: failed to allocate ipv4 in any cidr block for network %s", ErrNoAvailableIPs, network) +} + +// GetCIDRBlockByCIDR checks if a CIDR block already exists in the database. +func (s *Store) GetCIDRBlockByCIDR(ctx context.Context, cidr string) (bool, error) { + var id int64 + err := s.db.QueryRowContext(ctx, ` + SELECT id FROM cidr_blocks WHERE cidr = ? LIMIT 1 + `, cidr).Scan(&id) + + if err == nil { + return true, nil + } + if err == sql.ErrNoRows { + return false, nil + } + return false, fmt.Errorf("failed to query cidr_blocks: %w", err) +} + +// AddCIDR parses the CIDR, determines family, and inserts it + its constituent IP addresses into the store. +// The first two IPs and the last IP are automatically marked as allocated. +func (s *Store) AddCIDR(ctx context.Context, network, cidr string) error { + ip, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + return fmt.Errorf("failed to parse cidr %s: %w", cidr, err) + } + + ipFamily := "ipv4" + if ip.To4() == nil { + ipFamily = "ipv6" + } + + // Generate the list of all IPs in this CIDR range + var ips []string + for curr := ipnet.IP.Mask(ipnet.Mask); ipnet.Contains(curr); curr = incIP(curr) { + ips = append(ips, curr.String()) + } + + if len(ips) == 0 { + return fmt.Errorf("cidr range is empty: %s", cidr) + } + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback() + + // 1. Insert into cidr_blocks + res, err := tx.ExecContext(ctx, ` + INSERT INTO cidr_blocks (cidr, network, ip_family, total_ips, allocated_ips, state) + VALUES (?, ?, ?, ?, ?, 'Ready') + `, cidr, network, ipFamily, len(ips), 0) // We will update allocated_ips later after insertions + + if err != nil { + return fmt.Errorf("failed to insert cidr_block: %w", err) + } + + cidrBlockID, err := res.LastInsertId() + if err != nil { + return fmt.Errorf("failed to get last inserted id: %w", err) + } + + // 2. Insert IP addresses and determine allocation status + var allocatedCount int + stmt, err := tx.PrepareContext(ctx, ` + INSERT INTO ip_addresses (cidr_block_id, address, is_allocated, container_id, interface_name) + VALUES (?, ?, ?, '', '') + `) + if err != nil { + return fmt.Errorf("failed to prepare insert statement: %w", err) + } + defer stmt.Close() + + for idx, addr := range ips { + isAllocated := false + // For small CIDRs (smaller than /30, i.e., /31 and /32), we do not reserve + // the first two and the last IPs. The IPs returned will still be routable + // by the underlying infrastructure. + if len(ips) >= 4 && (idx == 0 || idx == 1 || idx == len(ips)-1) { + isAllocated = true + allocatedCount++ + } + + _, err = stmt.ExecContext(ctx, cidrBlockID, addr, isAllocated) + if err != nil { + return fmt.Errorf("failed to insert ip_address %s: %w", addr, err) + } + } + + // 3. Update allocated_ips to reflect the defaults we just reserved + _, err = tx.ExecContext(ctx, ` + UPDATE cidr_blocks SET allocated_ips = ? WHERE id = ? + `, allocatedCount, cidrBlockID) + if err != nil { + return fmt.Errorf("failed to update allocated_ips: %w", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + s.log.Info("Successfully added CIDR and its IPs to store", "cidr", cidr, "network", network, "totalIPs", len(ips), "reservedIPs", allocatedCount) + return nil +} + +// incIP increments an IP address. +func incIP(ip net.IP) net.IP { + newIP := make(net.IP, len(ip)) + copy(newIP, ip) + for i := len(newIP) - 1; i >= 0; i-- { + newIP[i]++ + if newIP[i] > 0 { + break + } + } + return newIP +} + +// ReleaseIPByOwner updates all IP addresses matching the network, container id and interface name to be is_allocated = FALSE, and sets release_at timestamp to be now + releaseCooldown. It also decrements allocated_ips count in cidr_blocks. +func (s *Store) ReleaseIPByOwner(ctx context.Context, network, containerID, interfaceName string, releaseCooldown time.Duration) (int, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return 0, fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback() + + releaseAt := time.Now().Add(releaseCooldown) + + rows, err := tx.QueryContext(ctx, ` + SELECT i.id, i.cidr_block_id + FROM ip_addresses i + JOIN cidr_blocks c ON i.cidr_block_id = c.id + WHERE c.network = ? AND i.container_id = ? AND i.interface_name = ? AND i.is_allocated = TRUE + `, network, containerID, interfaceName) + + if err != nil { + return 0, fmt.Errorf("failed to query matching IP owners: %w", err) + } + defer rows.Close() + + type release struct { + id int64 + cidrBlockID int64 + } + var releases []release + + for rows.Next() { + var r release + if err := rows.Scan(&r.id, &r.cidrBlockID); err != nil { + return 0, fmt.Errorf("failed to scan affected IP details: %w", err) + } + releases = append(releases, r) + } + + for _, r := range releases { + _, err = tx.ExecContext(ctx, ` + UPDATE ip_addresses + SET is_allocated = FALSE, release_at = ? + WHERE id = ? + `, releaseAt, r.id) + if err != nil { + return 0, fmt.Errorf("failed to release IP %d: %w", r.id, err) + } + + _, err = tx.ExecContext(ctx, ` + UPDATE cidr_blocks + SET allocated_ips = allocated_ips - 1 + WHERE id = ? + `, r.cidrBlockID) + if err != nil { + return 0, fmt.Errorf("failed to update cidr_block %d count: %w", r.cidrBlockID, err) + } + } + + if err := tx.Commit(); err != nil { + return 0, fmt.Errorf("failed to commit release transaction: %w", err) + } + + return len(releases), nil +} diff --git a/metis/pkg/store/store_test.go b/metis/pkg/store/store_test.go new file mode 100644 index 000000000..6008828ac --- /dev/null +++ b/metis/pkg/store/store_test.go @@ -0,0 +1,799 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package store + +import ( + "context" + "database/sql" + "errors" + "fmt" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/go-logr/logr" + _ "github.com/mattn/go-sqlite3" // SQLite driver +) + +// TestNewStore_SuccessAndClose verifies that a new Store can be created, +// the schema is successfully initialized, and the database closes cleanly. +func TestNewStore_SuccessAndClose(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "test_ipam.db") + logger := logr.Discard() + + store, err := NewStore(context.Background(), logger, dbPath) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + if store == nil { + t.Fatal("Expected a valid Store instance, got nil") + } + + // Verify tables were actually created. + var tableName string + err = store.db.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='cidr_blocks';").Scan(&tableName) + if err != nil || tableName != "cidr_blocks" { + t.Errorf("Expected table 'cidr_blocks' to exist, got error: %v", err) + } + + // Verify the database connection is alive. + if err := store.db.Ping(); err != nil { + t.Errorf("Database ping failed: %v", err) + } + + if err := store.Close(); err != nil { + t.Errorf("Store.Close() failed: %v", err) + } +} + +// TestNewStore_Idempotency verifies idempotency by ensuring a second +// initialization safely skips schema creation. By intentionally dropping an +// index after the first run, the test deterministically proves that the second +// execution short-circuits and leaves the index uncreated. +func TestNewStore_Idempotency(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "idempotency_test.db") + logger := logr.Discard() + + // Initial creation sets up the full schema and user_version = 1. + store1, err := NewStore(context.Background(), logger, dbPath) + if err != nil { + t.Fatalf("First NewStore call failed: %v", err) + } + if err := store1.Close(); err != nil { + t.Fatalf("Failed to close first store: %v", err) + } + + // Sabotage the schema slightly to prove the bypass works. Manually connect + // and drop an index that initSchema would normally create. + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + t.Fatalf("Failed to open DB for manual intervention: %v", err) + } + if _, err := db.Exec("DROP INDEX idx_ip_idempotency;"); err != nil { + t.Fatalf("Failed to manually drop index: %v", err) + } + db.Close() + + // If the short-circuit works, it will see user_version=1 and return early, + // meaning it will NOT execute the CREATE statements to fix the missing index. + store2, err := NewStore(context.Background(), logger, dbPath) + if err != nil { + t.Fatalf("Second NewStore call failed: %v", err) + } + defer store2.Close() + + // Verify the index is still missing. + var name string + query := "SELECT name FROM sqlite_master WHERE type='index' AND name='idx_ip_idempotency';" + err = store2.db.QueryRow(query).Scan(&name) + + if err == nil { + // If err is nil, it means the query found the index, which means the + // schema block executed and recreated it. + t.Errorf("Expected index 'idx_ip_idempotency' to be missing, but it was recreated. Short-circuit failed.") + } else if err != sql.ErrNoRows { + // If an error other than ErrNoRows occurs, the query failed unexpectedly. + t.Fatalf("Unexpected error querying sqlite_master: %v", err) + } +} + +// TestNewStore_SchemaVerification rigorously checks the sqlite_master table +// to ensure all expected tables, indexes, and triggers were successfully +// created during initialization. +func TestNewStore_SchemaVerification(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "schema_test.db") + logger := logr.Discard() + + store, err := NewStore(context.Background(), logger, dbPath) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + defer store.Close() + + // Define the expected schema components. + expectedTables := []string{"cidr_blocks", "ip_addresses"} + expectedIndexes := []string{"idx_available_ips", "idx_ip_idempotency"} + expectedTriggers := []string{"update_cidr_blocks_updated_at", "update_ip_addresses_updated_at"} + + // Verify Tables. + for _, table := range expectedTables { + var name string + query := "SELECT name FROM sqlite_master WHERE type='table' AND name=?;" + if err := store.db.QueryRow(query, table).Scan(&name); err != nil { + t.Errorf("Schema verification failed: expected table '%s' not found: %v", table, err) + } + } + + // Verify Indexes. + for _, index := range expectedIndexes { + var name string + query := "SELECT name FROM sqlite_master WHERE type='index' AND name=?;" + if err := store.db.QueryRow(query, index).Scan(&name); err != nil { + t.Errorf("Schema verification failed: expected index '%s' not found: %v", index, err) + } + } + + // Verify Triggers. + for _, trigger := range expectedTriggers { + var name string + query := "SELECT name FROM sqlite_master WHERE type='trigger' AND name=?;" + if err := store.db.QueryRow(query, trigger).Scan(&name); err != nil { + t.Errorf("Schema verification failed: expected trigger '%s' not found: %v", trigger, err) + } + } +} + +// TestStore_Concurrency verifies that the SQLite connection pool and +// transaction locks (_txlock=immediate, MaxOpenConns=10) can successfully +// handle high bursts of concurrent requests without throwing SQLITE_BUSY +// errors. It achieves maximum contention by spawning multiple goroutines and +// using a broadcast channel as a "starting gun" to release them at the exact +// same logical moment. +func TestStore_Concurrency(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "ipam-concurrency.db") + s, err := NewStore(context.Background(), logr.Discard(), dbPath) + if err != nil { + t.Fatalf("Failed to initialize store: %v", err) + } + defer s.Close() + + var wg sync.WaitGroup + numGoroutines := 10 + + // Create the Starting Line channel. + startLine := make(chan struct{}) + + // Simulate 10 concurrent gRPC requests hitting the database. + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + // Block this goroutine until the starting gun fires. + <-startLine + + // Simulate an Insert. + cidr := fmt.Sprintf("10.0.%d.0/24", id) + insertQuery := `INSERT INTO cidr_blocks (cidr, network, ip_family, total_ips, allocated_ips, state) + VALUES (?, 'test-network', 'ipv4', 256, 0, 'Ready')` + + _, err := s.db.Exec(insertQuery, cidr) + if err != nil { + t.Errorf("Goroutine %d failed to insert: %v", id, err) + return + } + + // Simulate a Read. + var state string + readQuery := `SELECT state FROM cidr_blocks WHERE cidr = ?` + err = s.db.QueryRow(readQuery, cidr).Scan(&state) + if err != nil { + t.Errorf("Goroutine %d failed to read: %v", id, err) + return + } + }(i) + } + + // Fire the starting gun! + // Closing the channel instantly releases all 10 blocked goroutines at the exact same time. + close(startLine) + + // Wait for all goroutines to finish. + wg.Wait() +} + +// TestStore_MaxOpenConns_Limit verifies that the connection pool can fan out +// to the configured maxOpenConns limit. It proves this by holding +// (maxOpenConns - 1) read connections hostage and ensuring the final allowed +// concurrent query can still execute successfully without hitting a pool bottleneck. +func TestStore_MaxOpenConns_Limit(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "ipam-pool-limit.db") + s, err := NewStore(context.Background(), logr.Discard(), dbPath) + if err != nil { + t.Fatalf("Failed to initialize store: %v", err) + } + defer s.Close() + + // Dynamically scale the test based on the Store's configuration + hostageCount := s.db.Stats().MaxOpenConnections - 1 + connAcquired := make(chan struct{}, hostageCount) + releaseHostages := make(chan struct{}) + var wg sync.WaitGroup + + // 1. Take (maxOpenConns - 1) connections hostage using unclosed read queries + for i := range hostageCount { + wg.Add(1) + go func(id int) { + defer wg.Done() + + rows, err := s.db.Query(`SELECT state FROM cidr_blocks`) + if err != nil { + t.Errorf("Goroutine %d failed to execute read query: %v", id, err) + return + } + defer rows.Close() + + connAcquired <- struct{}{} + <-releaseHostages + }(i) + } + + // Wait for all hostage connections to be checked out of the pool. + // We use a 1-second timeout to prevent the test from hanging indefinitely + // if the connection pool is configured smaller than hostageCount. + timeout := time.After(1 * time.Second) + for i := 0; i < hostageCount; i++ { + select { + case <-connAcquired: + // Connection successfully acquired + case <-timeout: + t.Fatalf("Test timed out waiting to acquire %d connections. MaxOpenConns is likely configured lower than the expected limit.", hostageCount) + } + } + + // 2. Execute the final allowed query to reach maxOpenConns + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + var state string + err = s.db.QueryRowContext(ctx, `SELECT state FROM cidr_blocks LIMIT 1`).Scan(&state) + + if err != nil && err != sql.ErrNoRows { + t.Fatalf("Final concurrent query failed (pool limit reached prematurely): %v", err) + } + + // 3. Clean up + close(releaseHostages) + wg.Wait() +} + +func TestStore_AddCIDR(t *testing.T) { + logger := logr.Discard() // Use discard logger to avoid klog dependency in tests + + // Use testing.T.TempDir() which is standard in modern Go and cleans up automatically! + tempDir := t.TempDir() + + dbPath := filepath.Join(tempDir, "metis.sqlite") + s, err := NewStore(context.Background(), logger, dbPath) + if err != nil { + t.Fatalf("NewStore returned unexpected error: %v", err) + } + defer s.Close() + + network := "gke-pod-network-addcidr" + cidr := "10.0.1.0/29" // 8 IPs: 10.0.1.0 to 10.0.1.7 + + err = s.AddCIDR(context.Background(), network, cidr) + if err != nil { + t.Fatalf("AddCIDR failed: %v", err) + } + + // 1. Verify cidr_block table insertion + var totalIPs, allocatedIPs int + var state string + err = s.db.QueryRow(`SELECT total_ips, allocated_ips, state FROM cidr_blocks WHERE cidr = ?`, cidr).Scan(&totalIPs, &allocatedIPs, &state) + if err != nil { + t.Fatalf("Failed to query inserted cidr_block: %v", err) + } + + if totalIPs != 8 { + t.Errorf("Expected total_ips 8, got %d", totalIPs) + } + if allocatedIPs != 3 { + t.Errorf("Expected allocated_ips 3 (first two and last one reserved), got %d", allocatedIPs) + } + if state != "Ready" { + t.Errorf("Expected state Ready, got %s", state) + } + + // 2. Verify ip_addresses table insertion and allocations + rows, err := s.db.Query(`SELECT address, is_allocated FROM ip_addresses WHERE cidr_block_id = (SELECT id FROM cidr_blocks WHERE cidr = ?) ORDER BY address`, cidr) + if err != nil { + t.Fatalf("Failed to query inserted ip_addresses: %v", err) + } + defer rows.Close() + + var addresses []string + var allocations []bool + for rows.Next() { + var addr string + var isAlloc bool + if err := rows.Scan(&addr, &isAlloc); err != nil { + t.Fatalf("Failed to scan ip_address: %v", err) + } + addresses = append(addresses, addr) + allocations = append(allocations, isAlloc) + } + + expectedAddrs := []string{ + "10.0.1.0", "10.0.1.1", "10.0.1.2", "10.0.1.3", + "10.0.1.4", "10.0.1.5", "10.0.1.6", "10.0.1.7", + } + expectedAllocs := []bool{ + true, true, false, false, + false, false, false, true, + } + + if len(addresses) != len(expectedAddrs) { + t.Fatalf("Expected %d addresses, got %d", len(expectedAddrs), len(addresses)) + } + + for i := range expectedAddrs { + if addresses[i] != expectedAddrs[i] { + t.Errorf("[%d] Expected address %s, got %s", i, expectedAddrs[i], addresses[i]) + } + if allocations[i] != expectedAllocs[i] { + t.Errorf("[%d] Expected allocation %v for address %s, got %v", i, expectedAllocs[i], addresses[i], allocations[i]) + } + } +} + +func TestStore_AddCIDR_Small(t *testing.T) { + logger := logr.Discard() + tempDir := t.TempDir() + + dbPath := filepath.Join(tempDir, "metis_small.sqlite") + s, err := NewStore(context.Background(), logger, dbPath) + if err != nil { + t.Fatalf("NewStore returned unexpected error: %v", err) + } + defer s.Close() + + // Test /31 (2 IPs) + network31 := "gke-pod-network-31" + cidr31 := "10.0.2.0/31" // 2 IPs: 10.0.2.0, 10.0.2.1 + + err = s.AddCIDR(context.Background(), network31, cidr31) + if err != nil { + t.Fatalf("AddCIDR failed for /31: %v", err) + } + + var allocatedIPs31 int + err = s.db.QueryRow(`SELECT allocated_ips FROM cidr_blocks WHERE cidr = ?`, cidr31).Scan(&allocatedIPs31) + if err != nil { + t.Fatalf("Failed to query inserted cidr_block for /31: %v", err) + } + if allocatedIPs31 != 0 { + t.Errorf("Expected allocated_ips 0 for /31, got %d", allocatedIPs31) + } + + // Test /32 (1 IP) + network32 := "gke-pod-network-32" + cidr32 := "10.0.3.0/32" // 1 IP: 10.0.3.0 + + err = s.AddCIDR(context.Background(), network32, cidr32) + if err != nil { + t.Fatalf("AddCIDR failed for /32: %v", err) + } + + var allocatedIPs32 int + err = s.db.QueryRow(`SELECT allocated_ips FROM cidr_blocks WHERE cidr = ?`, cidr32).Scan(&allocatedIPs32) + if err != nil { + t.Fatalf("Failed to query inserted cidr_block for /32: %v", err) + } + if allocatedIPs32 != 0 { + t.Errorf("Expected allocated_ips 0 for /32, got %d", allocatedIPs32) + } +} + +func TestStore_AllocateIPv4_SingleCIDR(t *testing.T) { + logger := logr.Discard() + tempDir := t.TempDir() + + dbPath := filepath.Join(tempDir, "metis_allocate_network_single.sqlite") + s, err := NewStore(context.Background(), logger, dbPath) + if err != nil { + t.Fatalf("NewStore returned unexpected error: %v", err) + } + defer s.Close() + + network := "gke-pod-network-allocate" + cidr := "10.0.2.0/29" // 8 IPs: .0 to .7. Reserved: .0, .1, .7. Available: .2, .3, .4, .5, .6. + + // Test Case 1: Error - No CIDR blocks found (DB empty) + _, _, err = s.AllocateIPv4(context.Background(), network, "eth0", "container-1") + if err == nil { + t.Error("Expected error for no CIDR blocks, got nil") + } else if !errors.Is(err, ErrNoAvailableIPs) { + t.Errorf("Expected ErrNoAvailableIPs, got %v", err) + } + + // Add the CIDR + if err := s.AddCIDR(context.Background(), network, cidr); err != nil { + t.Fatalf("AddCIDR failed: %v", err) + } + + // Test Case 2: Happy path - First allocation + ip1, cidrRange1, err := s.AllocateIPv4(context.Background(), network, "eth0", "container-1") + if err != nil { + t.Fatalf("First allocation failed: %v", err) + } + if ip1 != "10.0.2.2" { + t.Errorf("Expected IP 10.0.2.2, got %s", ip1) + } + if cidrRange1 != cidr { + t.Errorf("Expected CIDR range %s, got %s", cidr, cidrRange1) + } + + // Verify DB state for first allocation + var isAlloc bool + var containerID, interfaceName string + err = s.db.QueryRow(`SELECT is_allocated, container_id, interface_name FROM ip_addresses WHERE address = '10.0.2.2'`).Scan(&isAlloc, &containerID, &interfaceName) + if err != nil { + t.Fatalf("Failed to query DB for IP status: %v", err) + } + if !isAlloc { + t.Error("Expected IP 10.0.2.2 to be marked as allocated") + } + if containerID != "container-1" || interfaceName != "eth0" { + t.Errorf("Expected container-1/eth0, got %s/%s", containerID, interfaceName) + } + + // Test Case 3: Happy path - Second allocation + ip2, _, err := s.AllocateIPv4(context.Background(), network, "eth0", "container-2") + if err != nil { + t.Fatalf("Second allocation failed: %v", err) + } + if ip2 != "10.0.2.3" { + t.Errorf("Expected IP 10.0.2.3, got %s", ip2) + } + + // Test Case 4: Exhaust allocation + // We had 5 available IPs: .2, .3, .4, .5, .6. + // We already allocated .2 and .3. + // Let's allocate .4, .5, .6. + for i := 4; i <= 6; i++ { + expectedIP := fmt.Sprintf("10.0.2.%d", i) + ip, _, err := s.AllocateIPv4(context.Background(), network, "eth0", fmt.Sprintf("container-%d", i)) + if err != nil { + t.Fatalf("Allocation failed for %s: %v", expectedIP, err) + } + if ip != expectedIP { + t.Errorf("Expected IP %s, got %s", expectedIP, ip) + } + } + + // Now it should be exhausted. Next allocation should fail. + ipEx, _, err := s.AllocateIPv4(context.Background(), network, "eth0", "container-exhaust") + if err == nil { + t.Errorf("Expected error for exhausted CIDR, got nil. Returned IP: %s", ipEx) + } else if !errors.Is(err, ErrNoAvailableIPs) { + t.Errorf("Expected ErrNoAvailableIPs, got %v", err) + } + + // Test Case 5: Error - No available IP address found (desync) + newNetwork := "gke-pod-network-2" + newCIDR := "10.0.3.0/29" + if err := s.AddCIDR(context.Background(), newNetwork, newCIDR); err != nil { + t.Fatalf("Failed to add new CIDR: %v", err) + } + + var newCidrBlockID int64 + err = s.db.QueryRow("SELECT id FROM cidr_blocks WHERE cidr = ?", newCIDR).Scan(&newCidrBlockID) + if err != nil { + t.Fatalf("Failed to get cidr_block_id: %v", err) + } + + // Manually mark all IPs as allocated in `ip_addresses` to simulate desync + _, err = s.db.Exec(`UPDATE ip_addresses SET is_allocated = TRUE WHERE cidr_block_id = ?`, newCidrBlockID) + if err != nil { + t.Fatalf("Failed to manually corrupt DB: %v", err) + } + + ipDesync, _, err := s.AllocateIPv4(context.Background(), newNetwork, "eth0", "container-desync") + if err == nil { + t.Errorf("Expected error due to IP address desync, got nil. Returned IP: %s", ipDesync) + } else if !errors.Is(err, ErrNoAvailableIPs) { + t.Errorf("Expected ErrNoAvailableIPs, got %v", err) + } +} + +func TestStore_ReleaseIPByOwner(t *testing.T) { + logger := logr.Discard() + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "metis_release_test.sqlite") + + s, err := NewStore(context.Background(), logger, dbPath) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + defer s.Close() + + network := "test-network" + cidr := "10.0.1.0/24" + + if err := s.AddCIDR(context.Background(), network, cidr); err != nil { + t.Fatalf("AddCIDR failed: %v", err) + } + + var cidrBlockID int64 + err = s.db.QueryRow("SELECT id FROM cidr_blocks WHERE cidr = ?", cidr).Scan(&cidrBlockID) + if err != nil { + t.Fatalf("Failed to query cidr_block_id: %v", err) + } + + containerID := "test-container" + interfaceName := "eth0" + + ip, _, err := s.AllocateIPv4(context.Background(), network, interfaceName, containerID) + if err != nil { + t.Fatalf("AllocateIPv4 failed: %v", err) + } + + var allocatedIPs int + err = s.db.QueryRow("SELECT allocated_ips FROM cidr_blocks WHERE id = ?", cidrBlockID).Scan(&allocatedIPs) + if err != nil { + t.Fatalf("QueryRow failed: %v", err) + } + if allocatedIPs != 4 { // 3 reserved + 1 allocated + t.Errorf("Expected 4 allocated IPs, got %d", allocatedIPs) + } + + cooloff := 1 * time.Minute + count, err := s.ReleaseIPByOwner(context.Background(), network, containerID, interfaceName, cooloff) + if err != nil { + t.Fatalf("ReleaseIPByOwner failed: %v", err) + } + if count != 1 { + t.Errorf("Expected 1 IP to be released, got %d", count) + } + + err = s.db.QueryRow("SELECT allocated_ips FROM cidr_blocks WHERE id = ?", cidrBlockID).Scan(&allocatedIPs) + if err != nil { + t.Fatalf("QueryRow failed after release: %v", err) + } + if allocatedIPs != 3 { // Back to 3! + t.Errorf("Expected 3 allocated IPs after release, got %d", allocatedIPs) + } + + var isAlloc bool + var releaseAt sql.NullTime + err = s.db.QueryRow("SELECT is_allocated, release_at FROM ip_addresses WHERE address = ?", ip).Scan(&isAlloc, &releaseAt) + if err != nil { + t.Fatalf("QueryRow failed for IP status: %v", err) + } + if isAlloc { + t.Error("Expected IP to be unallocated") + } + if !releaseAt.Valid { + t.Error("Expected release_at to be valid") + } +} + +func TestStore_AllocateIPv4_FallbackAndCooldown(t *testing.T) { + logger := logr.Discard() + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "fallback_test.sqlite") + + s, err := NewStore(context.Background(), logger, dbPath) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + defer s.Close() + + network := "test-network" + cidr1 := "10.0.1.0/29" // 5 available addresses (.2 to .6) + + if err := s.AddCIDR(context.Background(), network, cidr1); err != nil { + t.Fatalf("AddCIDR failed: %v", err) + } + + // 1. Allocate 5 IPs to exhaust the first CIDR + for i := 1; i <= 5; i++ { + _, _, err := s.AllocateIPv4(context.Background(), network, "eth0", fmt.Sprintf("container-%d", i)) + if err != nil { + t.Fatalf("Failed to allocate container-%d: %v", i, err) + } + } + + // 2. Attempting to allocate another should FAIL because the first CIDR is full + _, _, err = s.AllocateIPv4(context.Background(), network, "eth0", "container-6") + if err == nil { + t.Error("Expected AllocateIPv4 to fail when first CIDR is full, got nil") + } + + // 3. Add a second CIDR block + cidr2 := "10.0.2.0/29" + if err := s.AddCIDR(context.Background(), network, cidr2); err != nil { + t.Fatalf("Failed to add second CIDR block: %v", err) + } + + // 4. Try allocating again, it should succeed by falling back to the second CIDR + ip, cidr, err := s.AllocateIPv4(context.Background(), network, "eth0", "container-7") + if err != nil { + t.Fatalf("AllocateIPv4 failed after adding second CIDR: %v", err) + } + + if ip != "10.0.2.2" { // First available in second CIDR + t.Errorf("Expected IP 10.0.2.2 from second CIDR, got %s", ip) + } + if cidr != cidr2 { + t.Errorf("Expected CIDR %s, got %s", cidr2, cidr) + } + + // 5. Release one IP with cooldown + _, err = s.ReleaseIPByOwner(context.Background(), network, "container-1", "eth0", 1*time.Hour) + if err != nil { + t.Fatalf("ReleaseIPByOwner failed: %v", err) + } + + // 6. Try to re-allocate for a NEW container. It should NOT pick the released IP (since it's in cooldown). + // It should pick the next available in the second CIDR (since first CIDR is full except for the cooled-down one). + ipNew, _, err := s.AllocateIPv4(context.Background(), network, "eth0", "container-new") + if err != nil { + t.Fatalf("AllocateIPv4 failed after release with cooldown: %v", err) + } + + if ipNew == "10.0.1.2" { + t.Errorf("Expected different IP from 10.0.1.2 which should be in release cooldown") + } +} + +func TestStore_AllocateIPv4_Idempotency_Concurrency(t *testing.T) { + logger := logr.Discard() + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "idempotency_concurrency_test.sqlite") + + s, err := NewStore(context.Background(), logger, dbPath) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + defer s.Close() + + network := "test-network" + cidr := "10.0.1.0/24" + + if err := s.AddCIDR(context.Background(), network, cidr); err != nil { + t.Fatalf("AddCIDR failed: %v", err) + } + + containerID := "test-concurrent-container" + interfaceName := "eth0" + + const numGoroutines = 10 + var wg sync.WaitGroup + ips := make([]string, numGoroutines) + errs := make([]error, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + addr, _, err := s.AllocateIPv4(context.Background(), network, interfaceName, containerID) + ips[idx] = addr + errs[idx] = err + }(i) + } + + wg.Wait() + + var firstIP string + for i := 0; i < numGoroutines; i++ { + if errs[i] != nil { + t.Errorf("Goroutine %d failed: %v", i, errs[i]) + } + if ips[i] == "" { + t.Errorf("Goroutine %d returned empty IP", i) + } else { + if firstIP == "" { + firstIP = ips[i] + } else if ips[i] != firstIP { + t.Errorf("Goroutine %d got different IP: %s, want %s", i, ips[i], firstIP) + } + } + } + + // Double check the DB to ensure only 1 row was created + var count int + err = s.DB().QueryRow("SELECT COUNT(*) FROM ip_addresses WHERE container_id = ? AND interface_name = ?", containerID, interfaceName).Scan(&count) + if err != nil { + t.Fatalf("Failed to query DB for count: %v", err) + } + if count != 1 { + t.Errorf("Expected exactly 1 row for container %s, got %d", containerID, count) + } +} + +func TestStore_AllocateIPv4_Concurrency_DifferentContainers(t *testing.T) { + logger := logr.Discard() + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "concurrency_diff_containers.sqlite") + + s, err := NewStore(context.Background(), logger, dbPath) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + defer s.Close() + + network := "test-network" + cidr := "10.0.1.0/24" // 253 available IPs + + if err := s.AddCIDR(context.Background(), network, cidr); err != nil { + t.Fatalf("AddCIDR failed: %v", err) + } + + const numGoroutines = 50 // High contention + var wg sync.WaitGroup + ips := make([]string, numGoroutines) + errs := make([]error, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + addr, _, err := s.AllocateIPv4(context.Background(), network, "eth0", fmt.Sprintf("container-%d", idx)) + ips[idx] = addr + errs[idx] = err + }(i) + } + + wg.Wait() + + // Verify all succeeded and IPs are unique + uniqueIPs := make(map[string]bool) + for i := 0; i < numGoroutines; i++ { + if errs[i] != nil { + t.Errorf("Goroutine %d failed: %v", i, errs[i]) + } + if ips[i] == "" { + t.Errorf("Goroutine %d returned empty IP", i) + } else { + if uniqueIPs[ips[i]] { + t.Errorf("Duplicate IP allocated: %s", ips[i]) + } + uniqueIPs[ips[i]] = true + } + } + + if len(uniqueIPs) != numGoroutines { + t.Errorf("Expected %d unique IPs, got %d", numGoroutines, len(uniqueIPs)) + } + + // Verify DB stats + var count int + err = s.DB().QueryRow("SELECT COUNT(*) FROM ip_addresses WHERE is_allocated = TRUE AND container_id LIKE 'container-%'").Scan(&count) + if err != nil { + t.Fatalf("Failed to query DB for count: %v", err) + } + if count != numGoroutines { + t.Errorf("Expected %d allocated rows in DB, got %d", numGoroutines, count) + } +}