Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions online/src/main/scala/ai/chronon/online/CatalystUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ class PoolMap[Key, Value](createFunc: Key => Value, maxSize: Int = 100, initialS
pool.offer(value)
}
}

/** Pre-populates the pool for `key` up to `targetSize` entries.
  *
  * No-op when the pool already holds `targetSize` or more values; otherwise
  * creates exactly the missing number of values via `createFunc` and offers
  * them to the pool.
  */
def warmup(key: Key, targetSize: Int): Unit = {
  val pool = getPool(key)
  val deficit = targetSize - pool.size()
  if (deficit > 0) {
    // Lazily create each missing value and hand it to the pool.
    Iterator.fill(deficit)(createFunc(key)).foreach(pool.offer)
  }
}
}

class PooledCatalystUtil(expressions: collection.Seq[(String, String)], inputSchema: StructType) {
Expand All @@ -125,6 +133,8 @@ class PooledCatalystUtil(expressions: collection.Seq[(String, String)], inputSch
}
// Borrows a CatalystUtil instance from the pool to read the output schema as Chronon (name, DataType) pairs.
def outputChrononSchema: Array[(String, DataType)] =
poolMap.performWithValue(poolKey, cuPool) { _.outputChrononSchema }

// Delegates to PoolMap.warmup to pre-create CatalystUtil instances for this expression/schema key,
// absorbing Catalyst codegen cost before real traffic arrives.
def warmup(targetSize: Int): Unit = poolMap.warmup(poolKey, targetSize)
}

// This class by itself it not thread safe because of the transformBuffer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
package ai.chronon.online.test

import ai.chronon.api._
import ai.chronon.online.{CatalystUtil, PooledCatalystUtil}
import ai.chronon.online.{CatalystUtil, PoolMap, PooledCatalystUtil}
import ai.chronon.online.CatalystUtil.PoolKey
import junit.framework.TestCase
import org.junit.Assert.{assertArrayEquals, assertEquals, assertTrue}
import org.junit.Test
Expand Down Expand Up @@ -655,6 +656,22 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs {
"value" -> Array(Array(1L, "data1"), Array(2L, "data2"))
)

@Test
def testPoolMapWarmupPrePopulatesPool(): Unit = {
  // Verifies PoolMap.warmup grows a pool to the requested size and never shrinks it.
  val exprs = Seq("int32_x" -> "int32_x")
  val poolKey = PoolKey(exprs, CommonScalarsStruct)
  val pools = new PoolMap[PoolKey, CatalystUtil](k => new CatalystUtil(k.expressions, k.inputSchema))

  // The first getPool call lazily creates the pool with initialSize = 2.
  pools.getPool(poolKey)
  assertEquals(2, pools.map.get(poolKey).size())

  // Growing to 5 should add exactly three more instances.
  pools.warmup(poolKey, 5)
  assertEquals(5, pools.map.get(poolKey).size())

  // A target below the current size must be a no-op.
  pools.warmup(poolKey, 3)
  assertEquals(5, pools.map.get(poolKey).size())
}

