-
Notifications
You must be signed in to change notification settings - Fork 129
Add Gemini batch processing support #926
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
3b4ce95
b56f595
e2ca0b9
d21fd55
bfd6086
a82ee02
fa91d8b
48a0678
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -872,3 +872,370 @@ models_google <- function( | |
| google_location <- function(location) { | ||
| if (location == "global") "" else paste0(location, "-") | ||
| } | ||
|
|
||
| # Batched requests ------------------------------------------------------------- | ||
|
|
||
| # https://ai.google.dev/gemini-api/docs/batch | ||
| method(has_batch_support, ProviderGoogleGemini) <- function(provider) { | ||
| grepl("generativelanguage.googleapis.com", provider@base_url, fixed = TRUE) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this could probably just be
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The problem is that both
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (Have set it to
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, it's probably fine. Thanks!
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After reviewing this PR in detail, I'm wondering if it might be worth restoring something in here to only return |
||
| } | ||
|
|
||
| method(batch_submit, ProviderGoogleGemini) <- function( | ||
| provider, | ||
| conversations, | ||
| type = NULL | ||
| ) { | ||
| path <- withr::local_tempfile(fileext = ".jsonl") | ||
|
|
||
| requests <- map(seq_along(conversations), function(i) { | ||
| body <- chat_body( | ||
| provider, | ||
| stream = FALSE, | ||
| turns = conversations[[i]], | ||
| type = type | ||
| ) | ||
|
|
||
| list( | ||
| key = paste0("chat-", i), | ||
| request = gemini_prepare_batch_body(body) | ||
| ) | ||
| }) | ||
|
|
||
| json_lines <- map_chr(requests, to_json) | ||
| writeLines(json_lines, path) | ||
|
|
||
| uploaded <- gemini_upload_file(provider, path) | ||
| if (is.null(uploaded$name) || !nzchar(uploaded$name)) { | ||
| cli::cli_abort("Gemini upload did not return a file resource name.") | ||
|
hadley marked this conversation as resolved.
Outdated
|
||
| } | ||
|
|
||
| req <- base_request(provider) | ||
| req <- req_url_path_append( | ||
| req, | ||
| "models", | ||
| paste0(provider@model, ":batchGenerateContent") | ||
| ) | ||
| req <- req_body_json( | ||
| req, | ||
| list( | ||
| batch = list( | ||
| displayName = paste0("ellmer-", as.integer(Sys.time())), | ||
| model = paste0("models/", provider@model), | ||
| inputConfig = list(fileName = uploaded$name) | ||
| ) | ||
| ) | ||
| ) | ||
|
|
||
| resp <- req_perform(req) | ||
| resp_body_json(resp) | ||
| } | ||
|
|
||
| method(batch_poll, ProviderGoogleGemini) <- function(provider, batch) { | ||
| req <- base_request(provider) | ||
| req <- req_url_path_append(req, batch$name) | ||
| resp <- req_perform(req) | ||
| resp_body_json(resp) | ||
| } | ||
|
|
||
| method(batch_status, ProviderGoogleGemini) <- function(provider, batch) { | ||
| metadata <- batch$metadata %||% list() | ||
| response <- batch$response %||% list() | ||
| state <- metadata$state %||% response$state %||% "BATCH_STATE_UNSPECIFIED" | ||
| stats <- metadata$batchStats %||% response$batchStats %||% list() | ||
|
Comment on lines
+944
to
+947
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you sure that we need to check |
||
|
|
||
| total <- as.integer(stats$requestCount %||% 0L) | ||
| pending <- as.integer(stats$pendingRequestCount %||% 0L) | ||
| succeeded <- as.integer(stats$successfulRequestCount %||% 0L) | ||
| failed <- as.integer(stats$failedRequestCount %||% 0L) | ||
|
|
||
| if (!is.null(batch$error) && total > 0 && failed == 0L) { | ||
| failed <- total | ||
| } | ||
|
|
||
| terminal_states <- c( | ||
| "BATCH_STATE_SUCCEEDED", | ||
| "BATCH_STATE_FAILED", | ||
| "BATCH_STATE_CANCELLED", | ||
| "BATCH_STATE_EXPIRED" | ||
| ) | ||
|
|
||
| is_terminal <- state %in% terminal_states | ||
|
hadley marked this conversation as resolved.
Outdated
|
||
|
|
||
| # Keep polling if succeeded but output file isn't available yet | ||
| if (state == "BATCH_STATE_SUCCEEDED") { | ||
| batch_resource <- batch$response %||% batch$metadata | ||
| responses_file <- batch_resource$output$responsesFile %||% | ||
| batch_resource$responsesFile %||% | ||
| metadata$output$responsesFile %||% | ||
| NULL | ||
| if (is.null(responses_file) || !nzchar(responses_file)) { | ||
|
hadley marked this conversation as resolved.
Outdated
|
||
| is_terminal <- FALSE | ||
| } | ||
| } | ||
|
|
||
| n_processing <- max(pending, total - succeeded - failed, 0L) | ||
|
|
||
| list( | ||
| working = !is_terminal, | ||
| n_processing = n_processing, | ||
| n_succeeded = max(succeeded, 0L), | ||
| n_failed = max(failed, 0L) | ||
| ) | ||
| } | ||
|
|
||
| method(batch_retrieve, ProviderGoogleGemini) <- function(provider, batch) { | ||
| metadata <- batch$metadata %||% list() | ||
| response <- batch$response %||% list() | ||
| stats <- metadata$batchStats %||% response$batchStats %||% list() | ||
| request_count <- as.integer(stats$requestCount %||% 0L) | ||
|
|
||
| if (!is.null(batch$error)) { | ||
| code <- as.integer(batch$error$code %||% 500L) | ||
| if (request_count <= 0L) { | ||
| return(list(list(status_code = code, body = NULL))) | ||
| } | ||
| return(replicate( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should use
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you could eliminate the branch above with But it might be safer to just error if
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Used rep now. But I left the rest as is;
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oops, yes I meant
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done now, using max(0L, request_count) |
||
| request_count, | ||
| list(status_code = code, body = NULL), | ||
| simplify = FALSE | ||
| )) | ||
| } | ||
|
|
||
| batch_resource <- batch$response %||% batch$metadata | ||
| responses_file <- batch_resource$output$responsesFile %||% | ||
|
hadley marked this conversation as resolved.
Outdated
|
||
| batch_resource$responsesFile %||% | ||
| metadata$output$responsesFile %||% | ||
| NULL | ||
|
|
||
| if (is.null(responses_file) || !nzchar(responses_file)) { | ||
| cli::cli_abort("Gemini batch completed but no output file was returned.") | ||
| } | ||
|
|
||
| path_output <- withr::local_tempfile(fileext = ".jsonl") | ||
| gemini_download_file(provider, responses_file, path_output) | ||
|
|
||
| parsed <- read_ndjson(path_output, fallback = gemini_json_fallback) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why the fallback here? Claude might have copied from OpenAI which seems to be flaky. Do you have evidence that gemini is similarly problematic?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed the fallback. I tested with |
||
|
|
||
| normalized <- imap(parsed, function(x, i) { | ||
| gemini_normalize_result(x, index_default = as.integer(i)) | ||
| }) | ||
|
|
||
| ids <- vapply(normalized, function(x) x$index, integer(1)) | ||
| results <- lapply(normalized, function(x) x$result) | ||
| results[order(ids)] | ||
| } | ||
|
|
||
| method(batch_result_turn, ProviderGoogleGemini) <- function( | ||
| provider, | ||
| result, | ||
| has_type = FALSE | ||
| ) { | ||
| if (!is.null(result) && result$status_code == 200L && !is.null(result$body)) { | ||
| value_turn(provider, result$body, has_type = has_type) | ||
| } else { | ||
| NULL | ||
| } | ||
| } | ||
|
|
||
| # Gemini batch helpers --------------------------------------------------------- | ||
|
|
||
| #' @noRd | ||
| gemini_to_snake_case <- function(x) { | ||
| if (is.list(x)) { | ||
| if (!is.null(names(x))) { | ||
| names(x) <- gsub("([a-z])([A-Z])", "\\1_\\2", names(x), perl = TRUE) |> | ||
| tolower() | ||
| } | ||
| lapply(x, gemini_to_snake_case) | ||
| } else { | ||
| x | ||
| } | ||
| } | ||
|
|
||
| #' @noRd | ||
| gemini_prepare_batch_body <- function(body) { | ||
| # Remove empty system instructions (batch parser rejects them) | ||
| si <- body$systemInstruction %||% body$system_instruction | ||
| if (!is.null(si)) { | ||
| parts <- si$parts | ||
| is_empty <- if (is.list(parts) && !is.null(names(parts))) { | ||
| identical(parts$text, "") || is.null(parts$text) | ||
| } else if (is.list(parts) && length(parts) > 0) { | ||
| all(vapply( | ||
| parts, | ||
| function(p) identical(p$text, "") || is.null(p$text), | ||
| logical(1) | ||
| )) | ||
| } else { | ||
| TRUE | ||
| } | ||
|
Comment on lines
+1055
to
+1065
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| if (is_empty) { | ||
| body$systemInstruction <- NULL | ||
| body$system_instruction <- NULL | ||
| } | ||
| } | ||
|
|
||
| # Save user-defined schema before snake_case conversion so property names | ||
| # like "firstName" are not mangled to "first_name" | ||
| gc_pre <- body$generationConfig %||% body$generation_config | ||
| saved_schema <- if (!is.null(gc_pre)) { | ||
| gc_pre$responseSchema %||% gc_pre$response_schema | ||
| } | ||
|
|
||
| body <- gemini_to_snake_case(body) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you sure this is necessary? Google APIs often seem to take both snake and camel case.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Claude tested with camelCase and the batch JSONL parser silently ignored the fields — it seems to require protobuf-style snake_case names, unlike the REST API which accepts both. The batch API docs also use snake_case in all their JSONL examples. So this seems to be necessary.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for checking! Can you please add a brief summary as a comment?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added a comment. Wrote another full test script to double-check the camelCase/snake_case problem, and it actually errors with HTTP 400 when it encounters camelCase (not silent), so have changed the other comment earlier in the file as well. |
||
|
|
||
| # Rename response_schema -> response_json_schema and restore original schema | ||
| gc <- body$generation_config | ||
| if ( | ||
| !is.null(gc) && (!is.null(gc$response_schema) || !is.null(saved_schema)) | ||
| ) { | ||
| gc$response_json_schema <- saved_schema %||% gc$response_schema | ||
| gc$response_schema <- NULL | ||
| body$generation_config <- gc | ||
| } | ||
|
|
||
| body | ||
| } | ||
|
|
||
| #' @noRd | ||
|
hadley marked this conversation as resolved.
Outdated
|
||
| gemini_upload_file <- function( | ||
| provider, | ||
| path, | ||
| mime_type = "application/jsonl" | ||
| ) { | ||
| upload_base_url <- sub("/v[^/]+/?$", "/", provider@base_url) | ||
|
|
||
| upload_url <- google_upload_init( | ||
| path = path, | ||
| base_url = upload_base_url, | ||
| credentials = provider@credentials, | ||
| mime_type = mime_type | ||
| ) | ||
|
|
||
| status <- google_upload_send( | ||
| upload_url = upload_url, | ||
| path = path, | ||
| credentials = provider@credentials | ||
| ) | ||
| google_upload_wait(status, provider@credentials) | ||
| status | ||
| } | ||
|
|
||
| #' @noRd | ||
| gemini_download_file <- function(provider, name, path) { | ||
| req <- base_request(provider) | ||
| req <- req_url_path_append(req, paste0(name, ":download")) | ||
| req <- req_url_query(req, alt = "media") | ||
| req_perform(req, path = path) | ||
| invisible(path) | ||
| } | ||
|
|
||
| #' @noRd | ||
| gemini_extract_index <- function(x, default = NA_integer_) { | ||
| metadata <- x$metadata %||% list() | ||
| idx <- metadata$request_index %||% metadata$index %||% default | ||
|
|
||
| if (!is.na(idx)) { | ||
| return(as.integer(idx)) | ||
| } | ||
|
|
||
| key <- x$key %||% x$custom_id %||% metadata$custom_id %||% "" | ||
| if (grepl("^chat-[0-9]+$", key)) { | ||
| return(as.integer(sub("^chat-([0-9]+)$", "\\1", key))) | ||
| } | ||
|
|
||
| as.integer(default) | ||
| } | ||
|
Comment on lines
+1096
to
+1114
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tested the batch API live across The |
||
|
|
||
| #' @noRd | ||
| gemini_json_fallback <- function(line) { | ||
| index <- suppressWarnings( | ||
| as.integer(sub( | ||
| '.*"request_index"\\s*:\\s*([0-9]+).*', | ||
| "\\1", | ||
| line, | ||
| perl = TRUE | ||
| )) | ||
| ) | ||
|
|
||
| if (length(index) == 0L || is.na(index)) { | ||
| custom_id <- tryCatch( | ||
| { | ||
| m <- regmatches( | ||
| line, | ||
| regexpr('"custom_id"\\s*:\\s*"chat-[0-9]+"', line, perl = TRUE) | ||
| ) | ||
| if (length(m) == 0L) { | ||
| NA_character_ | ||
| } else { | ||
| sub('.*"chat-([0-9]+)".*', "\\1", m) | ||
| } | ||
| }, | ||
| error = function(e) NA_character_ | ||
| ) | ||
| index <- suppressWarnings(as.integer(custom_id)) | ||
| } | ||
|
|
||
| if (length(index) == 0L || is.na(index)) { | ||
| key_match <- tryCatch( | ||
| { | ||
| m <- regmatches( | ||
| line, | ||
| regexpr('"key"\\s*:\\s*"chat-[0-9]+"', line, perl = TRUE) | ||
| ) | ||
| if (length(m) == 0L) { | ||
| NA_character_ | ||
| } else { | ||
| sub('.*"chat-([0-9]+)".*', "\\1", m) | ||
| } | ||
| }, | ||
| error = function(e) NA_character_ | ||
| ) | ||
| index <- suppressWarnings(as.integer(key_match)) | ||
| } | ||
|
|
||
| list( | ||
| metadata = if (length(index) == 0L || is.na(index)) { | ||
| list() | ||
| } else { | ||
| list(request_index = index) | ||
| }, | ||
| status = list( | ||
| code = 500L, | ||
| message = "Failed to parse Gemini batch output line" | ||
| ) | ||
| ) | ||
| } | ||
|
|
||
| #' @noRd | ||
| gemini_normalize_result <- function(x, index_default) { | ||
| index <- gemini_extract_index(x, default = index_default) | ||
|
|
||
| # Formats where response and error/status are wrapped in one object | ||
| if (!is.null(x$response) || !is.null(x$error) || !is.null(x$status)) { | ||
| if (!is.null(x$response) && is.null(x$error) && is.null(x$status)) { | ||
| return(list( | ||
| index = index, | ||
| result = list(status_code = 200L, body = x$response) | ||
| )) | ||
| } | ||
|
|
||
| status <- x$error %||% x$status %||% list() | ||
| code <- status$code %||% 500L | ||
| return(list( | ||
| index = index, | ||
| result = list(status_code = as.integer(code), body = NULL) | ||
| )) | ||
|
Comment on lines
+1120
to
+1133
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm finding this a little hard to follow on a first read. I'm wondering if we could simplify it with sequential early returns Something like: if (!is.null(x$response) && is.null(x$error) && is.null(x$status)) {
return(list(index = index, result = list(status_code = 200L, body = x$response)))
}
if (!is.null(x$error) || !is.null(x$status)) {
code <- (x$error %||% x$status %||% list())$code %||% 500L
return(list(index = index, result = list(status_code = as.integer(code), body = NULL)))
} |
||
| } | ||
|
|
||
| # Plain GenerateContentResponse lines (current file-mode output) | ||
| if ( | ||
| !is.null(x$candidates) || | ||
| !is.null(x$promptFeedback) || | ||
| !is.null(x$usageMetadata) | ||
| ) { | ||
| return(list(index = index, result = list(status_code = 200L, body = x))) | ||
| } | ||
|
Comment on lines
+1136
to
+1143
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just wanted to check when this branch would get hit as I don't think this quite matches the developer API, though could be wrong |
||
|
|
||
| list(index = index, result = list(status_code = 500L, body = NULL)) | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.