diff --git a/api/py/ai/chronon/group_by.py b/api/py/ai/chronon/group_by.py index de76fb6ed..17ae3f257 100644 --- a/api/py/ai/chronon/group_by.py +++ b/api/py/ai/chronon/group_by.py @@ -257,6 +257,14 @@ def validate_group_by(group_by: ttypes.GroupBy): Keys {unselected_keys}, are unselected in source """ + # For global aggregations (empty keys), aggregations must be specified + if not keys: + assert aggregations is not None and len(aggregations) > 0, ( + "Global aggregations (empty keys) require at least one aggregation to be specified. " + "To compute global aggregates, provide aggregations like " + "[Aggregation(input_column='col', operation=Operation.SUM)]." + ) + # Aggregations=None is only valid if group_by is Entities if aggregations is None: is_events = any([s.events for s in sources]) @@ -359,7 +367,7 @@ def get_output_col_names(aggregation): def GroupBy( sources: Union[List[_ANY_SOURCE_TYPE], _ANY_SOURCE_TYPE], - keys: List[str], + keys: Optional[List[str]], aggregations: Optional[List[ttypes.Aggregation]], online: Optional[bool] = DEFAULT_ONLINE, production: Optional[bool] = DEFAULT_PRODUCTION, @@ -408,8 +416,9 @@ def GroupBy( :type sources: List[ai.chronon.api.ttypes.Events|ai.chronon.api.ttypes.Entities] :param keys: List of primary keys that defines the data that needs to be collected in the result table. Similar to the - GroupBy in the SQL context. - :type keys: List[String] + GroupBy in the SQL context. For global aggregations (computing a single aggregate value across all data), + pass either None or an empty list. In this case, aggregations will be computed without grouping by any keys. + :type keys: Optional[List[String]] :param aggregations: List of aggregations that needs to be computed for the data following the grouping defined by the keys:: @@ -500,11 +509,12 @@ def GroupBy( """ assert sources, "Sources are not specified" + key_columns = keys or [] agg_inputs = [] if aggregations is not None: agg_inputs = [agg.inputColumn for agg in aggregations] - required_columns = keys + agg_inputs + required_columns = key_columns + agg_inputs def _sanitize_columns(source: ttypes.Source): query = ( @@ -577,7 +587,7 @@ def _normalize_source(source): group_by = ttypes.GroupBy( sources=sources, - keyColumns=keys, + keyColumns=key_columns, aggregations=aggregations, metaData=metadata, backfillStartDate=backfill_start_date, diff --git a/api/py/test/test_group_by.py b/api/py/test/test_group_by.py index 3464a86a8..95b89aefd 100644 --- a/api/py/test/test_group_by.py +++ b/api/py/test/test_group_by.py @@ -1,4 +1,3 @@ - # Copyright (C) 2023 The Chronon Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest, json +import json +import pytest from ai.chronon import group_by, query -from ai.chronon.group_by import GroupBy, Derivation, TimeUnit, Window, Aggregation, Accuracy from ai.chronon.api import ttypes -from ai.chronon.api.ttypes import EventSource, EntitySource, Operation +from ai.chronon.group_by import Accuracy, Derivation @pytest.fixture @@ -50,11 +49,7 @@ def event_source(table, topic=None): topic=topic, query=ttypes.Query( startPartition="2020-04-09", - selects={ - "subject": "subject_sql", - "event_id": "event_sql", - "cnt": 1 - }, + selects={"subject": "subject_sql", "event_id": "event_sql", "cnt": 1}, timeColumn="CAST(ts AS DOUBLE)", ), ) @@ -69,30 +64,21 @@ def entity_source(snapshotTable, mutationTable): mutationTable=mutationTable, query=ttypes.Query( startPartition="2020-04-09", - selects={ - "subject": "subject_sql", - "event_id": "event_sql", - "cnt": 1 - }, + selects={"subject": "subject_sql", "event_id": "event_sql", "cnt": 1}, timeColumn="CAST(ts AS DOUBLE)", mutationTimeColumn="__mutationTs", reversalColumn="is_reverse", ), ) + def test_pretty_window_str(days_unit, hours_unit): """ Test pretty window utils. """ - window = ttypes.Window( - length=7, - timeUnit=days_unit - ) + window = ttypes.Window(length=7, timeUnit=days_unit) assert group_by.window_to_str_pretty(window) == "7 days" - window = ttypes.Window( - length=2, - timeUnit=hours_unit - ) + window = ttypes.Window(length=2, timeUnit=hours_unit) assert group_by.window_to_str_pretty(window) == "2 hours" @@ -108,7 +94,7 @@ def test_select(): """ Test select builder """ - assert query.select('subject', event="event_expr") == {"subject": "subject", "event": "event_expr"} + assert query.select("subject", event="event_expr") == {"subject": "subject", "event": "event_expr"} def test_contains_windowed_aggregation(sum_op, min_op, days_unit): @@ -117,16 +103,12 @@ def test_contains_windowed_aggregation(sum_op, min_op, days_unit): """ assert not group_by.contains_windowed_aggregation([]) aggregations = [ - ttypes.Aggregation(inputColumn='event', operation=sum_op), - ttypes.Aggregation(inputColumn='event', operation=min_op), + ttypes.Aggregation(inputColumn="event", operation=sum_op), + ttypes.Aggregation(inputColumn="event", operation=min_op), ] assert not group_by.contains_windowed_aggregation(aggregations) aggregations.append( - ttypes.Aggregation( - inputColumn='event', - operation=sum_op, - windows=[ttypes.Window(length=7, timeUnit=days_unit)] - ) + ttypes.Aggregation(inputColumn="event", operation=sum_op, windows=[ttypes.Window(length=7, timeUnit=days_unit)]) ) assert group_by.contains_windowed_aggregation(aggregations) @@ -174,6 +156,7 @@ def test_validator_ok(): aggregations=None, ) + def test_validator_accuracy(): with pytest.raises(AssertionError, match="SNAPSHOT accuracy should not be specified for streaming sources"): gb = group_by.GroupBy( @@ -192,9 +175,9 @@ def test_validator_accuracy(): assert all([agg.inputColumn for agg in gb.aggregations if agg.operation != ttypes.Operation.COUNT]) group_by.validate_group_by(gb) + def test_generic_collector(): - aggregation = group_by.Aggregation( - input_column="test", operation=group_by.Operation.APPROX_PERCENTILE([0.4, 0.2])) + aggregation = group_by.Aggregation(input_column="test", operation=group_by.Operation.APPROX_PERCENTILE([0.4, 0.2])) assert aggregation.argMap == {"k": "128", "percentiles": "[0.4, 0.2]"} @@ -202,21 +185,11 @@ def test_select_sanitization(): gb = group_by.GroupBy( sources=[ ttypes.EventSource( # No selects are spcified - 
table="event_table1", - query=query.Query( - selects=None, - time_column="ts" - ) + table="event_table1", query=query.Query(selects=None, time_column="ts") ), ttypes.EntitySource( # Some selects are specified - snapshotTable="entity_table1", - query=query.Query( - selects={ - "key1": "key1_sql", - "event_id": "event_sql" - } - ) - ) + snapshotTable="entity_table1", query=query.Query(selects={"key1": "key1_sql", "event_id": "event_sql"}) + ), ], keys=["key1", "key2"], aggregations=group_by.Aggregations( @@ -239,19 +212,20 @@ def test_snapshot_with_hour_aggregation(): ttypes.EntitySource( # Some selects are specified snapshotTable="entity_table1", query=query.Query( - selects={ - "key1": "key1_sql", - "event_id": "event_sql" - }, + selects={"key1": "key1_sql", "event_id": "event_sql"}, time_column="ts", - ) + ), ) ], keys=["key1"], aggregations=group_by.Aggregations( - random=ttypes.Aggregation(inputColumn="event_id", operation=ttypes.Operation.SUM, windows=[ - ttypes.Window(1, ttypes.TimeUnit.HOURS), - ]), + random=ttypes.Aggregation( + inputColumn="event_id", + operation=ttypes.Operation.SUM, + windows=[ + ttypes.Window(1, ttypes.TimeUnit.HOURS), + ], + ), ), backfill_start_date="2021-01-04", ) @@ -259,56 +233,79 @@ def test_snapshot_with_hour_aggregation(): def test_additional_metadata(): gb = group_by.GroupBy( - sources=[ - ttypes.EventSource( - table="event_table1", - query=query.Query( - selects=None, - time_column="ts" - ) - ) - ], + sources=[ttypes.EventSource(table="event_table1", query=query.Query(selects=None, time_column="ts"))], keys=["key1", "key2"], aggregations=[group_by.Aggregation(input_column="event_id", operation=ttypes.Operation.SUM)], - tags={"to_deprecate": True} + tags={"to_deprecate": True}, ) - assert json.loads(gb.metaData.customJson)['groupby_tags']['to_deprecate'] - + assert json.loads(gb.metaData.customJson)["groupby_tags"]["to_deprecate"] def test_group_by_with_description(): gb = group_by.GroupBy( - sources=[ - ttypes.EventSource( - table="event_table1", - query=query.Query( - selects=None, - time_column="ts" - ) - ) - ], + sources=[ttypes.EventSource(table="event_table1", query=query.Query(selects=None, time_column="ts"))], keys=["key1", "key2"], aggregations=[group_by.Aggregation(input_column="event_id", operation=ttypes.Operation.SUM)], name="test.additional_metadata_gb", - description="GroupBy description" + description="GroupBy description", ) assert gb.metaData.description == "GroupBy description" def test_derivation(): derivation = Derivation(name="derivation_name", expression="derivation_expression") - expected_derivation = ttypes.Derivation( - name="derivation_name", - expression="derivation_expression") + expected_derivation = ttypes.Derivation(name="derivation_name", expression="derivation_expression") assert derivation == expected_derivation def test_derivation_with_description(): - derivation = Derivation(name="derivation_name", expression="derivation_expression", description="Derivation description") + derivation = Derivation( + name="derivation_name", expression="derivation_expression", description="Derivation description" + ) expected_derivation = ttypes.Derivation( name="derivation_name", expression="derivation_expression", - metaData=ttypes.MetaData(description="Derivation description")) + metaData=ttypes.MetaData(description="Derivation description"), + ) - assert derivation == expected_derivation \ No newline at end of file + assert derivation == expected_derivation + + +def test_global_aggregation(): + """ + Test global aggregations with 
empty keys
+    """
+    # Test with keys=[]
+    gb = group_by.GroupBy(
+        sources=event_source("table"),
+        keys=[],
+        aggregations=group_by.Aggregations(
+            total_count=ttypes.Aggregation(inputColumn="cnt", operation=ttypes.Operation.COUNT),
+            total_sum=ttypes.Aggregation(inputColumn="cnt", operation=ttypes.Operation.SUM),
+        ),
+    )
+    assert gb.keyColumns == []
+    assert len(gb.aggregations) == 2
+    group_by.validate_group_by(gb)
+
+    # Test with keys=None
+    gb = group_by.GroupBy(
+        sources=event_source("table"),
+        keys=None,
+        aggregations=group_by.Aggregations(
+            total_count=ttypes.Aggregation(inputColumn="cnt", operation=ttypes.Operation.COUNT),
+            total_sum=ttypes.Aggregation(inputColumn="cnt", operation=ttypes.Operation.SUM),
+        ),
+    )
+    assert gb.keyColumns == []
+    assert len(gb.aggregations) == 2
+    group_by.validate_group_by(gb)
+
+    # Global aggregations must specify at least one aggregation
+    with pytest.raises(AssertionError, match="Global aggregations"):
+        group_by.GroupBy(
+            sources=event_source("table"),
+            keys=[],
+            aggregations=None,
+        )
diff --git a/api/py/test/test_join.py b/api/py/test/test_join.py
index edec40d0c..f34da9095 100644
--- a/api/py/test/test_join.py
+++ b/api/py/test/test_join.py
@@ -163,3 +163,52 @@ def test_derivation_with_description():
     )
 
     assert derivation == expected_derivation
+
+
+def test_join_with_global_aggregation():
+    """
+    Test that joins work with global aggregations (GroupBys with empty keys).
+    Global aggregations should join on system keys (partition/timestamp) only.
+    """
+    # Create a global aggregation GroupBy (no keys)
+    global_gb = GroupBy(
+        sources=[event_source("global_stats_table")],
+        keys=[],  # Empty keys = global aggregation
+        aggregations=[
+            api.Aggregation(inputColumn="event_id", operation=api.Operation.COUNT),
+            api.Aggregation(inputColumn="event_id", operation=api.Operation.SUM),
+        ],
+        name="global_stats",
+    )
+
+    # Create a normal GroupBy with keys for comparison
+    regular_gb = GroupBy(
+        sources=[event_source("user_stats_table")],
+        keys=["subject"],
+        aggregations=[
+            api.Aggregation(inputColumn="event_id", operation=api.Operation.LAST),
+        ],
+        name="user_stats",
+    )
+
+    # Create a join with both global and regular aggregations
+    join = Join(
+        left=event_source("events_table"),
+        right_parts=[
+            api.JoinPart(groupBy=global_gb, prefix="global"),
+            api.JoinPart(groupBy=regular_gb),  # Uses default key mapping
+        ],
+        name="events_with_global_and_user_stats",
+    )
+
+    # Verify the join was created successfully
+    assert join is not None
+    assert len(join.joinParts) == 2
+
+    # Verify global aggregation has empty keyColumns
+    assert join.joinParts[0].groupBy.keyColumns == []
+    assert len(join.joinParts[0].groupBy.aggregations) == 2
+
+    # Verify regular aggregation has keys
+    assert join.joinParts[1].groupBy.keyColumns == ["subject"]
+    assert len(join.joinParts[1].groupBy.aggregations) == 1
diff --git a/api/src/main/scala/ai/chronon/api/Constants.scala b/api/src/main/scala/ai/chronon/api/Constants.scala
index d5c09c5f2..3a090dcac 100644
--- a/api/src/main/scala/ai/chronon/api/Constants.scala
+++ b/api/src/main/scala/ai/chronon/api/Constants.scala
@@ -67,4 +67,5 @@ object Constants {
   val chrononArchiveFlag: String = "chronon_archived"
   val ChainingRequestTs: String = "chaining_request_ts"
   val ChainingFetchTs: String = "chaining_fetch_ts"
+  val GlobalAggregationKVStoreKey: String = "__global_aggregation_dummy_key__"
 }
diff --git a/online/src/main/scala/ai/chronon/online/Api.scala b/online/src/main/scala/ai/chronon/online/Api.scala
index 
258900465..a01832462 100644 --- a/online/src/main/scala/ai/chronon/online/Api.scala +++ b/online/src/main/scala/ai/chronon/online/Api.scala @@ -98,7 +98,12 @@ trait KVStore { def createKeyBytes(keys: Map[String, AnyRef], groupByServingInfo: GroupByServingInfoParsed, dataset: String): Array[Byte] = { - groupByServingInfo.keyCodec.encode(keys) + // For global aggregations (empty key schema), use plain UTF-8 bytes for dummy key + if (groupByServingInfo.keyChrononSchema.fields.isEmpty) { + Constants.GlobalAggregationKVStoreKey.getBytes(Constants.UTF8) + } else { + groupByServingInfo.keyCodec.encode(keys) + } } } diff --git a/spark/src/main/scala/ai/chronon/spark/ChrononKryoRegistrator.scala b/spark/src/main/scala/ai/chronon/spark/ChrononKryoRegistrator.scala index f5365d499..a73437836 100644 --- a/spark/src/main/scala/ai/chronon/spark/ChrononKryoRegistrator.scala +++ b/spark/src/main/scala/ai/chronon/spark/ChrononKryoRegistrator.scala @@ -25,6 +25,7 @@ import com.yahoo.sketches.ArrayOfItemsSerDe import com.yahoo.sketches.cpc.CpcSketch import com.yahoo.sketches.frequencies.ItemsSketch import org.apache.spark.serializer.KryoRegistrator +import java.util.regex.Pattern class CpcSketchKryoSerializer extends Serializer[CpcSketch] { override def write(kryo: Kryo, output: Output, sketch: CpcSketch): Unit = { @@ -67,6 +68,25 @@ class ItemsSketchKryoSerializer[T] extends Serializer[ItemsSketchIR[T]] { } } +/** + * Custom Kryo serializer for java.util.regex.Pattern. + * + * Required for serializing Join configurations (JoinOps.identifierRegex) when Spark broadcasts metadata. + * Kryo's default FieldSerializer fails due to Java 9+ module restrictions on java.util.regex reflection. + */ +class PatternKryoSerializer extends Serializer[Pattern] { + override def write(kryo: Kryo, output: Output, pattern: Pattern): Unit = { + output.writeString(pattern.pattern()) + output.writeInt(pattern.flags()) + } + + override def read(kryo: Kryo, input: Input, `type`: Class[Pattern]): Pattern = { + val patternString = input.readString() + val flags = input.readInt() + Pattern.compile(patternString, flags) + } +} + class ChrononKryoRegistrator extends KryoRegistrator { // registering classes tells kryo to not send schema on the wire // helps shuffles and spilling to disk @@ -156,6 +176,7 @@ class ChrononKryoRegistrator extends KryoRegistrator { kryo.register(classOf[CpcSketch], new CpcSketchKryoSerializer()) kryo.register(classOf[Array[ItemSketchSerializable]]) kryo.register(classOf[ItemsSketchIR[AnyRef]], new ItemsSketchKryoSerializer[AnyRef]) + kryo.register(classOf[Pattern], new PatternKryoSerializer) } def doRegister(name: String, kryo: Kryo): Unit = { diff --git a/spark/src/main/scala/ai/chronon/spark/Extensions.scala b/spark/src/main/scala/ai/chronon/spark/Extensions.scala index 76695b015..b3b7d0f6a 100644 --- a/spark/src/main/scala/ai/chronon/spark/Extensions.scala +++ b/spark/src/main/scala/ai/chronon/spark/Extensions.scala @@ -43,7 +43,7 @@ object Extensions { } // pad the first column so that the second column is aligned vertically - val padding = schemaTuples.map(_._1.length).max + val padding = if (schemaTuples.isEmpty) 0 else schemaTuples.map(_._1.length).max schemaTuples .map { case (typ, name) => s" ${typ.padTo(padding, ' ')} : $name" @@ -60,8 +60,8 @@ object Extensions { case class DfStats(count: Long, partitionRange: PartitionRange) // helper class to maintain datafram stats that are necessary for downstream operations case class DfWithStats(df: DataFrame, partitionCounts: Map[String, Long])(implicit val 
tableUtils: TableUtils) { - private val minPartition: String = partitionCounts.keys.min - private val maxPartition: String = partitionCounts.keys.max + private val minPartition: String = if (partitionCounts.isEmpty) null else partitionCounts.keys.min + private val maxPartition: String = if (partitionCounts.isEmpty) null else partitionCounts.keys.max val partitionRange: PartitionRange = PartitionRange(minPartition, maxPartition) val count: Long = partitionCounts.values.sum @@ -226,9 +226,13 @@ object Extensions { } def removeNulls(cols: Seq[String]): DataFrame = { - logger.info(s"filtering nulls from columns: [${cols.mkString(", ")}]") - // do not use != or <> operator with null, it doesn't return false ever! - df.filter(cols.map(_ + " IS NOT NULL").mkString(" AND ")) + if (cols.isEmpty) { + df // Return unfiltered DataFrame for empty columns + } else { + logger.info(s"filtering nulls from columns: [${cols.mkString(", ")}]") + // do not use != or <> operator with null, it doesn't return false ever! + df.filter(cols.map(_ + " IS NOT NULL").mkString(" AND ")) + } } def nullSafeJoin(right: DataFrame, keys: Seq[String], joinType: String): DataFrame = { diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index beac142ed..003ca7850 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -510,8 +510,13 @@ object GroupBy { val processedInputDf = bloomMapOpt.map { skewFilteredDf.filterBloom }.getOrElse { skewFilteredDf } // at-least one of the keys should be present in the row. - val nullFilterClause = groupByConf.keyColumns.toScala.map(key => s"($key IS NOT NULL)").mkString(" OR ") - val nullFiltered = processedInputDf.filter(nullFilterClause) + // For global aggregations (empty keys), skip null filtering + val nullFiltered = if (groupByConf.keyColumns.isEmpty) { + processedInputDf + } else { + val nullFilterClause = groupByConf.keyColumns.toScala.map(key => s"($key IS NOT NULL)").mkString(" OR ") + processedInputDf.filter(nullFilterClause) + } if (showDf) { logger.info(s"printing input date for groupBy: ${groupByConf.metaData.name}") nullFiltered.prettyPrint() diff --git a/spark/src/main/scala/ai/chronon/spark/KvRdd.scala b/spark/src/main/scala/ai/chronon/spark/KvRdd.scala index 81457bf39..ead7d2bfa 100644 --- a/spark/src/main/scala/ai/chronon/spark/KvRdd.scala +++ b/spark/src/main/scala/ai/chronon/spark/KvRdd.scala @@ -51,9 +51,18 @@ sealed trait BaseKvRdd { val baseFlatSchema: StructType = StructType(keySchema ++ valueSchema) def flatSchema: StructType = if (withTime) StructType(baseFlatSchema :+ timeField) else baseFlatSchema def flatZSchema: api.StructType = flatSchema.toChrononSchema("Flat") - lazy val keyToBytes: Any => Array[Byte] = AvroConversions.encodeBytes(keyZSchema, GenericRowHandler.func) + // For global aggregations, use plain UTF-8 bytes for a dummy key (no need for an avro schema) + lazy val keyToBytes: Any => Array[Byte] = if (keySchema.fields.isEmpty) { + _ => api.Constants.GlobalAggregationKVStoreKey.getBytes(api.Constants.UTF8) + } else { + AvroConversions.encodeBytes(keyZSchema, GenericRowHandler.func) + } + lazy val keyToJson: Any => String = if (keySchema.fields.isEmpty) { + _ => api.Constants.GlobalAggregationKVStoreKey + } else { + AvroConversions.encodeJson(keyZSchema, GenericRowHandler.func) + } lazy val valueToBytes: Any => Array[Byte] = AvroConversions.encodeBytes(valueZSchema, GenericRowHandler.func) - lazy val keyToJson: Any => 
String = AvroConversions.encodeJson(keyZSchema, GenericRowHandler.func)
   lazy val valueToJson: Any => String = AvroConversions.encodeJson(valueZSchema, GenericRowHandler.func)
 
   private val baseRowSchema = StructType(
     Seq(
@@ -75,17 +84,25 @@ case class KvRdd(data: RDD[(Array[Any], Array[Any])], keySchema: StructType, val
   val withTime = false
 
   def toAvroDf(jsonPercent: Int = 1): DataFrame = {
-    val avroRdd: RDD[Row] = data.map {
-      case (keys: Array[Any], values: Array[Any]) =>
-        // json encoding is very expensive (50% of entire job).
-        // We only do it for a specified fraction to retain debuggability.
-        val (keyJson, valueJson) = if (math.random < jsonPercent.toDouble / 100) {
-          (keyToJson(keys), valueToJson(values))
-        } else {
-          (null, null)
-        }
-        val result: Array[Any] = Array(keyToBytes(keys), valueToBytes(values), keyJson, valueJson)
-        new GenericRow(result)
+    val avroRdd: RDD[Row] = data.mapPartitions { iterator =>
+      // Compute the sampling threshold once per partition and track the first row
+      val jsonThreshold = jsonPercent.toDouble / 100
+      var isFirstRow = true
+
+      iterator.map {
+        case (keys: Array[Any], values: Array[Any]) =>
+          // json encoding is very expensive (50% of entire job).
+          // We only do it for a specified fraction to retain debuggability.
+          // We also always encode the first row of each partition to enable better debugging for small datasets
+          val (keyJson, valueJson) = if (isFirstRow || math.random < jsonThreshold) {
+            isFirstRow = false
+            (keyToJson(keys), valueToJson(values))
+          } else {
+            (null, null)
+          }
+          val result: Array[Any] = Array(keyToBytes(keys), valueToBytes(values), keyJson, valueJson)
+          new GenericRow(result)
+      }
     }
     logger.info(s"""
                    |key schema:
@@ -118,15 +135,22 @@ case class TimedKvRdd(data: RDD[(Array[Any], Array[Any], Long)],
 
   // TODO make json percent configurable
   def toAvroDf: DataFrame = {
-    val avroRdd: RDD[Row] = data.map {
-      case (keys, values, ts) =>
-        val (keyJson, valueJson) = if (math.random < 0.01) {
-          (keyToJson(keys), valueToJson(values))
-        } else {
-          (null, null)
-        }
-        val result: Array[Any] = Array(keyToBytes(keys), valueToBytes(values), keyJson, valueJson, ts)
-        new GenericRow(result)
+    val avroRdd: RDD[Row] = data.mapPartitions { iterator =>
+      // Compute the sampling threshold once per partition and track the first row
+      val jsonThreshold = 0.01
+      var isFirstRow = true
+
+      iterator.map {
+        case (keys, values, ts) =>
+          val (keyJson, valueJson) = if (isFirstRow || math.random < jsonThreshold) {
+            isFirstRow = false
+            (keyToJson(keys), valueToJson(values))
+          } else {
+            (null, null)
+          }
+          val result: Array[Any] = Array(keyToBytes(keys), valueToBytes(values), keyJson, valueJson, ts)
+          new GenericRow(result)
+      }
     }
 
     val schemasStr = Seq(keyZSchema, valueZSchema).map(AvroConversions.fromChrononSchema(_).toString(true))
diff --git a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala
index 4adf2583a..92cf69f75 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala
@@ -952,10 +952,77 @@ class GroupByTest {
 
     // Verify we can query the results successfully with partition pushdown
     val filteredResult = tableUtils.sql(s"""
-      SELECT user, time_spent_ms_count
-      FROM $groupByOutputTable
+      SELECT user, time_spent_ms_count
+      FROM $groupByOutputTable
       WHERE ds >= '${tableUtils.partitionSpec.minus(today, new Window(7, TimeUnit.DAYS))}'
     """)
 
    assertTrue("Should be able to filter GroupBy results", 
filteredResult.count() >= 0) } + + @Test + def testGlobalAggregationSnapshot(): Unit = { + lazy val spark: SparkSession = + SparkSessionBuilder.build("GroupByTest" + "_" + Random.alphanumeric.take(6).mkString, local = true) + implicit val tableUtils = TableUtils(spark) + + val schema = List( + Column("event_type", StringType, 5), + Column("value", IntType, 10000) + ) + + val df = DataFrameGen.events(spark, schema, 10000, 10) + val viewName = "test_global_aggregation" + df.createOrReplaceTempView(viewName) + + // Global aggregation: no keys, compute aggregates across all data + val aggregations: Seq[Aggregation] = Seq( + Builders.Aggregation(Operation.COUNT, "value", Seq(WindowUtils.Unbounded)), + Builders.Aggregation(Operation.SUM, "value", Seq(WindowUtils.Unbounded)) + ) + + + val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) + val monthAgo = tableUtils.partitionSpec.minus(today, new Window(30, TimeUnit.DAYS)) + + // Empty keys array = global aggregation + val groupBy = new GroupBy(aggregations, Seq(), df) + val actualDf = groupBy.snapshotEvents(PartitionRange(monthAgo, today)) + + // Expected: one row per partition with global aggregates + val expectedDf = df.sqlContext.sql(s""" + |WITH + | counts AS ( + | SELECT ds, + | COUNT(value) AS value_count, + | SUM(value) AS value_sum + | FROM $viewName + | GROUP BY ds) + | + |SELECT counts.ds, + | SUM(counts.value_count) OVER (ORDER BY counts.ds ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS value_count, + | SUM(counts.value_sum) OVER (ORDER BY counts.ds ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS value_sum + |FROM counts + |ORDER BY counts.ds + |""".stripMargin) + + val diff = Comparison.sideBySide(actualDf, expectedDf, List(tableUtils.partitionColumn)) + if (diff.count() > 0) { + println("Global aggregation test failed - showing differences:") + diff.show() + } + assertEquals(0, diff.count()) + } + + @Test + def testGlobalAggregationDummyKeyInjection(): Unit = { + lazy val spark: SparkSession = SparkSessionBuilder.build("GlobalAggDummyKeyTest", local = true) + val emptyKeySchema = StructType(Seq.empty) + val valueSchema = StructType(Seq(StructField("count", SparkLongType))) + val data = spark.sparkContext.parallelize(Seq((Array.empty[Any], Array[Any](100L)))) + val kvRdd = KvRdd(data, emptyKeySchema, valueSchema)(spark) + val df = kvRdd.toAvroDf() + + val keyBytes = df.select("key_bytes").head().getAs[Array[Byte]](0) + assertEquals(Constants.GlobalAggregationKVStoreKey, new String(keyBytes, Constants.UTF8)) + } } diff --git a/spark/src/test/scala/ai/chronon/spark/test/JoinBasicTest.scala b/spark/src/test/scala/ai/chronon/spark/test/JoinBasicTest.scala index 7636a26c9..c9a96d154 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/JoinBasicTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/JoinBasicTest.scala @@ -506,7 +506,7 @@ class JoinBasicTests { val end = tableUtils.partitionSpec.minus(today, new Window(15, TimeUnit.DAYS)) val joinConf = Builders.Join( left = Builders.Source.entities(Builders.Query(selects = Map("user" -> "user"), startPartition = start), - snapshotTable = usersTable), + snapshotTable = usersTable), joinParts = Seq(Builders.JoinPart(groupBy = namesGroupBy)), metaData = Builders.MetaData(name = "test.user_features", namespace = namespace, team = "chronon") ) @@ -550,4 +550,84 @@ class JoinBasicTests { assertEquals(diffCount, 0) } + @Test + def testGlobalAggregation(): Unit = { + val spark: SparkSession = + SparkSessionBuilder.build("JoinBasicTest" + "_" + 
Random.alphanumeric.take(6).mkString, local = true) + val tableUtils = TableUtils(spark) + val namespace = "test_namespace_jointest" + "_" + Random.alphanumeric.take(6).mkString + tableUtils.createDatabase(namespace) + + // Create a table for global aggregation (no keys) + val statsSchema = List( + Column("metric_value", api.LongType, 1000) + ) + val statsTable = s"$namespace.stats" + DataFrameGen.entities(spark, statsSchema, 1000, partitions = 200).save(statsTable) + + val statsSource = Builders.Source.entities( + query = Builders.Query(selects = Builders.Selects("metric_value"), startPartition = yearAgo, endPartition = dayAndMonthBefore), + snapshotTable = statsTable + ) + + val globalGroupBy = Builders.GroupBy( + sources = Seq(statsSource), + keyColumns = Seq(), // Empty keys = global aggregation + aggregations = Seq(Builders.Aggregation(operation = Operation.SUM, inputColumn = "metric_value")), + metaData = Builders.MetaData(name = "unit_test.global_stats", team = "chronon") + ) + + // left side + val userSchema = List(Column("user", api.StringType, 100)) + val usersTable = s"$namespace.users" + DataFrameGen.entities(spark, userSchema, 1000, partitions = 200).dropDuplicates().save(usersTable) + + val start = tableUtils.partitionSpec.minus(today, new Window(60, TimeUnit.DAYS)) + val end = tableUtils.partitionSpec.minus(today, new Window(15, TimeUnit.DAYS)) + val joinConf = Builders.Join( + left = Builders.Source.entities(Builders.Query(selects = Map("user" -> "user"), startPartition = start), + snapshotTable = usersTable), + joinParts = Seq(Builders.JoinPart(groupBy = globalGroupBy)), + metaData = Builders.MetaData(name = "test.user_with_global_stats", namespace = namespace, team = "chronon") + ) + + val runner = new Join(joinConf, end, tableUtils) + val computed = runner.computeJoin(Some(7)) + logger.debug(s"join start = $start") + + val expected = tableUtils.sql(s""" + |WITH + | users AS (SELECT user, ds from $usersTable where ds >= '$start' and ds <= '$end'), + | global_agg AS ( + | SELECT SUM(metric_value) as unit_test_global_stats_metric_value_sum, + | ds + | FROM $statsTable + | WHERE ds >= '$yearAgo' and ds <= '$dayAndMonthBefore' + | GROUP BY ds) + | SELECT users.user, + | global_agg.unit_test_global_stats_metric_value_sum, + | users.ds + | FROM users left outer join global_agg + | ON users.ds = global_agg.ds + """.stripMargin) + if (logger.isDebugEnabled) { + logger.debug("showing join result") + computed.show() + logger.debug("showing query result") + expected.show() + logger.debug( + s"Left side count: ${spark.sql(s"SELECT user, ds from $usersTable where ds >= '$start' and ds <= '$end'").count()}") + logger.debug(s"Actual count: ${computed.count()}") + logger.debug(s"Expected count: ${expected.count()}") + } + val diff = Comparison.sideBySide(computed, expected, List("user", "ds")) + val diffCount = diff.count() + if (diffCount > 0) { + logger.warn(s"Diff count: $diffCount") + logger.warn(s"diff result rows") + diff.show() + } + assertEquals(diffCount, 0) + } + }
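
---

Usage sketch, pulling the pieces of this patch together from the Python API side. The snippet is a minimal illustration: the table and column names (shop.purchases, amount, ts) are hypothetical, modeled on the event_source fixture in test_group_by.py.

    from ai.chronon.api import ttypes
    from ai.chronon.group_by import Aggregation, GroupBy, Operation

    # Hypothetical event source; any EventSource or EntitySource works the same way.
    purchases = ttypes.EventSource(
        table="shop.purchases",
        query=ttypes.Query(
            selects={"amount": "amount"},
            timeColumn="ts",
        ),
    )

    # keys=[] (or keys=None) now means "aggregate across all rows".
    # validate_group_by requires at least one aggregation in this mode.
    global_purchase_stats = GroupBy(
        sources=purchases,
        keys=[],
        aggregations=[
            Aggregation(input_column="amount", operation=Operation.SUM),
            Aggregation(input_column="amount", operation=Operation.COUNT),
        ],
    )

Such a GroupBy can be used as a JoinPart like any other; with no key columns to map, the join matches rows on the partition column alone, as exercised by JoinBasicTests.testGlobalAggregation above. Online, because the key schema is empty, the upload path (KvRdd.keyToBytes) and the fetch path (KVStore.createKeyBytes) both write the constant Constants.GlobalAggregationKVStoreKey as plain UTF-8 bytes instead of going through the Avro key codec, so readers and writers agree on the single dummy key.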