diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
index ff80f02253..94ca8332e7 100644
--- a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
+++ b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
@@ -160,6 +160,82 @@ class Average extends SimpleAggregator[Double, Array[Any], Double] {
   override def isDeletable: Boolean = true
 }
 
+class RunningAverage extends SimpleAggregator[Double, Array[Any], Double] {
+  override def outputType: DataType = DoubleType
+
+  override def irType: DataType =
+    StructType(
+      "RunningAvgIr",
+      Array(StructField("running_average", DoubleType), StructField("weight", DoubleType))
+    )
+
+  override def prepare(input: Double): Array[Any] = Array(input, 1.0)
+
+  /**
+    * When combining averages, if the counts are too close in size we should use a different algorithm. This
+    * constant defines how close the ratio of the smaller to the total count can be:
+    */
+  private[this] val STABILITY_CONSTANT = 0.1
+
+  /**
+    * Given two streams of doubles (left, leftWeight) and (right, rightWeight) of form (mean, weighted count),
+    * calculates the mean of the combined stream.
+    *
+    * Uses a more stable online algorithm which should be suitable for large numbers of records similar to:
+    * https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
+    */
+  private def computeRunningAverage(ir: Array[Any], right: Double, rightWeight: Double): Array[Any] = {
+    val left = ir(0).asInstanceOf[Double]
+    val leftWeight = ir(1).asInstanceOf[Double]
+    if (leftWeight < rightWeight) {
+      computeRunningAverage(Array(right, rightWeight), left, leftWeight)
+    } else {
+      val newCount = leftWeight + rightWeight
+      val newAverage = newCount match {
+        case 0.0 => 0.0
+        case newCount if newCount == leftWeight => left
+        case newCount =>
+          val scaling = rightWeight / newCount
+          if (scaling < STABILITY_CONSTANT) {
+            left + (right - left) * scaling
+          } else {
+            (leftWeight * left + rightWeight * right) / newCount
+          }
+      }
+
+      ir.update(0, newAverage)
+      ir.update(1, newCount)
+      ir
+    }
+  }
+
+  // mutating; NOTE: may return a fresh array (leaving `ir` untouched) when the incoming weight dominates
+  override def update(ir: Array[Any], input: Double): Array[Any] = {
+    computeRunningAverage(ir, input, 1.0)
+  }
+
+  // mutating; NOTE: may return a fresh array (leaving `ir1` untouched) when ir2 carries the larger weight
+  override def merge(ir1: Array[Any], ir2: Array[Any]): Array[Any] = {
+    computeRunningAverage(ir1, ir2(0).asInstanceOf[Double], ir2(1).asInstanceOf[Double])
+  }
+
+  override def finalize(ir: Array[Any]): Double =
+    ir(0).asInstanceOf[Double]
+
+  // mutating
+  override def delete(ir: Array[Any], input: Double): Array[Any] = {
+    computeRunningAverage(ir, input, -1.0)
+  }
+
+  override def clone(ir: Array[Any]): Array[Any] = {
+    val arr = new Array[Any](ir.length)
+    ir.copyToArray(arr)
+    arr
+  }
+
+  override def isDeletable: Boolean = true
+}
+
 // Welford algo for computing variance
 // Traditional sum of squares based formula has serious numerical stability problems
 class WelfordState(ir: Array[Any]) {
diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala b/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala
index dc335537ee..43575122f1 100644
--- a/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala
+++ b/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala
@@ -355,6 +355,16 @@ object ColumnAggregator {
           case _ => mismatchException
         }
 
+      case Operation.RUNNING_AVERAGE =>
+        inputType match {
+          case IntType => simple(new RunningAverage, toDouble[Int])
+          case LongType => simple(new RunningAverage, toDouble[Long])
+          case ShortType => simple(new RunningAverage, toDouble[Short])
+          case DoubleType => simple(new RunningAverage)
+          case FloatType => simple(new RunningAverage, toDouble[Float])
+          case _ => mismatchException
+        }
+
       case Operation.VARIANCE =>
         inputType match {
           case IntType => simple(new Variance, toDouble[Int])
diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/RowAggregatorTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/RowAggregatorTest.scala
index ebc67839fb..7afa9a51ff 100644
--- a/aggregator/src/test/scala/ai/chronon/aggregator/test/RowAggregatorTest.scala
+++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/RowAggregatorTest.scala
@@ -109,7 +109,12 @@ class RowAggregatorTest extends TestCase {
       Builders.AggregationPart(Operation.HISTOGRAM, "hist_input", argMap = Map("k" -> "2")) -> histogram,
       Builders.AggregationPart(Operation.AVERAGE, "hist_map") -> mapAvg,
       Builders.AggregationPart(Operation.SUM, "embeddings", elementWise = true) -> List(12.0, 14.0).toJava,
-      Builders.AggregationPart(Operation.AVERAGE, "embeddings", elementWise = true) -> List(6.0, 7.0).toJava
+      Builders.AggregationPart(Operation.AVERAGE, "embeddings", elementWise = true) -> List(6.0, 7.0).toJava,
+      Builders.AggregationPart(Operation.RUNNING_AVERAGE, "views") -> 19.0 / 3,
+      Builders.AggregationPart(Operation.RUNNING_AVERAGE, "session_lengths") -> 8.0,
+      Builders.AggregationPart(Operation.RUNNING_AVERAGE, "session_lengths", bucket = "title") -> sessionLengthAvgByTitle,
+      Builders.AggregationPart(Operation.RUNNING_AVERAGE, "hist_map") -> mapAvg,
+      Builders.AggregationPart(Operation.RUNNING_AVERAGE, "embeddings", elementWise = true) -> List(6.0, 7.0).toJava
     )
 
     val (specs, expectedVals) = specsAndExpected.unzip
@@ -142,7 +147,10 @@ class RowAggregatorTest extends TestCase {
     val finalized = rowAggregator.finalize(forDeletion)
 
     expectedVals.zip(finalized).zip(rowAggregator.outputSchema.map(_._1)).foreach {
-      case ((expected, actual), name) => assertEquals(expected, actual)
+      case ((expected: Double, actual: Double), _) =>
+        assertEquals(expected, actual, 1e-9)
+      case ((expected, actual), name) =>
+        assertEquals(expected, actual)
     }
   }
 }
diff --git a/api/py/ai/chronon/group_by.py b/api/py/ai/chronon/group_by.py
index de76fb6edc..63e2679248 100644
--- a/api/py/ai/chronon/group_by.py
+++ b/api/py/ai/chronon/group_by.py
@@ -69,6 +69,7 @@ class Operation:
     COUNT = ttypes.Operation.COUNT
     SUM = ttypes.Operation.SUM
     AVERAGE = ttypes.Operation.AVERAGE
+    RUNNING_AVERAGE = ttypes.Operation.RUNNING_AVERAGE
     VARIANCE = ttypes.Operation.VARIANCE
     SKEW = ttypes.Operation.SKEW
     KURTOSIS = ttypes.Operation.KURTOSIS
diff --git a/api/thrift/api.thrift b/api/thrift/api.thrift
index 1e48f624fc..fbd554ab60 100644
--- a/api/thrift/api.thrift
+++ b/api/thrift/api.thrift
@@ -182,6 +182,7 @@ enum Operation {
     HISTOGRAM = 17, // use this only if you know the set of inputs is bounded
     APPROX_HISTOGRAM_K = 18,
     BOUNDED_UNIQUE_COUNT = 19
+    RUNNING_AVERAGE = 20
 }
 
 // integers map to milliseconds in the timeunit