-
Notifications
You must be signed in to change notification settings - Fork 129
Add Gemini batch processing support #926
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
3b4ce95
b56f595
e2ca0b9
d21fd55
bfd6086
a82ee02
fa91d8b
48a0678
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -872,3 +872,269 @@ models_google <- function( | |
# Turn a location name into a prefix string: "global" maps to the empty
# string, any other location is returned with a trailing "-".
google_location <- function(location) {
  if (location != "global") {
    return(paste0(location, "-"))
  }
  ""
}
|
|
||
| # Batched requests ------------------------------------------------------------- | ||
|
|
||
# Batch API reference: https://ai.google.dev/gemini-api/docs/batch-api
#
# Reports that the Gemini provider implements batched requests.
method(has_batch_support, ProviderGoogleGemini) <- function(provider) TRUE
|
|
||
# Submit a set of conversations to Gemini as a single file-based batch job.
#
# Each conversation becomes one JSONL line holding a "chat-<i>" key and the
# prepared request body. The file is uploaded, then a batch is created that
# points at the uploaded file. Returns the parsed JSON body of the
# batch-creation response.
method(batch_submit, ProviderGoogleGemini) <- function(
  provider,
  conversations,
  type = NULL
) {
  # Temporary request file; removed automatically when this method exits.
  jsonl_path <- withr::local_tempfile(fileext = ".jsonl")

  # Build the keyed request entry for the i-th conversation.
  build_entry <- function(i) {
    chat <- chat_body(
      provider,
      stream = FALSE,
      turns = conversations[[i]],
      type = type
    )
    list(
      key = paste0("chat-", i),
      request = gemini_prepare_batch_body(chat)
    )
  }
  entries <- map(seq_along(conversations), build_entry)
  writeLines(map_chr(entries, to_json), jsonl_path)

  # The batch job refers to its input by the uploaded file's resource name,
  # so an upload without one cannot be used.
  file_name <- gemini_upload_file(provider, jsonl_path)$name
  if (is.null(file_name) || !nzchar(file_name)) {
    cli::cli_abort(
      "Gemini upload did not return a file resource name.",
      .internal = TRUE
    )
  }

  endpoint <- paste0(provider@model, ":batchGenerateContent")
  base_request(provider) |>
    req_url_path_append("models", endpoint) |>
    req_body_json(
      list(
        batch = list(
          displayName = paste0("ellmer-", as.integer(Sys.time())),
          model = paste0("models/", provider@model),
          inputConfig = list(fileName = file_name)
        )
      )
    ) |>
    req_perform() |>
    resp_body_json()
}
|
|
||
# Fetch the current state of a submitted batch.
#
# Requests the resource path stored in `batch$name` and returns the parsed
# JSON body of the response.
method(batch_poll, ProviderGoogleGemini) <- function(provider, batch) {
  base_request(provider) |>
    req_url_path_append(batch$name) |>
    req_perform() |>
    resp_body_json()
}
|
|
||
# Summarise a polled Gemini batch operation as progress information.
#
# `batch` is the parsed operation returned by `batch_poll()`. The job state
# and request counters are read from `batch$metadata`, falling back to
# `batch$response`. Returns a list with `working` (should polling continue?)
# and the counts `n_processing`, `n_succeeded` and `n_failed`.
method(batch_status, ProviderGoogleGemini) <- function(provider, batch) {
  metadata <- batch$metadata %||% list()
  response <- batch$response %||% list()
  state <- metadata$state %||% response$state %||% "BATCH_STATE_UNSPECIFIED"
  stats <- metadata$batchStats %||% response$batchStats %||% list()

  total <- as.integer(stats$requestCount %||% 0L)
  pending <- as.integer(stats$pendingRequestCount %||% 0L)
  succeeded <- as.integer(stats$successfulRequestCount %||% 0L)
  failed <- as.integer(stats$failedRequestCount %||% 0L)

  # An operation-level error fails the whole job, so count every request as
  # failed unless per-request failures were already reported.
  has_error <- !is.null(batch$error)
  if (has_error && total > 0 && failed == 0L) {
    failed <- total
  }

  terminal_states <- c(
    "BATCH_STATE_SUCCEEDED",
    "BATCH_STATE_FAILED",
    "BATCH_STATE_CANCELLED",
    "BATCH_STATE_EXPIRED"
  )

  # `batch_retrieve()` treats `batch$error` as a final outcome, so an errored
  # operation must stop polling even when no terminal state is reported;
  # otherwise it would be polled forever.
  is_done <- has_error || state %in% terminal_states

  # Keep polling if succeeded but output file isn't available yet.
  # The API can report BATCH_STATE_SUCCEEDED before the responsesFile
  # metadata is populated.
  if (!has_error && state == "BATCH_STATE_SUCCEEDED") {
    responses_file <- batch$response$responsesFile %||% ""
    if (!nzchar(responses_file)) {
      is_done <- FALSE
    }
  }

  # Trust whichever is larger: the reported pending count or the requests
  # not yet accounted for; never report a negative number.
  n_processing <- max(pending, total - succeeded - failed, 0L)

  list(
    working = !is_done,
    n_processing = n_processing,
    n_succeeded = max(succeeded, 0L),
    n_failed = max(failed, 0L)
  )
}
|
|
||
# Collect the per-request results of a finished Gemini batch.
#
# Returns a list of `list(status_code, body)` entries. When the operation
# carries an error, every request (or a single placeholder if the request
# count is unknown) gets that error code and a NULL body. Otherwise the
# responses file is downloaded, each line is normalised, and the results are
# returned ordered by their extracted index. Aborts if no responses file is
# named.
method(batch_retrieve, ProviderGoogleGemini) <- function(provider, batch) {
  stats <- batch$metadata$batchStats %||%
    batch$response$batchStats %||%
    list()
  request_count <- as.integer(stats$requestCount %||% 0L)

  if (!is.null(batch$error)) {
    failure <- list(
      status_code = as.integer(batch$error$code %||% 500L),
      body = NULL
    )
    n_copies <- if (request_count <= 0L) 1L else request_count
    return(rep(list(failure), n_copies))
  }

  output_name <- batch$response$responsesFile %||% ""
  if (!nzchar(output_name)) {
    cli::cli_abort("Gemini batch completed but no output file was returned.")
  }

  # Temporary download target; removed automatically when this method exits.
  local_copy <- withr::local_tempfile(fileext = ".jsonl")
  gemini_download_file(provider, output_name, local_copy)

  entries <- imap(read_ndjson(local_copy), function(line, pos) {
    gemini_normalize_result(line, index_default = as.integer(pos))
  })

  positions <- vapply(entries, function(entry) entry$index, integer(1))
  lapply(entries, function(entry) entry$result)[order(positions)]
}
|
|
||
# Convert one retrieved batch result into a turn.
#
# Returns `value_turn()` of the result body when the result has status 200
# and a body; returns NULL for missing, failed or empty results.
method(batch_result_turn, ProviderGoogleGemini) <- function(
  provider,
  result,
  has_type = FALSE
) {
  if (is.null(result) || result$status_code != 200L || is.null(result$body)) {
    return(NULL)
  }
  value_turn(provider, result$body, has_type = has_type)
}
|
|
||
| # Gemini batch helpers --------------------------------------------------------- | ||
|
|
||
# The Gemini REST API accepts both camelCase and snake_case, but the batch
# JSONL file parser expects protobuf field names, which are always snake_case.
# camelCase fields in the file are not accepted (testing reported during
# review saw the request fail with HTTP 400).
#
# Recursively renames the elements of every named list in `x`, inserting "_"
# between a lower-case letter and the upper-case letter that follows it and
# lower-casing the result. Non-list values are returned unchanged. Every
# nesting level is renamed, including any caller-defined keys.
gemini_to_snake_case <- function(x) {
  if (!is.list(x)) {
    return(x)
  }

  keys <- names(x)
  if (!is.null(keys)) {
    names(x) <- tolower(gsub("([a-z])([A-Z])", "\\1_\\2", keys, perl = TRUE))
  }

  lapply(x, gemini_to_snake_case)
}
|
|
||
# Adapt a chat request body for use as one line of a Gemini batch file.
#
# Three adjustments are made: an empty system instruction is dropped, all
# field names are converted to snake_case, and the response schema is moved
# to `response_json_schema` with its original (unconverted) contents.
# Returns the adjusted body.
gemini_prepare_batch_body <- function(body) {
  # A part counts as blank when it has no text or its text is "".
  is_blank <- function(part) is.null(part$text) || identical(part$text, "")

  # Remove empty system instructions (batch parser rejects them)
  instruction <- body$systemInstruction %||% body$system_instruction
  if (!is.null(instruction)) {
    parts <- instruction$parts
    if (!is.list(parts) || length(parts) == 0) {
      drop_instruction <- TRUE
    } else if (!is.null(names(parts))) {
      # `parts` is a single part given as a named list
      drop_instruction <- is_blank(parts)
    } else {
      # `parts` is an unnamed list of parts
      drop_instruction <- all(vapply(parts, is_blank, logical(1)))
    }

    if (drop_instruction) {
      body$systemInstruction <- NULL
      body$system_instruction <- NULL
    }
  }

  # Save user-defined schema before snake_case conversion so property names
  # like "firstName" are not mangled to "first_name"
  original_schema <- NULL
  config_before <- body$generationConfig %||% body$generation_config
  if (!is.null(config_before)) {
    original_schema <- config_before$responseSchema %||%
      config_before$response_schema
  }

  # The conversion is required: the REST API takes either casing, but a batch
  # file with camelCase field names was reported in review to fail with
  # HTTP 400, and the batch API docs use snake_case in their JSONL examples.
  body <- gemini_to_snake_case(body)

  # Rename response_schema -> response_json_schema and restore original schema
  config <- body$generation_config
  has_schema <- !is.null(config) &&
    !(is.null(config$response_schema) && is.null(original_schema))
  if (has_schema) {
    config$response_json_schema <- original_schema %||% config$response_schema
    config$response_schema <- NULL
    body$generation_config <- config
  }

  body
}
|
|
||
# Work out which request a batch output line belongs to.
#
# Sources are consulted in order of reliability:
#   1. an explicit index in `x$metadata` (`request_index`, then `index`);
#   2. a "chat-<n>" key in `x$key`, `x$custom_id` or `x$metadata$custom_id`;
#   3. the caller-supplied `default` (e.g. the line's position in the file).
#
# The default must come last: callers pass a non-NA positional default, and
# consulting it before the key would make the key unreachable, so results
# would be ordered by file position instead of by request.
#
# Returns an integer; NA when nothing matches and no default was given.
gemini_extract_index <- function(x, default = NA_integer_) {
  metadata <- x$metadata %||% list()

  idx <- metadata$request_index %||% metadata$index
  if (!is.null(idx) && !is.na(idx)) {
    return(as.integer(idx))
  }

  key <- x$key %||% x$custom_id %||% metadata$custom_id %||% ""
  if (grepl("^chat-[0-9]+$", key)) {
    return(as.integer(sub("^chat-([0-9]+)$", "\\1", key)))
  }

  as.integer(default)
}
|
|
||
# Bring one parsed line of a batch output file into a common shape.
#
# Returns `list(index, result)` where `result` is
# `list(status_code, body)`. `index` comes from `gemini_extract_index()`,
# with `index_default` as its fallback.
gemini_normalize_result <- function(x, index_default) {
  index <- gemini_extract_index(x, default = index_default)

  # Assemble the return value for a given status code and body.
  outcome <- function(code, body = NULL) {
    list(index = index, result = list(status_code = code, body = body))
  }

  # Formats where response and error/status are wrapped in one object
  has_response <- !is.null(x$response)
  has_failure <- !is.null(x$error) || !is.null(x$status)

  if (has_failure) {
    # An error/status field takes precedence over any response present.
    status <- x$error %||% x$status %||% list()
    return(outcome(as.integer(status$code %||% 500L)))
  }
  if (has_response) {
    return(outcome(200L, x$response))
  }

  # Plain GenerateContentResponse lines (current file-mode output)
  looks_like_response <- !is.null(x$candidates) ||
    !is.null(x$promptFeedback) ||
    !is.null(x$usageMetadata)
  if (looks_like_response) {
    return(outcome(200L, x))
  }

  # Unrecognised line shape: report it as a failed request.
  outcome(500L)
}
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you could reduce the duplication here by making this
`google_upload_file()`; then `google_upload()` could call `google_upload_file()` and then create the `ContentUploaded` object.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done now - extracted
google_upload_file()