diff --git a/NEWS.md b/NEWS.md index f6e5e582a..acf544ecf 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # ellmer (development version) +* `batch_chat()` now supports `chat_google_gemini()` for batch processing via + the Gemini Developer API (@xmarquez, #914). * ellmer will now distinguish text content from thinking content while streaming, allowing downstream packages like shinychat to provide specific UI for thinking content (@simonpcouch, #909). * `chat_github()` now uses `chat_openai_compatible()` for improved compatibility, and `models_github()` now supports custom `base_url` configuration (@D-M4rk, #877). * `chat_ollama()` now contains a slot for `top_k` within the `params` argument (@frankiethull). diff --git a/R/batch-chat.R b/R/batch-chat.R index 4a4b5084c..d251859fa 100644 --- a/R/batch-chat.R +++ b/R/batch-chat.R @@ -2,10 +2,11 @@ #' #' @description #' `batch_chat()` and `batch_chat_structured()` currently only work with -#' [chat_openai()] and [chat_anthropic()]. They use the -#' [OpenAI](https://platform.openai.com/docs/guides/batch) and -#' [Anthropic](https://docs.claude.com/en/docs/build-with-claude/batch-processing) -#' batch APIs which allow you to submit multiple requests simultaneously. +#' [chat_openai()], [chat_anthropic()], and [chat_google_gemini()]. They use +#' the [OpenAI](https://platform.openai.com/docs/guides/batch), +#' [Anthropic](https://docs.claude.com/en/docs/build-with-claude/batch-processing), +#' and [Google Gemini](https://ai.google.dev/gemini-api/docs/batch-api) batch APIs +#' which allow you to submit multiple requests simultaneously. #' The results can take up to 24 hours to complete, but in return you pay 50% #' less than usual (but note that ellmer doesn't include this discount in #' its pricing metadata). 
If you want to get results back more quickly, or
@@ -90,6 +91,9 @@ batch_chat <- function(chat, prompts, path, wait = TRUE, ignore_hash = FALSE) {
     ignore_hash = ignore_hash
   )
   job$step_until_done()
