diff --git a/build.sbt b/build.sbt index cf20ce1ae5..fc1dc3cb4a 100644 --- a/build.sbt +++ b/build.sbt @@ -78,7 +78,7 @@ Global / concurrentRestrictions := Seq( val awsSdkVersion = "1.12.470" val awsSdkV2Version = "2.42.25" -val elastic4sVersion = "8.18.2" +val elastic4sVersion = "8.19.1" val okHttpVersion = "3.12.1" val bbcBuildProcess: Boolean = System.getenv().asScala.get("BUILD_ORG").contains("bbc") diff --git a/kahuna/public/js/search/index.js b/kahuna/public/js/search/index.js index 5d57125684..8897e6006e 100644 --- a/kahuna/public/js/search/index.js +++ b/kahuna/public/js/search/index.js @@ -179,6 +179,7 @@ search.config(['$stateProvider', '$urlMatcherFactoryProvider', 'until', 'orderBy', 'useAISearch', + 'vecWeight', 'dateField', 'takenSince', 'takenUntil', diff --git a/kahuna/public/js/search/query.js b/kahuna/public/js/search/query.js index 56d1e255b1..6066e77478 100644 --- a/kahuna/public/js/search/query.js +++ b/kahuna/public/js/search/query.js @@ -82,8 +82,10 @@ query.controller('SearchQueryCtrl', [ ctrl.shouldDisplayAISearchOption = getFeatureSwitchActive("enable-ai-search"); if (!ctrl.shouldDisplayAISearchOption) { ctrl.useAISearch = false; + ctrl.vecWeight = false; } else { ctrl.useAISearch = ($stateParams.useAISearch === 'true' || $stateParams.useAISearch === true) ? true : false; + ctrl.vecWeight = $stateParams.vecWeight; } //--react - angular interop events-- @@ -455,11 +457,12 @@ query.controller('SearchQueryCtrl', [ if (ctrl.useAISearch) { $state.go('search.results', { ...ctrl.filter, - useAISearch: true + useAISearch: true, + vecWeight: ctrl.vecWeight }); } else { $state.go('search.results', {...ctrl.filter, useAISearch: null}); - } + } }); $scope.$watchCollection(() => ctrl.dateFilter, onValChange(({field, since, until}) => { diff --git a/kahuna/public/js/search/results.js b/kahuna/public/js/search/results.js index ce3a30afcb..1268fd9a9e 100644 --- a/kahuna/public/js/search/results.js +++ b/kahuna/public/js/search/results.js @@ -578,6 +578,7 @@ results.controller('SearchResultsCtrl', [ length: length, orderBy: orderBy, useAISearch: $stateParams.useAISearch, + vecWeight: $stateParams.vecWeight, hasRightsAcquired: $stateParams.hasRightsAcquired, hasCrops: $stateParams.hasCrops, syndicationStatus: $stateParams.syndicationStatus, diff --git a/kahuna/public/js/services/api/media-api.js b/kahuna/public/js/services/api/media-api.js index 845f4df31b..4687cb6277 100644 --- a/kahuna/public/js/services/api/media-api.js +++ b/kahuna/public/js/services/api/media-api.js @@ -42,7 +42,7 @@ mediaApi.factory('mediaApi', payType, uploadedBy, offset, length, orderBy, takenSince, takenUntil, modifiedSince, modifiedUntil, hasRightsAcquired, hasCrops, - syndicationStatus, countAll, persisted, useAISearch} = {}) { + syndicationStatus, countAll, persisted, useAISearch, vecWeight} = {}) { return root.follow('search', { q: query, since: since, @@ -65,7 +65,8 @@ mediaApi.factory('mediaApi', syndicationStatus: syndicationStatus, countAll, persisted, - useAISearch: maybeStringToBoolean(useAISearch) + useAISearch: maybeStringToBoolean(useAISearch), + vecWeight: vecWeight }).get(); } diff --git a/media-api/app/controllers/MediaApi.scala b/media-api/app/controllers/MediaApi.scala index c94676cd2c..9fb8455f05 100644 --- a/media-api/app/controllers/MediaApi.scala +++ b/media-api/app/controllers/MediaApi.scala @@ -88,7 +88,8 @@ class MediaApi( "syndicationStatus", "countAll", "persisted", - "useAISearch" + "useAISearch", + "vecWeight" ).mkString(",") private val searchLinkHref = s"${config.rootUri}/images{?$searchParamList}" @@ -618,7 +619,7 @@ class MediaApi( } yield searchResults } - def semanticSearchByText(query: String, k: Int): Future[SearchResults] = { + def semanticSearchByText(query: String, k: Int, vecWeight: Option[Double]): Future[SearchResults] = { // Normalise key so that "Dogs" and "dogs " share a cache entry. val cacheKey = query.trim.toLowerCase @@ -630,21 +631,30 @@ class MediaApi( logger.info(markers, s"AI search embedding cache miss query=$query") } + val weight = vecWeight.getOrElse(0.8) + // cache.get(key) is atomic: if two requests race on the same key, only one // load fires and both callers receive the same Future. val embeddingFuture = embeddingCache.get(cacheKey) + logger.info(markers, s"vecWeight for query '$query' is $weight") for { embedding <- embeddingFuture - searchResults <- elasticSearch.knnSearch(embedding, k = k, numCandidates = Math.max(k * 2, 100)) + searchResults <- elasticSearch.hybridSearch( + query = query, + queryEmbedding = embedding, + k = k, + numCandidates = Math.max(k * 2, 100), + vecWeight = weight, + ) } yield searchResults } - def performAiSearchAndRespond(query: String): Future[Result] = { + def performAiSearchAndRespond(query: String, vecWeight: Option[Double]): Future[Result] = { val k = config.aiSearchResultLimit val searchResultsFuture = parseAiSearchMode(query) match { case SimilarSearch(imageId) => semanticSearchByImage(imageId, k) - case TextSearch(textQuery) => semanticSearchByText(textQuery, k) + case TextSearch(textQuery) => semanticSearchByText(textQuery, k, vecWeight) } searchResultsFuture.map(aiSearchResponseFromResults) @@ -656,7 +666,7 @@ class MediaApi( case _ if _searchParams.length == 0 => emptyAiSearchResponse case Some(q) if !q.isBlank => - performAiSearchAndRespond(q) + performAiSearchAndRespond(q, _searchParams.vecWeight) // Empty queries do not make sense for AI search as we can // only rank results once we have a meaningful vector to compare with. // So return 0 results if the query was empty. diff --git a/media-api/app/lib/elasticsearch/ElasticSearch.scala b/media-api/app/lib/elasticsearch/ElasticSearch.scala index 1aa0f6fc32..63700ffd5d 100644 --- a/media-api/app/lib/elasticsearch/ElasticSearch.scala +++ b/media-api/app/lib/elasticsearch/ElasticSearch.scala @@ -11,6 +11,7 @@ import com.gu.mediaservice.lib.metrics.FutureSyntax import com.gu.mediaservice.model.{Agencies, Agency, AwaitingReviewForSyndication, Image} import com.sksamuel.elastic4s.ElasticDsl import com.sksamuel.elastic4s.ElasticDsl._ +import com.sksamuel.elastic4s.requests.common.Operator.And import com.sksamuel.elastic4s.requests.get.{GetRequest, GetResponse} import com.sksamuel.elastic4s.requests.script.{Script, ScriptField} import com.sksamuel.elastic4s.requests.searches._ @@ -19,6 +20,9 @@ import com.sksamuel.elastic4s.requests.searches.aggs.responses.Aggregations import com.sksamuel.elastic4s.requests.searches.aggs.responses.bucket.{DateHistogram, Terms} import com.sksamuel.elastic4s.requests.searches.queries.Query import com.sksamuel.elastic4s.requests.searches.knn.Knn +import com.sksamuel.elastic4s.requests.searches.queries.compound.BoolQuery +import com.sksamuel.elastic4s.requests.searches.queries.matches.MultiMatchQueryBuilderType.BEST_FIELDS +import com.sksamuel.elastic4s.requests.searches.queries.matches.{FieldWithOptionalBoost, MultiMatchQuery} import lib.querysyntax.{HierarchyField, Match, Parser, Phrase} import lib.{MediaApiConfig, MediaApiMetrics, SupplierUsageSummary} import play.api.libs.json.{JsError, JsObject, JsSuccess, Json} @@ -191,6 +195,86 @@ class ElasticSearch( } } + def hybridSearch( + query: String, + queryEmbedding: List[Float], + k: Int, + numCandidates: Int, + vecWeight: Double, + )( + implicit ex: ExecutionContext, + logMarker: LogMarker + ): Future[SearchResults] = { + +// BM25 scores are unbounded [0,inf] and typically much larger in magnitude +// than cosine similarity (knn). So we get the max BM25 score for the query and use that to calculate +// the scaling factor for the lexical part of the query, so that BM25 and knn scores are both between 0-1 scale +// and can be effectively combined in a hybrid query. + def maxBM25Score(query: String): Future[Double] = { + val maxScore = ElasticDsl.search(imagesCurrentAlias) + .query(BoolQuery().must( + MultiMatchQuery( + text = query, + fields = matchFields.map(field => FieldWithOptionalBoost(field, None)), + `type` = Some(BEST_FIELDS), + fuzziness = Some("AUTO"), + maxExpansions = Some(50), + operator = Some(And), + prefixLength = Some(1), + )) + ) + val maxScoreFuture = executeAndLog(withSearchQueryTimeout(maxScore), "max BM25 score").map { r => + logger.info(logMarker, s"Max BM25 score for query '$query' is ${r.result.hits.maxScore} with total hits ${r.result.totalHits}") + if (r.result.hits.hits.isEmpty) 1.0 else r.result.hits.maxScore + } + maxScoreFuture + } + + val queryEmbeddingDouble: List[Double] = queryEmbedding.map(_.toDouble) + val knn = Knn("embedding.cohereEmbedV4.image") + .queryVector(queryEmbeddingDouble) + .k(k) + .numCandidates(numCandidates) + .boost(1.0) + + val lexicalWeight = 1.0 - vecWeight + + maxBM25Score(query).flatMap { maxScore => +// KNN results are in [0,1], but BM25 scores are unbounded and typically much +// larger than cosine similarity, so we need to apply a scaling factor to the +// BM25 score to bring it to the same range as the cosine similarity + val scalingFactor = if (maxScore > 0.0) 1.0 / maxScore else 1.0 + +// We want to apply only one boost if we can help it, so we scale the +// multi_match boost to be in line with the max_score and the desired +// lexical_weight/vec_weight balance + val multiMatchBoost = if (vecWeight > 0.0) (lexicalWeight/vecWeight) * scalingFactor else 1.0 + + logger.info(logMarker, s"Scaling factor for BM25 score is $scalingFactor, multi-match boost is $multiMatchBoost") + +// TODO make case class for multimatchQuery to avoid repetition + val multiMatchQuery = MultiMatchQuery( + text = query, + fields = matchFields.map(field => FieldWithOptionalBoost(field, None)), + `type` = Some(BEST_FIELDS), + fuzziness = Some("AUTO"), + maxExpansions = Some(50), + operator = Some(And), + prefixLength = Some(1), + boost = Some(multiMatchBoost) + ) + + val searchRequest = ElasticDsl.search(imagesCurrentAlias) + .bool(BoolQuery().should(Seq(multiMatchQuery, knn))) + .size(k) + + executeAndLog(withSearchQueryTimeout(searchRequest), "hybrid search").map { r => + val imageHits = r.result.hits.hits.map(resolveHit).toSeq.flatten.map(i => (i.instance.id, i)) + SearchResults(hits = imageHits, total = r.result.totalHits, extraCounts = None) + } + } + } + def search(params: SearchParams)(implicit ex: ExecutionContext, request: AuthenticatedRequest[AnyContent, Principal], logMarker: LogMarker = MarkerMap()): Future[SearchResults] = { val query: Query = queryBuilder.makeQuery(params.structuredQuery) diff --git a/media-api/app/lib/elasticsearch/ElasticSearchModel.scala b/media-api/app/lib/elasticsearch/ElasticSearchModel.scala index a0cd6a0db7..7ea48b73fc 100644 --- a/media-api/app/lib/elasticsearch/ElasticSearchModel.scala +++ b/media-api/app/lib/elasticsearch/ElasticSearchModel.scala @@ -86,6 +86,7 @@ case class SearchParams( printUsageFilters: Option[PrintUsageFilters] = None, shouldFlagGraphicImages: Boolean = false, useAISearch: Option[Boolean] = None, + vecWeight: Option[Double] = None ) case class InvalidUriParams(message: String) extends Throwable @@ -115,6 +116,7 @@ object SearchParams { // TODO: return descriptive 400 error if invalid def parseIntFromQuery(s: String): Option[Int] = Try(s.toInt).toOption + def parseDoubleFromQuery(s: String): Option[Double] = Try(s.toDouble).toOption def parsePayTypeFromQuery(s: String): Option[PayType.Value] = PayType.create(s) def parseBooleanFromQuery(s: String): Option[Boolean] = Try(s.toBoolean).toOption def parseSyndicationStatus(s: String): Option[SyndicationStatus] = Some(SyndicationStatus(s)) @@ -175,6 +177,7 @@ object SearchParams { printUsageFilters, shouldFlagGraphicImages = false, request.getQueryString("useAISearch") flatMap parseBooleanFromQuery, + request.getQueryString("vecWeight") flatMap parseDoubleFromQuery, ) }