def testPooledCatalystUtil(): Unit = {
val selects = Seq(
"ts" -> "ts",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package ai.chronon.spark.streaming

import ai.chronon.api
import ai.chronon.api.Extensions.{GroupByOps, JoinOps, MetadataOps, SourceOps}
import ai.chronon.api.Extensions.{DerivationOps, GroupByOps, JoinOps, MetadataOps, SourceOps}
import ai.chronon.api._
import ai.chronon.online.Fetcher.{Request, ResponseWithContext}
import ai.chronon.online.KVStore.PutRequest
Expand Down Expand Up @@ -48,6 +48,8 @@ import scala.util.{Failure, Success}
object LocalIOCache {
private var fetcher: Fetcher = null
private var kvStore: KVStore = null
@volatile var fetcherWarmedUp: Boolean = false

def getOrSetFetcher(builderFunc: () => Fetcher): Fetcher = {
if (fetcher == null) {
fetcher = builderFunc()
Expand Down Expand Up @@ -121,6 +123,14 @@ class JoinSourceRunner(groupByConf: api.GroupBy, conf: Map[String, String] = Map
// Micro batch repartition size - when set to 0, we won't do the repartition
private val microBatchRepartition: Int = getProp("batch_repartition", "0").toInt

// Warm-up: pre-initialize lazy components and JIT warm-up using real requests before real processing
private val warmupEnabled: Boolean = getProp("warmup.enabled", "true").toBoolean
private val warmupPoolSize: Int = getProp("warmup.pool_size", "4").toInt
private val warmupRequestCount: Int = getProp("warmup.request_count", "10").toInt
// Longer timeout for warm-up: absorbs cold-start cost (KV connections, CatalystUtil.session,
// Janino codegen) which can exceed the normal 5s production timeout on first request.
private val warmupTimeoutSeconds: Int = getProp("warmup.timeout_seconds", "60").toInt

private case class PutRequestHelper(inputSchema: StructType) extends Serializable {
@transient implicit lazy val logger = LoggerFactory.getLogger(getClass)
private val keyIndices: Array[Int] = keyColumns.map(inputSchema.fieldIndex)
Expand Down Expand Up @@ -377,6 +387,54 @@ class JoinSourceRunner(groupByConf: api.GroupBy, conf: Map[String, String] = Map
}
}

/** Driver-side warm-up: eagerly initializes the expensive lazy components used by the
  * streaming job (SparkSession singleton, Fetcher/KV store, TTL caches, Catalyst codegen)
  * so the first real micro-batch does not pay the cold-start cost.
  *
  * Failures are logged and counted but never propagated — warm-up is best-effort and must
  * not prevent the job from starting.
  */
private def warmupDriver(schemas: Schemas, joinSource: JoinSource): Unit = {
  if (!warmupEnabled) return
  // Local import keeps the fix self-contained in this method.
  import scala.util.control.NonFatal
  val startMs = System.currentTimeMillis()
  logger.info("Starting driver-side warm-up for join derivations...")
  try {
    // 1. Force CatalystUtil.session (JVM singleton SparkSession, most expensive)
    CatalystUtil.session

    // 2. Initialize Fetcher + KV store
    val fetcher = getOrCreateFetcher()
    LocalIOCache.getOrSetKvStore { () => apiImpl.genKvStore }

    // 3. Pre-populate TTLCache: GroupByServingInfo for all join parts
    val joinRequestName = joinSource.join.metaData.getName.replaceFirst("\\.", "/")
    joinSource.join.joinPartOps.foreach { part =>
      fetcher.getGroupByServingInfo(part.groupBy.metaData.getName)
    }

    // 4. Pre-populate JoinCodec TTLCache
    fetcher.getJoinCodecs(joinRequestName)

    // 5. Force JoinCodec.deriveFunc to trigger PooledCatalystUtil creation + Catalyst codegen
    schemas.joinCodec.deriveFunc

    // 6. Pre-populate CatalystUtil pool beyond default initialSize=2, but only when
    //    derivations exist and are not rename-only (rename-only skips Catalyst entirely).
    if (warmupPoolSize > 2) {
      val derivationsScala = schemas.joinCodec.conf.derivationsScala
      if (
        derivationsScala != null && !derivationsScala.isEmpty &&
        !derivationsScala.areDerivationsRenameOnly
      ) {
        val pcu = DerivationUtils.buildCatalystUtil(derivationsScala,
                                                    schemas.joinCodec.keySchema,
                                                    schemas.joinCodec.baseValueSchema)
        pcu.warmup(warmupPoolSize)
      }
    }

    val elapsed = System.currentTimeMillis() - startMs
    logger.info(s"Driver-side warm-up completed in ${elapsed}ms")
    context.distribution("warmup.driver.latency_ms", elapsed)
  } catch {
    // NonFatal (not Throwable): fatal errors such as OutOfMemoryError or
    // InterruptedException must propagate instead of being swallowed here.
    case NonFatal(ex) =>
      logger.warn(s"Driver-side warm-up failed (non-fatal): ${ex.getMessage}", ex)
      context.increment("warmup.driver.failure")
  }
}

private def arrayToRow(values: Any, schema: StructType): Row = {
val genericRow = SparkConversions
.toSparkRowSparkType(values, schema)
Expand All @@ -398,6 +456,41 @@ class JoinSourceRunner(groupByConf: api.GroupBy, conf: Map[String, String] = Map

// Convert left rows to fetcher requests
val rowsScala = rows.toScala.toArray

// Executor-side JIT warm-up
if (warmupEnabled && !LocalIOCache.fetcherWarmedUp && rowsScala.nonEmpty) {
val warmupRows = rowsScala.take(warmupRequestCount)
val warmupRequests = warmupRows.map { row =>
val keyMap = row.getValuesMap[AnyRef](schemas.leftSourceSchema.fieldNames)
val eventTs = row.getAs[Long](eventTimeColumn)
val ts = if (useEventTimeForQuery) Some(eventTs) else None
Request(joinRequestName, keyMap, atMillis = ts.map(_ + queryShiftMs))
}
val warmupStartMs = System.currentTimeMillis()
try {
val warmupFuture = fetcher.fetchBaseJoin(warmupRequests, Option(joinSource.join))
val warmupResponses = Await.result(warmupFuture, warmupTimeoutSeconds.seconds)

val deriveFunc = schemas.joinCodec.deriveFunc
warmupResponses.foreach { response =>
try { deriveFunc(response.request.keys, response.baseValues) }
catch { case _: Throwable => }
}

logger.info(
s"Executor-side warm-up complete with ${warmupRows.length} requests in " +
s"${System.currentTimeMillis() - warmupStartMs}ms")
context.distribution("warmup.executor.latency_ms", System.currentTimeMillis() - warmupStartMs)
} catch {
case ex: Throwable =>
logger.warn(s"Executor-side warm-up failed (non-fatal): ${ex.getMessage}", ex)
context.increment("warmup.executor.failure")
try { schemas.joinCodec.deriveFunc }
catch { case _: Throwable => }
}
LocalIOCache.fetcherWarmedUp = true
}

val requests = rowsScala.map { row =>
val keyMap = row.getValuesMap[AnyRef](schemas.leftSourceSchema.fieldNames)
val eventTs = row.getAs[Long](eventTimeColumn)
Expand Down Expand Up @@ -645,6 +738,9 @@ class JoinSourceRunner(groupByConf: api.GroupBy, conf: Map[String, String] = Map
// Build schemas for each stages of the chaining transformation
val schemas = buildSchemas(decoded, reqColumns)

// Driver-side warm-up: pre-initialize CatalystUtil, TTLCaches, and derivation codegen
warmupDriver(schemas, joinSource)

// Enrich each left source rows with base columns (pre-derivations) of the join source
val enrichedBase = enrichBaseJoin(
leftSource,
Expand Down