diff --git a/plugins/outputs/all/snowpipe_streaming.go b/plugins/outputs/all/snowpipe_streaming.go
new file mode 100644
index 0000000000000..401baa6ee18e3
--- /dev/null
+++ b/plugins/outputs/all/snowpipe_streaming.go
@@ -0,0 +1,5 @@
+//go:build !custom || outputs || outputs.snowpipe_streaming
+
+package all
+
+import _ "github.com/influxdata/telegraf/plugins/outputs/snowpipe_streaming" // register plugin
diff --git a/plugins/outputs/snowpipe_streaming/README.md b/plugins/outputs/snowpipe_streaming/README.md
new file mode 100644
index 0000000000000..08fa8c0e72d6a
--- /dev/null
+++ b/plugins/outputs/snowpipe_streaming/README.md
@@ -0,0 +1,201 @@
+# Snowpipe Streaming Output Plugin
+
+This plugin writes metrics to [Snowflake][snowflake] using efficient batch
+inserts via the [gosnowflake][gosnowflake] driver's array binding, which
+leverages Snowpipe Streaming internally for low-latency, high-throughput
+ingest without staging files.
+
+[snowflake]: https://www.snowflake.com/
+[gosnowflake]: https://github.com/snowflakedb/gosnowflake
+
+⭐ Telegraf v1.35.0
+🏷️ cloud, datastore
+💻 all
+
+## Prerequisites
+
+1. A Snowflake account with a database and schema already created.
+2. Key-pair authentication configured for the Snowflake user:
+   - Generate an RSA key pair:
+
+     ```bash
+     openssl genrsa 2048 | openssl pkcs8 -topk8 -inform PEM -out rsa_key.p8 -nocrypt
+     openssl rsa -in rsa_key.p8 -pubout -out rsa_key.pub
+     ```
+
+   - Assign the public key to the user (paste the key body without the PEM
+     header and footer lines):
+
+     ```sql
+     ALTER USER my_user SET RSA_PUBLIC_KEY='<public_key>';
+     ```
+
+3. The user must have INSERT privileges on the target table(s).
+4. If `create_table = true`, the user must also have CREATE TABLE privileges.
+
+## Global configuration options
+
+Plugins support additional global and plugin configuration settings for tasks
+such as modifying metrics, tags, and fields, creating aliases, and configuring
+plugin ordering. See [CONFIGURATION.md][CONFIGURATION.md] for more details.
+
+[CONFIGURATION.md]: ../../../docs/CONFIGURATION.md#plugins
+
+## Configuration
+
+```toml @sample.conf
+# Stream metrics to Snowflake via Snowpipe Streaming
+[[outputs.snowpipe_streaming]]
+  ## Snowflake account identifier (e.g. "xy12345.us-east-1")
+  account = ""
+
+  ## Snowflake username for key-pair authentication
+  user = ""
+
+  ## Path to RSA private key file (PEM format) for key-pair auth
+  private_key_path = ""
+
+  ## Optional passphrase for the RSA private key
+  # private_key_passphrase = ""
+
+  ## Snowflake role to use
+  # role = ""
+
+  ## Target database name
+  database = ""
+
+  ## Target schema name
+  schema = ""
+
+  ## Target table name
+  ## Supports Go templates with access to metric properties:
+  ##   {{.Name}}      - metric name
+  ##   {{.Tag "key"}} - tag value
+  ## Example: "metrics_{{.Name}}" routes each metric name to a separate table
+  table = ""
+
+  ## Number of rows per insert batch
+  # batch_size = 1000
+
+  ## Maximum number of retries on transient errors
+  # retry_max = 3
+
+  ## Delay between retries (exponential backoff base)
+  # retry_delay = "1s"
+
+  ## Column name to store the metric timestamp
+  # timestamp_column = "timestamp"
+
+  ## Restrict which tags to include as columns (empty = all tags)
+  # tag_columns = []
+
+  ## Restrict which fields to include as columns (empty = all fields)
+  # field_columns = []
+
+  ## Automatically create the target table if it does not exist
+  # create_table = false
+
+  ## How long to cache table schema information
+  # table_schema_cache_ttl = "5m"
+```
+
+## Table Schema
+
+Each metric is stored as a row with the following column mapping:
+
+| Column             | Type          | Source                                            |
+|--------------------|---------------|---------------------------------------------------|
+| `timestamp`        | TIMESTAMP_NTZ | Metric timestamp (name set by `timestamp_column`) |
+| `name`             | VARCHAR       | Metric name                                       |
+| *(each tag key)*   | VARCHAR       | Tag value                                         |
+| *(each field key)* | varies        | Field value                                       |
+
+Field type mapping:
+
+| Go Type         | Snowflake Type |
+|-----------------|----------------|
+| int64, uint64   | NUMBER         |
+| float64         | DOUBLE         |
+| bool            | BOOLEAN        |
+| string          | VARCHAR        |
+
+When `create_table = true`, the plugin will create the table with appropriate
+types. When new tags or fields appear, columns are automatically added via
+`ALTER TABLE ADD COLUMN`.
+
+## Example Configurations
+
+### Basic — single table
+
+```toml
+[[outputs.snowpipe_streaming]]
+  account = "xy12345.us-east-1"
+  user = "TELEGRAF_USER"
+  private_key_path = "/etc/telegraf/snowflake_key.p8"
+  database = "TELEMETRY"
+  schema = "PUBLIC"
+  table = "METRICS"
+  create_table = true
+```
+
+### Template-based table routing
+
+```toml
+[[outputs.snowpipe_streaming]]
+  account = "xy12345.us-east-1"
+  user = "TELEGRAF_USER"
+  private_key_path = "/etc/telegraf/snowflake_key.p8"
+  database = "TELEMETRY"
+  schema = "RAW"
+  table = "metrics_{{.Name}}"
+  create_table = true
+```
+
+### Specific columns only
+
+```toml
+[[outputs.snowpipe_streaming]]
+  account = "xy12345.us-east-1"
+  user = "TELEGRAF_USER"
+  private_key_path = "/etc/telegraf/snowflake_key.p8"
+  database = "TELEMETRY"
+  schema = "PUBLIC"
+  table = "CPU_METRICS"
+  tag_columns = ["host", "cpu"]
+  field_columns = ["usage_idle", "usage_user", "usage_system"]
+  batch_size = 5000
+```
+
+## Troubleshooting
+
+### Authentication errors
+
+Ensure your RSA key pair is correctly configured:
+
+```sql
+DESC USER my_user;
+```
+
+Check that `RSA_PUBLIC_KEY_FP` is set and matches your key.
+ +### Permission errors + +The user/role must have the required grants: + +```sql +GRANT USAGE ON DATABASE telemetry TO ROLE my_role; +GRANT USAGE ON SCHEMA telemetry.public TO ROLE my_role; +GRANT INSERT ON TABLE telemetry.public.metrics TO ROLE my_role; +-- If using create_table = true: +GRANT CREATE TABLE ON SCHEMA telemetry.public TO ROLE my_role; +``` + +### Transient errors and retries + +The plugin automatically retries on transient errors (connection resets, +timeouts, service unavailable) with exponential backoff. Increase `retry_max` +and `retry_delay` for unreliable networks. + +### NaN/Inf field values + +Fields containing NaN or Inf float values are inserted as NULL to avoid +Snowflake errors. diff --git a/plugins/outputs/snowpipe_streaming/sample.conf b/plugins/outputs/snowpipe_streaming/sample.conf new file mode 100644 index 0000000000000..f37af952217b8 --- /dev/null +++ b/plugins/outputs/snowpipe_streaming/sample.conf @@ -0,0 +1,53 @@ +# Stream metrics to Snowflake via Snowpipe Streaming +[[outputs.snowpipe_streaming]] + ## Snowflake account identifier (e.g. "xy12345.us-east-1") + account = "" + + ## Snowflake username for key-pair authentication + user = "" + + ## Path to RSA private key file (PEM format) for key-pair auth + private_key_path = "" + + ## Optional passphrase for the RSA private key + # private_key_passphrase = "" + + ## Snowflake role to use + # role = "" + + ## Target database name + database = "" + + ## Target schema name + schema = "" + + ## Target table name + ## Supports Go templates with access to metric properties: + ## {{.Name}} - metric name + ## {{.Tag "key"}} - tag value + ## Example: "metrics_{{.Name}}" routes each metric name to a separate table + table = "" + + ## Number of rows per insert batch + # batch_size = 1000 + + ## Maximum number of retries on transient errors + # retry_max = 3 + + ## Delay between retries (exponential backoff base) + # retry_delay = "1s" + + ## Column name to store the metric timestamp + # timestamp_column = "timestamp" + + ## Restrict which tags to include as columns (empty = all tags) + # tag_columns = [] + + ## Restrict which fields to include as columns (empty = all fields) + # field_columns = [] + + ## Automatically create the target table if it does not exist + # create_table = false + + ## How long to cache table schema information + # table_schema_cache_ttl = "5m" diff --git a/plugins/outputs/snowpipe_streaming/snowpipe_streaming.go b/plugins/outputs/snowpipe_streaming/snowpipe_streaming.go new file mode 100644 index 0000000000000..cf95153d952c9 --- /dev/null +++ b/plugins/outputs/snowpipe_streaming/snowpipe_streaming.go @@ -0,0 +1,588 @@ +//go:generate ../../../tools/readme_config_includer/generator +package snowpipe_streaming + +import ( + "crypto/rsa" + "crypto/x509" + gosql "database/sql" + _ "embed" + "encoding/pem" + "errors" + "fmt" + "math" + "os" + "strings" + "sync" + "text/template" + "time" + + "github.com/snowflakedb/gosnowflake" + + "github.com/influxdata/telegraf" + "github.com/influxdata/telegraf/config" + "github.com/influxdata/telegraf/plugins/outputs" +) + +//go:embed sample.conf +var sampleConfig string + +type SnowpipeStreaming struct { + Account string `toml:"account"` + User string `toml:"user"` + PrivateKeyPath string `toml:"private_key_path"` + PrivateKeyPass string `toml:"private_key_passphrase"` + Role string `toml:"role"` + Database string `toml:"database"` + Schema string `toml:"schema"` + Table string `toml:"table"` + BatchSize int `toml:"batch_size"` + RetryMax int 
`toml:"retry_max"`
+	RetryDelay          config.Duration `toml:"retry_delay"`
+	TimestampColumn     string          `toml:"timestamp_column"`
+	TagColumns          []string        `toml:"tag_columns"`
+	FieldColumns        []string        `toml:"field_columns"`
+	CreateTable         bool            `toml:"create_table"`
+	TableSchemaCacheTTL config.Duration `toml:"table_schema_cache_ttl"`
+
+	Log telegraf.Logger `toml:"-"`
+
+	db           *gosql.DB
+	tableTmpl    *template.Template
+	tableHasTmpl bool
+	tagSet       map[string]bool
+	fieldSet     map[string]bool
+
+	schemaMu    sync.RWMutex
+	schemaCache map[string]*tableSchema
+
+	// For testing: allow overriding the connection opener
+	openDB func() (*gosql.DB, error)
+}
+
+type tableSchema struct {
+	columns   map[string]bool
+	fetchedAt time.Time
+}
+
+func (*SnowpipeStreaming) SampleConfig() string {
+	return sampleConfig
+}
+
+func (s *SnowpipeStreaming) Init() error {
+	if s.Account == "" {
+		return errors.New(`"account" is required`)
+	}
+	if s.User == "" {
+		return errors.New(`"user" is required`)
+	}
+	if s.Database == "" {
+		return errors.New(`"database" is required`)
+	}
+	if s.Schema == "" {
+		return errors.New(`"schema" is required`)
+	}
+	if s.Table == "" {
+		return errors.New(`"table" is required`)
+	}
+
+	// A non-positive batch size would stall the batching loop in writeTable,
+	// so fall back to the default.
+	if s.BatchSize <= 0 {
+		s.BatchSize = 1000
+	}
+
+	if strings.Contains(s.Table, "{{") {
+		tmpl, err := template.New("table").Parse(s.Table)
+		if err != nil {
+			return fmt.Errorf("parsing table template: %w", err)
+		}
+		s.tableTmpl = tmpl
+		s.tableHasTmpl = true
+	}
+
+	if len(s.TagColumns) > 0 {
+		s.tagSet = make(map[string]bool, len(s.TagColumns))
+		for _, t := range s.TagColumns {
+			s.tagSet[t] = true
+		}
+	}
+	if len(s.FieldColumns) > 0 {
+		s.fieldSet = make(map[string]bool, len(s.FieldColumns))
+		for _, f := range s.FieldColumns {
+			s.fieldSet[f] = true
+		}
+	}
+
+	s.schemaCache = make(map[string]*tableSchema)
+
+	return nil
+}
+
+func (s *SnowpipeStreaming) Connect() error {
+	var db *gosql.DB
+	var err error
+
+	if s.openDB != nil {
+		db, err = s.openDB()
+	} else {
+		var dsn string
+		dsn, err = s.buildDSN()
+		if err != nil {
+			return fmt.Errorf("building DSN: %w", err)
+		}
+		db, err = gosql.Open("snowflake", dsn)
+	}
+	if err != nil {
+		return fmt.Errorf("opening snowflake connection: %w", err)
+	}
+
+	if err := db.Ping(); err != nil {
+		db.Close()
+		return fmt.Errorf("pinging snowflake: %w", err)
+	}
+
+	s.db = db
+	return nil
+}
+
+func (s *SnowpipeStreaming) Close() error {
+	if s.db != nil {
+		return s.db.Close()
+	}
+	return nil
+}
+
+func (s *SnowpipeStreaming) Write(metrics []telegraf.Metric) error {
+	grouped := s.groupByTable(metrics)
+	for tableName, rows := range grouped {
+		if err := s.writeTable(tableName, rows); err != nil {
+			return fmt.Errorf("writing to table %q: %w", tableName, err)
+		}
+	}
+	return nil
+}
+
+func (s *SnowpipeStreaming) writeTable(tableName string, metrics []telegraf.Metric) error {
+	if s.CreateTable {
+		if err := s.ensureTable(tableName, metrics[0]); err != nil {
+			return err
+		}
+	}
+
+	for start := 0; start < len(metrics); start += s.BatchSize {
+		end := start + s.BatchSize
+		if end > len(metrics) {
+			end = len(metrics)
+		}
+		if err := s.insertBatch(tableName, metrics[start:end]); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (s *SnowpipeStreaming) insertBatch(tableName string, metrics []telegraf.Metric) error {
+	if len(metrics) == 0 {
+		return nil
+	}
+
+	columns, allValues := s.metricsToRows(metrics)
+	if len(columns) == 0 {
+		return nil
+	}
+
+	query := s.buildInsertQuery(tableName, columns, len(metrics))
+
+	flat := make([]interface{}, 0, len(columns)*len(metrics))
+	for _, row := range allValues {
+		flat
= append(flat, row...) + } + + var lastErr error + for attempt := 0; attempt <= s.RetryMax; attempt++ { + if attempt > 0 { + delay := time.Duration(s.RetryDelay) * (1 << (attempt - 1)) + time.Sleep(delay) + } + + _, err := s.db.Exec(query, flat...) + if err == nil { + return nil + } + lastErr = err + + if !isTransientError(err) { + return fmt.Errorf("insert failed: %w", err) + } + s.Log.Warnf("Transient error on attempt %d/%d for table %q: %v", attempt+1, s.RetryMax+1, tableName, err) + } + + return fmt.Errorf("insert failed after %d retries: %w", s.RetryMax, lastErr) +} + +func (s *SnowpipeStreaming) metricsToRows(metrics []telegraf.Metric) ([]string, [][]interface{}) { + columnOrder := s.buildColumnOrder(metrics[0]) + columnSet := make(map[string]bool, len(columnOrder)) + for _, c := range columnOrder { + columnSet[c] = true + } + + rows := make([][]interface{}, 0, len(metrics)) + for _, m := range metrics { + row := s.metricToRow(m, columnOrder, columnSet) + rows = append(rows, row) + } + + return columnOrder, rows +} + +func (s *SnowpipeStreaming) buildColumnOrder(m telegraf.Metric) []string { + columns := make([]string, 0, 1+len(m.TagList())+len(m.FieldList())) + + if s.TimestampColumn != "" { + columns = append(columns, s.TimestampColumn) + } + + columns = append(columns, "name") + + for _, tag := range m.TagList() { + if s.tagSet != nil && !s.tagSet[tag.Key] { + continue + } + columns = append(columns, tag.Key) + } + + for _, field := range m.FieldList() { + if s.fieldSet != nil && !s.fieldSet[field.Key] { + continue + } + columns = append(columns, field.Key) + } + + return columns +} + +func (s *SnowpipeStreaming) metricToRow(m telegraf.Metric, columns []string, columnSet map[string]bool) []interface{} { + vals := make(map[string]interface{}, len(columns)) + + if s.TimestampColumn != "" { + vals[s.TimestampColumn] = m.Time() + } + vals["name"] = m.Name() + + for _, tag := range m.TagList() { + if s.tagSet != nil && !s.tagSet[tag.Key] { + continue + } + if columnSet[tag.Key] { + vals[tag.Key] = tag.Value + } + } + + for _, field := range m.FieldList() { + if s.fieldSet != nil && !s.fieldSet[field.Key] { + continue + } + if columnSet[field.Key] { + vals[field.Key] = sanitizeFieldValue(field.Value) + } + } + + row := make([]interface{}, len(columns)) + for i, col := range columns { + row[i] = vals[col] + } + return row +} + +func sanitizeFieldValue(v interface{}) interface{} { + if f, ok := v.(float64); ok { + if math.IsNaN(f) || math.IsInf(f, 0) { + return nil + } + } + return v +} + +func (s *SnowpipeStreaming) buildInsertQuery(tableName string, columns []string, numRows int) string { + quoted := make([]string, len(columns)) + for i, c := range columns { + quoted[i] = quoteIdent(c) + } + + placeholders := make([]string, len(columns)) + for i := range columns { + placeholders[i] = "?" 
+ } + rowPlaceholder := "(" + strings.Join(placeholders, ", ") + ")" + + rowPlaceholders := make([]string, numRows) + for i := range numRows { + rowPlaceholders[i] = rowPlaceholder + } + + return fmt.Sprintf("INSERT INTO %s.%s.%s (%s) VALUES %s", + quoteIdent(s.Database), + quoteIdent(s.Schema), + quoteIdent(tableName), + strings.Join(quoted, ", "), + strings.Join(rowPlaceholders, ", "), + ) +} + +func (s *SnowpipeStreaming) groupByTable(metrics []telegraf.Metric) map[string][]telegraf.Metric { + groups := make(map[string][]telegraf.Metric) + for _, m := range metrics { + name := s.resolveTableName(m) + groups[name] = append(groups[name], m) + } + return groups +} + +func (s *SnowpipeStreaming) resolveTableName(m telegraf.Metric) string { + if !s.tableHasTmpl { + return s.Table + } + var b strings.Builder + if err := s.tableTmpl.Execute(&b, m); err != nil { + s.Log.Errorf("Executing table template: %v, falling back to literal table name", err) + return s.Table + } + return b.String() +} + +func (s *SnowpipeStreaming) ensureTable(tableName string, sample telegraf.Metric) error { + s.schemaMu.RLock() + cached, exists := s.schemaCache[tableName] + s.schemaMu.RUnlock() + + ttl := time.Duration(s.TableSchemaCacheTTL) + if exists && time.Since(cached.fetchedAt) < ttl { + return s.evolveSchema(tableName, sample, cached) + } + + s.schemaMu.Lock() + defer s.schemaMu.Unlock() + + if err := s.createTableIfNotExists(tableName, sample); err != nil { + return err + } + + schema, err := s.fetchTableSchema(tableName) + if err != nil { + return err + } + s.schemaCache[tableName] = schema + + s.evolveSchemaLocked(tableName, sample, schema) + return nil +} + +func (s *SnowpipeStreaming) createTableIfNotExists(tableName string, sample telegraf.Metric) error { + columns := s.buildColumnOrder(sample) + colDefs := make([]string, len(columns)) + for i, col := range columns { + colDefs[i] = fmt.Sprintf("%s %s", quoteIdent(col), s.sqlTypeFor(col, sample)) + } + + query := fmt.Sprintf( + "CREATE TABLE IF NOT EXISTS %s.%s.%s (%s)", + quoteIdent(s.Database), + quoteIdent(s.Schema), + quoteIdent(tableName), + strings.Join(colDefs, ", "), + ) + + _, err := s.db.Exec(query) + if err != nil { + return fmt.Errorf("creating table %q: %w", tableName, err) + } + return nil +} + +func (s *SnowpipeStreaming) sqlTypeFor(col string, m telegraf.Metric) string { + if col == s.TimestampColumn { + return "TIMESTAMP_NTZ" + } + if col == "name" { + return "VARCHAR" + } + for _, tag := range m.TagList() { + if tag.Key == col { + return "VARCHAR" + } + } + for _, field := range m.FieldList() { + if field.Key == col { + return goTypeToSnowflake(field.Value) + } + } + return "VARCHAR" +} + +func goTypeToSnowflake(v interface{}) string { + switch v.(type) { + case int, int8, int16, int32, int64: + return "NUMBER" + case uint, uint8, uint16, uint32, uint64: + return "NUMBER" + case float32, float64: + return "DOUBLE" + case bool: + return "BOOLEAN" + default: + return "VARCHAR" + } +} + +func (s *SnowpipeStreaming) fetchTableSchema(tableName string) (*tableSchema, error) { + query := fmt.Sprintf( //nolint:gosec // G201: quoteIdent sanitises identifier + "SELECT COLUMN_NAME FROM %s.INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = ? 
AND TABLE_NAME = ?", + quoteIdent(s.Database), + ) + + rows, err := s.db.Query(query, strings.ToUpper(s.Schema), strings.ToUpper(tableName)) + if err != nil { + return nil, fmt.Errorf("fetching schema for %q: %w", tableName, err) + } + defer rows.Close() + + cols := make(map[string]bool) + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + cols[strings.ToUpper(name)] = true + } + + return &tableSchema{columns: cols, fetchedAt: time.Now()}, rows.Err() +} + +func (s *SnowpipeStreaming) evolveSchema(tableName string, sample telegraf.Metric, schema *tableSchema) error { + s.schemaMu.Lock() + defer s.schemaMu.Unlock() + s.evolveSchemaLocked(tableName, sample, schema) + return nil +} + +func (s *SnowpipeStreaming) evolveSchemaLocked(tableName string, sample telegraf.Metric, schema *tableSchema) { + needed := s.buildColumnOrder(sample) + for _, col := range needed { + if schema.columns[strings.ToUpper(col)] { + continue + } + sqlType := s.sqlTypeFor(col, sample) + alter := fmt.Sprintf("ALTER TABLE %s.%s.%s ADD COLUMN %s %s", + quoteIdent(s.Database), + quoteIdent(s.Schema), + quoteIdent(tableName), + quoteIdent(col), + sqlType, + ) + if _, err := s.db.Exec(alter); err != nil { + s.Log.Warnf("Failed to add column %q to %q: %v", col, tableName, err) + continue + } + schema.columns[strings.ToUpper(col)] = true + s.Log.Infof("Added column %q (%s) to table %q", col, sqlType, tableName) + } +} + +func (s *SnowpipeStreaming) buildDSN() (string, error) { + cfg := &gosnowflake.Config{ + Account: s.Account, + User: s.User, + Database: s.Database, + Schema: s.Schema, + Role: s.Role, + } + + if s.PrivateKeyPath != "" { + key, err := loadPrivateKey(s.PrivateKeyPath, s.PrivateKeyPass) + if err != nil { + return "", fmt.Errorf("loading private key: %w", err) + } + cfg.Authenticator = gosnowflake.AuthTypeJwt + cfg.PrivateKey = key + } + + dsn, err := gosnowflake.DSN(cfg) + if err != nil { + return "", fmt.Errorf("building snowflake DSN: %w", err) + } + return dsn, nil +} + +func loadPrivateKey(path, passphrase string) (*rsa.PrivateKey, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("reading key file: %w", err) + } + + block, _ := pem.Decode(data) + if block == nil { + return nil, errors.New("failed to decode PEM block") + } + + var keyBytes []byte + if passphrase != "" { + keyBytes, err = x509.DecryptPEMBlock(block, []byte(passphrase)) //nolint:staticcheck // SA1019: required for PKCS#5 keys + if err != nil { + return nil, fmt.Errorf("decrypting private key: %w", err) + } + } else { + keyBytes = block.Bytes + } + + parsed, err := x509.ParsePKCS8PrivateKey(keyBytes) + if err != nil { + // Fall back to PKCS1 + key, err2 := x509.ParsePKCS1PrivateKey(keyBytes) + if err2 != nil { + return nil, fmt.Errorf("parsing private key (PKCS8: %w, PKCS1: %w)", err, err2) + } + return key, nil + } + + rsaKey, ok := parsed.(*rsa.PrivateKey) + if !ok { + return nil, errors.New("private key is not RSA") + } + return rsaKey, nil +} + +func isTransientError(err error) bool { + if err == nil { + return false + } + var sfErr *gosnowflake.SnowflakeError + if errors.As(err, &sfErr) { + // HTTP 429, 503, and internal server errors are transient + switch sfErr.Number { + case gosnowflake.ErrCodeServiceUnavailable: + return true + case gosnowflake.ErrCodeFailedToConnect: + return true + } + } + + msg := err.Error() + return strings.Contains(msg, "connection refused") || + strings.Contains(msg, "connection reset") || + strings.Contains(msg, "i/o 
timeout") || + strings.Contains(msg, "service unavailable") +} + +func quoteIdent(name string) string { + return `"` + strings.ReplaceAll(name, `"`, `""`) + `"` +} + +func init() { + outputs.Add("snowpipe_streaming", func() telegraf.Output { + return &SnowpipeStreaming{ + BatchSize: 1000, + RetryMax: 3, + RetryDelay: config.Duration(1 * time.Second), + TimestampColumn: "timestamp", + TableSchemaCacheTTL: config.Duration(5 * time.Minute), + } + }) +} + +// Compile-time interface check +var _ telegraf.Output = (*SnowpipeStreaming)(nil) diff --git a/plugins/outputs/snowpipe_streaming/snowpipe_streaming_test.go b/plugins/outputs/snowpipe_streaming/snowpipe_streaming_test.go new file mode 100644 index 0000000000000..ffbaef352d44c --- /dev/null +++ b/plugins/outputs/snowpipe_streaming/snowpipe_streaming_test.go @@ -0,0 +1,785 @@ +package snowpipe_streaming + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "io" + "math" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/influxdata/telegraf" + "github.com/influxdata/telegraf/config" + "github.com/influxdata/telegraf/metric" + "github.com/influxdata/telegraf/testutil" +) + +// --------------------------------------------------------------------------- +// Mock SQL driver +// --------------------------------------------------------------------------- + +// mockDriver records every query executed against it. +type mockDriver struct{} + +type mockConn struct { + mu sync.Mutex + queries []executedQuery + closed bool + + // If set, Exec returns this error for the first N calls + execErr error + execErrCount int32 // atomic: how many execs should fail +} + +type executedQuery struct { + query string + args []driver.Value +} + +type mockStmt struct { + conn *mockConn + query string +} + +type mockTx struct { + conn *mockConn +} + +type mockRows struct { + columns []string + data [][]driver.Value + pos int +} + +var ( + globalMockConn *mockConn + globalMockMu sync.Mutex +) + +func resetGlobalMock() *mockConn { + globalMockMu.Lock() + defer globalMockMu.Unlock() + globalMockConn = &mockConn{} + return globalMockConn +} + +func getGlobalMock() *mockConn { + globalMockMu.Lock() + defer globalMockMu.Unlock() + return globalMockConn +} + +func init() { + sql.Register("snowflake_mock", &mockDriver{}) +} + +func (*mockDriver) Open(_ string) (driver.Conn, error) { + c := getGlobalMock() + if c == nil { + return nil, errors.New("no mock conn configured") + } + return c, nil +} + +func (c *mockConn) Prepare(query string) (driver.Stmt, error) { + return &mockStmt{conn: c, query: query}, nil +} + +func (c *mockConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return nil +} + +func (c *mockConn) Begin() (driver.Tx, error) { + return &mockTx{conn: c}, nil +} + +func (c *mockConn) Exec(query string, args []driver.Value) (driver.Result, error) { + c.mu.Lock() + c.queries = append(c.queries, executedQuery{query: query, args: args}) + c.mu.Unlock() + + if c.execErr != nil { + remaining := atomic.AddInt32(&c.execErrCount, -1) + if remaining >= 0 { + return nil, c.execErr + } + } + + return mockResult{}, nil +} + +func (c *mockConn) Query(query string, args []driver.Value) (driver.Rows, error) { + c.mu.Lock() + c.queries = append(c.queries, executedQuery{query: query, args: args}) + c.mu.Unlock() + + if strings.Contains(strings.ToUpper(query), "INFORMATION_SCHEMA.COLUMNS") { + return &mockRows{ + columns: 
[]string{"COLUMN_NAME"}, + data: make([][]driver.Value, 0), + }, nil + } + + return &mockRows{columns: make([]string, 0), data: make([][]driver.Value, 0)}, nil +} + +func (c *mockConn) getQueries() []executedQuery { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]executedQuery, len(c.queries)) + copy(out, c.queries) + return out +} + +func (*mockStmt) Close() error { return nil } +func (*mockStmt) NumInput() int { return -1 } + +func (s *mockStmt) Exec(args []driver.Value) (driver.Result, error) { + return s.conn.Exec(s.query, args) +} + +func (s *mockStmt) Query(args []driver.Value) (driver.Rows, error) { + return s.conn.Query(s.query, args) +} + +func (*mockTx) Commit() error { return nil } +func (*mockTx) Rollback() error { return nil } + +type mockResult struct{} + +func (mockResult) LastInsertId() (int64, error) { return 0, nil } +func (mockResult) RowsAffected() (int64, error) { return 1, nil } + +func (r *mockRows) Columns() []string { return r.columns } +func (*mockRows) Close() error { return nil } +func (r *mockRows) Next(dest []driver.Value) error { + if r.pos >= len(r.data) { + return io.EOF + } + copy(dest, r.data[r.pos]) + r.pos++ + return nil +} + +// --------------------------------------------------------------------------- +// Helper: create a plugin wired to the mock driver +// --------------------------------------------------------------------------- + +func newTestPlugin(t *testing.T) *SnowpipeStreaming { + t.Helper() + s := &SnowpipeStreaming{ + Account: "test_account", + User: "test_user", + Database: "TEST_DB", + Schema: "PUBLIC", + Table: "METRICS", + BatchSize: 1000, + RetryMax: 3, + RetryDelay: config.Duration(10 * time.Millisecond), + TimestampColumn: "timestamp", + TableSchemaCacheTTL: config.Duration(5 * time.Minute), + Log: testutil.Logger{}, + } + require.NoError(t, s.Init()) + return s +} + +func connectTestPlugin(t *testing.T, s *SnowpipeStreaming) { + t.Helper() + mc := resetGlobalMock() + _ = mc + + s.openDB = func() (*sql.DB, error) { + return sql.Open("snowflake_mock", "mock") + } + require.NoError(t, s.Connect()) + t.Cleanup(func() { s.Close() }) +} + +func testMetric(name string, tags map[string]string, fields map[string]interface{}, ts time.Time) telegraf.Metric { + m := metric.New(name, tags, fields, ts) + return m +} + +// --------------------------------------------------------------------------- +// Unit Tests +// --------------------------------------------------------------------------- + +func TestInit(t *testing.T) { + tests := []struct { + name string + plugin *SnowpipeStreaming + wantErr string + }{ + { + name: "missing account", + plugin: &SnowpipeStreaming{User: "u", Database: "d", Schema: "s", Table: "t"}, + wantErr: `"account" is required`, + }, + { + name: "missing user", + plugin: &SnowpipeStreaming{Account: "a", Database: "d", Schema: "s", Table: "t"}, + wantErr: `"user" is required`, + }, + { + name: "missing database", + plugin: &SnowpipeStreaming{Account: "a", User: "u", Schema: "s", Table: "t"}, + wantErr: `"database" is required`, + }, + { + name: "missing schema", + plugin: &SnowpipeStreaming{Account: "a", User: "u", Database: "d", Table: "t"}, + wantErr: `"schema" is required`, + }, + { + name: "missing table", + plugin: &SnowpipeStreaming{Account: "a", User: "u", Database: "d", Schema: "s"}, + wantErr: `"table" is required`, + }, + { + name: "valid minimal config", + plugin: &SnowpipeStreaming{Account: "a", User: "u", Database: "d", Schema: "s", Table: "t"}, + }, + { + name: "invalid table template", + plugin: 
&SnowpipeStreaming{Account: "a", User: "u", Database: "d", Schema: "s", Table: "{{.Invalid"}, + wantErr: "parsing table template", + }, + { + name: "valid template table", + plugin: &SnowpipeStreaming{Account: "a", User: "u", Database: "d", Schema: "s", Table: "metrics_{{.Name}}"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.plugin.Init() + if tc.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestMetricToRow(t *testing.T) { + s := newTestPlugin(t) + + ts := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + m := testMetric("cpu", map[string]string{"host": "server1"}, map[string]interface{}{ + "usage_idle": float64(95.5), + "count": int64(42), + }, ts) + + columns := s.buildColumnOrder(m) + columnSet := make(map[string]bool, len(columns)) + for _, c := range columns { + columnSet[c] = true + } + row := s.metricToRow(m, columns, columnSet) + + require.Contains(t, columns, "timestamp") + require.Contains(t, columns, "name") + require.Contains(t, columns, "host") + require.Contains(t, columns, "usage_idle") + require.Contains(t, columns, "count") + + // Check that row values match + valMap := make(map[string]interface{}, len(columns)) + for i, col := range columns { + valMap[col] = row[i] + } + + require.Equal(t, ts, valMap["timestamp"]) + require.Equal(t, "cpu", valMap["name"]) + require.Equal(t, "server1", valMap["host"]) + require.InDelta(t, float64(95.5), valMap["usage_idle"], 1e-9) + require.Equal(t, int64(42), valMap["count"]) +} + +func TestMetricToRowNaNInf(t *testing.T) { + s := newTestPlugin(t) + + ts := time.Now() + m := testMetric("test", nil, map[string]interface{}{ + "nan_val": math.NaN(), + "ok_val": float64(1.0), + }, ts) + + columns := s.buildColumnOrder(m) + columnSet := make(map[string]bool, len(columns)) + for _, c := range columns { + columnSet[c] = true + } + row := s.metricToRow(m, columns, columnSet) + + valMap := make(map[string]interface{}, len(columns)) + for i, col := range columns { + valMap[col] = row[i] + } + + require.Nil(t, valMap["nan_val"]) + require.InDelta(t, float64(1.0), valMap["ok_val"], 1e-9) +} + +func TestTableNameTemplate(t *testing.T) { + s := &SnowpipeStreaming{ + Account: "a", + User: "u", + Database: "d", + Schema: "s", + Table: "metrics_{{.Name}}", + } + require.NoError(t, s.Init()) + + m := testMetric("cpu", map[string]string{"host": "h1"}, map[string]interface{}{"val": 1.0}, time.Now()) + name := s.resolveTableName(m) + require.Equal(t, "metrics_cpu", name) + + m2 := testMetric("mem", nil, map[string]interface{}{"val": 1.0}, time.Now()) + name2 := s.resolveTableName(m2) + require.Equal(t, "metrics_mem", name2) +} + +func TestTableNameNoTemplate(t *testing.T) { + s := &SnowpipeStreaming{ + Account: "a", + User: "u", + Database: "d", + Schema: "s", + Table: "fixed_table", + } + require.NoError(t, s.Init()) + + m := testMetric("cpu", nil, map[string]interface{}{"val": 1.0}, time.Now()) + require.Equal(t, "fixed_table", s.resolveTableName(m)) +} + +func TestBatching(t *testing.T) { + s := newTestPlugin(t) + s.BatchSize = 3 + connectTestPlugin(t, s) + + ts := time.Now() + metrics := make([]telegraf.Metric, 7) + for i := range 7 { + metrics[i] = testMetric("cpu", nil, map[string]interface{}{ + "val": float64(i), + }, ts) + } + + require.NoError(t, s.Write(metrics)) + + mc := getGlobalMock() + queries := mc.getQueries() + + // With batch_size=3 and 7 metrics, expect 3 INSERT queries (3+3+1) + insertCount := 
0 + for _, q := range queries { + if strings.HasPrefix(q.query, "INSERT INTO") { + insertCount++ + } + } + require.Equal(t, 3, insertCount, "expected 3 batch inserts for 7 rows with batch_size=3") + + // Verify the last batch has only 1 row's worth of placeholders + lastInsert := queries[len(queries)-1] + require.Contains(t, lastInsert.query, "VALUES (?") + // Count the number of VALUES groups + valuesCount := strings.Count(lastInsert.query, "(?,") + // Last batch should be 1 row only + require.Equal(t, 1, valuesCount, "last batch should contain 1 row") +} + +func TestRetryLogic(t *testing.T) { + s := newTestPlugin(t) + s.RetryMax = 2 + s.RetryDelay = config.Duration(1 * time.Millisecond) + + mc := resetGlobalMock() + mc.execErr = errors.New("connection refused") + atomic.StoreInt32(&mc.execErrCount, 2) // fail first 2, succeed on 3rd + + s.openDB = func() (*sql.DB, error) { + return sql.Open("snowflake_mock", "mock") + } + require.NoError(t, s.Connect()) + t.Cleanup(func() { s.Close() }) + + ts := time.Now() + m := testMetric("test", nil, map[string]interface{}{"val": 1.0}, ts) + require.NoError(t, s.Write([]telegraf.Metric{m})) + + queries := mc.getQueries() + insertCount := 0 + for _, q := range queries { + if strings.HasPrefix(q.query, "INSERT INTO") { + insertCount++ + } + } + require.Equal(t, 3, insertCount, "expected 3 attempts (1 initial + 2 retries)") +} + +func TestRetryExhausted(t *testing.T) { + s := newTestPlugin(t) + s.RetryMax = 1 + s.RetryDelay = config.Duration(1 * time.Millisecond) + + mc := resetGlobalMock() + mc.execErr = errors.New("connection refused") + atomic.StoreInt32(&mc.execErrCount, 100) // always fail + + s.openDB = func() (*sql.DB, error) { + return sql.Open("snowflake_mock", "mock") + } + require.NoError(t, s.Connect()) + t.Cleanup(func() { s.Close() }) + + ts := time.Now() + m := testMetric("test", nil, map[string]interface{}{"val": 1.0}, ts) + err := s.Write([]telegraf.Metric{m}) + require.Error(t, err) + require.Contains(t, err.Error(), "after 1 retries") +} + +func TestNonTransientErrorNoRetry(t *testing.T) { + s := newTestPlugin(t) + s.RetryMax = 3 + s.RetryDelay = config.Duration(1 * time.Millisecond) + + mc := resetGlobalMock() + mc.execErr = errors.New("sql compilation error: invalid identifier") + atomic.StoreInt32(&mc.execErrCount, 100) + + s.openDB = func() (*sql.DB, error) { + return sql.Open("snowflake_mock", "mock") + } + require.NoError(t, s.Connect()) + t.Cleanup(func() { s.Close() }) + + ts := time.Now() + m := testMetric("test", nil, map[string]interface{}{"val": 1.0}, ts) + err := s.Write([]telegraf.Metric{m}) + require.Error(t, err) + require.Contains(t, err.Error(), "insert failed") + + queries := mc.getQueries() + insertCount := 0 + for _, q := range queries { + if strings.HasPrefix(q.query, "INSERT INTO") { + insertCount++ + } + } + require.Equal(t, 1, insertCount, "non-transient error should not be retried") +} + +func TestTagFieldFiltering(t *testing.T) { + t.Run("filter tags", func(t *testing.T) { + s := &SnowpipeStreaming{ + Account: "a", + User: "u", + Database: "d", + Schema: "s", + Table: "t", + BatchSize: 1000, + TimestampColumn: "timestamp", + TagColumns: []string{"host"}, + } + require.NoError(t, s.Init()) + + m := testMetric("cpu", map[string]string{"host": "h1", "region": "us"}, map[string]interface{}{"val": 1.0}, time.Now()) + columns := s.buildColumnOrder(m) + + require.Contains(t, columns, "host") + require.NotContains(t, columns, "region") + }) + + t.Run("filter fields", func(t *testing.T) { + s := &SnowpipeStreaming{ + 
Account: "a", + User: "u", + Database: "d", + Schema: "s", + Table: "t", + BatchSize: 1000, + TimestampColumn: "timestamp", + FieldColumns: []string{"usage_idle"}, + } + require.NoError(t, s.Init()) + + m := testMetric("cpu", nil, map[string]interface{}{ + "usage_idle": 95.5, + "usage_system": 4.5, + }, time.Now()) + columns := s.buildColumnOrder(m) + + require.Contains(t, columns, "usage_idle") + require.NotContains(t, columns, "usage_system") + }) + + t.Run("no filter includes all", func(t *testing.T) { + s := newTestPlugin(t) + + m := testMetric("cpu", map[string]string{"host": "h1", "region": "us"}, map[string]interface{}{ + "usage_idle": 95.5, + "usage_system": 4.5, + }, time.Now()) + columns := s.buildColumnOrder(m) + + require.Contains(t, columns, "host") + require.Contains(t, columns, "region") + require.Contains(t, columns, "usage_idle") + require.Contains(t, columns, "usage_system") + }) +} + +// --------------------------------------------------------------------------- +// Integration-style tests (with mocked Snowflake) +// --------------------------------------------------------------------------- + +func TestConnectAndWriteIntegration(t *testing.T) { + s := newTestPlugin(t) + connectTestPlugin(t, s) + + ts := time.Date(2024, 6, 1, 12, 0, 0, 0, time.UTC) + metrics := []telegraf.Metric{ + testMetric("cpu", map[string]string{"host": "srv1"}, map[string]interface{}{ + "usage_idle": float64(95.5), + }, ts), + testMetric("cpu", map[string]string{"host": "srv2"}, map[string]interface{}{ + "usage_idle": float64(80.0), + }, ts), + } + + require.NoError(t, s.Write(metrics)) + + mc := getGlobalMock() + queries := mc.getQueries() + + // Should have exactly one INSERT + var insertQ executedQuery + found := false + for _, q := range queries { + if strings.HasPrefix(q.query, "INSERT INTO") { + insertQ = q + found = true + break + } + } + require.True(t, found, "expected an INSERT query") + + // Verify fully qualified table + require.Contains(t, insertQ.query, `"TEST_DB"."PUBLIC"."METRICS"`) + + // Verify columns + require.Contains(t, insertQ.query, `"timestamp"`) + require.Contains(t, insertQ.query, `"name"`) + require.Contains(t, insertQ.query, `"host"`) + require.Contains(t, insertQ.query, `"usage_idle"`) + + // 2 rows, each with 4 columns = 8 args + require.Len(t, insertQ.args, 8) +} + +func TestCreateTable(t *testing.T) { + s := newTestPlugin(t) + s.CreateTable = true + connectTestPlugin(t, s) + + ts := time.Now() + m := testMetric("new_metric", map[string]string{"host": "h1"}, map[string]interface{}{ + "val": int64(42), + "active": true, + }, ts) + + require.NoError(t, s.Write([]telegraf.Metric{m})) + + mc := getGlobalMock() + queries := mc.getQueries() + + // Find CREATE TABLE query + var createQ string + for _, q := range queries { + if strings.Contains(q.query, "CREATE TABLE IF NOT EXISTS") { + createQ = q.query + break + } + } + require.NotEmpty(t, createQ, "expected a CREATE TABLE query") + require.Contains(t, createQ, `"TEST_DB"."PUBLIC"."METRICS"`) + require.Contains(t, createQ, `"timestamp" TIMESTAMP_NTZ`) + require.Contains(t, createQ, `"name" VARCHAR`) + require.Contains(t, createQ, `"host" VARCHAR`) + require.Contains(t, createQ, `"val" NUMBER`) + require.Contains(t, createQ, `"active" BOOLEAN`) +} + +func TestSchemaEvolution(t *testing.T) { + s := newTestPlugin(t) + s.CreateTable = true + connectTestPlugin(t, s) + + mc := getGlobalMock() + + ts := time.Now() + m1 := testMetric("evolve", nil, map[string]interface{}{"val": int64(1)}, ts) + require.NoError(t, 
s.Write([]telegraf.Metric{m1})) + + // Now write a metric with an extra field — should trigger ALTER TABLE + m2 := testMetric("evolve", nil, map[string]interface{}{ + "val": int64(2), + "new_col": "hello", + }, ts) + require.NoError(t, s.Write([]telegraf.Metric{m2})) + + queries := mc.getQueries() + var alterFound bool + for _, q := range queries { + if strings.Contains(q.query, "ALTER TABLE") && strings.Contains(q.query, `"new_col"`) { + alterFound = true + break + } + } + require.True(t, alterFound, "expected ALTER TABLE to add new_col") +} + +func TestBatchRetry(t *testing.T) { + s := newTestPlugin(t) + s.BatchSize = 2 + s.RetryMax = 2 + s.RetryDelay = config.Duration(1 * time.Millisecond) + + mc := resetGlobalMock() + mc.execErr = errors.New("i/o timeout") + atomic.StoreInt32(&mc.execErrCount, 1) // fail first attempt, succeed second + + s.openDB = func() (*sql.DB, error) { + return sql.Open("snowflake_mock", "mock") + } + require.NoError(t, s.Connect()) + t.Cleanup(func() { s.Close() }) + + ts := time.Now() + metrics := []telegraf.Metric{ + testMetric("cpu", nil, map[string]interface{}{"val": 1.0}, ts), + testMetric("cpu", nil, map[string]interface{}{"val": 2.0}, ts), + } + + require.NoError(t, s.Write(metrics)) +} + +func TestConcurrentWrites(t *testing.T) { + s := newTestPlugin(t) + s.Table = "metrics_{{.Name}}" + require.NoError(t, s.Init()) + connectTestPlugin(t, s) + + ts := time.Now() + var wg sync.WaitGroup + + for i := range 10 { + wg.Add(1) + go func(i int) { + defer wg.Done() + name := fmt.Sprintf("metric_%d", i) + m := testMetric(name, nil, map[string]interface{}{"val": float64(i)}, ts) + assert.NoError(t, s.Write([]telegraf.Metric{m})) + }(i) + } + + wg.Wait() + + mc := getGlobalMock() + queries := mc.getQueries() + insertCount := 0 + for _, q := range queries { + if strings.HasPrefix(q.query, "INSERT INTO") { + insertCount++ + } + } + require.Equal(t, 10, insertCount, "expected 10 inserts from 10 concurrent writers") +} + +func TestBuildInsertQuery(t *testing.T) { + s := newTestPlugin(t) + + query := s.buildInsertQuery("my_table", []string{"timestamp", "name", "host", "value"}, 2) + require.Contains(t, query, `"TEST_DB"."PUBLIC"."my_table"`) + require.Contains(t, query, `"timestamp", "name", "host", "value"`) + require.Contains(t, query, "(?, ?, ?, ?), (?, ?, ?, ?)") +} + +func TestQuoteIdent(t *testing.T) { + require.Equal(t, `"simple"`, quoteIdent("simple")) + require.Equal(t, `"has""quote"`, quoteIdent(`has"quote`)) + require.Equal(t, `"has space"`, quoteIdent("has space")) +} + +func TestGroupByTable(t *testing.T) { + s := &SnowpipeStreaming{ + Account: "a", + User: "u", + Database: "d", + Schema: "s", + Table: "metrics_{{.Name}}", + } + require.NoError(t, s.Init()) + + ts := time.Now() + metrics := []telegraf.Metric{ + testMetric("cpu", nil, map[string]interface{}{"val": 1.0}, ts), + testMetric("mem", nil, map[string]interface{}{"val": 2.0}, ts), + testMetric("cpu", nil, map[string]interface{}{"val": 3.0}, ts), + } + + groups := s.groupByTable(metrics) + require.Len(t, groups, 2) + require.Len(t, groups["metrics_cpu"], 2) + require.Len(t, groups["metrics_mem"], 1) +} + +func TestGoTypeToSnowflake(t *testing.T) { + require.Equal(t, "NUMBER", goTypeToSnowflake(int64(1))) + require.Equal(t, "NUMBER", goTypeToSnowflake(uint64(1))) + require.Equal(t, "DOUBLE", goTypeToSnowflake(float64(1.0))) + require.Equal(t, "BOOLEAN", goTypeToSnowflake(true)) + require.Equal(t, "VARCHAR", goTypeToSnowflake("text")) +} + +func TestIsTransientError(t *testing.T) { + require.False(t, 
isTransientError(nil)) + require.True(t, isTransientError(errors.New("connection refused"))) + require.True(t, isTransientError(errors.New("i/o timeout"))) + require.True(t, isTransientError(errors.New("service unavailable"))) + require.True(t, isTransientError(errors.New("connection reset by peer"))) + require.False(t, isTransientError(errors.New("sql compilation error"))) +} + +func TestSanitizeFieldValue(t *testing.T) { + require.InDelta(t, float64(1.0), sanitizeFieldValue(float64(1.0)), 1e-9) + require.Equal(t, "hello", sanitizeFieldValue("hello")) + require.Nil(t, sanitizeFieldValue(math.NaN())) + require.Nil(t, sanitizeFieldValue(math.Inf(1))) + require.Equal(t, int64(42), sanitizeFieldValue(int64(42))) +} + +func TestSampleConfig(t *testing.T) { + s := &SnowpipeStreaming{} + conf := s.SampleConfig() + require.NotEmpty(t, conf) + require.Contains(t, conf, "snowpipe_streaming") + require.Contains(t, conf, "account") +}
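+
+// TestLoadPrivateKey is a sketch of a round-trip test for loadPrivateKey: it
+// generates a throwaway RSA key, writes it to a temp file as unencrypted
+// PKCS#8 PEM, and checks that the loader returns an equal *rsa.PrivateKey.
+// It assumes crypto/rand, crypto/rsa, crypto/x509, encoding/pem, os, and
+// path/filepath are added to the imports above.
+func TestLoadPrivateKey(t *testing.T) {
+	key, err := rsa.GenerateKey(rand.Reader, 2048)
+	require.NoError(t, err)
+
+	der, err := x509.MarshalPKCS8PrivateKey(key)
+	require.NoError(t, err)
+
+	pemData := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})
+	path := filepath.Join(t.TempDir(), "rsa_key.p8")
+	require.NoError(t, os.WriteFile(path, pemData, 0o600))
+
+	loaded, err := loadPrivateKey(path, "")
+	require.NoError(t, err)
+	require.True(t, loaded.Equal(key), "loaded key should match the generated key")
+}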