Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,82 @@ class Average extends SimpleAggregator[Double, Array[Any], Double] {
override def isDeletable: Boolean = true
}

class RunningAverage extends SimpleAggregator[Double, Array[Any], Double] {

  override def outputType: DataType = DoubleType

  // IR holds the running mean plus the total weight folded in so far.
  override def irType: DataType =
    StructType(
      "RunningAvgIr",
      Array(StructField("running_average", DoubleType), StructField("weight", DoubleType))
    )

  override def prepare(input: Double): Array[Any] = Array(input, 1.0)

  /**
    * When combining averages, if the count sizes are too close we should use a
    * different formula. This constant bounds how close the ratio of the smaller
    * count to the total count may get before switching formulas.
    */
  private[this] val STABILITY_CONSTANT = 0.1

  /**
    * Folds the weighted value (incoming, incomingWeight) into the IR, which is of
    * the form (mean, weighted count), and returns the combined IR.
    *
    * Uses a more stable online algorithm suitable for large numbers of records,
    * similar to:
    * http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    *
    * NOTE(review): mutates and returns `ir`, except when `ir` carries the smaller
    * weight — then the sides are swapped and a freshly allocated array is mutated
    * and returned instead; callers here always use the return value.
    */
  private def computeRunningAverage(ir: Array[Any], incoming: Double, incomingWeight: Double): Array[Any] = {
    val mean = ir(0).asInstanceOf[Double]
    val weight = ir(1).asInstanceOf[Double]
    if (weight < incomingWeight) {
      // Always fold the lighter side into the heavier one.
      computeRunningAverage(Array(incoming, incomingWeight), mean, weight)
    } else {
      val combinedWeight = weight + incomingWeight
      val combinedMean =
        if (combinedWeight == 0.0) {
          0.0
        } else if (combinedWeight == weight) {
          // incomingWeight is zero — nothing to fold in.
          mean
        } else {
          val fraction = incomingWeight / combinedWeight
          if (fraction < STABILITY_CONSTANT) {
            // Small correction term keeps the update numerically stable.
            mean + (incoming - mean) * fraction
          } else {
            // Weights are comparable; the plain weighted mean is fine here.
            (weight * mean + incomingWeight * incoming) / combinedWeight
          }
        }
      ir.update(0, combinedMean)
      ir.update(1, combinedWeight)
      ir
    }
  }

  // mutating
  override def update(ir: Array[Any], input: Double): Array[Any] =
    computeRunningAverage(ir, input, 1.0)

  // mutating
  override def merge(ir1: Array[Any], ir2: Array[Any]): Array[Any] =
    computeRunningAverage(ir1, ir2(0).asInstanceOf[Double], ir2(1).asInstanceOf[Double])

  override def finalize(ir: Array[Any]): Double =
    ir(0).asInstanceOf[Double]

  // mutating: a deletion is just an update carrying weight -1.
  override def delete(ir: Array[Any], input: Double): Array[Any] =
    computeRunningAverage(ir, input, -1.0)

  override def clone(ir: Array[Any]): Array[Any] = ir.clone()

  override def isDeletable: Boolean = true
}

// Welford's algorithm for computing variance.
// The traditional sum-of-squares formula has serious numerical-stability problems.
class WelfordState(ir: Array[Any]) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,16 @@ object ColumnAggregator {
case _ => mismatchException
}

case Operation.RUNNING_AVERAGE =>
  // Same numeric widening scheme as AVERAGE: everything is promoted to Double.
  inputType match {
    case DoubleType => simple(new RunningAverage)
    case FloatType  => simple(new RunningAverage, toDouble[Float])
    case IntType    => simple(new RunningAverage, toDouble[Int])
    case LongType   => simple(new RunningAverage, toDouble[Long])
    case ShortType  => simple(new RunningAverage, toDouble[Short])
    case _          => mismatchException
  }

case Operation.VARIANCE =>
inputType match {
case IntType => simple(new Variance, toDouble[Int])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,12 @@ class RowAggregatorTest extends TestCase {
Builders.AggregationPart(Operation.HISTOGRAM, "hist_input", argMap = Map("k" -> "2")) -> histogram,
Builders.AggregationPart(Operation.AVERAGE, "hist_map") -> mapAvg,
Builders.AggregationPart(Operation.SUM, "embeddings", elementWise = true) -> List(12.0, 14.0).toJava,
Builders.AggregationPart(Operation.AVERAGE, "embeddings", elementWise = true) -> List(6.0, 7.0).toJava
Builders.AggregationPart(Operation.AVERAGE, "embeddings", elementWise = true) -> List(6.0, 7.0).toJava,
Builders.AggregationPart(Operation.RUNNING_AVERAGE, "views") -> 19.0 / 3,
Builders.AggregationPart(Operation.RUNNING_AVERAGE, "session_lengths") -> 8.0,
Builders.AggregationPart(Operation.RUNNING_AVERAGE, "session_lengths", bucket = "title") -> sessionLengthAvgByTitle,
Builders.AggregationPart(Operation.RUNNING_AVERAGE, "hist_map") -> mapAvg,
Builders.AggregationPart(Operation.RUNNING_AVERAGE, "embeddings", elementWise = true) -> List(6.0, 7.0).toJava
)

val (specs, expectedVals) = specsAndExpected.unzip
Expand Down Expand Up @@ -142,7 +147,10 @@ class RowAggregatorTest extends TestCase {
val finalized = rowAggregator.finalize(forDeletion)

expectedVals.zip(finalized).zip(rowAggregator.outputSchema.map(_._1)).foreach {
case ((expected, actual), name) => assertEquals(expected, actual)
case ((expected: Double, actual: Double), _) =>
assertEquals(expected, actual, 1e-9)
case ((expected, actual), name) =>
assertEquals(expected, actual)
}
}
}
1 change: 1 addition & 0 deletions api/py/ai/chronon/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class Operation:
COUNT = ttypes.Operation.COUNT
SUM = ttypes.Operation.SUM
AVERAGE = ttypes.Operation.AVERAGE
RUNNING_AVERAGE = ttypes.Operation.RUNNING_AVERAGE
VARIANCE = ttypes.Operation.VARIANCE
SKEW = ttypes.Operation.SKEW
KURTOSIS = ttypes.Operation.KURTOSIS
Expand Down
1 change: 1 addition & 0 deletions api/thrift/api.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ enum Operation {
HISTOGRAM = 17, // use this only if you know the set of inputs is bounded
APPROX_HISTOGRAM_K = 18,
BOUNDED_UNIQUE_COUNT = 19
RUNNING_AVERAGE = 20
}

// integers map to milliseconds in the timeunit
Expand Down