diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java index b5b860214564..336aadd73c48 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -111,4 +111,12 @@ private SparkSQLProperties() {} // Prefix for custom snapshot properties public static final String SNAPSHOT_PROPERTY_PREFIX = "spark.sql.iceberg.snapshot-property."; + + // Controls whether to shred variant columns during write operations + public static final String SHRED_VARIANTS = "spark.sql.iceberg.shred-variants"; + + // Controls the buffer size for variant schema inference during writes + // This determines how many rows are buffered before inferring shredded schema + public static final String VARIANT_INFERENCE_BUFFER_SIZE = + "spark.sql.iceberg.variant-inference-buffer-size"; } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java index aba7e4dda082..add12e6040b0 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java @@ -33,6 +33,8 @@ import static org.apache.iceberg.TableProperties.ORC_COMPRESSION_STRATEGY; import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION; import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION_LEVEL; +import static org.apache.iceberg.TableProperties.PARQUET_SHRED_VARIANTS; +import static org.apache.iceberg.TableProperties.PARQUET_VARIANT_BUFFER_SIZE; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; import java.util.Locale; @@ -529,6 +531,14 @@ private Map dataWriteProperties() { if (parquetCompressionLevel != null) { writeProperties.put(PARQUET_COMPRESSION_LEVEL, parquetCompressionLevel); } + boolean shouldShredVariants = shredVariants(); + writeProperties.put(PARQUET_SHRED_VARIANTS, String.valueOf(shouldShredVariants)); + + // Add variant shredding configuration properties + if (shouldShredVariants) { + writeProperties.put( + PARQUET_VARIANT_BUFFER_SIZE, String.valueOf(variantInferenceBufferSize())); + } break; case AVRO: @@ -749,4 +759,24 @@ public DeleteGranularity deleteGranularity() { .defaultValue(DeleteGranularity.FILE) .parse(); } + + public boolean shredVariants() { + return confParser + .booleanConf() + .option(SparkWriteOptions.SHRED_VARIANTS) + .sessionConf(SparkSQLProperties.SHRED_VARIANTS) + .tableProperty(TableProperties.PARQUET_SHRED_VARIANTS) + .defaultValue(TableProperties.PARQUET_SHRED_VARIANTS_DEFAULT) + .parse(); + } + + public int variantInferenceBufferSize() { + return confParser + .intConf() + .option(SparkWriteOptions.VARIANT_INFERENCE_BUFFER_SIZE) + .sessionConf(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE) + .tableProperty(TableProperties.PARQUET_VARIANT_BUFFER_SIZE) + .defaultValue(TableProperties.PARQUET_VARIANT_BUFFER_SIZE_DEFAULT) + .parse(); + } } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java index 1be02feaf0c0..6c76b5c873c5 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java @@ -86,4 +86,10 @@ private SparkWriteOptions() {} // Overrides the delete granularity public static final String DELETE_GRANULARITY = "delete-granularity"; + + // Controls whether to shred variant columns during write operations + public static final String SHRED_VARIANTS = "shred-variants"; + + // Controls the buffer size for variant schema inference during writes + public static final String VARIANT_INFERENCE_BUFFER_SIZE = "variant-inference-buffer-size"; } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java index 23fbe54a4be3..5b7862116aea 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java @@ -51,7 +51,9 @@ public static void register() { StructType.class, SparkParquetWriters::buildWriter, (icebergSchema, fileSchema, engineSchema, idToConstant) -> - SparkParquetReaders.buildReader(icebergSchema, fileSchema, idToConstant))); + SparkParquetReaders.buildReader(icebergSchema, fileSchema, idToConstant), + new SparkVariantShreddingAnalyzer(), + InternalRow::copy)); FormatModelRegistry.register( ParquetFormatModel.create( diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java new file mode 100644 index 000000000000..2c08c662c9da --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.List; +import org.apache.iceberg.parquet.VariantShreddingAnalyzer; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.variants.VariantMetadata; +import org.apache.iceberg.variants.VariantValue; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.VariantVal; + +/** + * Spark-specific implementation that extracts variant values from {@link InternalRow} instances. + */ +class SparkVariantShreddingAnalyzer extends VariantShreddingAnalyzer { + + SparkVariantShreddingAnalyzer() {} + + @Override + protected int resolveColumnIndex(StructType sparkSchema, String columnName) { + try { + return sparkSchema.fieldIndex(columnName); + } catch (IllegalArgumentException e) { + return -1; + } + } + + @Override + protected List extractVariantValues( + List bufferedRows, int variantFieldIndex) { + List values = Lists.newArrayList(); + + for (InternalRow row : bufferedRows) { + if (!row.isNullAt(variantFieldIndex)) { + VariantVal variantVal = row.getVariant(variantFieldIndex); + if (variantVal != null) { + VariantValue variantValue = + VariantValue.from( + VariantMetadata.from( + ByteBuffer.wrap(variantVal.getMetadata()).order(ByteOrder.LITTLE_ENDIAN)), + ByteBuffer.wrap(variantVal.getValue()).order(ByteOrder.LITTLE_ENDIAN)); + values.add(variantValue); + } + } + } + + return values; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java index c83b1b6e26ac..c5cfbe62b1be 100644 --- a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java @@ -34,6 +34,7 @@ import static org.apache.iceberg.TableProperties.ORC_COMPRESSION_STRATEGY; import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION; import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION_LEVEL; +import static org.apache.iceberg.TableProperties.PARQUET_SHRED_VARIANTS; import static org.apache.iceberg.TableProperties.UPDATE_DISTRIBUTION_MODE; import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE; import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_HASH; @@ -61,6 +62,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.TestTemplate; @@ -340,6 +342,8 @@ public void testSparkConfOverride() { TableProperties.DELETE_PARQUET_COMPRESSION, "snappy"), ImmutableMap.of( + PARQUET_SHRED_VARIANTS, + "false", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -461,6 +465,8 @@ public void testDataPropsDefaultsAsDeleteProps() { PARQUET_COMPRESSION_LEVEL, "5"), ImmutableMap.of( + PARQUET_SHRED_VARIANTS, + "false", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -532,6 +538,8 @@ public void testDeleteFileWriteConf() { DELETE_PARQUET_COMPRESSION_LEVEL, "6"), ImmutableMap.of( + PARQUET_SHRED_VARIANTS, + "false", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -686,4 +694,81 @@ private void checkMode(DistributionMode expectedMode, SparkWriteConf writeConf) assertThat(writeConf.copyOnWriteDistributionMode(MERGE)).isEqualTo(expectedMode); assertThat(writeConf.positionDeltaDistributionMode(MERGE)).isEqualTo(expectedMode); } + + @TestTemplate + public void testShredVariantsDefault() { + Table table = validationCatalog.loadTable(tableIdent); + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + assertThat(writeConf.shredVariants()).isFalse(); + } + + @TestTemplate + public void testVariantInferenceBufferSizeDefault() { + Table table = validationCatalog.loadTable(tableIdent); + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + assertThat(writeConf.variantInferenceBufferSize()) + .isEqualTo(TableProperties.PARQUET_VARIANT_BUFFER_SIZE_DEFAULT); + } + + @TestTemplate + public void testVariantInferenceBufferSizeTableProperty() { + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(TableProperties.PARQUET_VARIANT_BUFFER_SIZE, "500").commit(); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + assertThat(writeConf.variantInferenceBufferSize()).isEqualTo(500); + } + + @TestTemplate + public void testShredVariantsSessionOverridesTableProperty() { + Table table = validationCatalog.loadTable(tableIdent); + table.updateProperties().set(TableProperties.PARQUET_SHRED_VARIANTS, "false").commit(); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.SHRED_VARIANTS, "true"), + () -> { + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + assertThat(writeConf.shredVariants()).isTrue(); + }); + } + + @TestTemplate + public void testShredVariantsWriteOptionOverridesSessionConf() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.SHRED_VARIANTS, "false"), + () -> { + Table table = validationCatalog.loadTable(tableIdent); + SparkWriteConf writeConf = + new SparkWriteConf( + spark, + table, + new CaseInsensitiveStringMap( + ImmutableMap.of(SparkWriteOptions.SHRED_VARIANTS, "true"))); + assertThat(writeConf.shredVariants()).isTrue(); + }); + } + + @TestTemplate + public void testVariantInferenceBufferSizeSessionConf() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "250"), + () -> { + Table table = validationCatalog.loadTable(tableIdent); + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + assertThat(writeConf.variantInferenceBufferSize()).isEqualTo(250); + }); + } + + @TestTemplate + public void testWritePropertiesIncludeVariantShredding() { + Table table = validationCatalog.loadTable(tableIdent); + table.updateProperties().set(TableProperties.PARQUET_SHRED_VARIANTS, "true").commit(); + table.updateProperties().set(TableProperties.PARQUET_VARIANT_BUFFER_SIZE, "200").commit(); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of()); + Map writeProperties = writeConf.writeProperties(); + assertThat(writeProperties).containsEntry(PARQUET_SHRED_VARIANTS, "true"); + assertThat(writeProperties).containsEntry(TableProperties.PARQUET_VARIANT_BUFFER_SIZE, "200"); + } } diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java new file mode 100644 index 000000000000..8cdcf22e5817 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java @@ -0,0 +1,1101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.variant; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; +import static org.apache.parquet.schema.Types.optional; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.math.BigDecimal; +import java.net.InetAddress; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.variants.Variant; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.util.HadoopInputFile; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestVariantShredding extends CatalogTestBase { + + private static final Schema SCHEMA = + new Schema( + Types.NestedField.required(1, "id", Types.IntegerType.get()), + Types.NestedField.optional(2, "address", Types.VariantType.get())); + + private static final Schema SCHEMA2 = + new Schema( + Types.NestedField.required(1, "id", Types.IntegerType.get()), + Types.NestedField.optional(2, "address", Types.VariantType.get()), + Types.NestedField.optional(3, "metadata", Types.VariantType.get())); + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + protected static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties() + }, + }; + } + + @BeforeAll + public static void startMetastoreAndSpark() { + // First call parent to initialize metastore and spark with local[2] + CatalogTestBase.startMetastoreAndSpark(); + + // Now stop and recreate spark with local[1] to write all rows to a single file + if (spark != null) { + spark.stop(); + } + + spark = + SparkSession.builder() + .master("local[1]") // Use one thread to write the rows to a single parquet file + .config("spark.driver.host", InetAddress.getLoopbackAddress().getHostAddress()) + .config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic") + .config("spark.hadoop." + METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname)) + .config("spark.sql.legacy.respectNullabilityInTextDatasetConversion", "true") + .config(DISABLE_UI) + .enableHiveSupport() + .getOrCreate(); + + sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + } + + @BeforeEach + public void before() { + super.before(); + validationCatalog.createTable( + tableIdent, SCHEMA, null, Map.of(TableProperties.FORMAT_VERSION, "3")); + } + + @AfterEach + public void after() { + spark.conf().unset(SparkSQLProperties.SHRED_VARIANTS); + spark.conf().unset(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE); + validationCatalog.dropTable(tableIdent, true); + } + + @TestTemplate + public void testVariantShreddingDisabled() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "false"); + + String values = "(1, parse_json('{\"city\": \"NYC\", \"zip\": 10001}')), (2, null)"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType address = variant("address", 2, Type.Repetition.OPTIONAL); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testExcludingNullValue() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('{"name": "Alice", "age": 30, "dummy": null}')),\ + (2, parse_json('{"name": "Bob", "age": 25}')),\ + (3, parse_json('{"name": "Charlie", "age": 35}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testInconsistentType() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('{"age": "25"}')),\ + (2, parse_json('{"age": 30}')),\ + (3, parse_json('{"age": "35"}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + List rows = + sql("SELECT variant_get(address, '$.age', 'int') FROM %s WHERE id = 2", tableName); + assertThat(rows).hasSize(1); + assertThat(rows.get(0)[0]).isEqualTo(30); + } + + @TestTemplate + public void testPrimitiveType() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = "(1, parse_json('123')), (2, parse_json('456')), (3, parse_json('789'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType address = + variant( + "address", + 2, + Type.Repetition.REQUIRED, + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(16, true))); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testPrimitiveDecimalType() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + "(1, parse_json('123.56')), (2, parse_json('\"abc\"')), (3, parse_json('12.56'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType address = + variant( + "address", + 2, + Type.Repetition.REQUIRED, + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.decimalType(2, 5))); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testBooleanType() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('{"active": true}')),\ + (2, parse_json('{"active": false}')),\ + (3, parse_json('{"active": true}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType active = field("active", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(active)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testDecimalTypeWithInconsistentScales() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('{"price": 123.456789}')),\ + (2, parse_json('{"price": 678.90}')),\ + (3, parse_json('{"price": 999.99}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType price = + field( + "price", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.decimalType(6, 9))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(price)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testDecimalTypeWithConsistentScales() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('{"price": 123.45}')),\ + (2, parse_json('{"price": 678.90}')),\ + (3, parse_json('{"price": 999.99}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType price = + field( + "price", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.decimalType(2, 5))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(price)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testArrayType() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('["java", "scala", "python"]')),\ + (2, parse_json('["rust", "go"]')),\ + (3, parse_json('["javascript"]'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType arr = + list( + element( + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType()))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, arr); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testNestedArrayType() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('{"tags": ["java", "scala", "python"]}')),\ + (2, parse_json('{"tags": ["rust", "go"]}')),\ + (3, parse_json('{"tags": ["javascript"]}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType tags = + field( + "tags", + list( + element( + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, + LogicalTypeAnnotation.stringType())))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(tags)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testNestedObjectType() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('{"location": {"city": "Seattle", "zip": 98101}, "tags": ["java", "scala", "python"]}')),\ + (2, parse_json('{"location": {"city": "Portland", "zip": 97201}}')),\ + (3, parse_json('{"location": {"city": "NYC", "zip": 10001}}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType city = + field( + "city", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType zip = + field( + "zip", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(32, true))); + GroupType location = field("location", objectFields(city, zip)); + GroupType tags = + field( + "tags", + list( + element( + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, + LogicalTypeAnnotation.stringType())))); + + GroupType address = + variant("address", 2, Type.Repetition.REQUIRED, objectFields(location, tags)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testLazyInitializationWithBufferedRows() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "5"); + + String values = + """ + (1, parse_json('{"name": "Alice", "age": 30}')),\ + (2, parse_json('{"name": "Bob", "age": 25}')),\ + (3, parse_json('{"name": "Charlie", "age": 35}')),\ + (4, parse_json('{"name": "David", "age": 28}')),\ + (5, parse_json('{"name": "Eve", "age": 32}')),\ + (6, parse_json('{"name": "Frank", "age": 40}')),\ + (7, parse_json('{"name": "Grace", "age": 27}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + long rowCount = spark.read().format("iceberg").load(tableName).count(); + assertThat(rowCount).isEqualTo(7); + } + + @TestTemplate + public void testMultipleRowGroups() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); + + int numRows = 1000; + StringBuilder valuesBuilder = new StringBuilder(); + for (int i = 1; i <= numRows; i++) { + if (i > 1) { + valuesBuilder.append(", "); + } + valuesBuilder.append( + String.format("(%d, parse_json('{\"name\": \"User%d\", \"age\": %d}'))", i, i, 20 + i)); + } + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", + tableName, PARQUET_ROW_GROUP_SIZE_BYTES, 1024); + sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString()); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + long rowCount = spark.read().format("iceberg").load(tableName).count(); + assertThat(rowCount).isEqualTo(numRows); + } + + @TestTemplate + public void testColumnIndexTruncateLength() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); + + int customTruncateLength = 10; + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", + tableName, "parquet.columnindex.truncate.length", customTruncateLength); + + StringBuilder valuesBuilder = new StringBuilder(); + for (int i = 1; i <= 10; i++) { + if (i > 1) { + valuesBuilder.append(", "); + } + String longValue = "A".repeat(20); + valuesBuilder.append( + String.format( + "(%d, parse_json('{\"description\": \"%s\", \"id\": %d}'))", i, longValue, i)); + } + sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString()); + + GroupType description = + field( + "description", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType id = + field( + "id", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = + variant("address", 2, Type.Repetition.REQUIRED, objectFields(description, id)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + long rowCount = spark.read().format("iceberg").load(tableName).count(); + assertThat(rowCount).isEqualTo(10); + } + + @TestTemplate + public void testIntegerFamilyPromotion() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // Mix of INT8, INT16, INT32, INT64 - should promote to INT64 + String values = + """ + (1, parse_json('{"value": 10}')),\ + (2, parse_json('{"value": 1000}')),\ + (3, parse_json('{"value": 100000}')),\ + (4, parse_json('{"value": 10000000000}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType value = + field( + "value", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT64, LogicalTypeAnnotation.intType(64, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(value)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testDecimalFamilyPromotion() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // Test that they get promoted to the most capable decimal type observed + String values = + """ + (1, parse_json('{"value": 1.5}')),\ + (2, parse_json('{"value": 123.456789}')),\ + (3, parse_json('{"value": 123456789123456.789}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType value = + field( + "value", + optional(PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) + .length(16) + .as(LogicalTypeAnnotation.decimalType(6, 21)) + .named("typed_value")); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(value)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testDataRoundTripWithShredding() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('{"name": "Alice", "age": 30}')),\ + (2, parse_json('{"name": "Bob", "age": 25}')),\ + (3, parse_json('{"name": "Charlie", "age": 35}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + // Verify that we can read the data back correctly + List rows = + sql( + "SELECT id, variant_get(address, '$.name', 'string')," + + " variant_get(address, '$.age', 'int')" + + " FROM %s ORDER BY id", + tableName); + assertThat(rows).hasSize(3); + assertThat(rows.get(0)[0]).isEqualTo(1); + assertThat(rows.get(0)[1]).isEqualTo("Alice"); + assertThat(rows.get(0)[2]).isEqualTo(30); + assertThat(rows.get(1)[0]).isEqualTo(2); + assertThat(rows.get(1)[1]).isEqualTo("Bob"); + assertThat(rows.get(1)[2]).isEqualTo(25); + assertThat(rows.get(2)[0]).isEqualTo(3); + assertThat(rows.get(2)[1]).isEqualTo("Charlie"); + assertThat(rows.get(2)[2]).isEqualTo(35); + } + + @TestTemplate + public void testMultipleVariantsWithShredding() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // Recreate table with SCHEMA2 (address + metadata variant columns) + validationCatalog.dropTable(tableIdent, true); + validationCatalog.createTable( + tableIdent, SCHEMA2, null, Map.of(TableProperties.FORMAT_VERSION, "3")); + + String values = + """ + (1, parse_json('{"city": "NYC"}'), parse_json('{"source": "web"}')),\ + (2, parse_json('{"city": "LA"}'), parse_json('{"source": "app"}')),\ + (3, parse_json('{"city": "SF"}'), parse_json('{"source": "api"}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType city = + field( + "city", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(city)); + + GroupType source = + field( + "source", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType metadata = variant("metadata", 3, Type.Repetition.REQUIRED, objectFields(source)); + MessageType expectedSchema = parquetSchema(address, metadata); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testVariantWithNullValues() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('null')),\ + (2, parse_json('null')),\ + (3, parse_json('null'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType address = variant("address", 2, Type.Repetition.REQUIRED); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testArrayOfNullElementsWithShredding() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + sql( + "INSERT INTO %s VALUES (1, parse_json('[null, null, null]')), " + + "(2, parse_json('[null]'))", + tableName); + + // Array elements are all null, element type is null, falls back to unshredded + GroupType address = variant("address", 2, Type.Repetition.REQUIRED); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testMixedNullAndNonNullVariantValues() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('{"name": "Alice", "age": 30}')),\ + (2, null),\ + (3, parse_json('{"name": "Charlie", "age": 35}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + long rowCount = spark.read().format("iceberg").load(tableName).count(); + assertThat(rowCount).isEqualTo(3); + } + + @TestTemplate + public void testWriteOptionOverridesSessionConfig() throws IOException, NoSuchTableException { + // Disable shredding at session level + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "false"); + + // Enable shredding via per-write option + String query = + "SELECT 1 as id, parse_json('{\"name\": \"Alice\", \"age\": 30}') as address" + + " UNION ALL SELECT 2, parse_json('{\"name\": \"Bob\", \"age\": 25}')" + + " UNION ALL SELECT 3, parse_json('{\"name\": \"Charlie\", \"age\": 35}')"; + spark.sql(query).writeTo(tableName).option("shred-variants", "true").append(); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testInfrequentFieldPruning() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "11"); + + StringBuilder valuesBuilder = new StringBuilder(); + for (int i = 1; i <= 11; i++) { + if (i > 1) { + valuesBuilder.append(", "); + } + if (i == 1) { + // Only the first row has rare_field + valuesBuilder.append( + String.format( + "(%d, parse_json('{\"name\": \"User%d\", \"rare_field\": \"rare\"}'))", i, i)); + } else { + valuesBuilder.append(String.format("(%d, parse_json('{\"name\": \"User%d\"}'))", i, i)); + } + } + sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString()); + + // rare_field appears in 1/11 rows, should be pruned + // name appears in 11/11 rows and should be kept + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testMixedTypeTieBreaking() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "10"); + + StringBuilder valuesBuilder = new StringBuilder(); + for (int i = 1; i <= 10; i++) { + if (i > 1) { + valuesBuilder.append(", "); + } + if (i <= 5) { + valuesBuilder.append(String.format("(%d, parse_json('{\"val\": %d}'))", i, i)); + } else { + valuesBuilder.append(String.format("(%d, parse_json('{\"val\": \"text%d\"}'))", i, i)); + } + } + sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString()); + + // 5 ints + 5 strings is a tie so STRING wins (higher TIE_BREAK_PRIORITY) + GroupType val = + field( + "val", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(val)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + // Verify data round-trips correctly + List rows = + sql("SELECT id, variant_get(address, '$.val', 'string') FROM %s ORDER BY id", tableName); + assertThat(rows).hasSize(10); + assertThat(rows.get(0)[1]).isEqualTo("1"); + assertThat(rows.get(5)[1]).isEqualTo("text6"); + } + + @TestTemplate + public void testFieldOnlyAfterBuffer() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); + + String values = + """ + (1, parse_json('{"name": "Alice"}')),\ + (2, parse_json('{"name": "Bob"}')),\ + (3, parse_json('{"name": "Charlie"}')),\ + (4, parse_json('{"name": "David", "score": 95}')),\ + (5, parse_json('{"name": "Eve", "score": 88}')),\ + (6, parse_json('{"name": "Frank", "score": 72}')),\ + (7, parse_json('{"name": "Grace", "score": 91}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // Schema is determined from buffer (rows 1-3) which only has "name". + // "score" is not shredded + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + // Verify all data round-trips despite "score" not being shredded + List rows = + sql( + "SELECT id, variant_get(address, '$.name', 'string')," + + " variant_get(address, '$.score', 'int')" + + " FROM %s ORDER BY id", + tableName); + assertThat(rows).hasSize(7); + assertThat(rows.get(0)[1]).isEqualTo("Alice"); + assertThat(rows.get(0)[2]).isNull(); + assertThat(rows.get(3)[1]).isEqualTo("David"); + assertThat(rows.get(3)[2]).isEqualTo(95); + assertThat(rows.get(6)[1]).isEqualTo("Grace"); + assertThat(rows.get(6)[2]).isEqualTo(91); + } + + @TestTemplate + public void testCrossFileDifferentShreddedType() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); + + // File 1: "score" is always integer → shredded as INT8 + String batch1 = + """ + (1, parse_json('{"score": 95}')),\ + (2, parse_json('{"score": 88}')),\ + (3, parse_json('{"score": 72}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, batch1); + + // Verify file 1 schema: score shredded as INT8 + Table table = validationCatalog.loadTable(tableIdent); + GroupType scoreInt = + field( + "score", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + MessageType expectedSchema1 = + parquetSchema(variant("address", 2, Type.Repetition.REQUIRED, objectFields(scoreInt))); + verifyParquetSchema(table, expectedSchema1); + + // File 2: "score" is always string → shredded as STRING + String batch2 = + """ + (4, parse_json('{"score": "high"}')),\ + (5, parse_json('{"score": "medium"}')),\ + (6, parse_json('{"score": "low"}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, batch2); + + // Query across both files, reader must handle different shredded types + List rows = + sql("SELECT id, variant_get(address, '$.score', 'string') FROM %s ORDER BY id", tableName); + assertThat(rows).hasSize(6); + assertThat(rows.get(0)[1]).isEqualTo("95"); + assertThat(rows.get(1)[1]).isEqualTo("88"); + assertThat(rows.get(3)[1]).isEqualTo("high"); + assertThat(rows.get(5)[1]).isEqualTo("low"); + } + + @TestTemplate + public void testAllNullVariantColumn() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + sql("INSERT INTO %s VALUES (1, null), (2, null), (3, null)", tableName); + + // All variant values are SQL NULL, so no shredding should occur + Table table = validationCatalog.loadTable(tableIdent); + MessageType expectedSchema = parquetSchema(variant("address", 2, Type.Repetition.OPTIONAL)); + verifyParquetSchema(table, expectedSchema); + + List rows = sql("SELECT id, address FROM %s ORDER BY id", tableName); + assertThat(rows).hasSize(3); + assertThat(rows.get(0)[1]).isNull(); + assertThat(rows.get(1)[1]).isNull(); + assertThat(rows.get(2)[1]).isNull(); + } + + @TestTemplate + public void testBufferSizeOne() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "1"); + + sql( + """ + INSERT INTO %s VALUES + (1, parse_json('{"name": "Alice", "age": 30}')), + (2, parse_json('{"name": "Bob", "age": 25}')), + (3, parse_json('{"name": "Charlie", "age": 35}')) + """, + tableName); + + // Schema inferred from first row only, should still shred name and age + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + List rows = + sql("SELECT id, variant_get(address, '$.name', 'string') FROM %s ORDER BY id", tableName); + assertThat(rows).hasSize(3); + assertThat(rows.get(0)[1]).isEqualTo("Alice"); + assertThat(rows.get(2)[1]).isEqualTo("Charlie"); + } + + @TestTemplate + public void testDecimalFallbackAfterBuffer() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); + + // Buffer: scale=2, 3 integer digits -> DECIMAL(5,2) + // Row 4: precision overflow -> fallback to value field + // Row 5: scale overflow -> fallback to value field + // Row 6: fits typed column, scale widened from 1 to 2 via setScale + String values = + """ + (1, parse_json('{"val": 123.45}')),\ + (2, parse_json('{"val": 678.90}')),\ + (3, parse_json('{"val": 999.99}')),\ + (4, parse_json('{"val": 123456.78}')),\ + (5, parse_json('{"val": 1.2345}')),\ + (6, parse_json('{"val": 12.3}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + List rows = + sql( + "SELECT id, variant_get(address, '$.val', 'decimal(10,4)') FROM %s ORDER BY id", + tableName); + assertThat(rows).hasSize(6); + assertThat(rows.get(0)[1]).isEqualTo(new BigDecimal("123.4500")); + assertThat(rows.get(3)[1]).isEqualTo(new BigDecimal("123456.7800")); + assertThat(rows.get(4)[1]).isEqualTo(new BigDecimal("1.2345")); + assertThat(rows.get(5)[1]).isEqualTo(new BigDecimal("12.3000")); + } + + private void verifyParquetSchema(Table table, MessageType expectedSchema) throws IOException { + try (CloseableIterable tasks = table.newScan().planFiles()) { + assertThat(tasks).isNotEmpty(); + + for (FileScanTask task : tasks) { + String path = task.file().location(); + + HadoopInputFile inputFile = HadoopInputFile.fromPath(new Path(path), new Configuration()); + + try (ParquetFileReader reader = ParquetFileReader.open(inputFile)) { + MessageType actualSchema = reader.getFileMetaData().getSchema(); + assertThat(actualSchema).isEqualTo(expectedSchema); + } + } + } + } + + private static MessageType parquetSchema(Type... variantTypes) { + return org.apache.parquet.schema.Types.buildMessage() + .required(PrimitiveType.PrimitiveTypeName.INT32) + .id(1) + .named("id") + .addFields(variantTypes) + .named("table"); + } + + private static GroupType variant(String name, int fieldId, Type.Repetition repetition) { + return org.apache.parquet.schema.Types.buildGroup(repetition) + .id(fieldId) + .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) + .required(PrimitiveType.PrimitiveTypeName.BINARY) + .named("metadata") + .required(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .named(name); + } + + private static GroupType variant( + String name, int fieldId, Type.Repetition repetition, Type shreddedType) { + checkShreddedType(shreddedType); + return org.apache.parquet.schema.Types.buildGroup(repetition) + .id(fieldId) + .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) + .required(PrimitiveType.PrimitiveTypeName.BINARY) + .named("metadata") + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedType) + .named(name); + } + + private static Type shreddedPrimitive(PrimitiveType.PrimitiveTypeName primitive) { + return optional(primitive).named("typed_value"); + } + + private static Type shreddedPrimitive( + PrimitiveType.PrimitiveTypeName primitive, LogicalTypeAnnotation annotation) { + return optional(primitive).as(annotation).named("typed_value"); + } + + private static GroupType objectFields(GroupType... fields) { + for (GroupType fieldType : fields) { + checkField(fieldType); + } + + return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.OPTIONAL) + .addFields(fields) + .named("typed_value"); + } + + private static GroupType field(String name, Type shreddedType) { + checkShreddedType(shreddedType); + return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedType) + .named(name); + } + + private static GroupType element(Type shreddedType) { + return field("element", shreddedType); + } + + private static GroupType list(GroupType elementType) { + return org.apache.parquet.schema.Types.optionalList().element(elementType).named("typed_value"); + } + + private static void checkShreddedType(Type shreddedType) { + Preconditions.checkArgument( + shreddedType.getName().equals("typed_value"), + "Invalid shredded type name: %s should be typed_value", + shreddedType.getName()); + Preconditions.checkArgument( + shreddedType.isRepetition(Type.Repetition.OPTIONAL), + "Invalid shredded type repetition: %s should be OPTIONAL", + shreddedType.getRepetition()); + } + + private static void checkField(GroupType fieldType) { + Preconditions.checkArgument( + fieldType.isRepetition(Type.Repetition.REQUIRED), + "Invalid field type repetition: %s should be REQUIRED", + fieldType.getRepetition()); + } +}