diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothAggregatorTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothAggregatorTest.scala index 60bb5fc2c..efe04221f 100644 --- a/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothAggregatorTest.scala +++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothAggregatorTest.scala @@ -174,6 +174,30 @@ class SawtoothAggregatorTest extends TestCase { timer.publish("comparison") } + def testElementWiseAverageWithMerge(): Unit = { + val columns = Seq( + Column("ts", LongType, 1), + Column("embeddings", ListType(DoubleType), 1000, chunkSize = 10, nullRate = -1) + ) + val events = CStream.gen(columns, 100).rows + val schema = columns.map(_.schema) + + val queries = Array(System.currentTimeMillis()) + val aggregations = Seq( + Builders.Aggregation( + Operation.AVERAGE, + "embeddings", + Seq(new Window(1, TimeUnit.DAYS)), + elementWise = Some(true) + ) + ) + + // Check that the sawtooth aggregator passes through the element wise aggregation correctly + val sawtoothIrs = sawtoothAggregate(events, queries, aggregations, schema) + assertNotNull(sawtoothIrs) + assertEquals(1, sawtoothIrs.length) + } + } object SawtoothAggregatorTest { diff --git a/api/src/main/scala/ai/chronon/api/Builders.scala b/api/src/main/scala/ai/chronon/api/Builders.scala index 97c353f0a..6d1722144 100644 --- a/api/src/main/scala/ai/chronon/api/Builders.scala +++ b/api/src/main/scala/ai/chronon/api/Builders.scala @@ -90,7 +90,8 @@ object Builders { inputColumn: String, windows: Seq[Window] = null, argMap: Map[String, String] = null, - buckets: Seq[String] = null): Aggregation = { + buckets: Seq[String] = null, + elementWise: Option[Boolean] = None): Aggregation = { val result = new Aggregation() result.setOperation(operation) result.setInputColumn(inputColumn) @@ -100,6 +101,7 @@ object Builders { result.setWindows(windows.toJava) if (buckets != null) result.setBuckets(buckets.toJava) + elementWise.foreach(result.setElementWise) result } } diff --git a/api/src/main/scala/ai/chronon/api/Extensions.scala b/api/src/main/scala/ai/chronon/api/Extensions.scala index d795bfe5d..12c25ea39 100644 --- a/api/src/main/scala/ai/chronon/api/Extensions.scala +++ b/api/src/main/scala/ai/chronon/api/Extensions.scala @@ -259,7 +259,8 @@ object Extensions { _.toScala.toMap ) .orNull, - bucket + bucket, + Option(agg.elementWise).getOrElse(false) ) for (window <- windows) { perWindow += WindowMapping( @@ -271,7 +272,8 @@ object Extensions { _.toScala.toMap ) .orNull, - bucket), + bucket, + Option(agg.elementWise).getOrElse(false)), counter ) }