diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/row/RowAggregator.scala b/aggregator/src/main/scala/ai/chronon/aggregator/row/RowAggregator.scala index c8bc1da08..151051619 100644 --- a/aggregator/src/main/scala/ai/chronon/aggregator/row/RowAggregator.scala +++ b/aggregator/src/main/scala/ai/chronon/aggregator/row/RowAggregator.scala @@ -70,6 +70,11 @@ class RowAggregator(val inputSchema: Seq[(String, DataType)], val aggregationPar .toArray .zip(columnAggregators.map(_.irType)) + val incrementalOutputSchema: Array[(String, DataType)] = aggregationParts + .map(_.incrementalOutputColumnName) + .toArray + .zip(columnAggregators.map(_.irType)) + val outputSchema: Array[(String, DataType)] = aggregationParts .map(_.outputColumnName) .toArray diff --git a/api/py/ai/chronon/group_by.py b/api/py/ai/chronon/group_by.py index 7601a8599..54237fe3f 100644 --- a/api/py/ai/chronon/group_by.py +++ b/api/py/ai/chronon/group_by.py @@ -390,6 +390,7 @@ def GroupBy( historical_backfill: Optional[bool] = None, deprecation_date: Optional[str] = None, description: Optional[str] = None, + is_incremental: Optional[bool] = None, **kwargs, ) -> ttypes.GroupBy: """ @@ -608,6 +609,7 @@ def _normalize_source(source): backfillStartDate=backfill_start_date, accuracy=accuracy, derivations=derivations, + isIncremental=is_incremental, ) validate_group_by(group_by) return group_by diff --git a/api/src/main/scala/ai/chronon/api/Extensions.scala b/api/src/main/scala/ai/chronon/api/Extensions.scala index f3e9b6821..44f653d3c 100644 --- a/api/src/main/scala/ai/chronon/api/Extensions.scala +++ b/api/src/main/scala/ai/chronon/api/Extensions.scala @@ -97,7 +97,7 @@ object Extensions { def cleanName: String = metaData.name.sanitize def outputTable = s"${metaData.outputNamespace}.${metaData.cleanName}" - + def incrementalOutputTable = s"${metaData.outputNamespace}.${metaData.cleanName}_daily_inc" def preModelTransformsTable = s"${metaData.outputNamespace}.${metaData.cleanName}_pre_mt" def 
outputLabelTable = s"${metaData.outputNamespace}.${metaData.cleanName}_labels" def outputFinalView = s"${metaData.outputNamespace}.${metaData.cleanName}_labeled" @@ -178,6 +178,10 @@ object Extensions { def outputColumnName = s"${aggregationPart.inputColumn}_$opSuffix${aggregationPart.window.suffix}${bucketSuffix}" + + def incrementalOutputColumnName = + s"${aggregationPart.inputColumn}_$opSuffix${bucketSuffix}" + } implicit class AggregationOps(aggregation: Aggregation) { diff --git a/api/thrift/api.thrift b/api/thrift/api.thrift index 1e48f624f..b9241129f 100644 --- a/api/thrift/api.thrift +++ b/api/thrift/api.thrift @@ -309,6 +309,7 @@ struct GroupBy { 6: optional string backfillStartDate // Optional derivation list 7: optional list derivations + 8: optional bool isIncremental } struct JoinPart { diff --git a/spark/src/main/scala/ai/chronon/spark/Comparison.scala b/spark/src/main/scala/ai/chronon/spark/Comparison.scala index 83c0db33d..2e5032dea 100644 --- a/spark/src/main/scala/ai/chronon/spark/Comparison.scala +++ b/spark/src/main/scala/ai/chronon/spark/Comparison.scala @@ -20,13 +20,32 @@ import org.slf4j.LoggerFactory import ai.chronon.online.Extensions.StructTypeOps import com.google.gson.{Gson, GsonBuilder} import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.types.{DecimalType, DoubleType, FloatType, MapType} +import org.apache.spark.sql.types.{ArrayType, DecimalType, DoubleType, FloatType, MapType, StructType} +import org.apache.spark.sql.functions.col import java.util +import scala.collection.mutable object Comparison { @transient lazy val logger = LoggerFactory.getLogger(getClass) + // Flatten struct columns into individual columns so nested double fields can be compared with tolerance + private def flattenStructs(df: DataFrame): DataFrame = { + val flattenedSelects = df.schema.fields.toSeq.flatMap { field => + field.dataType match { + case structType: StructType => + // Flatten struct fields: struct_name.field_name -> 
struct_name_field_name + structType.fields.map { subField => + col(s"${field.name}.${subField.name}").alias(s"${field.name}_${subField.name}") + }.toSeq + case _ => + // Keep non-struct fields as-is + Seq(col(field.name)) + } + } + df.select(flattenedSelects: _*) + } + // used for comparison def sortedJson(m: Map[String, Any]): String = { if (m == null) return null @@ -69,8 +88,12 @@ object Comparison { |""".stripMargin ) - val prefixedExpectedDf = prefixColumnName(stringifyMaps(a), s"${aName}_") - val prefixedOutputDf = prefixColumnName(stringifyMaps(b), s"${bName}_") + // Flatten structs so nested double fields can be compared with tolerance + val aFlattened = flattenStructs(stringifyMaps(a)) + val bFlattened = flattenStructs(stringifyMaps(b)) + + val prefixedExpectedDf = prefixColumnName(aFlattened, s"${aName}_") + val prefixedOutputDf = prefixColumnName(bFlattened, s"${bName}_") val joinExpr = keys .map(key => prefixedExpectedDf(s"${aName}_$key") <=> prefixedOutputDf(s"${bName}_$key")) @@ -82,15 +105,16 @@ object Comparison { ) var finalDf = joined + // Use flattened schema for comparison val comparisonColumns = - a.schema.fieldNames.toSet.diff(keys.toSet).toList.sorted + aFlattened.schema.fieldNames.toSet.diff(keys.toSet).toList.sorted val colOrder = keys.map(key => { finalDf(s"${aName}_$key").as(key) }) ++ comparisonColumns.flatMap { col => List(finalDf(s"${aName}_$col"), finalDf(s"${bName}_$col")) } - // double columns need to be compared approximately - val doubleCols = a.schema.fields + // double columns need to be compared approximately (now includes flattened struct fields) + val doubleCols = aFlattened.schema.fields .filter(field => field.dataType == DoubleType || field.dataType == FloatType || field.dataType.isInstanceOf[DecimalType]) .map(_.name) diff --git a/spark/src/main/scala/ai/chronon/spark/DataRange.scala b/spark/src/main/scala/ai/chronon/spark/DataRange.scala index 7728f69e4..3f8d135ab 100644 --- 
a/spark/src/main/scala/ai/chronon/spark/DataRange.scala +++ b/spark/src/main/scala/ai/chronon/spark/DataRange.scala @@ -54,6 +54,11 @@ case class PartitionRange(start: String, end: String)(implicit tableUtils: Table } } + def daysBetween: Int = { // inclusive count of partitions in [start, end]; 0 when a bound is null or start > end + if (start == null || end == null) 0 + else Iterator.iterate(start)(tableUtils.partitionSpec.after).takeWhile(_ <= end).size // Iterator: same count as Stream, nothing memoized, not deprecated on 2.13 + } + def isSingleDay: Boolean = { start == end } diff --git a/spark/src/main/scala/ai/chronon/spark/Driver.scala b/spark/src/main/scala/ai/chronon/spark/Driver.scala index a5f0c8ab2..bd81e02e3 100644 --- a/spark/src/main/scala/ai/chronon/spark/Driver.scala +++ b/spark/src/main/scala/ai/chronon/spark/Driver.scala @@ -467,7 +467,8 @@ object Driver { tableUtils, args.stepDays.toOption, args.startPartitionOverride.toOption, - !args.runFirstHole() + !args.runFirstHole(), + Option(args.groupByConf.isIncremental).getOrElse(false) ) if (args.shouldExport()) { diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index 566a19cd1..139cc000b 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -18,11 +18,12 @@ package ai.chronon.spark import ai.chronon.aggregator.base.TimeTuple import ai.chronon.aggregator.row.RowAggregator +import ai.chronon.aggregator.windowing.HopsAggregator.HopIr import ai.chronon.aggregator.windowing._ import ai.chronon.api import ai.chronon.api.DataModel.{Entities, Events} import ai.chronon.api.Extensions._ -import ai.chronon.api.{Accuracy, Constants, DataModel, ParametricMacro} +import ai.chronon.api.{Accuracy, Constants, DataModel, ParametricMacro, TimeUnit, Window} import ai.chronon.online.serde.{RowWrapper, SparkConversions} import ai.chronon.spark.Extensions._ import ai.chronon.spark.catalog.TableUtils @@ -33,6 +34,8 @@ import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.util.sketch.BloomFilter import 
org.slf4j.LoggerFactory +import scala.jdk.CollectionConverters._ + import java.util import scala.collection.{Seq, mutable} import scala.util.ScalaJavaConversions.{JListOps, ListOps, MapOps} @@ -89,10 +92,18 @@ class GroupBy(val aggregations: Seq[api.Aggregation], lazy val aggPartWithSchema = aggregationParts.zip(columnAggregators.map(_.outputType)) lazy val postAggSchema: StructType = { - val valueChrononSchema = if (finalize) windowAggregator.outputSchema else windowAggregator.irSchema + val valueChrononSchema = if (finalize) { + windowAggregator.outputSchema + } else { + windowAggregator.irSchema + } SparkConversions.fromChrononSchema(valueChrononSchema) } + @transient lazy val flattenedAgg: RowAggregator = + new RowAggregator(selectedSchema, aggregations.flatMap(_.unWindowed)) + lazy val incrementalSchema: Array[(String, api.DataType)] = flattenedAgg.incrementalOutputSchema + @transient protected[spark] lazy val windowAggregator: RowAggregator = new RowAggregator(selectedSchema, aggregations.flatMap(_.unpack)) @@ -357,6 +368,30 @@ class GroupBy(val aggregations: Seq[api.Aggregation], toDf(outputRdd, Seq(Constants.TimeColumn -> LongType, tableUtils.partitionColumn -> StringType)) } + def flattenOutputArrayType( + hopsArrays: RDD[(KeyWithHash, HopsAggregator.OutputArrayType)]): RDD[(Array[Any], Array[Any])] = { + hopsArrays.flatMap { + case (keyWithHash: KeyWithHash, hopsArray: HopsAggregator.OutputArrayType) => + hopsArray.headOption match { + case Some(dailyHops) => + dailyHops.map { hopIr => + val timestamp = hopIr.last.asInstanceOf[Long] + val normalizedIR = flattenedAgg.normalize(hopIr.dropRight(1)) + ((keyWithHash.data :+ tableUtils.partitionSpec.at(timestamp) :+ timestamp), normalizedIR) + } + case None => Iterator.empty + } + } + } + + def convertHopsToDf(hops: RDD[(KeyWithHash, HopsAggregator.OutputArrayType)], + schema: Array[(String, ai.chronon.api.DataType)]): DataFrame = { + val hopsDf = flattenOutputArrayType(hops) + toDf(hopsDf, + 
Seq(tableUtils.partitionColumn -> StringType, Constants.TimeColumn -> LongType), + Some(SparkConversions.fromChrononSchema(schema))) + } + // convert raw data into IRs, collected by hopSizes // TODO cache this into a table: interface below // Class HopsCacher(keySchema, irSchema, resolution) extends RddCacher[(KeyWithHash, HopsOutput)] @@ -380,9 +415,10 @@ class GroupBy(val aggregations: Seq[api.Aggregation], } protected[spark] def toDf(aggregateRdd: RDD[(Array[Any], Array[Any])], - additionalFields: Seq[(String, DataType)]): DataFrame = { + additionalFields: Seq[(String, DataType)], + schema: Option[StructType] = None): DataFrame = { val finalKeySchema = StructType(keySchema ++ additionalFields.map { case (name, typ) => StructField(name, typ) }) - KvRdd(aggregateRdd, finalKeySchema, postAggSchema).toFlatDf + KvRdd(aggregateRdd, finalKeySchema, schema.getOrElse(postAggSchema)).toFlatDf } private def normalizeOrFinalize(ir: Array[Any]): Array[Any] = @@ -392,6 +428,18 @@ class GroupBy(val aggregations: Seq[api.Aggregation], windowAggregator.normalize(ir) } + /** + * Runs hopsAggregate at daily resolution, converts the hops to a DataFrame using incrementalSchema, and saves it. + * @param incrementalOutputTable table the resulting DataFrame is saved to + * @param range partition range whose earliest time point is passed to hopsAggregate as the minimum query timestamp + * @param tableProps table properties forwarded unchanged to the save call + */ + def computeIncrementalDf(incrementalOutputTable: String, range: PartitionRange, tableProps: Map[String, String]) = { + + val hops = hopsAggregate(range.toTimePoints.min, DailyResolution) + val hopsDf: DataFrame = convertHopsToDf(hops, incrementalSchema) + hopsDf.save(incrementalOutputTable, tableProps) + } } // TODO: truncate queryRange for caching @@ -463,31 +511,16 @@ object GroupBy { bloomMapOpt: Option[util.Map[String, BloomFilter]] = None, skewFilter: Option[String] = None, finalize: Boolean = true, - showDf: Boolean = false): GroupBy = { + showDf: Boolean = false, + incrementalMode: Boolean = false): GroupBy = { logger.info(s"\n----[Processing GroupBy: ${groupByConfOld.metaData.name}]----") val 
groupByConf = replaceJoinSource(groupByConfOld, queryRange, tableUtils, computeDependency, showDf) - val inputDf = groupByConf.sources.toScala - .map { source => - val partitionColumn = tableUtils.getPartitionColumn(source.query) - tableUtils.sqlWithDefaultPartitionColumn( - renderDataSourceQuery( - groupByConf, - source, - groupByConf.getKeyColumns.toScala, - queryRange, - tableUtils, - groupByConf.maxWindow, - groupByConf.inferredAccuracy, - partitionColumn = partitionColumn - ), - existingPartitionColumn = partitionColumn - ) - } - .reduce { (df1, df2) => - // align the columns by name - when one source has select * the ordering might not be aligned - val columns1 = df1.schema.fields.map(_.name) - df1.union(df2.selectExpr(columns1: _*)) - } + val sourceQueryWindow: Option[Window] = + if (incrementalMode) Some(new Window(queryRange.daysBetween, TimeUnit.DAYS)) else groupByConf.maxWindow + val backfillQueryRange: PartitionRange = + if (incrementalMode) PartitionRange(queryRange.end, queryRange.end)(tableUtils) else queryRange + val inputDf = + buildSourceDataFrame(groupByConf, backfillQueryRange, sourceQueryWindow, tableUtils, schemaOnly = false) def doesNotNeedTime = !Option(groupByConf.getAggregations).exists(_.toScala.needsTimestamp) def hasValidTimeColumn = inputDf.schema.find(_.name == Constants.TimeColumn).exists(_.dataType == LongType) @@ -553,15 +586,22 @@ object GroupBy { logger.info(s"printing mutation data for groupBy: ${groupByConf.metaData.name}") df.prettyPrint() } - df } + //if incrementalMode is enabled, we do not compute finalize values + //IR values are stored in the table + val finalizeValue = if (incrementalMode) { + false + } else { + finalize + } + new GroupBy(Option(groupByConf.getAggregations).map(_.toScala).orNull, keyColumns, nullFiltered, mutationDfFn, - finalize = finalize) + finalize = finalizeValue) } def getIntersectedRange(source: api.Source, @@ -688,12 +728,220 @@ object GroupBy { query } + /** + * Builds a unified DataFrame from 
all sources in the GroupBy configuration. + * Used to create input DataFrames with proper schema alignment. + * + * @param groupByConf the GroupBy configuration + * @param range the partition range to query + * @param window the window size for querying (None uses maxWindow from config) + * @param tableUtils table utilities for partition handling + * @param schemaOnly if true, returns empty DataFrame with just schema (uses .limit(0)) + * @return unified DataFrame from all sources + */ + private def buildSourceDataFrame( + groupByConf: api.GroupBy, + range: PartitionRange, + window: Option[Window], + tableUtils: TableUtils, + schemaOnly: Boolean = false + ): DataFrame = { + groupByConf.sources.toScala + .map { source => + val partitionColumn = tableUtils.getPartitionColumn(source.query) + val df = tableUtils.sqlWithDefaultPartitionColumn( + renderDataSourceQuery( + groupByConf, + source, + groupByConf.getKeyColumns.toScala, + range, + tableUtils, + window.orElse(groupByConf.maxWindow), + groupByConf.inferredAccuracy, + partitionColumn = partitionColumn + ), + existingPartitionColumn = partitionColumn + ) + if (schemaOnly) df.limit(0) else df + } + .reduce { (df1, df2) => + // align the columns by name - when one source has select * the ordering might not be aligned + val columns1 = df1.schema.fields.map(_.name) + df1.union(df2.selectExpr(columns1: _*)) + } + } + + /** + * Computes and saves the output of hopsAggregation. 
+ * HopsAggregate computes event level data to daily aggregates and saves the output in IR format + * Returns the range covered in incrementalOutputTable together with the un-windowed aggregation parts. + * @param groupByConf GroupBy configuration; an IllegalArgumentException is thrown when it has no windowed aggregation (maxWindow is empty) + * @param range requested output range; the incremental table is filled from (range.start minus maxWindow) through range.end + * @param tableUtils used for partition arithmetic and to find the unfilled ranges of incrementalOutputTable, each of which is then computed and saved + */ + def computeIncrementalDf( + groupByConf: api.GroupBy, + range: PartitionRange, + tableUtils: TableUtils, + incrementalOutputTable: String + ): (PartitionRange, Seq[api.AggregationPart]) = { + + val tableProps: Map[String, String] = Option(groupByConf.metaData.tableProperties) + .map(_.toScala) + .orNull + + val maxWindow = groupByConf.maxWindow.getOrElse( + throw new IllegalArgumentException( + s"GroupBy ${groupByConf.metaData.name} has no windowed aggregations. " + + "Incremental mode requires at least one windowed aggregation.")) + + val incrementalQueryableRange = PartitionRange( + tableUtils.partitionSpec.minus(range.start, maxWindow), + range.end + )(tableUtils) + + logger.info(s"Writing incremental df to $incrementalOutputTable") + + val partitionRangeHoles: Option[Seq[PartitionRange]] = tableUtils.unfilledRanges( + incrementalOutputTable, + incrementalQueryableRange, + skipFirstHole = false + ) + + val aggregationParts = groupByConf.getAggregations.toScala.flatMap(_.unWindowed) + + partitionRangeHoles.foreach { holes => + holes.foreach { hole => + logger.info(s"Filling hole in incremental table: $hole") + val incrementalGroupByBackfill = + from(groupByConf, hole, tableUtils, computeDependency = true, incrementalMode = true) + incrementalGroupByBackfill.computeIncrementalDf(incrementalOutputTable, hole, tableProps) + } + } + + (incrementalQueryableRange, aggregationParts) + } + + private def convertIncrementalDfToHops( + incrementalDf: DataFrame, + aggregationParts: Seq[api.AggregationPart], + groupByConf: api.GroupBy, + tableUtils: TableUtils, + inputSchema: Seq[(String, api.DataType)]): RDD[(KeyWithHash, HopsAggregator.OutputArrayType)] = { + + val keyColumns = groupByConf.getKeyColumns.toScala + val keyBuilder: Row => KeyWithHash = + FastHashing.generateKeyBuilder(keyColumns.toArray, 
incrementalDf.schema) + + incrementalDf.rdd + .mapPartitions { rows => + // Reconstruct RowAggregator once per partition on executor + // Avoids serializing TimedDispatcher and reduces network overhead + val rowAggregator = new RowAggregator(inputSchema, aggregationParts) + + rows.map { row => + //Extract timestamp from partition column + val ds = row.getAs[String](tableUtils.partitionColumn) + val ts = tableUtils.partitionSpec.epochMillis(ds) + + // Extract normalized IRs from the row + // Recursively convert Spark types to Java types that denormalize() expects + // This handles nested structures (e.g., FIRST_K: Array → ArrayList[Array]) + def convertSparkToJava(value: Any): Any = + value match { + case null => null + case r: Row => + // Struct → Array[Any], recursively convert nested values + r.toSeq.map(convertSparkToJava).toArray + case arr: scala.collection.Seq[_] => + // Array → ArrayList, recursively convert elements; match on Seq (not only WrappedArray) so any Seq implementation Spark returns is converted + val converted = arr.map(convertSparkToJava) + new java.util.ArrayList[Any](converted.toSeq.asJava) + case map: scala.collection.Map[_, _] => + // Map → HashMap, recursively convert values + val javaMap = new java.util.HashMap[Any, Any]() + map.foreach { + case (k, v) => + javaMap.put(k, convertSparkToJava(v)) + } + javaMap + case other => + // Scalars (Long, Double, String, byte arrays, etc.) 
pass through + other + } + + val normalizedIrs = new Array[Any](aggregationParts.length) + aggregationParts.zipWithIndex.foreach { + case (part, idx) => + val value = row.get(row.fieldIndex(part.incrementalOutputColumnName)) + normalizedIrs(idx) = convertSparkToJava(value) + } + + // Denormalize IRs to in-memory format (e.g., ArrayList -> HashSet) + val irs = rowAggregator.denormalize(normalizedIrs) + + // Build HopIR : [IR1, IR2, ..., IRn, ts] + val hopIr: HopIr = irs :+ ts + (keyBuilder(row), hopIr) + } + } + .groupByKey() + .mapValues { hopIrs => + //Convert to HopsAggregator.OutputArrayType: Array[Array[HopIr]] + val sortedHops = hopIrs.toArray.sortBy(_.last.asInstanceOf[Long]) + Array(sortedHops) + } + } + + def fromIncrementalDf( + groupByConf: api.GroupBy, + range: PartitionRange, + tableUtils: TableUtils + ): GroupBy = { + + val incrementalOutputTable = groupByConf.metaData.incrementalOutputTable + assert(incrementalOutputTable != null, + s"incrementalOutputTable is not set for GroupBy: ${groupByConf.metaData.name}") + + val (incrementalQueryableRange, aggregationParts) = + computeIncrementalDf(groupByConf, range, tableUtils, incrementalOutputTable) + + val (_, incrementalDf: DataFrame) = incrementalQueryableRange.scanQueryStringAndDf(null, incrementalOutputTable) + + // Create a DataFrame with the source schema (raw data schema) to match aggregations + // We need this because GroupBy class variables expect inputDf schema to match aggregation input columns + // We create an empty DataFrame with the correct schema - it won't be used for computation + val sourceDf = buildSourceDataFrame(groupByConf, range, None, tableUtils, schemaOnly = true) + + // Extract input schema for RowAggregator reconstruction on executors + // Pass lightweight schema instead of heavy RowAggregator to avoid serializing TimedDispatcher + val chrononSchema = SparkConversions.toChrononSchema(sourceDf.schema) + + val incrementalHops = + convertIncrementalDfToHops(incrementalDf, 
aggregationParts, groupByConf, tableUtils, chrononSchema) + + new GroupBy( + groupByConf.getAggregations.toScala, + groupByConf.getKeyColumns.toScala, + sourceDf, // Use source schema, not incremental schema + () => null + ) { + // Override hopsAggregate to return precomputed hops instead of computing from raw data + override def hopsAggregate(minQueryTs: Long, + resolution: Resolution): RDD[(KeyWithHash, HopsAggregator.OutputArrayType)] = { + incrementalHops + } + } + + } + def computeBackfill(groupByConf: api.GroupBy, endPartition: String, tableUtils: TableUtils, stepDays: Option[Int] = None, overrideStartPartition: Option[String] = None, - skipFirstHole: Boolean = true): Unit = { + skipFirstHole: Boolean = true, + incrementalMode: Boolean = false): Unit = { assert( groupByConf.backfillStartDate != null, s"GroupBy:${groupByConf.metaData.name} has null backfillStartDate. This needs to be set for offline backfilling.") @@ -751,7 +999,11 @@ object GroupBy { stepRanges.zipWithIndex.foreach { case (range, index) => logger.info(s"Computing group by for range: $range [${index + 1}/${stepRanges.size}]") - val groupByBackfill = from(groupByConf, range, tableUtils, computeDependency = true) + val groupByBackfill: GroupBy = if (incrementalMode) { + fromIncrementalDf(groupByConf, range, tableUtils) + } else { + from(groupByConf, range, tableUtils, computeDependency = true) + } val outputDf = groupByConf.dataModel match { // group by backfills have to be snapshot only case Entities => groupByBackfill.snapshotEntities diff --git a/spark/src/test/scala/ai/chronon/spark/test/GroupByIncrementalTest.scala b/spark/src/test/scala/ai/chronon/spark/test/GroupByIncrementalTest.scala new file mode 100644 index 000000000..28ccde4a1 --- /dev/null +++ b/spark/src/test/scala/ai/chronon/spark/test/GroupByIncrementalTest.scala @@ -0,0 +1,469 @@ +/* + * Copyright (C) 2023 The Chronon Authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.chronon.spark.test + +import ai.chronon.aggregator.test.{CStream, Column} +import ai.chronon.api.Extensions._ +import ai.chronon.api.{ + Aggregation, + Builders, + DoubleType, + IntType, + LongType, + Operation, + Source, + StringType, + TimeUnit, + Window +} +import ai.chronon.spark.Extensions._ +import ai.chronon.spark._ +import ai.chronon.spark.catalog.TableUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.{StructField, StructType, LongType => SparkLongType, StringType => SparkStringType} +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.functions.col +import org.junit.Assert._ +import org.junit.Test + +import scala.util.Random + +class GroupByIncrementalTest { + + private def createTestSourceIncremental(windowSize: Int = 365, + suffix: String = "", + partitionColOpt: Option[String] = None): (Source, String) = { + lazy val spark: SparkSession = + SparkSessionBuilder.build("GroupByIncrementalTest" + "_" + Random.alphanumeric.take(6).mkString, local = true) + implicit val tableUtils = TableUtils(spark) + val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) + val startPartition = tableUtils.partitionSpec.minus(today, new Window(windowSize, TimeUnit.DAYS)) + val endPartition = tableUtils.partitionSpec.at(System.currentTimeMillis()) + val sourceSchema = List( + Column("user", StringType, 10000), + Column("item", StringType, 
100), + Column("time_spent_ms", LongType, 5000), + Column("price", DoubleType, 100) + ) + val namespace = "chronon_incremental_test" + val sourceTable = s"$namespace.test_group_by_steps$suffix" + + tableUtils.createDatabase(namespace) + val genDf = + DataFrameGen.events(spark, sourceSchema, count = 1000, partitions = 200, partitionColOpt = partitionColOpt) + partitionColOpt match { + case Some(partitionCol) => genDf.save(sourceTable, partitionColumns = Seq(partitionCol)) + case None => genDf.save(sourceTable) + } + + val source = Builders.Source.events( + query = Builders.Query(selects = Builders.Selects("ts", "user", "time_spent_ms", "price", "item"), + startPartition = startPartition, + partitionColumn = partitionColOpt.orNull), + table = sourceTable + ) + (source, endPartition) + } + + /** + * Tests basic aggregations in incremental mode by comparing Chronon's output against SQL. + * + * Operations: SUM, COUNT, AVERAGE, MIN, MAX, VARIANCE, UNIQUE_COUNT, HISTOGRAM, BOUNDED_UNIQUE_COUNT + * + * Actual: Chronon computes daily IRs using computeIncrementalDf, storing intermediate results + * Expected: SQL query computes the same aggregations directly on the input data for the same date + */ + @Test + def testIncrementalBasicAggregations(): Unit = { + lazy val spark: SparkSession = SparkSessionBuilder.build("GroupByTestIncrementalBasic" + "_" + Random.alphanumeric.take(6).mkString, local = true) + implicit val tableUtils = TableUtils(spark) + val namespace = s"incremental_basic_aggs_${Random.alphanumeric.take(6).mkString}" + tableUtils.createDatabase(namespace) + + val schema = List( + Column("user", StringType, 10), + Column("price", DoubleType, 100), + Column("quantity", IntType, 50), + Column("product_id", StringType, 20), // Low cardinality for UNIQUE_COUNT, HISTOGRAM, BOUNDED_UNIQUE_COUNT + Column("rating", DoubleType, 2000) + ) + + val df = DataFrameGen.events(spark, schema, count = 100000, partitions = 100) + + val aggregations: Seq[Aggregation] = Seq( + // 
Simple aggregations + Builders.Aggregation(Operation.SUM, "price", Seq(new Window(7, TimeUnit.DAYS))), + Builders.Aggregation(Operation.COUNT, "quantity", Seq(new Window(7, TimeUnit.DAYS))), + + // Complex aggregation - AVERAGE (struct IR with sum/count) + Builders.Aggregation(Operation.AVERAGE, "rating", Seq(new Window(7, TimeUnit.DAYS))), + + // Min/Max + Builders.Aggregation(Operation.MIN, "price", Seq(new Window(7, TimeUnit.DAYS))), + Builders.Aggregation(Operation.MAX, "quantity", Seq(new Window(7, TimeUnit.DAYS))), + + // Variance (struct IR with count/mean/m2) + Builders.Aggregation(Operation.VARIANCE, "price", Seq(new Window(7, TimeUnit.DAYS))), + + // UNIQUE_COUNT (array IR): IR = array of distinct values + Builders.Aggregation(Operation.UNIQUE_COUNT, "price", Seq(new Window(7, TimeUnit.DAYS))), + + // HISTOGRAM (map IR) + Builders.Aggregation(Operation.HISTOGRAM, "product_id", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "0")), + + // BOUNDED_UNIQUE_COUNT (array IR with bound): IR = array of bounded distinct values (MD5-hashed) + Builders.Aggregation(Operation.BOUNDED_UNIQUE_COUNT, "product_id", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "100")) + ) + + val tableProps: Map[String, String] = Map("source" -> "chronon") + + val today_date = tableUtils.partitionSpec.at(System.currentTimeMillis()) + val today_minus_7_date = tableUtils.partitionSpec.minus(today_date, new Window(7, TimeUnit.DAYS)) + val today_minus_20_date = tableUtils.partitionSpec.minus(today_date, new Window(20, TimeUnit.DAYS)) + + val partitionRange = PartitionRange(today_minus_20_date, today_date) + + val groupBy = new GroupBy(aggregations, Seq("user"), df) + groupBy.computeIncrementalDf(s"${namespace}.testIncrementalBasicAggsOutput", partitionRange, tableProps) + + val actualIncrementalDf = spark.sql(s"select * from ${namespace}.testIncrementalBasicAggsOutput where ds='$today_minus_7_date'") + df.createOrReplaceTempView("test_basic_aggs_input") + + // Compare against 
SQL computation + val query = + s""" + |WITH base_aggs AS ( + | SELECT user, ds, UNIX_TIMESTAMP(ds, 'yyyy-MM-dd')*1000 as ts, + | sum(price) as price_sum, + | count(quantity) as quantity_count, + | struct(sum(rating) as sum, count(rating) as count) as rating_average, + | min(price) as price_min, + | max(quantity) as quantity_max, + | struct( + | cast(count(price) as int) as count, + | avg(price) as mean, + | sum(price * price) - count(price) * avg(price) * avg(price) as m2 + | ) as price_variance, + | count(distinct price) as price_unique_count, + | least(count(distinct product_id), 100) as product_id_bounded_unique_count + | FROM test_basic_aggs_input + | WHERE ds='$today_minus_7_date' + | GROUP BY user, ds + |), + |histogram_agg AS ( + | SELECT user, ds, + | map_from_entries(collect_list(struct(product_id, cast(cnt as int)))) as product_id_histogram + | FROM ( + | SELECT user, ds, product_id, count(*) as cnt + | FROM test_basic_aggs_input + | WHERE ds='$today_minus_7_date' AND product_id IS NOT NULL + | GROUP BY user, ds, product_id + | ) + | GROUP BY user, ds + |) + |SELECT b.*, h.product_id_histogram + |FROM base_aggs b + |LEFT JOIN histogram_agg h ON b.user <=> h.user AND b.ds <=> h.ds + |""".stripMargin + + val expectedDf = spark.sql(query) + + // Replace UNIQUE_COUNT and BOUNDED_UNIQUE_COUNT array columns with their sizes for comparison. + // SQL produces Long counts; size() returns Int — cast both to Long for type consistency. 
+ import org.apache.spark.sql.functions.size + val actualForComparison = actualIncrementalDf + .withColumn("price_unique_count", size(col("price_unique_count")).cast("long")) + .withColumn("product_id_bounded_unique_count", size(col("product_id_bounded_unique_count")).cast("long")) + + val diff = Comparison.sideBySide(actualForComparison, expectedDf, List("user", tableUtils.partitionColumn)) + + val irRowCount = actualIncrementalDf.count() + if (diff.count() > 0) { + println(s"=== Diff Details for All Aggregations ===") + println(s"Actual count: ${irRowCount}") + println(s"Expected count: ${expectedDf.count()}") + println(s"Diff count: ${diff.count()}") + actualForComparison.show(10, truncate = false) + diff.show(100, truncate = false) + } + + assertEquals(0, diff.count()) + } + + /** + * This test verifies that the incremental snapshotEvents output matches the non-incremental output. + * + * 1. Computes snapshotEvents using the standard GroupBy on the full input data. + * 2. Computes snapshotEvents using GroupBy in incremental mode over the same date range. + * 3. Compares the two outputs to ensure they are identical. 
*/
  @Test
  def testSnapshotIncrementalEvents(): Unit = {
    lazy val spark: SparkSession =
      SparkSessionBuilder.build("GroupByTest" + "_" + Random.alphanumeric.take(6).mkString, local = true)
    implicit val tableUtils = TableUtils(spark)
    val namespace = s"incremental_groupBy_snapshot_${Random.alphanumeric.take(6).mkString}"
    tableUtils.createDatabase(namespace)

    val outputDates = CStream.genPartitions(10, tableUtils.partitionSpec)
    // Both code paths below are evaluated over this single range.
    val outputRange = PartitionRange(outputDates.min, outputDates.max)

    val aggregations: Seq[Aggregation] = Seq(
      // Basic
      Builders.Aggregation(Operation.SUM,
                           "time_spent_ms",
                           Seq(new Window(10, TimeUnit.DAYS), new Window(5, TimeUnit.DAYS))),
      Builders.Aggregation(Operation.SUM, "price", Seq(new Window(10, TimeUnit.DAYS))),
      Builders.Aggregation(Operation.COUNT, "user", Seq(new Window(10, TimeUnit.DAYS))),
      Builders.Aggregation(Operation.AVERAGE, "price", Seq(new Window(10, TimeUnit.DAYS))),
      Builders.Aggregation(Operation.MIN, "price", Seq(new Window(10, TimeUnit.DAYS))),
      Builders.Aggregation(Operation.MAX, "price", Seq(new Window(10, TimeUnit.DAYS))),
      // Statistical
      Builders.Aggregation(Operation.VARIANCE, "price", Seq(new Window(10, TimeUnit.DAYS))),
      Builders.Aggregation(Operation.SKEW, "price", Seq(new Window(10, TimeUnit.DAYS))),
      Builders.Aggregation(Operation.KURTOSIS, "price", Seq(new Window(10, TimeUnit.DAYS))),
      Builders.Aggregation(Operation.APPROX_PERCENTILE,
                           "price",
                           Seq(new Window(10, TimeUnit.DAYS)),
                           argMap = Map("percentiles" -> "[0.5, 0.25, 0.75]")),
      // Temporal
      Builders.Aggregation(Operation.FIRST, "price", Seq(new Window(10, TimeUnit.DAYS))),
      Builders.Aggregation(Operation.LAST, "price", Seq(new Window(10, TimeUnit.DAYS))),
      Builders.Aggregation(Operation.FIRST_K, "price", Seq(new Window(10, TimeUnit.DAYS)), argMap = Map("k" -> "3")),
      Builders.Aggregation(Operation.LAST_K, "price", Seq(new Window(10, TimeUnit.DAYS)), argMap = Map("k" -> "3")),
      Builders.Aggregation(Operation.TOP_K, "price", Seq(new Window(10, TimeUnit.DAYS)), argMap = Map("k" -> "3")),
      Builders.Aggregation(Operation.BOTTOM_K, "price", Seq(new Window(10, TimeUnit.DAYS)), argMap = Map("k" -> "3")),
      // Cardinality / Set
      Builders.Aggregation(Operation.UNIQUE_COUNT, "user", Seq(new Window(10, TimeUnit.DAYS))),
      Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "user", Seq(new Window(10, TimeUnit.DAYS))),
      Builders.Aggregation(Operation.BOUNDED_UNIQUE_COUNT, "user", Seq(new Window(10, TimeUnit.DAYS))),
      // Distribution
      Builders.Aggregation(Operation.HISTOGRAM, "user", Seq(new Window(10, TimeUnit.DAYS))),
      Builders.Aggregation(Operation.APPROX_HISTOGRAM_K,
                           "user",
                           Seq(new Window(10, TimeUnit.DAYS)),
                           argMap = Map("k" -> "10"))
    )

    // Only the source is needed here; the second tuple element is not used by this test.
    val (source, _) = createTestSourceIncremental(windowSize = 30,
                                                  suffix = "_snapshot_events",
                                                  partitionColOpt = Some(tableUtils.partitionColumn))
    val groupByConf = Builders.GroupBy(
      sources = Seq(source),
      keyColumns = Seq("item"),
      aggregations = aggregations,
      metaData = Builders.MetaData(name = "testSnapshotIncremental", namespace = namespace, team = "chronon"),
      backfillStartDate = tableUtils.partitionSpec.minus(tableUtils.partitionSpec.at(System.currentTimeMillis()),
                                                         new Window(20, TimeUnit.DAYS))
    )

    // Baseline: standard (non-incremental) snapshot computation straight from the source table.
    val df = spark.read.table(source.table)
    val groupBy = new GroupBy(aggregations, Seq("item"), df.filter("item is not null"))
    val nonIncrementalDf = groupBy.snapshotEvents(outputRange)

    // Candidate: the same snapshot computed through the incremental path.
    val groupByIncremental = GroupBy.fromIncrementalDf(groupByConf, outputRange, tableUtils)
    val incrementalDf = groupByIncremental.snapshotEvents(outputRange)

    val diff = Comparison.sideBySide(nonIncrementalDf, incrementalDf, List("item", tableUtils.partitionColumn))
    // count() is a Spark action; evaluate it once and reuse the result.
    val diffCount = diff.count()
    if (diffCount > 0) {
      println("=== Diff result rows ===")
      diff.show()
    }
    assertEquals(0, diffCount)
  }

  /**
    * Checks the incremental IR table produced by `computeIncrementalDf` for the time-based
    * aggregations FIRST, LAST, FIRST_K, LAST_K, TOP_K and BOTTOM_K.
    *
    * The partition from seven days ago is read back from the IR table and compared with a
    * Spark SQL query over the same input table. The SQL side models:
    *  - FIRST / LAST as a struct {epochMillis, payload} holding the earliest / latest event,
    *  - FIRST_K / LAST_K as arrays of such structs,
    *  - TOP_K / BOTTOM_K as arrays of plain values.
    *
    * Array columns of the IR are sorted before comparing because the comparison is
    * order-sensitive while the stored IR arrays are not kept in sorted order.
    */
  @Test
  def testIncrementalFirstLast(): Unit = {
    import org.apache.spark.sql.functions.{col, rand, sort_array}
    import org.apache.spark.sql.types.{LongType => SparkLongType}

    lazy val spark: SparkSession =
      SparkSessionBuilder.build("GroupByTestFirstLast" + "_" + Random.alphanumeric.take(6).mkString, local = true)
    implicit val tableUtils = TableUtils(spark)

    try {
      val namespace = s"incremental_first_last_${Random.alphanumeric.take(6).mkString}"
      tableUtils.createDatabase(namespace)

      val schema = List(
        Column("user", StringType, 5),
        Column("value", DoubleType, 100)
      )

      // Generate events and add 0-24h of random milliseconds to ts so timestamps are unique.
      // The frame is cached and materialized before writing so the random values are computed once.
      val dfWithRandom = DataFrameGen
        .events(spark, schema, count = 10000, partitions = 20)
        .withColumn("ts", col("ts") + (rand() * 86400000).cast(SparkLongType))
        .cache()
      dfWithRandom.count()

      // Persist the materialized rows, then read them back so that both the aggregation under
      // test and the SQL reference below see exactly the same data.
      val inputTable = s"${namespace}.test_first_last_input"
      dfWithRandom.write.mode("overwrite").saveAsTable(inputTable)
      val df = spark.table(inputTable)

      val aggregations: Seq[Aggregation] = Seq(
        Builders.Aggregation(Operation.FIRST, "value", Seq(new Window(7, TimeUnit.DAYS))),
        Builders.Aggregation(Operation.LAST, "value", Seq(new Window(7, TimeUnit.DAYS))),
        Builders.Aggregation(Operation.FIRST_K, "value", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "3")),
        Builders.Aggregation(Operation.LAST_K, "value", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "3")),
        Builders.Aggregation(Operation.TOP_K, "value", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "3")),
        Builders.Aggregation(Operation.BOTTOM_K, "value", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "3"))
      )

      val tableProps: Map[String, String] = Map("source" -> "chronon")
      val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
      val sevenDaysAgo = tableUtils.partitionSpec.minus(today, new Window(7, TimeUnit.DAYS))
      val twentyDaysAgo = tableUtils.partitionSpec.minus(today, new Window(20, TimeUnit.DAYS))
      val partitionRange = PartitionRange(twentyDaysAgo, today)

      val outputTable = s"${namespace}.testIncrementalFirstLastOutput"
      val groupBy = new GroupBy(aggregations, Seq("user"), df)
      groupBy.computeIncrementalDf(outputTable, partitionRange, tableProps)

      // Only a single partition of the IR table is verified.
      val rawIncrementalDf = spark.sql(s"select * from $outputTable where ds='$sevenDaysAgo'")

      println("=== Incremental FIRST/LAST IR Schema ===")
      rawIncrementalDf.printSchema()

      // Sort array columns in the raw IRs to match the ordering produced by the SQL reference.
      val actualIncrementalDf = rawIncrementalDf
        .withColumn("value_first3", sort_array(col("value_first3")))
        .withColumn("value_last3", sort_array(col("value_last3")))
        .withColumn("value_top3", sort_array(col("value_top3")))
        .withColumn("value_bottom3", sort_array(col("value_bottom3")))

      // Reference computation in SQL.
      // Note: the ts column in the IR table is the partition timestamp (derived from ds),
      // while FIRST/LAST use the actual event timestamps (with random milliseconds).
      val query =
        s"""
           |SELECT user,
           |       to_date(from_unixtime(ts / 1000, 'yyyy-MM-dd HH:mm:ss')) as ds,
           |       named_struct(
           |         'epochMillis', min(ts),
           |         'payload', sort_array(collect_list(struct(ts, value)))[0].value
           |       ) as value_first,
           |       named_struct(
           |         'epochMillis', max(ts),
           |         'payload', reverse(sort_array(collect_list(struct(ts, value))))[0].value
           |       ) as value_last,
           |       transform(
           |         slice(sort_array(filter(collect_list(struct(ts, value)), x -> x.value IS NOT NULL)), 1, 3),
           |         x -> named_struct('epochMillis', x.ts, 'payload', x.value)
           |       ) as value_first3,
           |       transform(
           |         slice(sort_array(filter(collect_list(struct(ts, value)), x -> x.value IS NOT NULL)),
           |               greatest(-size(sort_array(filter(collect_list(struct(ts, value)), x -> x.value IS NOT NULL))), -3), 3),
           |         x -> named_struct('epochMillis', x.ts, 'payload', x.value)
           |       ) as value_last3,
           |       transform(
           |         slice(sort_array(filter(collect_list(struct(value, ts)), x -> x.value IS NOT NULL), true),
           |               greatest(-size(sort_array(filter(collect_list(struct(value, ts)), x -> x.value IS NOT NULL))), -3), 3),
           |         x -> x.value
           |       ) as value_top3,
           |       transform(
           |         slice(sort_array(filter(collect_list(struct(value, ts)), x -> x.value IS NOT NULL), true), 1, 3),
           |         x -> x.value
           |       ) as value_bottom3
           |FROM $inputTable
           |WHERE to_date(from_unixtime(ts / 1000, 'yyyy-MM-dd HH:mm:ss'))='$sevenDaysAgo'
           |GROUP BY user, to_date(from_unixtime(ts / 1000, 'yyyy-MM-dd HH:mm:ss'))
           |""".stripMargin

      val expectedDf = spark.sql(query)

      // Drop ts from the comparison - it is just the partition timestamp, not part of the aggregation IR.
      val actualWithoutTs = actualIncrementalDf.drop("ts")

      val diff = Comparison.sideBySide(actualWithoutTs, expectedDf, List("user", tableUtils.partitionColumn))
      // count() is a Spark action; evaluate it once and reuse the result.
      val diffCount = diff.count()

      if (diffCount > 0) {
        println("=== Diff Details for Time-based Aggregations ===")
        println(s"Expected count: ${expectedDf.count()}")
        println(s"Diff count: ${diffCount}")
        diff.show(100, truncate = false)
      }

      assertEquals(0, diffCount)

      println("=== Time-based Aggregations Incremental Test Passed ===")
      println("✓ FIRST: TimeTuple IR {epochMillis, payload}")
      println("✓ LAST: TimeTuple IR {epochMillis, payload}")
      println("✓ FIRST_K: Array[TimeTuple] - stores timestamps")
      println("✓ LAST_K: Array[TimeTuple] - stores timestamps")
      println("✓ TOP_K: Array[Double] - stores only values")
      println("✓ BOTTOM_K: Array[Double] - stores only values")
    } finally {
      // Cleanup: stop the session even when an assertion above fails.
      spark.stop()
    }
  }

  /**
    * Verifies that `snapshotEvents` computed through `GroupBy.fromIncrementalDf` matches the
    * standard `GroupBy.snapshotEvents` output for the statistical aggregations SKEW, KURTOSIS,
    * APPROX_PERCENTILE, APPROX_UNIQUE_COUNT and APPROX_HISTOGRAM_K.
    *
    * Both computations run over the same source table and the same output partition range, and
    * are compared side by side on ("item", partition column); the test passes when the
    * comparison yields no differing rows.
    */
  @Test
  def testIncrementalStatisticalAggregations(): Unit = {
    lazy val spark: SparkSession =
      SparkSessionBuilder.build("GroupByTestStatistical" + "_" + Random.alphanumeric.take(6).mkString, local = true)
    implicit val tableUtils = TableUtils(spark)

    try {
      val namespace = s"incremental_stats_${Random.alphanumeric.take(6).mkString}"
      tableUtils.createDatabase(namespace)

      val outputDates = CStream.genPartitions(10, tableUtils.partitionSpec)
      // Both code paths below are evaluated over this single range.
      val outputRange = PartitionRange(outputDates.min, outputDates.max)

      val aggregations: Seq[Aggregation] = Seq(
        // Moment-based (IR = array [n, m1, m2, m3, m4]); finalized to Double
        Builders.Aggregation(Operation.SKEW, "price", Seq(new Window(10, TimeUnit.DAYS))),
        Builders.Aggregation(Operation.KURTOSIS, "price", Seq(new Window(10, TimeUnit.DAYS))),
        // Sketch-based (IR = binary KLL sketch); finalized to Array[Float]
        Builders.Aggregation(Operation.APPROX_PERCENTILE,
                             "price",
                             Seq(new Window(10, TimeUnit.DAYS)),
                             argMap = Map("percentiles" -> "[0.5, 0.25, 0.75]")),
        // Sketch-based (IR = binary CPC sketch); finalized to Long
        Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "user", Seq(new Window(10, TimeUnit.DAYS))),
        // Sketch-based (IR = binary); finalized to Map[String, Long]
        Builders.Aggregation(Operation.APPROX_HISTOGRAM_K,
                             "user",
                             Seq(new Window(10, TimeUnit.DAYS)),
                             argMap = Map("k" -> "10"))
      )

      val (source, _) = createTestSourceIncremental(windowSize = 30,
                                                    suffix = "_stats_events",
                                                    partitionColOpt = Some(tableUtils.partitionColumn))
      val groupByConf = Builders.GroupBy(
        sources = Seq(source),
        keyColumns = Seq("item"),
        aggregations = aggregations,
        metaData = Builders.MetaData(name = "testIncrementalStats", namespace = namespace, team = "chronon"),
        backfillStartDate = tableUtils.partitionSpec.minus(tableUtils.partitionSpec.at(System.currentTimeMillis()),
                                                           new Window(20, TimeUnit.DAYS))
      )

      // Baseline: standard (non-incremental) snapshot computation straight from the source table.
      val df = spark.read.table(source.table)
      val groupBy = new GroupBy(aggregations, Seq("item"), df.filter("item is not null"))
      val nonIncrementalDf = groupBy.snapshotEvents(outputRange)

      // Candidate: the same snapshot computed through the incremental path.
      val groupByIncremental = GroupBy.fromIncrementalDf(groupByConf, outputRange, tableUtils)
      val incrementalDf = groupByIncremental.snapshotEvents(outputRange)

      val diff = Comparison.sideBySide(nonIncrementalDf, incrementalDf, List("item", tableUtils.partitionColumn))
      // count() is a Spark action; evaluate it once and reuse the result.
      val diffCount = diff.count()
      if (diffCount > 0) {
        println("=== Diff Details for Statistical Aggregations ===")
        diff.show(100, truncate = false)
      }
      assertEquals(0, diffCount)
    } finally {
      // Cleanup: stop the session even when an assertion above fails.
      spark.stop()
    }
  }
}