+  # The job has not reached the "done" stage, so there are no result turns to
+  # build chats from yet; return NULL instead.
+  if (job$stage != "done") {
+    return(NULL)
+  }
 
   assistant_turns <- job$result_turns()
   map2(job$user_turns, assistant_turns, function(user, assistant) {
@@ -119,6 +123,9 @@ batch_chat_text <- function(
     wait = wait,
     ignore_hash = ignore_hash
   )
+  # Propagate NULL (nothing to extract text from) instead of mapping over it.
+  if (is.null(chats)) {
+    return(NULL)
+  }
   map_chr(chats, \(chat) {
     if (is.null(chat)) NA_character_ else chat$last_turn()@text
   })
@@ -151,6 +158,9 @@ batch_chat_structured <- function(
     ignore_hash = ignore_hash
   )
   job$step_until_done()
+  # The job has not reached the "done" stage, so there is nothing to convert
+  # yet; return NULL instead.
+  if (job$stage != "done") {
+    return(NULL)
+  }
 
   turns <- job$result_turns()
   multi_convert(
diff --git a/R/provider-google-upload.R b/R/provider-google-upload.R
index 586a7f2bc..2b819950b 100644
--- a/R/provider-google-upload.R
+++ b/R/provider-google-upload.R
@@ -43,6 +43,17 @@ google_upload <- function(
 
   mime_type <- mime_type %||% guess_mime_type(path)
 
+  # Delegate the upload steps to google_upload_file() and wrap the returned
+  # status as uploaded content.
+  status <- google_upload_file(
+    path = path,
+    base_url = base_url,
+    credentials = credentials,
+    mime_type = mime_type
+  )
+
+  ContentUploaded(uri = status$uri, mime_type = status$mimeType)
+}
+
+# Shared upload routine. Returns the upload status object itself (not a
+# ContentUploaded) so that each caller can read the fields it needs.
+google_upload_file <- function(path, base_url, credentials, mime_type) {
   upload_url <- google_upload_init(
     path = path,
     base_url = base_url,
@@ -56,8 +67,7 @@ google_upload <- function(
     credentials = credentials
   )
   google_upload_wait(status, credentials)
-
-  ContentUploaded(uri = status$uri, mime_type = status$mimeType)
+  status
 }
 
 # https://ai.google.dev/api/files#method:-media.upload
@@ -124,6 +134,31 @@ google_upload_wait <- function(status, credentials) {
   invisible()
 }
 
+# Batch file helpers -----------------------------------------------------------
+
+# Upload a batch input file for `provider` and return the upload status object.
+# A trailing version segment of `provider@base_url` (matched by "/v[^/]+/?$",
+# e.g. ".../v1beta/") is replaced with "/" before the upload.
+# NOTE(review): presumably because google_upload_file() expects the
+# unversioned API root -- confirm against google_upload_init().
+gemini_upload_file <- function(
+  provider,
+  path,
+  mime_type = "application/jsonl"
+) {
+  upload_base_url <- sub("/v[^/]+/?$", "/", provider@base_url)
+
+  google_upload_file(
+    path = path,
+    base_url = upload_base_url,
+    credentials = provider@credentials,
+    mime_type = mime_type
+  )
+}
+
+# Download the file resource `name` to the local file `path` by requesting
+# "<name>:download" with the query parameter alt=media. Returns `path`
+# invisibly.
+gemini_download_file <- function(provider, name, path) {
+  req <- base_request(provider)
+  req <- req_url_path_append(req, paste0(name, ":download"))
+  req <- req_url_query(req, alt = "media")
+  req_perform(req, path = path)
+  invisible(path)
+}
+
 # Helpers ----------------------------------------------------------------------
 
 guess_mime_type <- function(file_path, call = caller_env()) {
diff --git a/R/provider-google.R b/R/provider-google.R
index c403435ab..1753e418c 100644
--- a/R/provider-google.R
+++ b/R/provider-google.R
@@ -872,3 +872,275 @@ models_google <- function(
 google_location <- function(location) {
   if (location == "global") "" else paste0(location, "-")
 }
+
+# Batched requests -------------------------------------------------------------
+
+# https://ai.google.dev/gemini-api/docs/batch-api
+method(has_batch_support, ProviderGoogleGemini) <- function(provider) {
+  TRUE
+}
+
+# Submit a batch: write one JSON line per conversation to a temporary JSONL
+# file, upload that file, then create a batch job that references the uploaded
+# file. Returns the parsed JSON body of the create-batch response.
+method(batch_submit, ProviderGoogleGemini) <- function(
+  provider,
+  conversations,
+  type = NULL
+) {
+  path <- withr::local_tempfile(fileext = ".jsonl")
+
+  requests <- map(seq_along(conversations), function(i) {
+    body <- chat_body(
+      provider,
+      stream = FALSE,
+      turns = conversations[[i]],
+      type = type
+    )
+
+    # The key encodes the request position ("chat-<i>");
+    # gemini_extract_index() parses it back when results are retrieved.
+    list(
+      key = paste0("chat-", i),
+      request = gemini_prepare_batch_body(body)
+    )
+  })
+
+  json_lines <- map_chr(requests, to_json)
+  writeLines(json_lines, path)
+
+  uploaded <- gemini_upload_file(provider, path)
+  # The batch job refers to its input by file resource name, so an upload
+  # without one cannot be used.
+  if (is.null(uploaded$name) || !nzchar(uploaded$name)) {
+    cli::cli_abort(
+      "Gemini upload did not return a file resource name.",
+      .internal = TRUE
+    )
+  }
+
+  req <- base_request(provider)
+  req <- req_url_path_append(
+    req,
+    "models",
+    paste0(provider@model, ":batchGenerateContent")
+  )
+  req <- req_body_json(
+    req,
+    list(
+      batch = list(
+        displayName = paste0("ellmer-", as.integer(Sys.time())),
+        model = paste0("models/", provider@model),
+        inputConfig = list(fileName = uploaded$name)
+      )
+    )
+  )
+
+  resp <- req_perform(req)
+  resp_body_json(resp)
+}
+
+# Fetch the current state of a batch by requesting its resource name
+# (`batch$name`); returns the parsed JSON body.
+method(batch_poll, ProviderGoogleGemini) <- function(provider, batch) {
+  req <- base_request(provider)
+  req <- req_url_path_append(req, batch$name)
+  resp <- req_perform(req)
+  resp_body_json(resp)
+}
+
+# Summarise a polled batch as list(working, n_processing, n_succeeded,
+# n_failed). The state and the request counters are read from
+# `batch$metadata`, falling back to `batch$response`.
+# NOTE(review): "done" is derived from `state` alone. A batch that carries
+# `batch$error` without a terminal state is still reported as working --
+# confirm the API always sets a terminal state alongside an error.
+method(batch_status, ProviderGoogleGemini) <- function(provider, batch) {
+  metadata <- batch$metadata %||% list()
+  response <- batch$response %||% list()
+  state <- metadata$state %||% response$state %||% "BATCH_STATE_UNSPECIFIED"
+  stats <- metadata$batchStats %||% response$batchStats %||% list()
+
+  # as.integer() also covers counters that arrive as strings (e.g. "4").
+  total <- as.integer(stats$requestCount %||% 0L)
+  pending <- as.integer(stats$pendingRequestCount %||% 0L)
+  succeeded <- as.integer(stats$successfulRequestCount %||% 0L)
+  failed <- as.integer(stats$failedRequestCount %||% 0L)
+
+  # A batch-level error with no per-request failure count: count every
+  # request as failed.
+  if (!is.null(batch$error) && total > 0 && failed == 0L) {
+    failed <- total
+  }
+
+  terminal_states <- c(
+    "BATCH_STATE_SUCCEEDED",
+    "BATCH_STATE_FAILED",
+    "BATCH_STATE_CANCELLED",
+    "BATCH_STATE_EXPIRED"
+  )
+
+  is_done <- state %in% terminal_states
+
+  # Keep polling if succeeded but output file isn't available yet.
+  # The API can report BATCH_STATE_SUCCEEDED before the responsesFile
+  # metadata is populated.
+  if (state == "BATCH_STATE_SUCCEEDED") {
+    responses_file <- batch$response$responsesFile %||% ""
+    if (!nzchar(responses_file)) {
+      is_done <- FALSE
+    }
+  }
+
+  n_processing <- max(pending, total - succeeded - failed, 0L)
+
+  list(
+    working = !is_done,
+    n_processing = n_processing,
+    n_succeeded = max(succeeded, 0L),
+    n_failed = max(failed, 0L)
+  )
+}
+
+# Download and parse the batch output file and return one
+# list(status_code, body) per result line, sorted by request index. If the
+# batch itself carries an error, `requestCount` copies of a failure entry with
+# that error code are returned instead and nothing is downloaded.
+# NOTE(review): results are sorted by index, not placed by index, so a missing
+# or duplicated key would shift later positions -- presumably the caller
+# validates the number of results; verify.
+method(batch_retrieve, ProviderGoogleGemini) <- function(provider, batch) {
+  metadata <- batch$metadata %||% list()
+  response <- batch$response %||% list()
+  stats <- metadata$batchStats %||% response$batchStats %||% list()
+  request_count <- as.integer(stats$requestCount %||% 0L)
+
+  if (!is.null(batch$error)) {
+    code <- as.integer(batch$error$code %||% 500L)
+    return(rep(
+      list(list(status_code = code, body = NULL)),
+      max(0L, request_count)
+    ))
+  }
+
+  responses_file <- batch$response$responsesFile %||% ""
+
+  if (!nzchar(responses_file)) {
+    cli::cli_abort("Gemini batch completed but no output file was returned.")
+  }
+
+  path_output <- withr::local_tempfile(fileext = ".jsonl")
+  gemini_download_file(provider, responses_file, path_output)
+
+  parsed <- read_ndjson(path_output)
+
+  # The line position is only the fallback index; a key or metadata index
+  # found on the line takes precedence (see gemini_extract_index()).
+  normalized <- imap(parsed, function(x, i) {
+    gemini_normalize_result(x, index_default = as.integer(i))
+  })
+
+  ids <- vapply(normalized, function(x) x$index, integer(1))
+  results <- lapply(normalized, function(x) x$result)
+  results[order(ids)]
+}
+
+# Convert one retrieved result into a turn via value_turn(). Anything other
+# than a 200 result with a body yields NULL.
+method(batch_result_turn, ProviderGoogleGemini) <- function(
+  provider,
+  result,
+  has_type = FALSE
+) {
+  if (!is.null(result) && result$status_code == 200L && !is.null(result$body)) {
+    value_turn(provider, result$body, has_type = has_type)
+  } else {
+    NULL
+  }
+}
+
+# Gemini batch helpers ---------------------------------------------------------
+
+# The Gemini REST API accepts both camelCase and snake_case, but the batch
+# JSONL file parser requires protobuf field names which are always snake_case.
+# Without this conversion, the batch API rejects requests with HTTP 400.
+#
+# gemini_to_snake_case() walks `x` recursively and rewrites the names of every
+# named list it meets (e.g. "generationConfig" -> "generation_config").
+# Non-list values are returned untouched.
+# NOTE(review): the walk is indiscriminate, so caller-chosen names nested in
+# the body are rewritten too. gemini_prepare_batch_body() only restores the
+# response schema afterwards -- confirm nothing else (e.g. tool parameter
+# schemas) carries user-defined camelCase names.
+gemini_to_snake_case <- function(x) {
+  if (is.list(x)) {
+    if (!is.null(names(x))) {
+      names(x) <- gsub("([a-z])([A-Z])", "\\1_\\2", names(x), perl = TRUE) |>
+        tolower()
+    }
+    lapply(x, gemini_to_snake_case)
+  } else {
+    x
+  }
+}
+
+# Adapt a regular request body for use as the `request` of one batch JSONL
+# line:
+#   1. drop a system instruction whose parts carry no text,
+#   2. convert all field names to snake_case,
+#   3. move the response schema to `response_json_schema`, keeping the
+#      property names exactly as the caller wrote them.
+# Returns the modified body.
+gemini_prepare_batch_body <- function(body) {
+  # Remove empty system instructions (batch parser rejects them)
+  si <- body$systemInstruction %||% body$system_instruction
+  if (!is.null(si)) {
+    parts <- si$parts
+    # `parts` may be a single named part (list(text = ...)) or an unnamed list
+    # of parts; anything else is treated as empty.
+    is_empty <- if (is.list(parts) && !is.null(names(parts))) {
+      identical(parts$text, "") || is.null(parts$text)
+    } else if (is.list(parts) && length(parts) > 0) {
+      all(vapply(
+        parts,
+        function(p) identical(p$text, "") || is.null(p$text),
+        logical(1)
+      ))
+    } else {
+      TRUE
+    }
+    if (is_empty) {
+      body$systemInstruction <- NULL
+      body$system_instruction <- NULL
+    }
+  }
+
+  # Save user-defined schema before snake_case conversion so property names
+  # like "firstName" are not mangled to "first_name"
+  gc_pre <- body$generationConfig %||% body$generation_config
+  saved_schema <- if (!is.null(gc_pre)) {
+    gc_pre$responseSchema %||% gc_pre$response_schema
+  }
+
+  # Batch JSONL requires protobuf-style snake_case field names; camelCase causes
+  # HTTP 400 (unlike the REST API which accepts both)
+  body <- gemini_to_snake_case(body)
+
+  # Rename response_schema -> response_json_schema and restore original schema
+  gc <- body$generation_config
+  if (
+    !is.null(gc) && (!is.null(gc$response_schema) || !is.null(saved_schema))
+  ) {
+    gc$response_json_schema <- saved_schema %||% gc$response_schema
+    gc$response_schema <- NULL
+    body$generation_config <- gc
+  }
+
+  body
+}
+
+# Work out which request a result line `x` belongs to. Checked in order:
+#   1. `metadata$request_index`, then `metadata$index`;
+#   2. a "chat-<n>" key in `key` / `custom_id`, top-level first, then inside
+#      `metadata` (this is the key format batch_submit() writes);
+#   3. `default`.
+# Always returns an integer (NA_integer_ when nothing matched and no default
+# was supplied).
+#
+# Fields are read with `[[` rather than `$`: `$` partially matches list names,
+# so a lookup could silently pick up an unrelated, longer field name.
+gemini_extract_index <- function(x, default = NA_integer_) {
+  metadata <- x[["metadata"]] %||% list()
+  if (!is.list(metadata)) {
+    # Unexpected shape: ignore it instead of failing on the lookups below.
+    metadata <- list()
+  }
+
+  idx <- metadata[["request_index"]] %||% metadata[["index"]]
+  # Only a single non-missing value is usable as an index.
+  if (length(idx) == 1L && !is.na(idx)) {
+    return(as.integer(idx))
+  }
+
+  key <- x[["key"]] %||%
+    x[["custom_id"]] %||%
+    metadata[["key"]] %||%
+    metadata[["custom_id"]] %||%
+    ""
+  if (length(key) == 1L && grepl("^chat-[0-9]+$", key)) {
+    return(as.integer(sub("^chat-([0-9]+)$", "\\1", key)))
+  }
+
+  as.integer(default)
+}
+
+# Normalise one parsed line of the batch output file to
+#   list(index = <integer>, result = list(status_code = <integer>, body = ...))
+# where `body` is the response for successes and NULL otherwise.
+# `index_default` is used when the line does not identify its request.
+gemini_normalize_result <- function(x, index_default) {
+  index <- gemini_extract_index(x, default = index_default)
+
+  # Exact `[[` lookups are essential here. With `$`, `x$response` partially
+  # matches the `responseId` field of a plain GenerateContentResponse line, so
+  # such a line was mistaken for a wrapped result and its body became the id
+  # string instead of the response.
+  response <- x[["response"]]
+  error <- x[["error"]]
+  status <- x[["status"]]
+
+  # Formats where response and error/status are wrapped in one object
+  if (!is.null(response) || !is.null(error) || !is.null(status)) {
+    if (!is.null(response) && is.null(error) && is.null(status)) {
+      return(list(
+        index = index,
+        result = list(status_code = 200L, body = response)
+      ))
+    }
+
+    # Prefer `error` over `status`; fall back to 500 when no code is given.
+    failure <- error %||% status
+    code <- if (is.list(failure)) failure[["code"]] else NULL
+    return(list(
+      index = index,
+      result = list(status_code = as.integer(code %||% 500L), body = NULL)
+    ))
+  }
+
+  # Plain GenerateContentResponse lines (current file-mode output)
+  if (
+    !is.null(x[["candidates"]]) ||
+      !is.null(x[["promptFeedback"]]) ||
+      !is.null(x[["usageMetadata"]])
+  ) {
+    return(list(index = index, result = list(status_code = 200L, body = x)))
+  }
+
+  # Unrecognised shape: report it as a failed request.
+  list(index = index, result = list(status_code = 500L, body = NULL))
+}
diff --git a/man/batch_chat.Rd b/man/batch_chat.Rd
index e6945267a..654592389 100644
--- a/man/batch_chat.Rd
+++ b/man/batch_chat.Rd
@@ -75,10 +75,11 @@ is not complete.
 }
 \description{
 \code{batch_chat()} and \code{batch_chat_structured()} currently only work with
-\code{\link[=chat_openai]{chat_openai()}} and \code{\link[=chat_anthropic]{chat_anthropic()}}. They use the
-\href{https://platform.openai.com/docs/guides/batch}{OpenAI} and
-\href{https://docs.claude.com/en/docs/build-with-claude/batch-processing}{Anthropic}
-batch APIs which allow you to submit multiple requests simultaneously.
+\code{\link[=chat_openai]{chat_openai()}}, \code{\link[=chat_anthropic]{chat_anthropic()}}, and \code{\link[=chat_google_gemini]{chat_google_gemini()}}.
They use +the \href{https://platform.openai.com/docs/guides/batch}{OpenAI}, +\href{https://docs.claude.com/en/docs/build-with-claude/batch-processing}{Anthropic}, +and \href{https://ai.google.dev/gemini-api/docs/batch-api}{Google Gemini} batch APIs +which allow you to submit multiple requests simultaneously. The results can take up to 24 hours to complete, but in return you pay 50\% less than usual (but note that ellmer doesn't include this discount in its pricing metadata). If you want to get results back more quickly, or diff --git a/tests/testthat/batch/state-capitals-gemini.json b/tests/testthat/batch/state-capitals-gemini.json new file mode 100644 index 000000000..c9532c485 --- /dev/null +++ b/tests/testthat/batch/state-capitals-gemini.json @@ -0,0 +1,172 @@ +{ + "version": 1, + "stage": "done", + "batch": { + "name": "batches/rsv6pu5hikanbevp6or89zaoo8bs2uw7vjav", + "metadata": { + "@type": "type.googleapis.com/google.ai.generativelanguage.v1main.GenerateContentBatch", + "model": "models/gemini-2.5-flash", + "displayName": "ellmer-1771051113", + "inputConfig": { + "fileName": "files/1xyylny2fl94" + }, + "output": { + "responsesFile": "files/batch-rsv6pu5hikanbevp6or89zaoo8bs2uw7vjav" + }, + "createTime": "2026-02-14T06:38:35.278957778Z", + "endTime": "2026-02-14T06:39:34.739884130Z", + "updateTime": "2026-02-14T06:39:34.739884083Z", + "batchStats": { + "requestCount": "4", + "successfulRequestCount": "4" + }, + "state": "BATCH_STATE_SUCCEEDED", + "name": "batches/rsv6pu5hikanbevp6or89zaoo8bs2uw7vjav" + }, + "done": true, + "response": { + "@type": "type.googleapis.com/google.ai.generativelanguage.v1main.GenerateContentBatchOutput", + "responsesFile": "files/batch-rsv6pu5hikanbevp6or89zaoo8bs2uw7vjav" + } + }, + "results": [ + { + "status_code": 200, + "body": { + "candidates": [ + { + "index": 0, + "content": { + "role": "model", + "parts": [ + { + "text": "Des Moines" + } + ] + }, + "finishReason": "STOP" + } + ], + "usageMetadata": { + 
"candidatesTokenCount": 2, + "totalTokenCount": 52, + "promptTokenCount": 15, + "thoughtsTokenCount": 35, + "promptTokensDetails": [ + { + "tokenCount": 15, + "modality": "TEXT" + } + ] + }, + "responseId": "phiQadCDIJGsmtkP_6qC0Ac", + "modelVersion": "gemini-2.5-flash" + } + }, + { + "status_code": 200, + "body": { + "responseId": "phiQac-NJIfP_uMPgebx2Qk", + "modelVersion": "gemini-2.5-flash", + "candidates": [ + { + "index": 0, + "content": { + "role": "model", + "parts": [ + { + "text": "Albany" + } + ] + }, + "finishReason": "STOP" + } + ], + "usageMetadata": { + "thoughtsTokenCount": 43, + "candidatesTokenCount": 2, + "totalTokenCount": 61, + "promptTokenCount": 16, + "promptTokensDetails": [ + { + "tokenCount": 16, + "modality": "TEXT" + } + ] + } + } + }, + { + "status_code": 200, + "body": { + "candidates": [ + { + "index": 0, + "content": { + "role": "model", + "parts": [ + { + "text": "Sacramento" + } + ] + }, + "finishReason": "STOP" + } + ], + "usageMetadata": { + "promptTokensDetails": [ + { + "tokenCount": 15, + "modality": "TEXT" + } + ], + "promptTokenCount": 15, + "candidatesTokenCount": 2, + "totalTokenCount": 43, + "thoughtsTokenCount": 26 + }, + "responseId": "phiQaciRJPDQz7IPi-LjiQk", + "modelVersion": "gemini-2.5-flash" + } + }, + { + "status_code": 200, + "body": { + "candidates": [ + { + "index": 0, + "content": { + "role": "model", + "parts": [ + { + "text": "Austin" + } + ] + }, + "finishReason": "STOP" + } + ], + "usageMetadata": { + "candidatesTokenCount": 1, + "totalTokenCount": 49, + "promptTokenCount": 15, + "thoughtsTokenCount": 33, + "promptTokensDetails": [ + { + "tokenCount": 15, + "modality": "TEXT" + } + ] + }, + "responseId": "phiQaaa0G_ON_PUPzb73gA0", + "modelVersion": "gemini-2.5-flash" + } + } + ], + "started_at": 1771051110, + "hash": { + "provider": "aebac3192349dfc3afe04861751c4aa4", + "prompts": "b8eafe281e3cc5113058d9722be3e295", + "user_turns": "d6990a1b8a9f5db0e97de86c2669de44" + } +} diff --git 
a/tests/testthat/test-batch-chat.R b/tests/testthat/test-batch-chat.R
index be8652957..5e5ad52bc 100644
--- a/tests/testthat/test-batch-chat.R
+++ b/tests/testthat/test-batch-chat.R
@@ -147,6 +147,32 @@ test_that("errors if wait = FALSE and not complete", {
   expect_equal(job$step_until_done(), NULL)
 })
 
+test_that("public batch helpers return NULL if wait = FALSE and not complete", {
+  # Mock the provider generics so a job can be submitted and polled but always
+  # reports that it is still working.
+  local_mocked_bindings(
+    batch_submit = function(...) list(id = "123"),
+    batch_poll = function(...) list(id = "123", results = TRUE),
+    batch_status = function(...) list(working = TRUE)
+  )
+
+  chat <- chat_openai_test()
+  prompts <- list("What's your name?")
+
+  # Each call gets its own temporary job file.
+  expect_null(batch_chat(chat, prompts, withr::local_tempfile(), wait = FALSE))
+  expect_null(batch_chat_text(
+    chat,
+    prompts,
+    withr::local_tempfile(),
+    wait = FALSE
+  ))
+  expect_null(batch_chat_structured(
+    chat,
+    prompts,
+    withr::local_tempfile(),
+    type = type_object(name = type_string()),
+    wait = FALSE
+  ))
+})
+
 test_that("informative error for bad inputs", {
   chat_openai <- chat_openai_test()
   chat_ollama <- chat_openai_test()
diff --git a/tests/testthat/test-provider-google-batch.R b/tests/testthat/test-provider-google-batch.R
new file mode 100644
index 000000000..48b2c24d4
--- /dev/null
+++ b/tests/testthat/test-provider-google-batch.R
@@ -0,0 +1,405 @@
+# Gemini batch helper functions -------------------------------------------
+
+test_that("gemini_extract_index extracts from metadata.request_index", {
+  x <- list(metadata = list(request_index = 5L))
+  expect_equal(gemini_extract_index(x), 5L)
+})
+
+test_that("gemini_extract_index extracts from custom_id key", {
+  x <- list(custom_id = "chat-3")
+  expect_equal(gemini_extract_index(x), 3L)
+})
+
+test_that("gemini_extract_index extracts from key field", {
+  x <- list(key = "chat-7")
+  expect_equal(gemini_extract_index(x), 7L)
+})
+
+test_that("gemini_extract_index extracts from metadata key field", {
+  x <- list(metadata = list(key = "chat-8"))
+  expect_equal(gemini_extract_index(x), 8L)
+})
+
+test_that("gemini_extract_index returns default when no index found", {
+  x <- list(foo = "bar")
+  expect_equal(gemini_extract_index(x, default = 99L), 99L)
+})
+
+test_that("gemini_normalize_result handles plain GenerateContentResponse", {
+  # NOTE(review): this sample has no `responseId` field, unlike the response
+  # bodies in the recorded fixture -- consider adding one to cover that shape.
+  x <- list(
+    candidates = list(list(content = list(parts = list(list(text = "hello"))))),
+    usageMetadata = list(totalTokenCount = 10L)
+  )
+  result <- gemini_normalize_result(x, index_default = 1L)
+
+  expect_equal(result$index, 1L)
+  expect_equal(result$result$status_code, 200L)
+  expect_equal(result$result$body, x)
+})
+
+test_that("gemini_normalize_result handles wrapped response", {
+  x <- list(
+    metadata = list(request_index = 2L),
+    response = list(candidates = list())
+  )
+  # The index found in metadata must win over the default.
+  result <- gemini_normalize_result(x, index_default = 99L)
+
+  expect_equal(result$index, 2L)
+  expect_equal(result$result$status_code, 200L)
+  expect_equal(result$result$body, list(candidates = list()))
+})
+
+test_that("gemini_normalize_result handles error response", {
+  x <- list(
+    metadata = list(request_index = 3L),
+    error = list(code = 400L, message = "bad request")
+  )
+  result <- gemini_normalize_result(x, index_default = 99L)
+
+  expect_equal(result$index, 3L)
+  expect_equal(result$result$status_code, 400L)
+  expect_null(result$result$body)
+})
+
+test_that("gemini_normalize_result handles unknown format", {
+  x <- list(unknown_field = "value")
+  result <- gemini_normalize_result(x, index_default = 5L)
+
+  expect_equal(result$index, 5L)
+  expect_equal(result$result$status_code, 500L)
+  expect_null(result$result$body)
+})
+
+# gemini_prepare_batch_body -----------------------------------------------
+
+test_that("gemini_prepare_batch_body converts API keys to snake_case", {
+  body <- list(
+    generationConfig = list(responseMimeType = "text/plain"),
+    contents = list(list(role = "user", parts = list(list(text = "hi"))))
+  )
+
+  result <- gemini_prepare_batch_body(body)
+
+  expect_true("generation_config" %in% names(result))
+  expect_null(result$generationConfig)
+  expect_true("response_mime_type" %in% names(result$generation_config))
+})
+
+test_that("gemini_prepare_batch_body preserves schema property names", {
+  # camelCase property names inside the schema belong to the caller and must
+  # survive the snake_case conversion of the surrounding API fields.
+  body <- list(
+    generationConfig = list(
+      responseMimeType = "application/json",
+      responseSchema = list(
+        type = "object",
+        properties = list(
+          firstName = list(type = "string"),
+          lastName = list(type = "string")
+        ),
+        required = list("firstName", "lastName")
+      )
+    ),
+    contents = list(list(role = "user", parts = list(list(text = "hi"))))
+  )
+  result <- gemini_prepare_batch_body(body)
+
+  schema <- result$generation_config$response_json_schema
+  expect_false(is.null(schema))
+  expect_true("firstName" %in% names(schema$properties))
+  expect_true("lastName" %in% names(schema$properties))
+  expect_equal(schema$required, list("firstName", "lastName"))
+  expect_null(result$generation_config$response_schema)
+})
+
+test_that("gemini_prepare_batch_body strips empty system instruction", {
+  body <- list(
+    systemInstruction = list(parts = list(text = "")),
+    contents = list(list(role = "user", parts = list(list(text = "hi"))))
+  )
+  result <- gemini_prepare_batch_body(body)
+
+  expect_null(result$system_instruction)
+  expect_null(result$systemInstruction)
+})
+
+test_that("gemini_prepare_batch_body keeps non-empty system instruction", {
+  body <- list(
+    systemInstruction = list(parts = list(text = "You are helpful.")),
+    contents = list(list(role = "user", parts = list(list(text = "hi"))))
+  )
+  result <- gemini_prepare_batch_body(body)
+
+  expect_false(is.null(result$system_instruction))
+  expect_equal(result$system_instruction$parts$text, "You are helpful.")
+})
+
+# Batch support -----------------------------------------------------------
+
+# Helper to create a dummy provider without needing real credentials
+# The provider name is derived from `base_url`: "Google/Vertex" when the URL
+# contains "aiplatform", "Google/Gemini" otherwise.
+dummy_gemini_provider <- function(
+  base_url = "https://generativelanguage.googleapis.com/v1beta/"
+) {
+  ProviderGoogleGemini(
+    name = if (grepl("aiplatform", base_url)) {
+      "Google/Vertex"
+    } else {
+      "Google/Gemini"
+    },
+    base_url = base_url,
+    model = "gemini-2.5-flash",
+    params = params(),
+    extra_args = list(),
+    extra_headers = character(),
+    credentials = NULL
+  )
+}
+
+test_that("ProviderGoogleGemini has batch support", {
+  provider <- dummy_gemini_provider()
+  expect_true(has_batch_support(provider))
+})
+
+test_that("Vertex provider also has batch support", {
+  provider <- dummy_gemini_provider(
+    base_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test/locations/us-central1/publishers/google/"
+  )
+  expect_true(has_batch_support(provider))
+})
+
+test_that("batch_status keeps working when succeeded but no responsesFile", {
+  provider <- dummy_gemini_provider()
+  # Terminal state, but no `response$responsesFile` to download yet.
+  batch <- list(
+    metadata = list(
+      state = "BATCH_STATE_SUCCEEDED",
+      batchStats = list(requestCount = 2L, successfulRequestCount = 2L)
+    )
+  )
+  status <- batch_status(provider, batch)
+  expect_true(status$working)
+})
+
+test_that("batch_status marks done when succeeded with responsesFile", {
+  provider <- dummy_gemini_provider()
+  batch <- list(
+    metadata = list(
+      state = "BATCH_STATE_SUCCEEDED",
+      batchStats = list(requestCount = 2L, successfulRequestCount = 2L)
+    ),
+    response = list(responsesFile = "files/abc123")
+  )
+  status <- batch_status(provider, batch)
+  expect_false(status$working)
+})
+
+test_that("batch_retrieve reorders out-of-order Gemini results by key", {
+  provider <- dummy_gemini_provider()
+  batch <- list(
+    metadata = list(batchStats = list(requestCount = 3L)),
+    response = list(responsesFile = "files/abc123")
+  )
+
+  # Stub the download: write an output file whose lines are keyed
+  # chat-3, chat-1, chat-2 (in that order) instead of calling the API.
+  local_mocked_bindings(
+    gemini_download_file = function(provider, name, path) {
+      lines <- c(
+        jsonlite::toJSON(
+          list(
+            key = "chat-3",
+            response = list(
+              responseId = "third",
+              candidates = list(list(
+                content = list(parts = list(list(text = "{}")))
+              )),
+              usageMetadata = list(totalTokenCount = 3L)
+            )
+          ),
+          auto_unbox = TRUE
+        ),
+        jsonlite::toJSON(
+          list(
+            key = "chat-1",
+            response = list(
+              responseId = "first",
+              candidates = list(list(
+                content = list(parts = list(list(text = "{}")))
+              )),
+              usageMetadata = list(totalTokenCount = 1L)
+            )
+          ),
+          auto_unbox = TRUE
+        ),
+        jsonlite::toJSON(
+          list(
+            key = "chat-2",
+            response = list(
+              responseId = "second",
+              candidates = list(list(
+                content = list(parts = list(list(text = "{}")))
+              )),
+              usageMetadata = list(totalTokenCount = 2L)
+            )
+          ),
+          auto_unbox = TRUE
+        )
+      )
+      writeLines(lines, path)
+      invisible(path)
+    }
+  )
+
+  results <- batch_retrieve(provider, batch)
+
+  # Results must come back in request order, not file order.
+  expect_equal(
+    vapply(results, \(x) x$body$responseId, character(1)),
+    c(
+      "first",
+      "second",
+      "third"
+    )
+  )
+})
+
+# Fixture-based tests ----------------------------------------------------
+
+test_that("batch chat works with Gemini fixture", {
+  # The saved job file at `path` is already at stage "done" and holds the four
+  # results. NOTE(review): presumably no request is sent, which is why a dummy
+  # API key suffices -- confirm.
+  withr::local_envvar(GEMINI_API_KEY = "dummy-key-for-fixture-test")
+  chat <- chat_google_gemini(
+    system_prompt = "Answer with just the city name",
+    model = "gemini-2.5-flash",
+    params = params(temperature = 0, seed = 1014)
+  )
+
+  prompts <- list(
+    "What's the capital of Iowa?",
+    "What's the capital of New York?",
+    "What's the capital of California?",
+    "What's the capital of Texas?"
+  )
+
+  out <- batch_chat_text(
+    chat,
+    prompts,
+    path = test_path("batch/state-capitals-gemini.json"),
+    ignore_hash = TRUE
+  )
+  expect_equal(out, c("Des Moines", "Albany", "Sacramento", "Austin"))
+})
+
+# Integration tests -------------------------------------------------------
+
+test_that("Gemini batch_chat submits and can be resumed", {
+  # Runs only when a Gemini or Google API key is available.
+  skip_if(
+    Sys.getenv("GEMINI_API_KEY") == "" && Sys.getenv("GOOGLE_API_KEY") == "",
+    "No Gemini credentials set"
+  )
+
+  chat <- chat_google_gemini_test()
+
+  prompts <- list("Reply with exactly: ok")
+  results_file <- withr::local_tempfile(fileext = ".json")
+
+  # Submit without waiting. An "unexpected number of responses" error is
+  # treated like "not finished yet"; any other error is re-raised.
+  chats <- tryCatch(
+    batch_chat(
+      chat,
+      prompts = prompts,
+      path = results_file,
+      wait = FALSE
+    ),
+    error = function(e) {
+      msg <- conditionMessage(e)
+      if (grepl("unexpected number of responses", msg, fixed = TRUE)) {
+        NULL
+      } else {
+        stop(e)
+      }
+    }
+  )
+
+  if (is.null(chats)) {
+    # Check for completion up to 100 times, sleeping 10 seconds before each
+    # check.
+    completed <- FALSE
+    for (i in seq_len(100)) {
+      Sys.sleep(10)
+      completed <- isTRUE(batch_chat_completed(chat, prompts, results_file))
+      if (completed) break
+    }
+
+    if (!completed) {
+      skip("Gemini batch did not complete within test timeout.")
+    }
+
+    # Resume from the same job file to collect the results.
+    chats <- batch_chat(
+      chat,
+      prompts = prompts,
+      path = results_file,
+      wait = TRUE
+    )
+  }
+
+  expect_equal(length(chats), 1)
+  expect_true(inherits(chats[[1]], "Chat"))
+})
+
+test_that("Gemini batch_chat_structured works", {
+  # Runs only when a Gemini or Google API key is available.
+  skip_if(
+    Sys.getenv("GEMINI_API_KEY") == "" && Sys.getenv("GOOGLE_API_KEY") == "",
+    "No Gemini credentials set"
+  )
+
+  chat <- chat_google_gemini_test()
+
+  type_answer <- type_object(
+    answer = type_string()
+  )
+
+  prompts <- list("What is 2+2? Reply with just the number.")
+  results_file <- withr::local_tempfile(fileext = ".json")
+
+  # Submit without waiting. Besides the "not finished yet" case, errors whose
+  # message matches the pattern below skip the test; anything else is
+  # re-raised.
+  result <- tryCatch(
+    batch_chat_structured(
+      chat,
+      prompts = prompts,
+      path = results_file,
+      type = type_answer,
+      wait = FALSE
+    ),
+    error = function(e) {
+      msg <- conditionMessage(e)
+      if (grepl("unexpected number of responses", msg, fixed = TRUE)) {
+        NULL
+      } else if (
+        grepl(
+          "HTTP 40[04]|invalid argument|not found|not supported",
+          msg,
+          ignore.case = TRUE
+        )
+      ) {
+        skip(paste0("Gemini batch API rejected request: ", msg))
+      } else {
+        stop(e)
+      }
+    }
+  )
+
+  if (is.null(result)) {
+    # Check for completion up to 12 times, sleeping 10 seconds before each
+    # check.
+    completed <- FALSE
+    for (i in seq_len(12)) {
+      Sys.sleep(10)
+      completed <- isTRUE(batch_chat_completed(chat, prompts, results_file))
+      if (completed) break
+    }
+
+    if (!completed) {
+      skip("Gemini batch did not complete within test timeout.")
+    }
+
+    # Resume from the same job file to collect the results.
+    result <- batch_chat_structured(
+      chat,
+      prompts = prompts,
+      path = results_file,
+      type = type_answer,
+      wait = TRUE
+    )
+  }
+
+  expect_true(is.data.frame(result))
+  expect_equal(nrow(result), 1)
+  expect_true("answer" %in% names(result))
+})