diff --git a/go/connection.go b/go/connection.go index 1a7e050..668b9de 100644 --- a/go/connection.go +++ b/go/connection.go @@ -514,9 +514,21 @@ func (c *connectionImpl) toArrowField(columnInfo driverbase.ColumnInfo) arrow.Fi field.Type = arrow.FixedWidthTypes.Timestamp_ns } case "GEOGRAPHY": - fallthrough + // With GEOGRAPHY_OUTPUT_FORMAT=WKB, data arrives as binary WKB. + // GEOGRAPHY is always WGS84 (SRID 4326). + field.Type = arrow.BinaryTypes.Binary + field.Metadata = arrow.MetadataFrom(map[string]string{ + "ARROW:extension:name": "geoarrow.wkb", + "ARROW:extension:metadata": `{"crs":"EPSG:4326"}`, + }) case "GEOMETRY": - field.Type = arrow.BinaryTypes.String + // With GEOMETRY_OUTPUT_FORMAT=WKB, data arrives as binary WKB. + // TODO: SRID for GEOMETRY requires inspecting data or a separate query. + // Same cross-driver issue as adbc-drivers/redshift#2 and adbc-drivers/databricks#350. + field.Type = arrow.BinaryTypes.Binary + field.Metadata = arrow.MetadataFrom(map[string]string{ + "ARROW:extension:name": "geoarrow.wkb", + }) case "VECTOR": // despite the fact that Snowflake *does* support returning data // for VECTOR typed columns as Arrow FixedSizeLists, there's no way @@ -559,9 +571,16 @@ func descToField(name, typ, isnull, primary string, comment sql.NullString, maxT case "VARIANT": field.Type = arrow.BinaryTypes.String case "GEOGRAPHY": - fallthrough + field.Type = arrow.BinaryTypes.Binary + field.Metadata = arrow.MetadataFrom(map[string]string{ + "ARROW:extension:name": "geoarrow.wkb", + "ARROW:extension:metadata": `{"crs":"EPSG:4326"}`, + }) case "GEOMETRY": - field.Type = arrow.BinaryTypes.String + field.Type = arrow.BinaryTypes.Binary + field.Metadata = arrow.MetadataFrom(map[string]string{ + "ARROW:extension:name": "geoarrow.wkb", + }) case "BOOLEAN": field.Type = arrow.FixedWidthTypes.Boolean default: @@ -623,6 +642,68 @@ func descToField(name, typ, isnull, primary string, comment sql.NullString, maxT return } +// detectGeoColumnsFromQuery attempts 
to extract a table name from a SQL query +// and runs DESCRIBE TABLE to identify GEOGRAPHY/GEOMETRY columns. +// Returns nil if the table name can't be determined or no geo columns exist. +// This works for table scans (SELECT ... FROM schema.table) which is the common +// case for adbc_scan. Arbitrary queries return nil — data is correct but without +// geoarrow metadata. TODO: Support arbitrary queries. +func (c *connectionImpl) detectGeoColumnsFromQuery(ctx context.Context, query string) map[string]geoColumnType { + // Simple extraction: find "FROM " in the query. + // Handles: SELECT ... FROM schema.table, SELECT ... FROM "schema"."table", etc. + upper := strings.ToUpper(strings.TrimSpace(query)) + fromIdx := strings.Index(upper, "FROM ") + if fromIdx == -1 { + return nil + } + + // Extract table reference after FROM + rest := strings.TrimSpace(query[fromIdx+5:]) + // Take until next SQL keyword or end + endIdx := len(rest) + for _, kw := range []string{" WHERE ", " ORDER ", " GROUP ", " HAVING ", " LIMIT ", " UNION ", " JOIN ", " LEFT ", " RIGHT ", " INNER ", " OUTER ", " CROSS "} { + if idx := strings.Index(strings.ToUpper(rest), kw); idx != -1 && idx < endIdx { + endIdx = idx + } + } + tableName := strings.TrimSpace(rest[:endIdx]) + if tableName == "" { + return nil + } + + // Run DESCRIBE TABLE to get original column types + rows, err := c.cn.QueryContext(ctx, "DESC TABLE "+tableName, nil) + if err != nil { + return nil + } + defer func() { _ = rows.Close() }() + + geoCols := make(map[string]geoColumnType) + dest := make([]driver.Value, len(rows.Columns())) + for { + if err := rows.Next(dest); err != nil { + break + } + if len(dest) < 2 { + continue + } + name, _ := dest[0].(string) + typ, _ := dest[1].(string) + typ = strings.ToUpper(typ) + + if strings.HasPrefix(typ, "GEOGRAPHY") { + geoCols[name] = geoColumnGeography + } else if strings.HasPrefix(typ, "GEOMETRY") { + geoCols[name] = geoColumnGeometry + } + } + + if len(geoCols) == 0 { + return nil + } + 
return geoCols +} + func (c *connectionImpl) getStringQuery(query string) (value string, err error) { result, err := c.cn.QueryContext(context.Background(), query, nil) if err != nil { diff --git a/go/database.go b/go/database.go index fb0e9d2..fdb38da 100644 --- a/go/database.go +++ b/go/database.go @@ -532,6 +532,22 @@ func (d *databaseImpl) Open(ctx context.Context) (adbcConnection adbc.Connection ctx, span := driverbase.StartSpan(ctx, "databaseImpl.Open", d) defer driverbase.EndSpan(span, err) + // Set WKB output for geospatial columns so they arrive as binary WKB + // instead of GeoJSON strings. Geo column detection is done separately + // via DESCRIBE TABLE (catalog metadata is unaffected by output format). + // Note: Snowflake's REST API rowtype metadata reports "binary" instead of + // "geography"/"geometry" when WKB format is set — we've reported this to Snowflake. + if d.cfg.Params == nil { + d.cfg.Params = make(map[string]*string) + } + wkb := "WKB" + if _, ok := d.cfg.Params["GEOGRAPHY_OUTPUT_FORMAT"]; !ok { + d.cfg.Params["GEOGRAPHY_OUTPUT_FORMAT"] = &wkb + } + if _, ok := d.cfg.Params["GEOMETRY_OUTPUT_FORMAT"]; !ok { + d.cfg.Params["GEOMETRY_OUTPUT_FORMAT"] = &wkb + } + connector := gosnowflake.NewConnector(drv, *d.cfg) ctx = gosnowflake.WithArrowAllocator( diff --git a/go/record_reader.go b/go/record_reader.go index 41aeede..6bd83cd 100644 --- a/go/record_reader.go +++ b/go/record_reader.go @@ -48,6 +48,15 @@ import ( const MetadataKeySnowflakeType = "SNOWFLAKE_TYPE" +// geoColumnType identifies the Snowflake geospatial type of a column. 
// geoColumnType classifies a result column by its Snowflake geospatial type.
// The zero value, geoColumnNone, marks a column that is not geospatial.
type geoColumnType int

const (
	// geoColumnNone: the column is not a geospatial column.
	geoColumnNone = geoColumnType(iota)
	// geoColumnGeography: GEOGRAPHY column, always WGS84 / SRID 4326.
	geoColumnGeography
	// geoColumnGeometry: GEOMETRY column, SRID unknown without data inspection.
	geoColumnGeometry
)
+ f.Metadata = arrow.MetadataFrom(map[string]string{ + "ARROW:extension:name": "geoarrow.wkb", + }) + } + transformers[i] = identCol + fields[i] = f + continue + } + switch strings.ToUpper(srcMeta.Type) { case "FIXED": switch f.Type.ID() { @@ -551,7 +585,7 @@ type reader struct { done chan struct{} // signals all producer goroutines have finished } -func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake.ArrowStreamLoader, bufferSize, prefetchConcurrency int, useHighPrecision bool, maxTimestampPrecision MaxTimestampPrecision) (array.RecordReader, error) { +func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake.ArrowStreamLoader, bufferSize, prefetchConcurrency int, useHighPrecision bool, maxTimestampPrecision MaxTimestampPrecision, geoCols map[string]geoColumnType) (array.RecordReader, error) { batches, err := ld.GetBatches() if err != nil { return nil, errToAdbcErr(adbc.StatusInternal, err) @@ -671,7 +705,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake done: make(chan struct{}), } close(rdr.done) // No goroutines to wait for - rdr.schema, _ = getTransformer(schema, ld, useHighPrecision, maxTimestampPrecision) + rdr.schema, _ = getTransformer(schema, ld, useHighPrecision, maxTimestampPrecision, nil) return rdr, nil } @@ -710,7 +744,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake } var recTransform recordTransformer - rdr.schema, recTransform = getTransformer(rr.Schema(), ld, useHighPrecision, maxTimestampPrecision) + rdr.schema, recTransform = getTransformer(rr.Schema(), ld, useHighPrecision, maxTimestampPrecision, geoCols) group.Go(func() (err error) { defer rr.Release() diff --git a/go/statement.go b/go/statement.go index 7769447..13f8a35 100644 --- a/go/statement.go +++ b/go/statement.go @@ -548,7 +548,7 @@ func (st *statement) ExecuteQuery(ctx context.Context) (reader array.RecordReade return nil, err } - reader, err = 
newRecordReader(ctx, st.alloc, loader, st.queueSize, st.prefetchConcurrency, st.useHighPrecision, st.maxTimestampPrecision) + reader, err = newRecordReader(ctx, st.alloc, loader, st.queueSize, st.prefetchConcurrency, st.useHighPrecision, st.maxTimestampPrecision, nil) return reader, err }, currentBatch: st.bound, @@ -566,6 +566,12 @@ func (st *statement) ExecuteQuery(ctx context.Context) (reader array.RecordReade return } + // Detect geo columns before executing the query. For table scans, + // try to extract the table name and run DESCRIBE TABLE to identify + // GEOGRAPHY/GEOMETRY columns (catalog metadata is unaffected by WKB output format). + // TODO: Support arbitrary queries — currently only table scans get geoarrow metadata. + geoCols := st.cnxn.detectGeoColumnsFromQuery(ctx, st.query) + var loader gosnowflake.ArrowStreamLoader loader, err = st.cnxn.cn.QueryArrowStream(ctx, st.query) if err != nil { @@ -573,7 +579,7 @@ func (st *statement) ExecuteQuery(ctx context.Context) (reader array.RecordReade return } - reader, err = newRecordReader(ctx, st.alloc, loader, st.queueSize, st.prefetchConcurrency, st.useHighPrecision, st.maxTimestampPrecision) + reader, err = newRecordReader(ctx, st.alloc, loader, st.queueSize, st.prefetchConcurrency, st.useHighPrecision, st.maxTimestampPrecision, geoCols) nRows = loader.TotalRows() return }