-
Notifications
You must be signed in to change notification settings - Fork 90
Switch average implementation to utilize running average to prevent overflow #1066
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -160,6 +160,82 @@ class Average extends SimpleAggregator[Double, Array[Any], Double] { | |
| override def isDeletable: Boolean = true | ||
| } | ||
|
|
||
| /** | ||
| * Aggregator whose intermediate result (IR) is a two-element array: the current mean at index 0 and its | ||
| * weight at index 1. The mean is maintained incrementally as inputs are added, merged or deleted. | ||
| */ | ||
| class RunningAverage extends SimpleAggregator[Double, Array[Any], Double] { | ||
| // The finalized value is the mean, reported as a double. | ||
| override def outputType: DataType = DoubleType | ||
|
|
||
| // IR layout: "running_average" followed by "weight", both doubles. | ||
| override def irType: DataType = | ||
| StructType( | ||
| "RunningAvgIr", | ||
| Array(StructField("running_average", DoubleType), StructField("weight", DoubleType)) | ||
| ) | ||
|
|
||
| // IR for a single input: the input itself as the mean, with weight 1.0. | ||
| override def prepare(input: Double): Array[Any] = Array(input, 1.0) | ||
|
|
||
| /** | ||
| * Threshold on rightWeight / (leftWeight + rightWeight), where right is the side with the smaller weight. | ||
| * Below it the incremental form `left + (right - left) * scaling` is used; at or above it (the two weights | ||
| * are of comparable size) the weighted-sum form is used instead. | ||
| */ | ||
| private[this] val STABILITY_CONSTANT = 0.1 | ||
|
|
||
| /** | ||
| * Combines the stream summarized by `ir` (mean at index 0, weight at index 1) with a second stream | ||
| * summarized by (right, rightWeight) and returns the [mean, weight] of the combined stream. | ||
| * | ||
| * Uses a more stable online algorithm which should be suitable for large numbers of records similar to: | ||
| * http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm | ||
| * | ||
| * If `ir` carries the smaller weight, the operands are swapped by recursing with a newly allocated array; | ||
| * in that case the returned array is the new one and `ir` is left unmodified. Otherwise `ir` is updated | ||
| * in place and returned. | ||
| * | ||
| * @param ir [mean, weight] of the left stream; both elements are read as Double | ||
| * @param right mean of the right stream | ||
| * @param rightWeight weight of the right stream (delete passes -1.0) | ||
| * @return [combined mean, combined weight] | ||
| */ | ||
| private def computeRunningAverage(ir: Array[Any], right: Double, rightWeight: Double): Array[Any] = { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there is apparently already a getCombinedMean below in the moments stuff
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good to know |
||
| val left = ir(0).asInstanceOf[Double] | ||
| val leftWeight = ir(1).asInstanceOf[Double] | ||
| // Keep the heavier side on the left, so that `scaling` below is the share of the smaller weight. | ||
| if (leftWeight < rightWeight) { | ||
| computeRunningAverage(Array(right, rightWeight), left, leftWeight) | ||
| } else { | ||
| val newCount = leftWeight + rightWeight | ||
| val newAverage = newCount match { | ||
| // Zero total weight: report 0.0 instead of dividing by zero below. | ||
| case 0.0 => 0.0 | ||
| // Adding rightWeight did not change the total, so the left mean is kept as is. | ||
| case newCount if newCount == leftWeight => left | ||
| case newCount => | ||
| val scaling = rightWeight / newCount | ||
| if (scaling < STABILITY_CONSTANT) { | ||
| // Small right share: move the left mean toward right by that share. | ||
| left + (right - left) * scaling | ||
| } else { | ||
| // Comparable weights: weighted sum of both means divided by the total weight. | ||
| (leftWeight * left + rightWeight * right) / newCount | ||
| } | ||
|
Comment on lines
+198
to
+203
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried to do this in-place (replace logic in average directly). and it actually makes the tests fail - the average operation stops being commutative due to slight errors in the double multiply and double division. I also tried the we ended up merging the following change instead: zipline-ai/chronon#1292 |
||
| } | ||
|
|
||
| // Write the combined mean and weight back into ir and return it. | ||
| ir.update(0, newAverage) | ||
| ir.update(1, newCount) | ||
| ir | ||
| } | ||
| } | ||
|
|
||
| // mutating | ||
| // Folds a single input into ir with weight 1.0. | ||
| // NOTE(review): computeRunningAverage returns a new array (leaving ir untouched) when ir's weight is | ||
| // below 1.0, so callers presumably need to use the returned array - confirm. | ||
| override def update(ir: Array[Any], input: Double): Array[Any] = { | ||
| computeRunningAverage(ir, input, 1.0) | ||
| } | ||
|
|
||
| // mutating | ||
| // Folds ir2's mean and weight into ir1; ir2 is only read. The result is ir1 updated in place, or a new | ||
| // array when ir1 carries the smaller weight. | ||
| override def merge(ir1: Array[Any], ir2: Array[Any]): Array[Any] = { | ||
| computeRunningAverage(ir1, ir2(0).asInstanceOf[Double], ir2(1).asInstanceOf[Double]) | ||
| } | ||
|
|
||
| // The final value is the mean stored at index 0 of the IR. | ||
| override def finalize(ir: Array[Any]): Double = | ||
| ir(0).asInstanceOf[Double] | ||
|
|
||
| // mutating | ||
| // Folds the input in with weight -1.0, so the total weight decreases by one. | ||
| override def delete(ir: Array[Any], input: Double): Array[Any] = { | ||
| computeRunningAverage(ir, input, -1.0) | ||
| } | ||
|
|
||
| // Returns a new array of the same length holding the same elements as ir. | ||
| override def clone(ir: Array[Any]): Array[Any] = { | ||
| val arr = new Array[Any](ir.length) | ||
| ir.copyToArray(arr) | ||
| arr | ||
| } | ||
|
|
||
| // delete is implemented above, so this aggregator reports itself as deletable. | ||
| override def isDeletable: Boolean = true | ||
| } | ||
|
|
||
| // Welford algo for computing variance | ||
| // Traditional sum of squares based formula has serious numerical stability problems | ||
| class WelfordState(ir: Array[Any]) { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of a new operator, would it make sense to add an argument
`prevent_overflow` or `running_average` to the Average operator that defaults to false?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah i like that! defaults to None that gets interpreted as false - to keep the semantic hashes as they were