diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/row/RowAggregator.scala b/aggregator/src/main/scala/ai/chronon/aggregator/row/RowAggregator.scala index c8bc1da08c..e9d0608d25 100644 --- a/aggregator/src/main/scala/ai/chronon/aggregator/row/RowAggregator.scala +++ b/aggregator/src/main/scala/ai/chronon/aggregator/row/RowAggregator.scala @@ -70,6 +70,11 @@ class RowAggregator(val inputSchema: Seq[(String, DataType)], val aggregationPar .toArray .zip(columnAggregators.map(_.irType)) + val incrementalOutputSchema = aggregationParts + .map(_.incrementalOutputColumnName) + .toArray + .zip(columnAggregators.map(_.irType)) + val outputSchema: Array[(String, DataType)] = aggregationParts .map(_.outputColumnName) .toArray diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothOnlineAggregatorTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothOnlineAggregatorTest.scala index 86546e3669..4b2a6e31b8 100644 --- a/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothOnlineAggregatorTest.scala +++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothOnlineAggregatorTest.scala @@ -143,7 +143,7 @@ class SawtoothOnlineAggregatorTest extends TestCase { operation = Operation.HISTOGRAM, inputColumn = "action", windows = Seq( - new Window(3, TimeUnit.DAYS), + new Window(3, TimeUnit.DAYS) ) ) ) @@ -162,15 +162,15 @@ class SawtoothOnlineAggregatorTest extends TestCase { val finalBatchIr = FinalBatchIr( Array[Any]( - null, // collapsed (T-1 -> T) + null // collapsed (T-1 -> T) ), Array( - Array.empty, // 1‑day hops (not used) - Array( // 1-hour hops - hop(1, 1746745200000L), // 2025-05-08 23:00:00 UTC - hop(1, 1746766800000L), // 2025-05-09 05:00:00 UTC + Array.empty, // 1‑day hops (not used) + Array( // 1-hour hops + hop(1, 1746745200000L), // 2025-05-08 23:00:00 UTC + hop(1, 1746766800000L) // 2025-05-09 05:00:00 UTC ), - Array.empty // 5‑minute hops (not used) + Array.empty // 5‑minute hops (not used) ) ) val 
queryTs = batchEndTs + 100 diff --git a/api/py/ai/chronon/group_by.py b/api/py/ai/chronon/group_by.py index 6ddb816d98..f8aea48b28 100644 --- a/api/py/ai/chronon/group_by.py +++ b/api/py/ai/chronon/group_by.py @@ -362,7 +362,8 @@ def GroupBy( tags: Optional[Dict[str, str]] = None, derivations: Optional[List[ttypes.Derivation]] = None, deprecation_date: Optional[str] = None, description: Optional[str] = None, + is_incremental: Optional[bool] = False, **kwargs, ) -> ttypes.GroupBy: """ @@ -570,6 +571,7 @@ def _normalize_source(source): backfillStartDate=backfill_start_date, accuracy=accuracy, derivations=derivations, + isIncremental=is_incremental, ) validate_group_by(group_by) return group_by diff --git a/api/py/test/sample/scripts/spark_submit.sh b/api/py/test/sample/scripts/spark_submit.sh index 45102e8843..ef048a5532 100644 --- a/api/py/test/sample/scripts/spark_submit.sh +++ b/api/py/test/sample/scripts/spark_submit.sh @@ -28,13 +28,14 @@ set -euxo pipefail CHRONON_WORKING_DIR=${CHRONON_TMPDIR:-/tmp}/${USER} +echo $CHRONON_WORKING_DIR mkdir -p ${CHRONON_WORKING_DIR} export TEST_NAME="${APP_NAME}_${USER}_test" unset PYSPARK_DRIVER_PYTHON unset PYSPARK_PYTHON unset SPARK_HOME unset SPARK_CONF_DIR -export LOG4J_FILE="${CHRONON_WORKING_DIR}/log4j_file" +export LOG4J_FILE="${CHRONON_WORKING_DIR}/log4j.properties" cat > ${LOG4J_FILE} << EOF log4j.rootLogger=INFO, stdout log4j.appender.stdout=org.apache.log4j.ConsoleAppender
diff --git a/api/py/test/sample/teams.json b/api/py/test/sample/teams.json index 39f7a25559..a60502b65d 100644 --- a/api/py/test/sample/teams.json +++ b/api/py/test/sample/teams.json @@ -5,7 +5,7 @@ }, "common_env": { "VERSION": "latest", - "SPARK_SUBMIT_PATH": "[TODO]/path/to/spark-submit", + "SPARK_SUBMIT_PATH": "spark-submit", "JOB_MODE": "local[*]", "HADOOP_DIR": "[STREAMING-TODO]/path/to/folder/containing", "CHRONON_ONLINE_CLASS": "[ONLINE-TODO]your.online.class", diff --git a/api/src/main/scala/ai/chronon/api/Extensions.scala b/api/src/main/scala/ai/chronon/api/Extensions.scala index 86ead89c07..6a4cb6de6f 100644 --- a/api/src/main/scala/ai/chronon/api/Extensions.scala +++ b/api/src/main/scala/ai/chronon/api/Extensions.scala @@ -98,7 +98,7 @@ object Extensions { def cleanName: String = metaData.name.sanitize def outputTable = s"${metaData.outputNamespace}.${metaData.cleanName}" - + def incrementalOutputTable = s"${metaData.outputNamespace}.${metaData.cleanName}_inc" def preModelTransformsTable = s"${metaData.outputNamespace}.${metaData.cleanName}_pre_mt" def outputLabelTable = s"${metaData.outputNamespace}.${metaData.cleanName}_labels" def outputFinalView = s"${metaData.outputNamespace}.${metaData.cleanName}_labeled" @@ -179,6 +179,10 @@ object Extensions { def outputColumnName = s"${aggregationPart.inputColumn}_$opSuffix${aggregationPart.window.suffix}${bucketSuffix}" + + def incrementalOutputColumnName = +
s"${aggregationPart.inputColumn}_$opSuffix${bucketSuffix}" + } implicit class AggregationOps(aggregation: Aggregation) { diff --git a/api/thrift/api.thrift b/api/thrift/api.thrift index dc8f162318..8a7734b43e 100644 --- a/api/thrift/api.thrift +++ b/api/thrift/api.thrift @@ -287,6 +287,7 @@ struct GroupBy { 6: optional string backfillStartDate // Optional derivation list 7: optional list derivations + 8: optional bool isIncremental } struct JoinPart { diff --git a/online/src/main/scala/ai/chronon/online/SparkConversions.scala b/online/src/main/scala/ai/chronon/online/SparkConversions.scala index c669f29a1d..22d2605a04 100644 --- a/online/src/main/scala/ai/chronon/online/SparkConversions.scala +++ b/online/src/main/scala/ai/chronon/online/SparkConversions.scala @@ -163,4 +163,78 @@ object SparkConversions { extraneousRecord ) } + + /** + * Converts a single Spark column value to Chronon normalized IR format. + * + * This is the inverse of toSparkRow() - used when reading pre-computed IR values + * from Spark DataFrames. Each IR column in the DataFrame is converted based on its + * Chronon IR type. 
+ * + * Examples: + * - Count IR: Long → Long (pass-through, primitives stay primitives) + * - Sum IR: Double → Double (pass-through) + * - Average IR: Spark Row(sum, count) → Array[Any](sum, count) + * - UniqueCount IR: Spark Array[T] → java.util.ArrayList[T] + * - Histogram IR: Spark Map[K,V] → java.util.HashMap[K,V] + * - ApproxPercentile IR: Array[Byte] → Array[Byte] (pass-through for binary) + * + * @param sparkValue The value from a Spark DataFrame column + * @param irType The Chronon IR type for this column (from RowAggregator.incrementalOutputSchema) + * @return Normalized IR value ready for denormalize() + */ + def fromSparkValue(sparkValue: Any, irType: api.DataType): Any = { + (sparkValue, irType) match { + // Null stays null for every IR type (also covers null elements nested in structs, lists and maps) + case (null, _) => null + // Primitives - pass through (Count, Sum, Min, Max, Binary sketches) + case (v, + api.IntType | api.LongType | api.ShortType | api.ByteType | api.FloatType | api.DoubleType | + api.StringType | api.BooleanType | api.BinaryType) => + v + + // Spark Row → Array[Any] (Average, Variance, Skew, Kurtosis, FirstK/LastK) + case (row: Row, api.StructType(_, fields)) => + val arr = new Array[Any](fields.length) + fields.zipWithIndex.foreach { + case (field, idx) => + arr(idx) = fromSparkValue(row.get(idx), field.fieldType) + } + arr + + // Any Spark Seq (e.g. WrappedArray or ArraySeq, depending on Scala/Spark version) → util.ArrayList (UniqueCount, TopK, BottomK) + case (arr: scala.collection.Seq[_], api.ListType(elementType)) => + val result = new util.ArrayList[Any](arr.length) + arr.foreach { elem => + result.add(fromSparkValue(elem, elementType)) + } + result + + // Spark native Array → util.ArrayList (alternative array representation) + case (arr: Array[_], api.ListType(elementType)) => + val result = new util.ArrayList[Any](arr.length) + arr.foreach { elem => + result.add(fromSparkValue(elem, elementType)) + } + result + + // Spark scala.collection.Map → util.HashMap (Histogram) + case (map: scala.collection.Map[_, _], api.MapType(keyType, valueType)) => + val result 
= new util.HashMap[Any, Any]() + map.foreach { + case (k, v) => + result.put( + fromSparkValue(k, keyType), + fromSparkValue(v, valueType) + ) + } + result + + case (value, tpe) => + throw new IllegalArgumentException( + s"Cannot convert Spark value $value (${value.getClass.getSimpleName}) " + + s"to Chronon IR type $tpe" + ) + } + } } diff --git a/spark/src/main/scala/ai/chronon/spark/DataRange.scala b/spark/src/main/scala/ai/chronon/spark/DataRange.scala index b0af96e05d..e6f13e0ea0 100644 --- a/spark/src/main/scala/ai/chronon/spark/DataRange.scala +++ b/spark/src/main/scala/ai/chronon/spark/DataRange.scala @@ -53,6 +53,11 @@ case class PartitionRange(start: String, end: String)(implicit tableUtils: Table } } + def daysBetween: Int = { + if (start == null || end == null) 0 + else Stream.iterate(start)(tableUtils.partitionSpec.after).takeWhile(_ <= end).size + } + def isSingleDay: Boolean = { start == end } diff --git a/spark/src/main/scala/ai/chronon/spark/Driver.scala b/spark/src/main/scala/ai/chronon/spark/Driver.scala index dbf2d24f25..a7a898f0f3 100644 --- a/spark/src/main/scala/ai/chronon/spark/Driver.scala +++ b/spark/src/main/scala/ai/chronon/spark/Driver.scala @@ -466,7 +466,8 @@ object Driver { tableUtils, args.stepDays.toOption, args.startPartitionOverride.toOption, - !args.runFirstHole() + !args.runFirstHole(), + args.groupByConf.isIncremental ) if (args.shouldExport()) { diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index ac4053932a..8b6a5fe69f 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -18,11 +18,12 @@ package ai.chronon.spark import ai.chronon.aggregator.base.TimeTuple import ai.chronon.aggregator.row.RowAggregator +import ai.chronon.aggregator.windowing.HopsAggregator.HopIr import ai.chronon.aggregator.windowing._ import ai.chronon.api import ai.chronon.api.DataModel.{Entities, Events} import 
ai.chronon.api.Extensions._ -import ai.chronon.api.{Accuracy, Constants, DataModel, ParametricMacro} +import ai.chronon.api.{Accuracy, Constants, DataModel, ParametricMacro, Source, TimeUnit, Window} import ai.chronon.online.{RowWrapper, SparkConversions} import ai.chronon.spark.Extensions._ import org.apache.spark.rdd.RDD @@ -35,13 +36,15 @@ import org.slf4j.LoggerFactory import java.util import scala.collection.{Seq, mutable} import scala.util.ScalaJavaConversions.{JListOps, ListOps, MapOps} +import _root_.com.google.common.collect.Table class GroupBy(val aggregations: Seq[api.Aggregation], val keyColumns: Seq[String], val inputDf: DataFrame, val mutationDfFn: () => DataFrame = null, skewFilter: Option[String] = None, - finalize: Boolean = true) + finalize: Boolean = true, + incrementalMode: Boolean = false) extends Serializable { @transient lazy val logger = LoggerFactory.getLogger(getClass) @@ -88,10 +91,17 @@ class GroupBy(val aggregations: Seq[api.Aggregation], lazy val aggPartWithSchema = aggregationParts.zip(columnAggregators.map(_.outputType)) lazy val postAggSchema: StructType = { - val valueChrononSchema = if (finalize) windowAggregator.outputSchema else windowAggregator.irSchema + val valueChrononSchema = if (finalize) { + windowAggregator.outputSchema + } else { + windowAggregator.irSchema + } SparkConversions.fromChrononSchema(valueChrononSchema) } + lazy val flattenedAgg: RowAggregator = new RowAggregator(selectedSchema, aggregations.flatMap(_.unWindowed)) + lazy val incrementalSchema: Array[(String, api.DataType)] = flattenedAgg.incrementalOutputSchema + @transient protected[spark] lazy val windowAggregator: RowAggregator = new RowAggregator(selectedSchema, aggregations.flatMap(_.unpack)) @@ -356,6 +366,29 @@ class GroupBy(val aggregations: Seq[api.Aggregation], toDf(outputRdd, Seq(Constants.TimeColumn -> LongType, tableUtils.partitionColumn -> StringType)) } + def flattenOutputArrayType( + hopsArrays: RDD[(KeyWithHash, 
HopsAggregator.OutputArrayType)]): RDD[(Array[Any], Array[Any])] = { + hopsArrays.flatMap { + case (keyWithHash: KeyWithHash, hopsArray: HopsAggregator.OutputArrayType) => + val hopsArrayHead: Array[HopIr] = hopsArray.headOption.get + hopsArrayHead.map { array: HopIr => + val timestamp = array.last.asInstanceOf[Long] + val withoutTimestamp = array.dropRight(1) + // Convert IR to Spark-native format for DataFrame storage + val sparkIr = SerializeIR.toSparkRow(withoutTimestamp, flattenedAgg) + ((keyWithHash.data :+ tableUtils.partitionSpec.at(timestamp) :+ timestamp), sparkIr) + } + } + } + + def convertHopsToDf(hops: RDD[(KeyWithHash, HopsAggregator.OutputArrayType)], + schema: Array[(String, ai.chronon.api.DataType)]): DataFrame = { + val hopsDf = flattenOutputArrayType(hops) + toDf(hopsDf, + Seq(tableUtils.partitionColumn -> StringType, Constants.TimeColumn -> LongType), + Some(SparkConversions.fromChrononSchema(schema))) + } + // convert raw data into IRs, collected by hopSizes // TODO cache this into a table: interface below // Class HopsCacher(keySchema, irSchema, resolution) extends RddCacher[(KeyWithHash, HopsOutput)] @@ -379,9 +412,10 @@ class GroupBy(val aggregations: Seq[api.Aggregation], } protected[spark] def toDf(aggregateRdd: RDD[(Array[Any], Array[Any])], - additionalFields: Seq[(String, DataType)]): DataFrame = { + additionalFields: Seq[(String, DataType)], + schema: Option[StructType] = None): DataFrame = { val finalKeySchema = StructType(keySchema ++ additionalFields.map { case (name, typ) => StructField(name, typ) }) - KvRdd(aggregateRdd, finalKeySchema, postAggSchema).toFlatDf + KvRdd(aggregateRdd, finalKeySchema, schema.getOrElse(postAggSchema)).toFlatDf } private def normalizeOrFinalize(ir: Array[Any]): Array[Any] = @@ -391,6 +425,12 @@ class GroupBy(val aggregations: Seq[api.Aggregation], windowAggregator.normalize(ir) } + def computeIncrementalDf(incrementalOutputTable: String, range: PartitionRange, tableProps: Map[String, String]) = { + + 
val hops = hopsAggregate(range.toTimePoints.min, DailyResolution) + val hopsDf: DataFrame = convertHopsToDf(hops, incrementalSchema) + hopsDf.save(incrementalOutputTable, tableProps) + } } // TODO: truncate queryRange for caching @@ -462,9 +502,14 @@ object GroupBy { bloomMapOpt: Option[util.Map[String, BloomFilter]] = None, skewFilter: Option[String] = None, finalize: Boolean = true, - showDf: Boolean = false): GroupBy = { + showDf: Boolean = false, + incrementalMode: Boolean = false): GroupBy = { logger.info(s"\n----[Processing GroupBy: ${groupByConfOld.metaData.name}]----") val groupByConf = replaceJoinSource(groupByConfOld, queryRange, tableUtils, computeDependency, showDf) + val sourceQueryWindow: Option[Window] = + if (incrementalMode) Some(new Window(queryRange.daysBetween, TimeUnit.DAYS)) else groupByConf.maxWindow + val backfillQueryRange: PartitionRange = + if (incrementalMode) PartitionRange(queryRange.end, queryRange.end)(tableUtils) else queryRange val inputDf = groupByConf.sources.toScala .map { source => val partitionColumn = tableUtils.getPartitionColumn(source.query) @@ -473,9 +518,9 @@ object GroupBy { groupByConf, source, groupByConf.getKeyColumns.toScala, - queryRange, + backfillQueryRange, tableUtils, - groupByConf.maxWindow, + sourceQueryWindow, groupByConf.inferredAccuracy, partitionColumn = partitionColumn ), @@ -552,15 +597,23 @@ object GroupBy { logger.info(s"printing mutation data for groupBy: ${groupByConf.metaData.name}") df.prettyPrint() } - df } + //if incrementalMode is enabled, we do not compute finalize values + //IR values are stored in the table + val finalizeValue = if (incrementalMode) { + false + } else { + finalize + } + new GroupBy(Option(groupByConf.getAggregations).map(_.toScala).orNull, keyColumns, nullFiltered, mutationDfFn, - finalize = finalize) + finalize = finalizeValue, + incrementalMode = incrementalMode) } def getIntersectedRange(source: api.Source, @@ -680,12 +733,121 @@ object GroupBy { query } + /** + * 
Computes and saves the output of hopsAggregation. + * HopsAggregate computes event level data to daily aggregates and saves the output in IR format + * + * @param groupByConf GroupBy config whose aggregations are materialized as daily IRs + * @param range requested output range; the IR table is filled from range.start minus the max window up to range.end + * @param tableUtils used for partition arithmetic and for finding unfilled partitions of the IR table + */ + def computeIncrementalDf( + groupByConf: api.GroupBy, + range: PartitionRange, + tableUtils: TableUtils, + incrementalOutputTable: String + ): (PartitionRange, Seq[api.AggregationPart]) = { + + val tableProps: Map[String, String] = Option(groupByConf.metaData.tableProperties) + .map(_.toScala) + .orNull + // maxWindow is an Option: fail with a descriptive error instead of a bare None.get when it is empty + val maxWindow = groupByConf.maxWindow.getOrElse( + throw new IllegalArgumentException(s"GroupBy:${groupByConf.metaData.name} needs a max window for incremental backfill")) + val incrementalQueryableRange = + PartitionRange(tableUtils.partitionSpec.minus(range.start, maxWindow), range.end)(tableUtils) + + logger.info(s"Writing incremental df to $incrementalOutputTable") + + val partitionRangeHoles: Option[Seq[PartitionRange]] = tableUtils.unfilledRanges( + incrementalOutputTable, + incrementalQueryableRange + ) + + val incrementalGroupByAggParts = partitionRangeHoles + .map { holes => + val incrementalAggregationParts = holes.map { hole => + logger.info(s"Filling hole in incremental table: $hole") + val incrementalGroupByBackfill = + from(groupByConf, hole, tableUtils, computeDependency = true, incrementalMode = true) + incrementalGroupByBackfill.computeIncrementalDf(incrementalOutputTable, hole, tableProps) + incrementalGroupByBackfill.flattenedAgg.aggregationParts + } + + incrementalAggregationParts.headOption.getOrElse(Seq.empty) + } + .getOrElse(Seq.empty) + + (incrementalQueryableRange, incrementalGroupByAggParts) + } + + def fromIncrementalDf( + groupByConf: api.GroupBy, + range: PartitionRange, + tableUtils: TableUtils + ): GroupBy = { + + val incrementalOutputTable = groupByConf.metaData.incrementalOutputTable + + val (incrementalQueryableRange, aggregationParts) = + computeIncrementalDf(groupByConf, range, tableUtils, incrementalOutputTable) + + val (_, incrementalDf: DataFrame) = incrementalQueryableRange.scanQueryStringAndDf(null, incrementalOutputTable) + + // Create RowAggregator 
for deserializing and merging IRs + val selectedSchema = SparkConversions.toChrononSchema(incrementalDf.schema) + val flattenedAggregations = groupByConf.getAggregations.toScala.flatMap(_.unWindowed) + val flattenedAgg = new RowAggregator(selectedSchema, flattenedAggregations) + + // Convert Spark DataFrame to RDD, deserialize IRs, and merge by key + val keyColumns = groupByConf.getKeyColumns.toScala + val keySchema = StructType(keyColumns.map(incrementalDf.schema.apply).toArray) + + val irRdd = incrementalDf.rdd.map { sparkRow => + // Extract keys + val keys = keyColumns.map(sparkRow.getAs[Any]).toArray + + // Deserialize IR columns from Spark Row + val ir = SerializeIR.fromSparkRow(sparkRow, flattenedAgg) + + (keys.toSeq, ir) + } + + // Merge IRs by key + val mergedRdd = irRdd.reduceByKey { (ir1, ir2) => + flattenedAgg.merge(ir1, ir2) + } + + // Finalize IRs to get final feature values + val finalRdd = mergedRdd.map { + case (keys, ir) => + (keys.toArray, flattenedAgg.finalize(ir)) + } + + // Convert back to DataFrame + val outputChrononSchema = flattenedAgg.outputSchema + val outputSparkSchema = SparkConversions.fromChrononSchema(outputChrononSchema) + implicit val session: SparkSession = incrementalDf.sparkSession + val finalDf = KvRdd(finalRdd, keySchema, outputSparkSchema).toFlatDf + + new GroupBy( + groupByConf.getAggregations.toScala, + keyColumns, + finalDf, + () => null, + finalize = true, + incrementalMode = false + ) + + } + def computeBackfill(groupByConf: api.GroupBy, endPartition: String, tableUtils: TableUtils, stepDays: Option[Int] = None, overrideStartPartition: Option[String] = None, - skipFirstHole: Boolean = true): Unit = { + skipFirstHole: Boolean = true, + incrementalMode: Boolean = false): Unit = { assert( groupByConf.backfillStartDate != null, s"GroupBy:${groupByConf.metaData.name} has null backfillStartDate. 
This needs to be set for offline backfilling.") @@ -734,7 +896,11 @@ object GroupBy { stepRanges.zipWithIndex.foreach { case (range, index) => logger.info(s"Computing group by for range: $range [${index + 1}/${stepRanges.size}]") - val groupByBackfill = from(groupByConf, range, tableUtils, computeDependency = true) + val groupByBackfill: GroupBy = if (incrementalMode) { + fromIncrementalDf(groupByConf, range, tableUtils) + } else { + from(groupByConf, range, tableUtils, computeDependency = true) + } val outputDf = groupByConf.dataModel match { // group by backfills have to be snapshot only case Entities => groupByBackfill.snapshotEntities diff --git a/spark/src/main/scala/ai/chronon/spark/SerializeIR.scala b/spark/src/main/scala/ai/chronon/spark/SerializeIR.scala new file mode 100644 index 0000000000..e887529d58 --- /dev/null +++ b/spark/src/main/scala/ai/chronon/spark/SerializeIR.scala @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2023 The Chronon Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.chronon.spark + +import ai.chronon.aggregator.row.RowAggregator +import ai.chronon.online.SparkConversions +import org.apache.spark.sql.Row + +/** + * Utilities for serializing/deserializing Chronon IR values to/from Spark DataFrames. 
+ * + * Key concepts: + * - IR tables have MULTIPLE COLUMNS, each with a different IR type + * - Example: [count:Long, sum:Double, avg:Struct(sum, count)] + * - Each column is converted independently based on its IR type + * + * This bridges between: + * - Java IR types (internal aggregator state: HashSet, CpcSketch, etc.) + * - Normalized IR types (serializable format: Array, ArrayList, bytes) + * - Spark native types (DataFrame columns: primitives, Row, WrappedArray, Map) + * + * The conversion pipeline: + * Writing: Java IR → normalize() → toSparkRow() → Spark Row → DataFrame + * Reading: DataFrame → Spark Row → fromSparkRow() → denormalize() → Java IR + */ +object SerializeIR { + + /** + * Converts a Chronon IR array to Spark-native format for DataFrame writing. + * + * Processes each column independently based on its IR type from + * RowAggregator.incrementalOutputSchema. + * + * @param ir The IR array from RowAggregator (Java types) + * @param rowAgg The RowAggregator with column schemas + * @return Array where each element is in Spark-native format + */ + def toSparkRow(ir: Array[Any], rowAgg: RowAggregator): Array[Any] = { + // Step 1: Normalize (Java types → serializable types) + val normalized = rowAgg.normalize(ir) + + // Step 2: Convert each column to Spark-native type + val sparkColumns = new Array[Any](normalized.length) + rowAgg.incrementalOutputSchema.zipWithIndex.foreach { + case ((_, irType), idx) => + sparkColumns(idx) = SparkConversions.toSparkRow(normalized(idx), irType) + } + sparkColumns + } + + /** + * Converts Spark DataFrame Row to Chronon IR format for aggregation. + * + * Reads each IR column from the Spark Row by name and converts based on IR type. + * Uses RowAggregator.incrementalOutputSchema to get both column names and types. 
+ * + * @param sparkRow The Spark Row from DataFrame.read() + * @param rowAgg The RowAggregator with IR schemas + * @return Denormalized IR array ready for merge() (Java types) + */ + def fromSparkRow(sparkRow: Row, rowAgg: RowAggregator): Array[Any] = { + val normalized = new Array[Any](rowAgg.incrementalOutputSchema.length) + + // Step 1: Extract each IR column from Spark Row by name + rowAgg.incrementalOutputSchema.zipWithIndex.foreach { + case ((colName, irType), idx) => + // Get column from Spark Row by NAME + val sparkValue = sparkRow.getAs[Any](colName) + // Convert using IR type + normalized(idx) = SparkConversions.fromSparkValue(sparkValue, irType) + } + + // Step 2: Denormalize (serializable types → Java types) + rowAgg.denormalize(normalized) + } + + /** + * Alternative: Extract IR columns by position instead of name. + * Faster but requires column order to match exactly. + * + * @param sparkRow The Spark Row from DataFrame + * @param rowAgg The RowAggregator with IR schemas + * @param startIndex The starting index of IR columns in the Row + * @return Denormalized IR array (Java types) + */ + def fromSparkRowByPosition( + sparkRow: Row, + rowAgg: RowAggregator, + startIndex: Int + ): Array[Any] = { + val normalized = new Array[Any](rowAgg.incrementalOutputSchema.length) + + // Step 1: Extract each IR column by position + rowAgg.incrementalOutputSchema.zipWithIndex.foreach { + case ((_, irType), idx) => + val sparkValue = sparkRow.get(startIndex + idx) + normalized(idx) = SparkConversions.fromSparkValue(sparkValue, irType) + } + + // Step 2: Denormalize + rowAgg.denormalize(normalized) + } +} diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala index a83943407d..57e6c80435 100644 --- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala @@ -873,6 +873,7 @@ case class TableUtils(sparkSession: SparkSession) { 
outputPartitionRange } val outputExisting = partitions(outputTable) + logger.info(s"outputExisting : ${outputExisting}") // To avoid recomputing partitions removed by retention mechanisms we will not fill holes in the very beginning of the range // If a user fills a new partition in the newer end of the range, then we will never fill any partitions before that range. // We instead log a message saying why we won't fill the earliest hole. @@ -881,13 +882,17 @@ case class TableUtils(sparkSession: SparkSession) { } else { validPartitionRange.start } + + logger.info(s"Cutoff partition for skipping holes is set to $cutoffPartition") val fillablePartitions = if (skipFirstHole) { validPartitionRange.partitions.toSet.filter(_ >= cutoffPartition) } else { validPartitionRange.partitions.toSet } + val outputMissing = fillablePartitions -- outputExisting + val allInputExisting = inputTables .map { tables => tables diff --git a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala index 0324c199f9..8739f1bbb3 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala @@ -36,6 +36,7 @@ import ai.chronon.api.{ import ai.chronon.online.{RowWrapper, SparkConversions} import ai.chronon.spark.Extensions._ import ai.chronon.spark._ +import ai.chronon.spark.test.TestUtils.makeDf import com.google.gson.Gson import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.{StructField, StructType, LongType => SparkLongType, StringType => SparkStringType} @@ -115,6 +116,7 @@ class GroupByTest { val groupBy = new GroupBy(aggregations, Seq("user"), df) val actualDf = groupBy.snapshotEvents(PartitionRange(outputDates.min, outputDates.max)) + val outputDatesRdd: RDD[Row] = spark.sparkContext.parallelize(outputDates.map(Row(_))) val outputDatesDf = spark.createDataFrame(outputDatesRdd, StructType(Seq(StructField("ds", SparkStringType)))) val 
datesViewName = "test_group_by_snapshot_events_output_range" @@ -863,10 +865,10 @@ class GroupByTest { val testDatabase = s"staging_query_view_test_${Random.alphanumeric.take(6).mkString}" tableUtils.createDatabase(testDatabase) - // Create source data table with partitions + // Create source data table with partitions val sourceSchema = List( Column("user", StringType, 20), - Column("item", StringType, 50), + Column("item", StringType, 50), Column("time_spent_ms", LongType, 5000), Column("price", DoubleType, 100) ) @@ -889,7 +891,7 @@ class GroupByTest { stagingQueryJob.createStagingQueryView() val viewTable = s"$testDatabase.test_staging_view" - + // Now create a GroupBy that uses the staging query view as its source val source = Builders.Source.events( table = viewTable, @@ -898,12 +900,12 @@ class GroupByTest { val aggregations = Seq( Builders.Aggregation(operation = Operation.COUNT, inputColumn = "time_spent_ms"), - Builders.Aggregation(operation = Operation.AVERAGE, - inputColumn = "price", - windows = Seq(new Window(7, TimeUnit.DAYS))), + Builders.Aggregation(operation = Operation.AVERAGE, + inputColumn = "price", + windows = Seq(new Window(7, TimeUnit.DAYS))), Builders.Aggregation(operation = Operation.MAX, - inputColumn = "time_spent_ms", - windows = Seq(new Window(30, TimeUnit.DAYS))) + inputColumn = "time_spent_ms", + windows = Seq(new Window(30, TimeUnit.DAYS))) ) val groupByConf = Builders.GroupBy( @@ -957,4 +959,118 @@ class GroupByTest { """) assertTrue("Should be able to filter GroupBy results", filteredResult.count() >= 0) } + + @Test + def testIncrementalMode(): Unit = { + lazy val spark: SparkSession = + SparkSessionBuilder.build("GroupByTestIncremental" + "_" + Random.alphanumeric.take(6).mkString, local = true) + implicit val tableUtils = TableUtils(spark) + val namespace = "incremental" + tableUtils.createDatabase(namespace) + val schema = List( + Column("user", StringType, 10), // ts = last 10 days + Column("session_length", IntType, 2), + 
Column("rating", DoubleType, 2000) + ) + + val df = DataFrameGen.events(spark, schema, count = 100000, partitions = 100) + + println(s"Input DataFrame: ${df.count()}") + + val aggregations: Seq[Aggregation] = Seq( + //Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "session_length", Seq(new Window(10, TimeUnit.DAYS), WindowUtils.Unbounded)), + //Builders.Aggregation(Operation.UNIQUE_COUNT, "session_length", Seq(new Window(10, TimeUnit.DAYS), WindowUtils.Unbounded)), + Builders.Aggregation(Operation.SUM, "session_length", Seq(new Window(10, TimeUnit.DAYS))) + ) + + val tableProps: Map[String, String] = Map( + "source" -> "chronon" + ) + + val groupBy = new GroupBy(aggregations, Seq("user"), df) + groupBy.computeIncrementalDf("incremental.testIncrementalOutput", + PartitionRange("2025-05-01", "2025-06-01"), + tableProps) + + val actualIncrementalDf = spark.sql(s"select * from incremental.testIncrementalOutput where ds='2025-05-11'") + df.createOrReplaceTempView("test_incremental_input") + //spark.sql(s"select * from test_incremental_input where user='user7' and ds='2025-05-11'").show(numRows=100) + + spark.sql(s"select * from incremental.testIncrementalOutput where ds='2025-05-11'").show() + + val query = + s""" + |select user, ds, UNIX_TIMESTAMP(ds, 'yyyy-MM-dd')*1000 as ts, sum(session_length) as session_length_sum + |from test_incremental_input + |where ds='2025-05-11' + |group by user, ds, UNIX_TIMESTAMP(ds, 'yyyy-MM-dd')*1000 + |""".stripMargin + + val expectedDf = spark.sql(query) + + val diff = Comparison.sideBySide(actualIncrementalDf, expectedDf, List("user", tableUtils.partitionColumn)) + if (diff.count() > 0) { + diff.show() + println("diff result rows") + } + assertEquals(0, diff.count()) + } + + @Test + def testSnapshotIncrementalEvents(): Unit = { + lazy val spark: SparkSession = + SparkSessionBuilder.build("GroupByTest" + "_" + Random.alphanumeric.take(6).mkString, local = true) + implicit val tableUtils = TableUtils(spark) + val schema = List( + 
Column("user", StringType, 10), // ts = last 10 days + Column("session_length", IntType, 2), + Column("rating", DoubleType, 2000) + ) + + val outputDates = CStream.genPartitions(10, tableUtils.partitionSpec) + + val df = DataFrameGen.events(spark, schema, count = 100000, partitions = 100) + // ts is deliberately kept on df: the expected-value SQL below filters on it (a bare df.drop("ts") would be a no-op anyway). + val viewName = "test_group_by_snapshot_events" + df.createOrReplaceTempView(viewName) + val aggregations: Seq[Aggregation] = Seq( + Builders.Aggregation(Operation.SUM, "session_length", Seq(new Window(10, TimeUnit.DAYS))), + Builders.Aggregation(Operation.SUM, "rating", Seq(new Window(10, TimeUnit.DAYS))) + ) + + val groupBy = new GroupBy(aggregations, Seq("user"), df) + val actualDf = groupBy.snapshotEvents(PartitionRange(outputDates.min, outputDates.max)) + + val groupByIncremental = new GroupBy(aggregations, Seq("user"), df, incrementalMode = true) + val actualDfIncremental = groupByIncremental.snapshotEvents(PartitionRange(outputDates.min, outputDates.max)) + + val outputDatesRdd: RDD[Row] = spark.sparkContext.parallelize(outputDates.map(Row(_))) + val outputDatesDf = spark.createDataFrame(outputDatesRdd, StructType(Seq(StructField("ds", SparkStringType)))) + val datesViewName = "test_group_by_snapshot_events_output_range" + outputDatesDf.createOrReplaceTempView(datesViewName) + val expectedDf = df.sqlContext.sql(s""" + |select user, + | $datesViewName.ds, + | SUM(IF(ts >= (unix_timestamp($datesViewName.ds, 'yyyy-MM-dd') - 86400*(10-1)) * 1000, session_length, null)) AS session_length_sum_10d, + | SUM(IF(ts >= (unix_timestamp($datesViewName.ds, 'yyyy-MM-dd') - 86400*(10-1)) * 1000, rating, null)) AS rating_sum_10d + |FROM $viewName CROSS JOIN $datesViewName + |WHERE ts < unix_timestamp($datesViewName.ds, 'yyyy-MM-dd') * 1000 + ${tableUtils.partitionSpec.spanMillis} + |group by user, $datesViewName.ds + |""".stripMargin) + + val diff = Comparison.sideBySide(actualDf, expectedDf, List("user", tableUtils.partitionColumn)) + if
(diff.count() > 0) { + diff.show() + println("diff result rows") + } + assertEquals(0, diff.count()) + + val diffIncremental = + Comparison.sideBySide(actualDfIncremental, expectedDf, List("user", tableUtils.partitionColumn)) + if (diffIncremental.count() > 0) { + diffIncremental.show() + println("diff result rows incremental") + } + assertEquals(0, diffIncremental.count()) + } } diff --git a/spark/src/test/scala/ai/chronon/spark/test/StagingQueryTest.scala b/spark/src/test/scala/ai/chronon/spark/test/StagingQueryTest.scala index cbf84be04e..193aa575e9 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/StagingQueryTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/StagingQueryTest.scala @@ -323,28 +323,35 @@ class StagingQueryTest { val outputView = stagingQueryConfView.metaData.outputTable val isView = tableUtils.tableReadFormat(outputView) match { case View => true - case _ => false + case _ => false } - + assert(isView, s"Expected $outputView to be a view when createView=true") // Verify virtual partition metadata was written for the view - val virtualPartitionExists = try { - val metadataCount = tableUtils.sql(s"SELECT COUNT(*) as count FROM ${stagingQueryView.signalPartitionsTable} WHERE table_name = '$outputView'").collect()(0).getAs[Long]("count") - metadataCount > 0 - } catch { - case _: Exception => false - } + val virtualPartitionExists = + try { + val metadataCount = tableUtils + .sql( + s"SELECT COUNT(*) as count FROM ${stagingQueryView.signalPartitionsTable} WHERE table_name = '$outputView'") + .collect()(0) + .getAs[Long]("count") + metadataCount > 0 + } catch { + case _: Exception => false + } assert(virtualPartitionExists, s"Expected virtual partition metadata to exist for view $outputView") // Verify the structure of virtual partition metadata if (virtualPartitionExists) { - val metadataRows = tableUtils.sql(s"SELECT * FROM ${stagingQueryView.signalPartitionsTable} WHERE table_name = '$outputView'").collect() + val metadataRows = 
tableUtils + .sql(s"SELECT * FROM ${stagingQueryView.signalPartitionsTable} WHERE table_name = '$outputView'") + .collect() assert(metadataRows.length > 0, "Should have at least one partition metadata entry") - + val firstRow = metadataRows(0) val tableName = firstRow.getAs[String]("table_name") - + assertEquals(s"Virtual partition metadata should have correct table name", outputView, tableName) } @@ -362,19 +369,25 @@ class StagingQueryTest { val outputTable = stagingQueryConfTable.metaData.outputTable val isTable = tableUtils.tableReadFormat(outputTable) match { case View => false - case _ => true + case _ => true } - + assert(isTable, s"Expected $outputTable to be a table when createView=false") // Verify virtual partition metadata was NOT written for the table - val virtualPartitionExistsForTable = try { - val metadataCountForTable = tableUtils.sql(s"SELECT COUNT(*) as count FROM ${stagingQueryTable.signalPartitionsTable} WHERE table_name = '$outputTable'").collect()(0).getAs[Long]("count") - metadataCountForTable > 0 - } catch { - case _: Exception => false - } - assert(!virtualPartitionExistsForTable, s"Expected NO virtual partition metadata for table $outputTable when createView=false") + val virtualPartitionExistsForTable = + try { + val metadataCountForTable = tableUtils + .sql( + s"SELECT COUNT(*) as count FROM ${stagingQueryTable.signalPartitionsTable} WHERE table_name = '$outputTable'") + .collect()(0) + .getAs[Long]("count") + metadataCountForTable > 0 + } catch { + case _: Exception => false + } + assert(!virtualPartitionExistsForTable, + s"Expected NO virtual partition metadata for table $outputTable when createView=false") // Test Case 3: createView unset (should default to false and create table) val stagingQueryConfUnset = Builders.StagingQuery( @@ -389,9 +402,9 @@ class StagingQueryTest { val outputUnset = stagingQueryConfUnset.metaData.outputTable val isTableUnset = tableUtils.tableReadFormat(outputUnset) match { case View => false - case _ => 
true + case _ => true } - + assert(isTableUnset, s"Expected $outputUnset to be a table when createView is unset") } }