Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ Global / concurrentRestrictions := Seq(

val awsSdkVersion = "1.12.470"
val awsSdkV2Version = "2.42.25"
val elastic4sVersion = "8.18.2"
val elastic4sVersion = "8.19.1"
val okHttpVersion = "3.12.1"

val bbcBuildProcess: Boolean = System.getenv().asScala.get("BUILD_ORG").contains("bbc")
Expand Down
1 change: 1 addition & 0 deletions kahuna/public/js/search/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ search.config(['$stateProvider', '$urlMatcherFactoryProvider',
'until',
'orderBy',
'useAISearch',
'vecWeight',
'dateField',
'takenSince',
'takenUntil',
Expand Down
7 changes: 5 additions & 2 deletions kahuna/public/js/search/query.js
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,10 @@ query.controller('SearchQueryCtrl', [
ctrl.shouldDisplayAISearchOption = getFeatureSwitchActive("enable-ai-search");
if (!ctrl.shouldDisplayAISearchOption) {
ctrl.useAISearch = false;
ctrl.vecWeight = false;
} else {
ctrl.useAISearch = ($stateParams.useAISearch === 'true' || $stateParams.useAISearch === true) ? true : false;
ctrl.vecWeight = $stateParams.vecWeight;
}

//--react - angular interop events--
Expand Down Expand Up @@ -455,11 +457,12 @@ query.controller('SearchQueryCtrl', [
if (ctrl.useAISearch) {
$state.go('search.results', {
...ctrl.filter,
useAISearch: true
useAISearch: true,
vecWeight: ctrl.vecWeight
});
} else {
$state.go('search.results', {...ctrl.filter, useAISearch: null});
}
}
});

$scope.$watchCollection(() => ctrl.dateFilter, onValChange(({field, since, until}) => {
Expand Down
1 change: 1 addition & 0 deletions kahuna/public/js/search/results.js
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,7 @@ results.controller('SearchResultsCtrl', [
length: length,
orderBy: orderBy,
useAISearch: $stateParams.useAISearch,
vecWeight: $stateParams.vecWeight,
hasRightsAcquired: $stateParams.hasRightsAcquired,
hasCrops: $stateParams.hasCrops,
syndicationStatus: $stateParams.syndicationStatus,
Expand Down
5 changes: 3 additions & 2 deletions kahuna/public/js/services/api/media-api.js
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ mediaApi.factory('mediaApi',
payType, uploadedBy, offset, length, orderBy,
takenSince, takenUntil,
modifiedSince, modifiedUntil, hasRightsAcquired, hasCrops,
syndicationStatus, countAll, persisted, useAISearch} = {}) {
syndicationStatus, countAll, persisted, useAISearch, vecWeight} = {}) {
return root.follow('search', {
q: query,
since: since,
Expand All @@ -65,7 +65,8 @@ mediaApi.factory('mediaApi',
syndicationStatus: syndicationStatus,
countAll,
persisted,
useAISearch: maybeStringToBoolean(useAISearch)
useAISearch: maybeStringToBoolean(useAISearch),
vecWeight: vecWeight
}).get();
}

Expand Down
22 changes: 16 additions & 6 deletions media-api/app/controllers/MediaApi.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ class MediaApi(
"syndicationStatus",
"countAll",
"persisted",
"useAISearch"
"useAISearch",
"vecWeight"
).mkString(",")

private val searchLinkHref = s"${config.rootUri}/images{?$searchParamList}"
Expand Down Expand Up @@ -618,7 +619,7 @@ class MediaApi(
} yield searchResults
}

def semanticSearchByText(query: String, k: Int): Future[SearchResults] = {
def semanticSearchByText(query: String, k: Int, vecWeight: Option[Double]): Future[SearchResults] = {
// Normalise key so that "Dogs" and "dogs " share a cache entry.
val cacheKey = query.trim.toLowerCase

Expand All @@ -630,21 +631,30 @@ class MediaApi(
logger.info(markers, s"AI search embedding cache miss query=$query")
}

val weight = vecWeight.getOrElse(0.8)

// cache.get(key) is atomic: if two requests race on the same key, only one
// load fires and both callers receive the same Future.
val embeddingFuture = embeddingCache.get(cacheKey)

logger.info(markers, s"vecWeight for query '$query' is $weight")
for {
embedding <- embeddingFuture
searchResults <- elasticSearch.knnSearch(embedding, k = k, numCandidates = Math.max(k * 2, 100))
searchResults <- elasticSearch.hybridSearch(
query = query,
queryEmbedding = embedding,
k = k,
numCandidates = Math.max(k * 2, 100),
vecWeight = weight,
)
} yield searchResults
}

def performAiSearchAndRespond(query: String): Future[Result] = {
def performAiSearchAndRespond(query: String, vecWeight: Option[Double]): Future[Result] = {
val k = config.aiSearchResultLimit
val searchResultsFuture = parseAiSearchMode(query) match {
case SimilarSearch(imageId) => semanticSearchByImage(imageId, k)
case TextSearch(textQuery) => semanticSearchByText(textQuery, k)
case TextSearch(textQuery) => semanticSearchByText(textQuery, k, vecWeight)
}

searchResultsFuture.map(aiSearchResponseFromResults)
Expand All @@ -656,7 +666,7 @@ class MediaApi(
case _ if _searchParams.length == 0 =>
emptyAiSearchResponse
case Some(q) if !q.isBlank =>
performAiSearchAndRespond(q)
performAiSearchAndRespond(q, _searchParams.vecWeight)
// Empty queries do not make sense for AI search as we can
// only rank results once we have a meaningful vector to compare with.
// So return 0 results if the query was empty.
Expand Down
84 changes: 84 additions & 0 deletions media-api/app/lib/elasticsearch/ElasticSearch.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import com.gu.mediaservice.lib.metrics.FutureSyntax
import com.gu.mediaservice.model.{Agencies, Agency, AwaitingReviewForSyndication, Image}
import com.sksamuel.elastic4s.ElasticDsl
import com.sksamuel.elastic4s.ElasticDsl._
import com.sksamuel.elastic4s.requests.common.Operator.And
import com.sksamuel.elastic4s.requests.get.{GetRequest, GetResponse}
import com.sksamuel.elastic4s.requests.script.{Script, ScriptField}
import com.sksamuel.elastic4s.requests.searches._
Expand All @@ -19,6 +20,9 @@ import com.sksamuel.elastic4s.requests.searches.aggs.responses.Aggregations
import com.sksamuel.elastic4s.requests.searches.aggs.responses.bucket.{DateHistogram, Terms}
import com.sksamuel.elastic4s.requests.searches.queries.Query
import com.sksamuel.elastic4s.requests.searches.knn.Knn
import com.sksamuel.elastic4s.requests.searches.queries.compound.BoolQuery
import com.sksamuel.elastic4s.requests.searches.queries.matches.MultiMatchQueryBuilderType.BEST_FIELDS
import com.sksamuel.elastic4s.requests.searches.queries.matches.{FieldWithOptionalBoost, MultiMatchQuery}
import lib.querysyntax.{HierarchyField, Match, Parser, Phrase}
import lib.{MediaApiConfig, MediaApiMetrics, SupplierUsageSummary}
import play.api.libs.json.{JsError, JsObject, JsSuccess, Json}
Expand Down Expand Up @@ -191,6 +195,86 @@ class ElasticSearch(
}
}

def hybridSearch(
query: String,
queryEmbedding: List[Float],
k: Int,
numCandidates: Int,
vecWeight: Double,
)(
implicit ex: ExecutionContext,
logMarker: LogMarker
): Future[SearchResults] = {

// BM25 scores are unbounded [0,inf] and typically much larger in magnitude
// than cosine similarity (knn). So we get the max BM25 score for the query and use that to calculate
// the scaling factor for the lexical part of the query, so that BM25 and knn scores are both between 0-1 scale
// and can be effectively combined in a hybrid query.
def maxBM25Score(query: String): Future[Double] = {
val maxScore = ElasticDsl.search(imagesCurrentAlias)
.query(BoolQuery().must(
MultiMatchQuery(
text = query,
fields = matchFields.map(field => FieldWithOptionalBoost(field, None)),
`type` = Some(BEST_FIELDS),
fuzziness = Some("AUTO"),
maxExpansions = Some(50),
operator = Some(And),
prefixLength = Some(1),
))
)
val maxScoreFuture = executeAndLog(withSearchQueryTimeout(maxScore), "max BM25 score").map { r =>
logger.info(logMarker, s"Max BM25 score for query '$query' is ${r.result.hits.maxScore} with total hits ${r.result.totalHits}")
if (r.result.hits.hits.isEmpty) 1.0 else r.result.hits.maxScore
}
maxScoreFuture
}

val queryEmbeddingDouble: List[Double] = queryEmbedding.map(_.toDouble)
val knn = Knn("embedding.cohereEmbedV4.image")
.queryVector(queryEmbeddingDouble)
.k(k)
.numCandidates(numCandidates)
.boost(1.0)

val lexicalWeight = 1.0 - vecWeight

maxBM25Score(query).flatMap { maxScore =>
// KNN results are in [0,1], but BM25 scores are unbounded and typically much
// larger than cosine similarity, so we need to apply a scaling factor to the
// BM25 score to bring it to the same range as the cosine similarity
val scalingFactor = if (maxScore > 0.0) 1.0 / maxScore else 1.0

// We want to apply only one boost if we can help it, so we scale the
// multi_match boost to be in line with the max_score and the desired
// lexical_weight/vec_weight balance
val multiMatchBoost = if (vecWeight > 0.0) (lexicalWeight/vecWeight) * scalingFactor else 1.0

logger.info(logMarker, s"Scaling factor for BM25 score is $scalingFactor, multi-match boost is $multiMatchBoost")

// TODO make case class for multimatchQuery to avoid repetition
val multiMatchQuery = MultiMatchQuery(
text = query,
fields = matchFields.map(field => FieldWithOptionalBoost(field, None)),
`type` = Some(BEST_FIELDS),
fuzziness = Some("AUTO"),
maxExpansions = Some(50),
operator = Some(And),
prefixLength = Some(1),
boost = Some(multiMatchBoost)
)

val searchRequest = ElasticDsl.search(imagesCurrentAlias)
.bool(BoolQuery().should(Seq(multiMatchQuery, knn)))
.size(k)

executeAndLog(withSearchQueryTimeout(searchRequest), "hybrid search").map { r =>
val imageHits = r.result.hits.hits.map(resolveHit).toSeq.flatten.map(i => (i.instance.id, i))
SearchResults(hits = imageHits, total = r.result.totalHits, extraCounts = None)
}
}
}

def search(params: SearchParams)(implicit ex: ExecutionContext, request: AuthenticatedRequest[AnyContent, Principal], logMarker: LogMarker = MarkerMap()): Future[SearchResults] = {
val query: Query = queryBuilder.makeQuery(params.structuredQuery)

Expand Down
3 changes: 3 additions & 0 deletions media-api/app/lib/elasticsearch/ElasticSearchModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ case class SearchParams(
printUsageFilters: Option[PrintUsageFilters] = None,
shouldFlagGraphicImages: Boolean = false,
useAISearch: Option[Boolean] = None,
vecWeight: Option[Double] = None
)

case class InvalidUriParams(message: String) extends Throwable
Expand Down Expand Up @@ -115,6 +116,7 @@ object SearchParams {

// TODO: return descriptive 400 error if invalid
def parseIntFromQuery(s: String): Option[Int] = Try(s.toInt).toOption
def parseDoubleFromQuery(s: String): Option[Double] = Try(s.toDouble).toOption
def parsePayTypeFromQuery(s: String): Option[PayType.Value] = PayType.create(s)
def parseBooleanFromQuery(s: String): Option[Boolean] = Try(s.toBoolean).toOption
def parseSyndicationStatus(s: String): Option[SyndicationStatus] = Some(SyndicationStatus(s))
Expand Down Expand Up @@ -175,6 +177,7 @@ object SearchParams {
printUsageFilters,
shouldFlagGraphicImages = false,
request.getQueryString("useAISearch") flatMap parseBooleanFromQuery,
request.getQueryString("vecWeight") flatMap parseDoubleFromQuery,
)
}

Expand Down
Loading