Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# ellmer (development version)

* `batch_chat()` now supports `chat_google_gemini()` for batch processing via
the Gemini Developer API (@xmarquez, #914).
* ellmer will now distinguish text content from thinking content while streaming, allowing downstream packages like shinychat to provide specific UI for thinking content (@simonpcouch, #909).
* `chat_github()` now uses `chat_openai_compatible()` for improved compatibility, and `models_github()` now supports custom `base_url` configuration (@D-M4rk, #877).
* `chat_ollama()` now contains a slot for `top_k` within the `params` argument (@frankiethull).
Expand Down
18 changes: 14 additions & 4 deletions R/batch-chat.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
#'
#' @description
#' `batch_chat()` and `batch_chat_structured()` currently only work with
#' [chat_openai()] and [chat_anthropic()]. They use the
#' [OpenAI](https://platform.openai.com/docs/guides/batch) and
#' [Anthropic](https://docs.claude.com/en/docs/build-with-claude/batch-processing)
#' batch APIs which allow you to submit multiple requests simultaneously.
#' [chat_openai()], [chat_anthropic()], and [chat_google_gemini()]. They use
#' the [OpenAI](https://platform.openai.com/docs/guides/batch),
#' [Anthropic](https://docs.claude.com/en/docs/build-with-claude/batch-processing),
#' and [Google Gemini](https://ai.google.dev/gemini-api/docs/batch-api) batch APIs
#' which allow you to submit multiple requests simultaneously.
#' The results can take up to 24 hours to complete, but in return you pay 50%
#' less than usual (but note that ellmer doesn't include this discount in
#' its pricing metadata). If you want to get results back more quickly, or
Expand Down Expand Up @@ -90,6 +91,9 @@ batch_chat <- function(chat, prompts, path, wait = TRUE, ignore_hash = FALSE) {
ignore_hash = ignore_hash
)
job$step_until_done()
if (job$stage != "done") {
return(NULL)
}

assistant_turns <- job$result_turns()
map2(job$user_turns, assistant_turns, function(user, assistant) {
Expand Down Expand Up @@ -119,6 +123,9 @@ batch_chat_text <- function(
wait = wait,
ignore_hash = ignore_hash
)
if (is.null(chats)) {
return(NULL)
}
map_chr(chats, \(chat) {
if (is.null(chat)) NA_character_ else chat$last_turn()@text
})
Expand Down Expand Up @@ -151,6 +158,9 @@ batch_chat_structured <- function(
ignore_hash = ignore_hash
)
job$step_until_done()
if (job$stage != "done") {
return(NULL)
}
turns <- job$result_turns()

multi_convert(
Expand Down
39 changes: 37 additions & 2 deletions R/provider-google-upload.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ google_upload <- function(

mime_type <- mime_type %||% guess_mime_type(path)

status <- google_upload_file(
path = path,
base_url = base_url,
credentials = credentials,
mime_type = mime_type
)

ContentUploaded(uri = status$uri, mime_type = status$mimeType)
}

google_upload_file <- function(path, base_url, credentials, mime_type) {
upload_url <- google_upload_init(
path = path,
base_url = base_url,
Expand All @@ -56,8 +67,7 @@ google_upload <- function(
credentials = credentials
)
google_upload_wait(status, credentials)

ContentUploaded(uri = status$uri, mime_type = status$mimeType)
status
}

# https://ai.google.dev/api/files#method:-media.upload
Expand Down Expand Up @@ -124,6 +134,31 @@ google_upload_wait <- function(status, credentials) {
invisible()
}

# Batch file helpers -----------------------------------------------------------

# Upload a local file to the Gemini Files API on behalf of `provider`.
# Used to stage the batch request JSONL before submitting a batch job.
gemini_upload_file <- function(
  provider,
  path,
  mime_type = "application/jsonl"
) {
  # Media uploads live at the API root, so drop the trailing version
  # segment (e.g. ".../v1beta/" -> ".../") from the provider's base URL.
  api_root <- sub("/v[^/]+/?$", "/", provider@base_url)

  google_upload_file(
    path = path,
    base_url = api_root,
    credentials = provider@credentials,
    mime_type = mime_type
  )
}

# Download a Files API resource (`name`, e.g. "files/abc123") to `path`
# using the ":download" media endpoint. Returns `path` invisibly.
gemini_download_file <- function(provider, name, path) {
  base_request(provider) |>
    req_url_path_append(paste0(name, ":download")) |>
    req_url_query(alt = "media") |>
    req_perform(path = path)
  invisible(path)
}

# Helpers ----------------------------------------------------------------------

guess_mime_type <- function(file_path, call = caller_env()) {
Expand Down
272 changes: 272 additions & 0 deletions R/provider-google.R
Original file line number Diff line number Diff line change
Expand Up @@ -872,3 +872,275 @@ models_google <- function(
google_location <- function(location) {
if (location == "global") "" else paste0(location, "-")
}

# Batched requests -------------------------------------------------------------

# https://ai.google.dev/gemini-api/docs/batch-api
# Advertise batch capability so batch_chat() accepts Gemini providers.
method(has_batch_support, ProviderGoogleGemini) <- function(provider) {
  TRUE
}

# Submit a batch of conversations via the Gemini file-mode batch API:
# write one JSONL request line per conversation, upload the file, then
# create a batch job that references the uploaded file by resource name.
method(batch_submit, ProviderGoogleGemini) <- function(
  provider,
  conversations,
  type = NULL
) {
  jsonl_path <- withr::local_tempfile(fileext = ".jsonl")

  # Each line is keyed "chat-<i>" so results can be matched back to the
  # prompt that produced them.
  json_lines <- map_chr(seq_along(conversations), function(i) {
    body <- chat_body(
      provider,
      stream = FALSE,
      turns = conversations[[i]],
      type = type
    )
    to_json(list(
      key = paste0("chat-", i),
      request = gemini_prepare_batch_body(body)
    ))
  })
  writeLines(json_lines, jsonl_path)

  uploaded <- gemini_upload_file(provider, jsonl_path)
  if (is.null(uploaded$name) || !nzchar(uploaded$name)) {
    cli::cli_abort(
      "Gemini upload did not return a file resource name.",
      .internal = TRUE
    )
  }

  req <- base_request(provider) |>
    req_url_path_append(
      "models",
      paste0(provider@model, ":batchGenerateContent")
    ) |>
    req_body_json(
      list(
        batch = list(
          displayName = paste0("ellmer-", as.integer(Sys.time())),
          model = paste0("models/", provider@model),
          inputConfig = list(fileName = uploaded$name)
        )
      )
    )

  resp_body_json(req_perform(req))
}

# Refresh the batch job by fetching its resource (batch$name is the
# operation/batch resource path returned at submission).
method(batch_poll, ProviderGoogleGemini) <- function(provider, batch) {
  resp <- base_request(provider) |>
    req_url_path_append(batch$name) |>
    req_perform()
  resp_body_json(resp)
}

# Translate a polled batch resource into ellmer's generic status list:
# working flag plus processing/succeeded/failed counts. State and stats
# may live under either $metadata or $response depending on job phase.
method(batch_status, ProviderGoogleGemini) <- function(provider, batch) {
  meta <- batch$metadata %||% list()
  resp <- batch$response %||% list()
  state <- meta$state %||% resp$state %||% "BATCH_STATE_UNSPECIFIED"
  stats <- meta$batchStats %||% resp$batchStats %||% list()

  n_total <- as.integer(stats$requestCount %||% 0L)
  n_pending <- as.integer(stats$pendingRequestCount %||% 0L)
  n_ok <- as.integer(stats$successfulRequestCount %||% 0L)
  n_bad <- as.integer(stats$failedRequestCount %||% 0L)

  # A batch-level error with no per-request failure counts means the
  # whole batch failed.
  if (!is.null(batch$error) && n_total > 0 && n_bad == 0L) {
    n_bad <- n_total
  }

  terminal_states <- c(
    "BATCH_STATE_SUCCEEDED",
    "BATCH_STATE_FAILED",
    "BATCH_STATE_CANCELLED",
    "BATCH_STATE_EXPIRED"
  )
  finished <- state %in% terminal_states

  # The API can report BATCH_STATE_SUCCEEDED before the responsesFile
  # metadata is populated; keep polling until the output file exists.
  responses_file <- batch$response$responsesFile %||% ""
  if (state == "BATCH_STATE_SUCCEEDED" && !nzchar(responses_file)) {
    finished <- FALSE
  }

  list(
    working = !finished,
    n_processing = max(n_pending, n_total - n_ok - n_bad, 0L),
    n_succeeded = max(n_ok, 0L),
    n_failed = max(n_bad, 0L)
  )
}

# Download and parse the batch output file, returning one result per
# request (status_code + body), ordered by the original request index.
method(batch_retrieve, ProviderGoogleGemini) <- function(provider, batch) {
  meta <- batch$metadata %||% list()
  resp <- batch$response %||% list()
  stats <- meta$batchStats %||% resp$batchStats %||% list()
  n_requests <- as.integer(stats$requestCount %||% 0L)

  # A batch-level error means every request failed; synthesize one error
  # result per request so downstream indexing still lines up.
  if (!is.null(batch$error)) {
    error_code <- as.integer(batch$error$code %||% 500L)
    failure <- list(status_code = error_code, body = NULL)
    return(rep(list(failure), max(0L, n_requests)))
  }

  responses_file <- batch$response$responsesFile %||% ""
  if (!nzchar(responses_file)) {
    cli::cli_abort("Gemini batch completed but no output file was returned.")
  }

  output_path <- withr::local_tempfile(fileext = ".jsonl")
  gemini_download_file(provider, responses_file, output_path)

  rows <- read_ndjson(output_path)
  normalized <- imap(rows, function(row, i) {
    gemini_normalize_result(row, index_default = as.integer(i))
  })

  # Sort by the recovered per-request index so results match the order
  # the conversations were submitted in.
  indices <- vapply(normalized, function(x) x$index, integer(1))
  lapply(normalized, function(x) x$result)[order(indices)]
}

# Convert a single retrieved batch result into an assistant turn.
# Failed or missing results yield NULL so callers can record gaps.
method(batch_result_turn, ProviderGoogleGemini) <- function(
  provider,
  result,
  has_type = FALSE
) {
  ok <- !is.null(result) &&
    result$status_code == 200L &&
    !is.null(result$body)
  if (!ok) {
    return(NULL)
  }
  value_turn(provider, result$body, has_type = has_type)
}

# Gemini batch helpers ---------------------------------------------------------

# The Gemini REST API accepts both camelCase and snake_case, but the batch
# JSONL file parser requires protobuf field names which are always snake_case.
# Without this conversion, the batch API rejects requests with HTTP 400.
# Recursively rename list keys from camelCase to snake_case.
# The Gemini REST API accepts both spellings, but the batch JSONL file
# parser requires protobuf field names, which are always snake_case;
# without this conversion the batch API rejects requests with HTTP 400.
gemini_to_snake_case <- function(x) {
  if (!is.list(x)) {
    return(x)
  }
  nms <- names(x)
  if (!is.null(nms)) {
    names(x) <- tolower(gsub("([a-z])([A-Z])", "\\1_\\2", nms, perl = TRUE))
  }
  lapply(x, gemini_to_snake_case)
}

# Adapt a chat_body() request for the batch JSONL format: strip empty
# system instructions (the batch parser rejects them), convert keys to
# protobuf snake_case, and carry the user's response schema over to
# response_json_schema without mangling its property names.
gemini_prepare_batch_body <- function(body) {
  si <- body$systemInstruction %||% body$system_instruction
  if (!is.null(si)) {
    parts <- si$parts
    has_text <- if (is.list(parts) && !is.null(names(parts))) {
      # Single named part: parts = list(text = "...")
      !is.null(parts$text) && !identical(parts$text, "")
    } else if (is.list(parts) && length(parts) > 0) {
      # A list of parts: keep if any part carries non-empty text
      any(vapply(
        parts,
        function(p) !is.null(p$text) && !identical(p$text, ""),
        logical(1)
      ))
    } else {
      FALSE
    }
    if (!has_text) {
      body$systemInstruction <- NULL
      body$system_instruction <- NULL
    }
  }

  # Remember the user-supplied schema before snake_casing the body so
  # property names like "firstName" are not mangled to "first_name".
  gen_config <- body$generationConfig %||% body$generation_config
  original_schema <- if (!is.null(gen_config)) {
    gen_config$responseSchema %||% gen_config$response_schema
  }

  # Batch JSONL requires protobuf-style snake_case field names; camelCase
  # causes HTTP 400 (unlike the REST API, which accepts both).
  body <- gemini_to_snake_case(body)

  # Rename response_schema -> response_json_schema and restore the
  # original, unmangled schema.
  gc <- body$generation_config
  if (
    !is.null(gc) && (!is.null(gc$response_schema) || !is.null(original_schema))
  ) {
    gc$response_json_schema <- original_schema %||% gc$response_schema
    gc$response_schema <- NULL
    body$generation_config <- gc
  }

  body
}

# Recover the 1-based request index for one batch-output line.
#
# Prefers an explicit index stored in the line's metadata, then falls
# back to the "chat-<i>" key assigned at submission time, then `default`.
#
# @param x Parsed JSONL line (a list).
# @param default Index to use when none can be recovered.
# @return A length-1 integer (possibly NA if `default` is NA).
gemini_extract_index <- function(x, default = NA_integer_) {
  metadata <- x$metadata %||% list()

  idx <- metadata$request_index %||% metadata$index
  if (!is.null(idx) && length(idx) >= 1) {
    # JSON parsing may yield a string or a wrapped scalar; coerce
    # defensively. The previous `!is.na(idx)` check errored under
    # R >= 4.3 whenever idx was not a length-1 vector, and as.integer()
    # warned on non-numeric values.
    idx <- suppressWarnings(as.integer(idx[[1]]))
    if (length(idx) == 1 && !is.na(idx)) {
      return(idx)
    }
  }

  key <- x$key %||%
    x$custom_id %||%
    metadata$key %||%
    metadata$custom_id %||%
    ""
  key <- as.character(key)[1]
  if (!is.na(key) && grepl("^chat-[0-9]+$", key)) {
    return(as.integer(sub("^chat-([0-9]+)$", "\\1", key)))
  }

  as.integer(default)
}

# Normalize one parsed batch-output line into
# list(index = <int>, result = list(status_code = <int>, body = <list|NULL>)).
#
# Handles two line shapes: an envelope carrying response/error/status
# fields, and a bare GenerateContentResponse (current file-mode output).
# Unrecognized lines become a generic 500 failure.
gemini_normalize_result <- function(x, index_default) {
  index <- gemini_extract_index(x, default = index_default)

  # Envelope format: response and/or error/status wrapped in one object.
  if (!is.null(x$response) || !is.null(x$error) || !is.null(x$status)) {
    status <- x$error %||% x$status %||% list()
    code <- suppressWarnings(as.integer(status$code %||% 0L))
    if (is.na(code)) {
      code <- 0L
    }

    # A response with no error and an absent or OK (code 0) status is a
    # success. Previously a line shaped {response, status: {code: 0}} was
    # misclassified as a failure, dropping a valid response.
    if (!is.null(x$response) && is.null(x$error) && code == 0L) {
      return(list(
        index = index,
        result = list(status_code = 200L, body = x$response)
      ))
    }

    # Error/status present without a usable code: report a generic 500.
    if (code == 0L) {
      code <- 500L
    }
    return(list(
      index = index,
      result = list(status_code = code, body = NULL)
    ))
  }

  # Plain GenerateContentResponse lines (current file-mode output)
  if (
    !is.null(x$candidates) ||
      !is.null(x$promptFeedback) ||
      !is.null(x$usageMetadata)
  ) {
    return(list(index = index, result = list(status_code = 200L, body = x)))
  }

  list(index = index, result = list(status_code = 500L, body = NULL))
}
Loading
Loading