From e848f48f7505399dff641ab7e24edbb02e65a562 Mon Sep 17 00:00:00 2001 From: Pengyu Hou <3771747+pengyu-hou@users.noreply.github.com> Date: Wed, 1 Apr 2026 14:38:26 -0700 Subject: [PATCH] feat: add derivation warm-up for JoinSourceRunner to reduce cold-start latency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pre-initializes lazy components before streaming starts and uses the first N real requests to warm up KV store connections, CatalystUtil.session, Janino codegen, and derivation UDF on executor JVMs — eliminating the 5-10 minute timeout spike on new deploys. - Add PoolMap.warmup() and PooledCatalystUtil.warmup() to pre-populate the CatalystUtil pool beyond the default initialSize=2 - Add driver-side warmupDriver(): forces CatalystUtil.session, TTLCaches (GroupByServingInfo, JoinCodec), deriveFunc + CatalystUtil pool - Add executor-side warm-up in enrichBaseJoin: runs the first N real rows through fetchBaseJoin (60s timeout) then invokes deriveFunc with real base values to warm up UDF lazy state and JIT; results discarded and all rows re-processed normally - Fallback: if fetchBaseJoin times out, still force deriveFunc init so enrichModelTransforms avoids CatalystUtil cold-start timeout Config (spark.chronon.stream.chain.*): warmup.enabled=true, warmup.request_count=10, warmup.timeout_seconds=60, warmup.pool_size=4 --- .../ai/chronon/online/CatalystUtil.scala | 10 ++ .../online/test/CatalystUtilTest.scala | 19 +++- .../spark/streaming/JoinSourceRunner.scala | 98 ++++++++++++++++++- 3 files changed, 125 insertions(+), 2 deletions(-) diff --git a/online/src/main/scala/ai/chronon/online/CatalystUtil.scala b/online/src/main/scala/ai/chronon/online/CatalystUtil.scala index 861c6e70c..2be72c90a 100644 --- a/online/src/main/scala/ai/chronon/online/CatalystUtil.scala +++ b/online/src/main/scala/ai/chronon/online/CatalystUtil.scala @@ -100,6 +100,14 @@ class PoolMap[Key, Value](createFunc: Key => Value, maxSize: Int = 100, initialS 
pool.offer(value) } } + + def warmup(key: Key, targetSize: Int): Unit = { + val pool = getPool(key) + val toCreate = Math.max(0, targetSize - pool.size()) + (0 until toCreate).foreach { _ => + pool.offer(createFunc(key)) + } + } } class PooledCatalystUtil(expressions: collection.Seq[(String, String)], inputSchema: StructType) { @@ -125,6 +133,8 @@ class PooledCatalystUtil(expressions: collection.Seq[(String, String)], inputSch } def outputChrononSchema: Array[(String, DataType)] = poolMap.performWithValue(poolKey, cuPool) { _.outputChrononSchema } + + def warmup(targetSize: Int): Unit = poolMap.warmup(poolKey, targetSize) } // This class by itself it not thread safe because of the transformBuffer diff --git a/online/src/test/scala/ai/chronon/online/test/CatalystUtilTest.scala b/online/src/test/scala/ai/chronon/online/test/CatalystUtilTest.scala index 50e47b9be..2a9b7498a 100644 --- a/online/src/test/scala/ai/chronon/online/test/CatalystUtilTest.scala +++ b/online/src/test/scala/ai/chronon/online/test/CatalystUtilTest.scala @@ -17,7 +17,8 @@ package ai.chronon.online.test import ai.chronon.api._ -import ai.chronon.online.{CatalystUtil, PooledCatalystUtil} +import ai.chronon.online.{CatalystUtil, PoolMap, PooledCatalystUtil} +import ai.chronon.online.CatalystUtil.PoolKey import junit.framework.TestCase import org.junit.Assert.{assertArrayEquals, assertEquals, assertTrue} import org.junit.Test @@ -655,6 +656,22 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { "value" -> Array(Array(1L, "data1"), Array(2L, "data2")) ) + @Test + def testPoolMapWarmupPrePopulatesPool(): Unit = { + val selects = Seq("int32_x" -> "int32_x") + val key = PoolKey(selects, CommonScalarsStruct) + val poolMap = new PoolMap[PoolKey, CatalystUtil](pi => new CatalystUtil(pi.expressions, pi.inputSchema)) + // pool starts with initialSize=2 after first getPool call + poolMap.getPool(key) + assertEquals(2, poolMap.map.get(key).size()) + // warmup to targetSize=5 adds 3 
more + poolMap.warmup(key, 5) + assertEquals(5, poolMap.map.get(key).size()) + // warmup below current size is a no-op + poolMap.warmup(key, 3) + assertEquals(5, poolMap.map.get(key).size()) + } + def testPooledCatalystUtil(): Unit = { val selects = Seq( "ts" -> "ts", diff --git a/spark/src/main/scala/ai/chronon/spark/streaming/JoinSourceRunner.scala b/spark/src/main/scala/ai/chronon/spark/streaming/JoinSourceRunner.scala index 4345de655..c1b21e0b4 100644 --- a/spark/src/main/scala/ai/chronon/spark/streaming/JoinSourceRunner.scala +++ b/spark/src/main/scala/ai/chronon/spark/streaming/JoinSourceRunner.scala @@ -17,7 +17,7 @@ package ai.chronon.spark.streaming import ai.chronon.api -import ai.chronon.api.Extensions.{GroupByOps, JoinOps, MetadataOps, SourceOps} +import ai.chronon.api.Extensions.{DerivationOps, GroupByOps, JoinOps, MetadataOps, SourceOps} import ai.chronon.api._ import ai.chronon.online.Fetcher.{Request, ResponseWithContext} import ai.chronon.online.KVStore.PutRequest @@ -48,6 +48,8 @@ import scala.util.{Failure, Success} object LocalIOCache { private var fetcher: Fetcher = null private var kvStore: KVStore = null + @volatile var fetcherWarmedUp: Boolean = false + def getOrSetFetcher(builderFunc: () => Fetcher): Fetcher = { if (fetcher == null) { fetcher = builderFunc() @@ -121,6 +123,14 @@ class JoinSourceRunner(groupByConf: api.GroupBy, conf: Map[String, String] = Map // Micro batch repartition size - when set to 0, we won't do the repartition private val microBatchRepartition: Int = getProp("batch_repartition", "0").toInt + // Warm-up: pre-initialize lazy components and JIT warm-up using real requests before real processing + private val warmupEnabled: Boolean = getProp("warmup.enabled", "true").toBoolean + private val warmupPoolSize: Int = getProp("warmup.pool_size", "4").toInt + private val warmupRequestCount: Int = getProp("warmup.request_count", "10").toInt + // Longer timeout for warm-up: absorbs cold-start cost (KV connections, 
CatalystUtil.session, + // Janino codegen) which can exceed the normal 5s production timeout on first request. + private val warmupTimeoutSeconds: Int = getProp("warmup.timeout_seconds", "60").toInt + private case class PutRequestHelper(inputSchema: StructType) extends Serializable { @transient implicit lazy val logger = LoggerFactory.getLogger(getClass) private val keyIndices: Array[Int] = keyColumns.map(inputSchema.fieldIndex) @@ -377,6 +387,54 @@ class JoinSourceRunner(groupByConf: api.GroupBy, conf: Map[String, String] = Map } } + private def warmupDriver(schemas: Schemas, joinSource: JoinSource): Unit = { + if (!warmupEnabled) return + val startMs = System.currentTimeMillis() + logger.info("Starting driver-side warm-up for join derivations...") + try { + // 1. Force CatalystUtil.session (JVM singleton SparkSession, most expensive) + CatalystUtil.session + + // 2. Initialize Fetcher + KV store + val fetcher = getOrCreateFetcher() + LocalIOCache.getOrSetKvStore { () => apiImpl.genKvStore } + + // 3. Pre-populate TTLCache: GroupByServingInfo for all join parts + val joinRequestName = joinSource.join.metaData.getName.replaceFirst("\\.", "/") + joinSource.join.joinPartOps.foreach { part => + fetcher.getGroupByServingInfo(part.groupBy.metaData.getName) + } + + // 4. Pre-populate JoinCodec TTLCache + fetcher.getJoinCodecs(joinRequestName) + + // 5. Force JoinCodec.deriveFunc to trigger PooledCatalystUtil creation + Catalyst codegen + schemas.joinCodec.deriveFunc + + // 6. 
Pre-populate CatalystUtil pool beyond default initialSize=2 + if (warmupPoolSize > 2) { + val derivationsScala = schemas.joinCodec.conf.derivationsScala + if ( + derivationsScala != null && !derivationsScala.isEmpty && + !derivationsScala.areDerivationsRenameOnly + ) { + val pcu = DerivationUtils.buildCatalystUtil(derivationsScala, + schemas.joinCodec.keySchema, + schemas.joinCodec.baseValueSchema) + pcu.warmup(warmupPoolSize) + } + } + + val elapsed = System.currentTimeMillis() - startMs + logger.info(s"Driver-side warm-up completed in ${elapsed}ms") + context.distribution("warmup.driver.latency_ms", elapsed) + } catch { + case ex: Throwable => + logger.warn(s"Driver-side warm-up failed (non-fatal): ${ex.getMessage}", ex) + context.increment("warmup.driver.failure") + } + } + private def arrayToRow(values: Any, schema: StructType): Row = { val genericRow = SparkConversions .toSparkRowSparkType(values, schema) @@ -398,6 +456,41 @@ class JoinSourceRunner(groupByConf: api.GroupBy, conf: Map[String, String] = Map // Convert left rows to fetcher requests val rowsScala = rows.toScala.toArray + + // Executor-side JIT warm-up + if (warmupEnabled && !LocalIOCache.fetcherWarmedUp && rowsScala.nonEmpty) { + val warmupRows = rowsScala.take(warmupRequestCount) + val warmupRequests = warmupRows.map { row => + val keyMap = row.getValuesMap[AnyRef](schemas.leftSourceSchema.fieldNames) + val eventTs = row.getAs[Long](eventTimeColumn) + val ts = if (useEventTimeForQuery) Some(eventTs) else None + Request(joinRequestName, keyMap, atMillis = ts.map(_ + queryShiftMs)) + } + val warmupStartMs = System.currentTimeMillis() + try { + val warmupFuture = fetcher.fetchBaseJoin(warmupRequests, Option(joinSource.join)) + val warmupResponses = Await.result(warmupFuture, warmupTimeoutSeconds.seconds) + + val deriveFunc = schemas.joinCodec.deriveFunc + warmupResponses.foreach { response => + try { deriveFunc(response.request.keys, response.baseValues) } + catch { case _: Throwable => } + } + + 
logger.info( + s"Executor-side warm-up complete with ${warmupRows.length} requests in " + + s"${System.currentTimeMillis() - warmupStartMs}ms") + context.distribution("warmup.executor.latency_ms", System.currentTimeMillis() - warmupStartMs) + } catch { + case ex: Throwable => + logger.warn(s"Executor-side warm-up failed (non-fatal): ${ex.getMessage}", ex) + context.increment("warmup.executor.failure") + try { schemas.joinCodec.deriveFunc } + catch { case _: Throwable => } + } + LocalIOCache.fetcherWarmedUp = true + } + val requests = rowsScala.map { row => val keyMap = row.getValuesMap[AnyRef](schemas.leftSourceSchema.fieldNames) val eventTs = row.getAs[Long](eventTimeColumn) @@ -645,6 +738,9 @@ class JoinSourceRunner(groupByConf: api.GroupBy, conf: Map[String, String] = Map // Build schemas for each stages of the chaining transformation val schemas = buildSchemas(decoded, reqColumns) + // Driver-side warm-up: pre-initialize CatalystUtil, TTLCaches, and derivation codegen + warmupDriver(schemas, joinSource) + // Enrich each left source rows with base columns (pre-derivations) of the join source val enrichedBase = enrichBaseJoin( leftSource,