diff --git a/DESCRIPTION b/DESCRIPTION index a816ff5..9a55081 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -36,6 +36,7 @@ Imports: tokenizers, tools, urltools, + vctrs, xml2, yaml Suggests: diff --git a/NAMESPACE b/NAMESPACE index e04da25..bb2bf46 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,8 +1,12 @@ # Generated by roxygen2: do not edit by hand +S3method(as_gpttools_message,default) +S3method(print,chat_list) +S3method(print,gpttools_message) export(add_roxygen_addin) export(addin_run_scrape_pkgs) export(addin_run_select_pkgs) +export(as_gpttools_message) export(chat) export(chat_with_context) export(chat_with_retrieval) @@ -20,6 +24,7 @@ export(get_selection) export(get_transformer_model) export(ghost_chat) export(ghost_writer) +export(ghost_writer_addin) export(gpt_sitrep) export(gpttools_index_all_scraped_data) export(ingest_pdf) @@ -41,11 +46,21 @@ export(script_to_function_addin) export(set_user_config) export(suggest_unit_test_addin) export(transcribe_audio) +export(vec_cast.gpttools_message.data.frame) +export(vec_ptype_abbr.gpttools_message) +export(vec_ptype_full.gpttools_message) import(cli) import(httr2) import(rlang) +import(stringr) importFrom(glue,glue) importFrom(graphics,text) +importFrom(jsonlite,fromJSON) +importFrom(purrr,compact) +importFrom(purrr,map) +importFrom(purrr,map_chr) +importFrom(purrr,map_dfr) +importFrom(purrr,pluck) importFrom(utils,globalVariables) importFrom(utils,head) importFrom(utils,installed.packages) diff --git a/R/addin_copilot.R b/R/addin_copilot.R index 39c01b6..fb37481 100755 --- a/R/addin_copilot.R +++ b/R/addin_copilot.R @@ -3,11 +3,11 @@ #' #' @export copilot_addin <- function() { - cli::cli_alert_info("Attempting to add code suggestions") + cli_alert_info("Attempting to add code suggestions") ghost_chat( service = getOption("gpttools.service", "openai"), stream = TRUE, where = "source" ) - cli::cli_alert_info("Done adding code suggestion") + cli_alert_info("Done adding code suggestion") } diff --git a/R/chat.R 
b/R/chat.R index 4394dc4..a224be5 100644 --- a/R/chat.R +++ b/R/chat.R @@ -240,7 +240,6 @@ get_cursor_context <- function(context_lines = 20, } else { file_ext <- doc$path |> tools::file_ext() } - list( above = context_above, below = context_below, diff --git a/R/class-message.R b/R/class-message.R new file mode 100644 index 0000000..7898759 --- /dev/null +++ b/R/class-message.R @@ -0,0 +1,106 @@ +new_gpttools_message <- function(role = character(), + content = character(), + datetime = character(), + model = character(), + service = character(), + temperature = double()) { + if (!is_character(role)) cli_abort("`role` must be a character vector.") + if (!is_character(content)) cli_abort("`content` must be a character vector.") + if (!is_character(datetime, n = 1)) { + cli_abort("`datetime` must be a character vector of length 1.") + } + if (!is_character(model, n = 1)) { + cli_abort("`model` must be a character vector of length 1.") + } + if (!is_character(service, n = 1)) { + cli_abort("`service` must be a character vector of length 1") + } + if (!is_double(temperature, n = 1)) { + cli_abort("`temperature` must be a numeric vector of length 1.") + } + + new_rcrd( + fields = + tibble::tibble( + role = role, + content = content, + ), + datetime = datetime, + model = model, + service = service, + temperature = temperature, + class = "gpttools_message" + ) +} + +gpttools_message <- function(role = character(), + content = character(), + datetime = character(), + model = character(), + service = character(), + temperature = double()) { + role <- vec_cast(role, character()) + content <- vec_cast(content, character()) + datetime <- vec_cast(datetime, character()) + model <- vec_cast(model, character()) + service <- vec_cast(service, character()) + temperature <- vec_cast(temperature, double()) + + new_gpttools_message(role, content, datetime, model, service, temperature) +} + +#' @export +print.gpttools_message <- function(x, ...) 
{ + x_valid <- which(!is.na(x)) + + role <- field(x, "role")[x_valid] + content <- field(x, "content")[x_valid] + datetime <- attr(x, "datetime") + model <- attr(x, "model") + service <- attr(x, "service") + temperature <- attr(x, "temperature") + + n <- length(role) + for (i in seq_len(n)) { + writeLines(col_silver(rule(stringr::str_to_title(role[i])))) + writeLines(content[i]) + } + writeLines(rule("Settings", col = "blue")) + writeLines(col_blue(paste0("date: ", unique(datetime)))) + writeLines(col_blue(paste0("model: ", unique(model)))) + writeLines(col_blue(paste0("service: ", unique(service)))) + writeLines(col_blue(paste0("temperature: ", unique(temperature)))) + invisible(x) +} + +#' @export +vec_ptype_full.gpttools_message <- function(x, ...) "gpttools_message" + +#' @export +vec_ptype_abbr.gpttools_message <- function(x, ...) "msg" + +#' @export +as_gpttools_message <- function(x, ...) { + UseMethod("as_gpttools_message") +} + +#' @export +as_gpttools_message.default <- function(x, ...) { + vec_cast(x, "gpttools_message") +} + +#' @export +vec_cast.gpttools_message.data.frame <- function(x, to, ...) 
{ + if (to == "data.frame") { + tibble::tibble( + role = field(x, "role"), + content = field(x, "content"), + datetime = attr(x, "datetime"), + model = attr(x, "model"), + service = attr(x, "service"), + temperature = attr(x, "temperature") + ) + } else { + cli_abort("Can't cast gpttools_message to ", to) + } +} diff --git a/R/document_data.R b/R/document_data.R index e1fd8ba..90baabf 100644 --- a/R/document_data.R +++ b/R/document_data.R @@ -7,13 +7,13 @@ #' @export collect_dataframes <- function() { objects <- names(rlang::global_env()) - purrr::map_chr( + map_chr( .x = objects, .f = \(x) { if (is.data.frame(get(x))) x else NA } ) |> - purrr::compact() |> + compact() |> unlist() } @@ -32,7 +32,7 @@ skim_lite <- function(data) { } collect_column_types <- function(data) { - purrr::map_dfr( + map_dfr( names(data), ~ data.frame( column = .x, diff --git a/R/embedding.R b/R/embedding.R index 705becc..6bcbcfb 100644 --- a/R/embedding.R +++ b/R/embedding.R @@ -34,7 +34,7 @@ prepare_scraped_files <- function(domain) { scraped |> dplyr::mutate( - chunks = purrr::map(text, \(x) { + chunks = map(text, \(x) { chunk_with_overlap(x, chunk_size = 500, overlap_size = 50, @@ -97,7 +97,7 @@ add_embeddings <- function(index, model <- get_transformer_model() index |> dplyr::mutate( - embeddings = purrr::map( + embeddings = map( .x = chunks, .f = \(x) create_text_embeddings(x, model), .progress = "Creating Embeddings Locally" @@ -108,7 +108,7 @@ add_embeddings <- function(index, } else { index |> dplyr::mutate( - embeddings = purrr::map( + embeddings = map( .x = chunks, .f = create_openai_embedding, .progress = "Create Embeddings" @@ -227,7 +227,7 @@ get_top_matches <- function(index, query_embedding, k = 5) { index |> tibble::as_tibble() |> dplyr::mutate( - similarity = purrr::map_dbl(embedding, \(x) { + similarity = map_dbl(embedding, \(x) { lsa::cosine(query_embedding, unlist(x)) }) ) |> @@ -262,6 +262,6 @@ chunk_with_overlap <- function(x, chunk_size, overlap_size, doc_id, ...) 
{ } else { names(chunks) <- NULL } - chunks <- purrr::compact(chunks) - purrr::map(chunks, \(x) stringr::str_c(x, collapse = " ")) + chunks <- compact(chunks) + map(chunks, \(x) stringr::str_c(x, collapse = " ")) } diff --git a/R/gpt-query.R b/R/gpt-query.R index 258fb64..061f8ee 100644 --- a/R/gpt-query.R +++ b/R/gpt-query.R @@ -40,7 +40,7 @@ gpt_chat <- function(instructions, cli_inform("Model: {model}") cli_inform("Sending query... this can take up to 3 minutes.") simple_prompt <- prompt |> - purrr::map_chr(.f = "content") |> + map_chr(.f = "content") |> paste(collapse = "\n\n") answer <- chat( diff --git a/R/gpttools-package.R b/R/gpttools-package.R index daed30d..7f3032f 100644 --- a/R/gpttools-package.R +++ b/R/gpttools-package.R @@ -2,10 +2,12 @@ "_PACKAGE" ## usethis namespace: start -#' @import cli #' @import rlang #' @importFrom glue glue #' @import httr2 +#' @import cli +#' @import stringr +#' @importFrom purrr map_dfr map_chr pluck compact map #' @importFrom utils globalVariables head installed.packages old.packages #' packageDescription packageVersion #' @importFrom graphics text diff --git a/R/harvest-docs.R b/R/harvest-docs.R index 0894642..6e759c8 100644 --- a/R/harvest-docs.R +++ b/R/harvest-docs.R @@ -77,12 +77,12 @@ recursive_hyperlinks <- function(local_domain, expanded_urls <- c(expanded_urls, links) cli_inform(c("i" = "Total urls: {length(expanded_urls)}")) - links_df <- purrr::map(links, get_hyperlinks, + links_df <- map(links, get_hyperlinks, .progress = "Getting more links" ) |> dplyr::bind_rows() |> dplyr::filter(!stringr::str_detect(link, "^\\.$|mailto:|\\#|^\\_$")) |> - dplyr::mutate(link = purrr::map2_chr( + dplyr::mutate(link = map2_chr( .x = parent, .y = link, .f = \(x, y) xml2::url_absolute(y, x) @@ -91,7 +91,7 @@ recursive_hyperlinks <- function(local_domain, cli_inform("Going to check {length(unique(links_df$link))} links") new_links <- - purrr::map(unique(links_df$link), \(x) { + map(unique(links_df$link), \(x) { if 
(rlang::is_true(stringr::str_detect(x, domain_pattern))) { validate_link(x) } else { @@ -239,7 +239,7 @@ scrape_and_process <- function(url, unique() cli_inform(c("i" = "Scraping validated links")) scraped_data <- - purrr::map(links, \(x) { + map(links, \(x) { if (identical(check_url(x), 200L)) { tibble::tibble( source = local_domain, @@ -322,11 +322,11 @@ extract_text <- function(url, use_html_text2 = TRUE) { ) xpath_tags <- exclude_tags |> - purrr::map_chr(.f = \(x) glue::glue("self::{x}")) |> + map_chr(.f = \(x) glue::glue("self::{x}")) |> stringr::str_c(collapse = " or ") xpath_attributes <- exclude_attributes |> - purrr::map_chr(.f = \(x) { + map_chr(.f = \(x) { glue::glue( "contains(concat(' ', normalize-space(@class), ' '), ' {x}')" ) @@ -336,7 +336,7 @@ extract_text <- function(url, use_html_text2 = TRUE) { # Handling general attribute selectors general_attributes <- c("role", "aria-", "data-", "id", "class", "style") xpath_general_attributes <- general_attributes |> - purrr::map_chr(.f = \(x) glue::glue("@{x}")) |> + map_chr(.f = \(x) glue::glue("@{x}")) |> stringr::str_c(collapse = " or ") xpath_combined <- xpath_combined <- glue::glue( diff --git a/R/history.R b/R/history.R index 8abfad9..35098c3 100644 --- a/R/history.R +++ b/R/history.R @@ -58,7 +58,7 @@ delete_history <- function(local = FALSE) { } history_files <- get_history_path(local = local) - purrr::map(history_files, \(x) { + map(history_files, \(x) { delete_file <- ui_yeah("Do you want to delete {basename(x)}?") if (delete_file) { file.remove(x) @@ -246,7 +246,7 @@ chat_with_context <- function(query, context <- full_context |> dplyr::select(source, link, chunks) |> - purrr::pmap(\(source, link, chunks) { + pmap(\(source, link, chunks) { glue::glue("Source: {source} Link: {link} Text: {chunks}") @@ -331,7 +331,7 @@ chat_with_context <- function(query, ) session_history <- - purrr::map(session_history, \(x) { + map(session_history, \(x) { if (x$role == "system") { NULL } else if 
(stringr::str_detect( @@ -344,7 +344,7 @@ chat_with_context <- function(query, x } }) |> - purrr::compact() + compact() prompt <- c( session_history, @@ -354,7 +354,7 @@ chat_with_context <- function(query, ) simple_prompt <- prompt |> - purrr::map_chr(.f = "content") |> + map_chr(.f = "content") |> paste(collapse = "\n\n") cli_alert_info("Service: {service}") @@ -390,7 +390,7 @@ chat_with_context <- function(query, } if (save_history) { - purrr::map(prompt, \(x) { + map(prompt, \(x) { save_user_history( file_name = history_name, role = "system", diff --git a/R/index.R b/R/index.R index 6ad6415..f342233 100755 --- a/R/index.R +++ b/R/index.R @@ -37,7 +37,7 @@ gpttools_index_all_scraped_data <- function(overwrite = FALSE, dont_ask = TRUE) { text_files <- list_index("text", full_path = TRUE) - purrr::walk(text_files, function(file_path) { + walk(text_files, function(file_path) { domain <- tools::file_path_sans_ext(basename(file_path)) cli_alert_info(glue("Creating/updating index for domain {domain}...")) create_index( diff --git a/R/site-index.R b/R/site-index.R index a6336c0..b5230c4 100644 --- a/R/site-index.R +++ b/R/site-index.R @@ -56,7 +56,7 @@ get_pkgs_to_scrape <- function(local = TRUE, installed_version = Version ) |> dplyr::filter(name %in% pkgs) |> - dplyr::mutate(url = purrr::map_chr(name, get_pkg_doc_page)) |> + dplyr::mutate(url = map_chr(name, get_pkg_doc_page)) |> tidyr::drop_na(url) |> dplyr::mutate(source = urltools::domain(url)) |> dplyr::left_join(get_outdated_pkgs(), by = "name") |> @@ -140,7 +140,7 @@ scrape_pkg_sites <- function(sites = get_pkgs_to_scrape(local = TRUE), } else { sites |> dplyr::select(url, version, name) |> - purrr::pmap(.f = \(url, version, name) { + pmap(.f = \(url, version, name) { crawl( url = url, index_create = index_create, diff --git a/R/stream-anthropic.R b/R/stream-anthropic.R index 6b69f4a..14cad1d 100644 --- a/R/stream-anthropic.R +++ b/R/stream-anthropic.R @@ -23,7 +23,8 @@ stream_chat_anthropic <- function(prompt, 
req_error(is_error = function(resp) FALSE) |> req_perform_stream( callback = element_callback, - buffer_kb = 0.01 + buffer_kb = 0.01, + round = "line" ) # error handling @@ -31,9 +32,10 @@ stream_chat_anthropic <- function(prompt, status <- resp_status(response) description <- resp_status_desc(response) - cli::cli_abort(message = c( + cli_abort(message = c( "x" = "Anthropic API request failed. Error {status} - {description}", "i" = "Visit the Anthropic API documentation for more details" )) } + invisible(response) } diff --git a/R/stream-azure-openai.R b/R/stream-azure-openai.R index ab9ec1e..a34af48 100755 --- a/R/stream-azure-openai.R +++ b/R/stream-azure-openai.R @@ -38,7 +38,8 @@ stream_chat_azure_openai <- function(prompt = NULL, req_error(is_error = function(resp) FALSE) |> req_perform_stream( callback = element_callback, - buffer_kb = 0.01 + buffer_kb = 0.01, + round = "line" ) invisible(response) diff --git a/R/stream-chat.R b/R/stream-chat.R index b4de647..8035057 100644 --- a/R/stream-chat.R +++ b/R/stream-chat.R @@ -41,15 +41,16 @@ stream_chat <- function(prompt, ) } ) + invisible(response) } create_handler <- function(service = "openai", r = NULL, output_id = "streaming", - where = "console") { - env <- rlang::env() + where = "console", + env = caller_env()) { env$resp <- NULL - env$full_resp <- NULL + env$full_resp <- "" stream_details <- get_stream_pattern(service) new_pattern <- stream_details$pattern @@ -73,7 +74,7 @@ create_handler <- function(service = "openai", if (stringr::str_detect(env$resp, pattern)) { parsed <- stringr::str_extract(env$resp, pattern) |> jsonlite::fromJSON() |> - purrr::pluck(!!!new_pluck) + pluck(!!!new_pluck) env$full_resp <- paste0(env$full_resp, parsed) if (where == "shiny") { @@ -121,8 +122,8 @@ get_stream_pattern <- function(service) { pluck <- "text" }, "ollama" = { - pattern <- '\\{"model":.*"done":false\\}' - pluck <- "response" + pattern <- '\\{"id":.*?\\}\\]\\}' + pluck <- c("choices", "delta", "content") }, 
"azure_openai" = { pattern <- '\\{"id":.*?\\}\\]\\}' diff --git a/R/stream-cohere.R b/R/stream-cohere.R index 4b5cd5f..48b8ee8 100644 --- a/R/stream-cohere.R +++ b/R/stream-cohere.R @@ -1,5 +1,5 @@ stream_chat_cohere <- function(prompt, - model = getOption("gpttools.model", "command"), + model = "command", element_callback = create_handler("cohere"), key = Sys.getenv("COHERE_API_KEY")) { request_body <- list( @@ -15,19 +15,23 @@ stream_chat_cohere <- function(prompt, `Authorization` = paste("Bearer", key), `content-type` = "application/json" ) |> - req_method("POST") |> req_body_json(data = request_body) |> req_retry(max_tries = 3) |> req_error(is_error = function(resp) FALSE) |> - req_perform_stream(callback = element_callback, buffer_kb = 0.01) + req_perform_stream( + callback = element_callback, + buffer_kb = 0.01, + round = "line" + ) if (resp_is_error(response)) { status <- resp_status(response) description <- resp_status_desc(response) - cli::cli_abort(message = c( - "x" = glue::glue("Cohere API request failed. Error {status} - {description}"), + cli_abort(c( + "x" = glue("Cohere API request failed. 
Error {status} - {description}"), + "i" = "Visit the Cohere API documentation for more details" )) } + invisible(response) } diff --git a/R/stream-copilot.R b/R/stream-copilot.R new file mode 100644 index 0000000..cb7ac64 --- /dev/null +++ b/R/stream-copilot.R @@ -0,0 +1,29 @@ +get_copilot_oauth <- function(dir = "~/.config/github-copilot") { + hosts <- jsonlite::read_json(file.path(dir, "hosts.json")) + hosts[[1]]$oauth_token +} + +chat_copilot <- function() { + token <- request("https://api.github.com/copilot_internal/v2/token") |> + req_auth_bearer_token(get_copilot_oauth()) |> + req_perform() |> + resp_body_json() |> + pluck("token") + + request("https://api.githubcopilot.com/chat/completions") |> + req_auth_bearer_token(token = token) |> + req_headers("Editor-Version" = "vscode/9.9.9") |> + req_body_json( + data = list( + messages = list( + list( + role = "user", + content = "Tell me a joke about the R language." + ) + ), + model = "copilot" + ) + ) |> + req_perform() |> + resp_chat_openai(stream = FALSE) +} diff --git a/R/stream-ollama.R b/R/stream-ollama.R index 3fce7ec..d28b2e7 100644 --- a/R/stream-ollama.R +++ b/R/stream-ollama.R @@ -15,17 +15,22 @@ stream_chat_ollama <- function(prompt, req_url_path_append("api") |> req_url_path_append("generate") |> req_body_json(data = body) |> - req_perform_stream(callback = element_callback, buffer_kb = 0.01) + req_perform_stream( + callback = element_callback, + buffer_kb = 0.01, + round = "line" + ) if (resp_is_error(response)) { status <- resp_status(response) description <- resp_status_desc(response) - cli::cli_abort(message = c( + cli_abort(message = c( "x" = glue::glue("Ollama API request failed. 
Error {status} - {description}"), "i" = "Visit the Ollama API documentation for more details" )) } + invisible(response) } ollama_is_available <- function(verbose = FALSE) { @@ -39,14 +44,14 @@ ollama_is_available <- function(verbose = FALSE) { response <- req_perform(request) |> resp_body_string() - if (verbose) cli::cli_alert_success(response) + if (verbose) cli_alert_success(response) check_value <- TRUE }, error = function(cnd) { if (inherits(cnd, "httr2_failure")) { - if (verbose) cli::cli_alert_danger("Couldn't connect to Ollama in {.url {ollama_api_url()}}. Is it running there?") + if (verbose) cli_alert_danger("Couldn't connect to Ollama in {.url {ollama_api_url()}}. Is it running there?") } else { - if (verbose) cli::cli_alert_danger(cnd) + if (verbose) cli_alert_danger(cnd) } check_value <- FALSE } diff --git a/R/stream-openai.R b/R/stream-openai.R index 231bc6e..66c7d0c 100644 --- a/R/stream-openai.R +++ b/R/stream-openai.R @@ -1,32 +1,171 @@ -stream_chat_openai <- function(prompt = NULL, - element_callback = create_handler("openai"), - model = getOption("gpttools.model", "gpt-4-turbo-preview"), - openai_api_key = Sys.getenv("OPENAI_API_KEY"), - shiny = FALSE) { - messages <- list( +chat_openai <- function(prompt = "Tell me a joke about the R language.", + model = "gpt-3.5-turbo", + history = NULL, + temperature = NULL, + stream = FALSE) { + prompt <- prompt |> add_history(history) + + response <- + req_chat_openai( + prompt = prompt, + model = model, + temperature = temperature, + stream = is_true(stream) + ) |> + resp_chat_openai(stream = is_true(stream)) + + response <- c( + prompt, + list(list(role = response$role, content = response$content)) + ) + + class(response) <- c("chat_list", class(response)) + + response +} + +#' @export +print.chat_list <- function(x, ...) 
{ + n <- length(x) + + writeLines("\n") + + for (i in seq_len(n)) { + print_role <- rule(stringr::str_to_title(x[[i]]$role)) + print_role <- + switch(x[[i]]$role, + "assistant" = col_green(print_role), + "system" = col_silver(print_role), + "user" = col_blue(print_role) + ) + writeLines(print_role) + writeLines(x[[i]]$content) + } + + writeLines("\n") + invisible(x) +} + + +# Make API Request -------------------------------------------------------- + +req_base_openai <- function( + url = getOption("gpttools.url", "https://api.openai.com/")) { + request(url) |> + req_url_path_append("v1", "chat", "completions") +} + +req_auth_openai <- function(request) { + request |> req_auth_bearer_token(token = Sys.getenv("OPENAI_API_KEY")) +} + +req_body_openai <- function(request, + prompt = "Tell me a joke about the R language.", + model = "gpt-4-turbo-preview", + history = NULL, + temperature = 0.7, + stream = FALSE) { + body <- list( - role = "user", - content = prompt + model = model, + messages = prompt, + temperature = temperature, + stream = is_true(stream) ) - ) - # Set the request body - body <- list( - model = model, - stream = TRUE, - messages = messages - ) + request |> + req_body_json(data = body) +} - response <- - request("https://api.openai.com/v1/chat/completions") |> - req_auth_bearer_token(token = openai_api_key) |> - req_body_json(data = body) |> + +req_chat_openai <- function(prompt, model, temperature, stream = FALSE) { + req <- + req_base_openai() |> + req_auth_openai() |> + req_body_openai( + prompt = prompt, + model = model, + temperature = temperature, + stream = is_true(stream) + ) |> req_retry(max_tries = 3) |> - req_error(is_error = function(resp) FALSE) |> - req_perform_stream( - callback = element_callback, - buffer_kb = 0.01 - ) + req_error(is_error = function(resp) FALSE) + + if (is_true(stream)) { + env <- caller_env() + req |> + req_perform_stream( + callback = \(x) stream_callback_openai(x, env), + buffer_kb = 0.01, + round = "line" + ) + 
tibble::tibble(role = "assistant", content = env$response) + } else { + req |> + req_perform() + } +} + +#' @importFrom jsonlite fromJSON +stream_callback_openai <- function(x, env) { + txt <- rawToChar(x) - invisible(response) + lines <- str_split(txt, "\n")[[1]] + lines <- lines[lines != ""] + lines <- str_replace_all(lines, "^data: ", "") + lines <- lines[!str_detect(lines, "\"finish_reason\":\"stop\"")] + lines <- lines[lines != "[DONE]"] + + tokens <- map_chr(lines, \(line) { + chunk <- jsonlite::fromJSON(line) + chunk$choices$delta$content + }) + + env$response <- paste0(env$response, tokens) + + cat(tokens) + + TRUE +} + +# Process API Response ---------------------------------------------------- + +resp_chat_openai <- function(response, stream) { + resp <- response + + if (is_true(stream)) { + resp + } else { + resp |> + resp_chat_error_openai() |> + resp_body_json(simplifyVector = TRUE) |> + pluck("choices", "message") + } +} + +resp_chat_error_openai <- function(response) { + if (resp_is_error(response)) { + status <- resp_status(response) + description <- resp_status_desc(response) + + cli_abort(c( + "x" = glue("OpenAI API request failed. 
Error {status} - {description}"), + "i" = "Visit the OpenAI API documentation for more details" + )) + } else { + invisible(response) + } +} + +add_history <- function(prompt, history = NULL) { + c( + history, + list( + list( + role = "user", + content = prompt + ) + ) + ) |> + purrr::compact() } diff --git a/R/stream-perplexity.R b/R/stream-perplexity.R index 8364a3c..8a96fe2 100644 --- a/R/stream-perplexity.R +++ b/R/stream-perplexity.R @@ -1,6 +1,6 @@ stream_chat_perplexity <- function(prompt, element_callback = create_handler("perplexity"), - model = getOption("gpttools.model", "sonar-small-chat"), + model = "sonar-small-chat", api_key = Sys.getenv("PERPLEXITY_API_KEY")) { request_body <- list( model = model, @@ -19,11 +19,19 @@ stream_chat_perplexity <- function(prompt, ) |> req_body_json(data = request_body) |> req_retry(max_tries = 3) |> - req_perform_stream(callback = element_callback, buffer_kb = 0.01) + req_perform_stream( + callback = element_callback, + buffer_kb = 0.01, + round = "line" + ) if (resp_is_error(response)) { status <- resp_status(response) description <- resp_status_desc(response) - stop("Perplexity API request failed with error ", status, ": ", description, call. = FALSE) + + cli_abort(message = c( + "x" = glue::glue("Perplexity API request failed. 
Error {status} - {description}"), + "i" = "Visit the Perplexity API documentation for more details" + )) } } diff --git a/R/transcribe.R b/R/transcribe.R index 89bbf9e..826c697 100644 --- a/R/transcribe.R +++ b/R/transcribe.R @@ -83,7 +83,7 @@ transcribe_audio <- function(file_path, link = NA, prompt = NA, chunk_size = 120) { audio_chunks <- split_audio(file_path = file_path, duration_secs = chunk_size) - purrr::map(audio_chunks, \(x) { + map(audio_chunks, \(x) { tibble::tibble( source = source, text = transcribe_audio_chunk(audio_file = x, prompt = prompt), @@ -135,7 +135,7 @@ create_index_from_audio <- function(file_path, #' @export create_transcript <- function(file_path, prompt = NULL, chunk_size = 120) { split_audio(file_path = file_path, duration_secs = chunk_size) |> - purrr::map(\(x) { + map(\(x) { transcribed_text <- transcribe_audio_chunk( audio_file = x, prompt = prompt diff --git a/inst/retriever/app.R b/inst/retriever/app.R index 6a29803..7992a03 100644 --- a/inst/retriever/app.R +++ b/inst/retriever/app.R @@ -12,6 +12,7 @@ library(bslib) library(bsicons) library(waiter) library(reprex) +library(purrr) window_height_ui <- function(id) { ns <- NS(id) @@ -45,14 +46,14 @@ window_height_server <- function(id) { make_chat_history <- function(chats) { history <- - purrr::discard(chats, \(x) x$role == "system") |> - purrr::map(\(x) { + discard(chats, \(x) x$role == "system") |> + map(\(x) { list( strong(stringr::str_to_title(x$role)), markdown(x$content) ) }) |> - purrr::list_flatten() + list_flatten() history } @@ -61,7 +62,7 @@ api_services <- stringr::str_remove( pattern = "gptstudio_request_perform.gptstudio_request_" ) |> - purrr::discard(~ .x == "gptstudio_request_perform.default") + discard(~ .x == "gptstudio_request_perform.default") ui <- page_fillable( useWaiter(), @@ -255,7 +256,7 @@ server <- function(input, output, session) { if ("All" %in% input$source) { load_index(domain = "All", local_embeddings = TRUE) } else { - purrr::map(input$source, 
\(x) { + map(input$source, \(x) { load_index(x, local_embeddings = TRUE) |> tibble::as_tibble() }) |> @@ -264,7 +265,7 @@ server <- function(input, output, session) { } else if ("All" %in% input$source) { load_index(domain = "All", local_embeddings = FALSE) } else { - purrr::map(input$source, \(x) { + map(input$source, \(x) { load_index(x, local_embeddings = FALSE) |> tibble::as_tibble() }) |> diff --git a/inst/scripts/scrape_other_docs.R b/inst/scripts/scrape_other_docs.R index 626010f..289601e 100644 --- a/inst/scripts/scrape_other_docs.R +++ b/inst/scripts/scrape_other_docs.R @@ -51,7 +51,7 @@ scrape_resources <- function(resources) { sites |> dplyr::select(url, name) |> - purrr::pmap(.f = \(url, name) { + pmap(.f = \(url, name) { crawl( url = url, index_create = TRUE, diff --git a/inst/settings/app.R b/inst/settings/app.R index 218896f..7d86167 100755 --- a/inst/settings/app.R +++ b/inst/settings/app.R @@ -9,7 +9,7 @@ api_services <- stringr::str_remove( pattern = "gptstudio_request_perform.gptstudio_request_" ) |> - purrr::discard(~ .x == "gptstudio_request_perform.default") + discard(~ .x == "gptstudio_request_perform.default") # Define UI ui <- page_fillable( diff --git a/man/ghost_writer_addin.Rd b/man/ghost_writer_addin.Rd new file mode 100644 index 0000000..2c206e9 --- /dev/null +++ b/man/ghost_writer_addin.Rd @@ -0,0 +1,11 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/addin_ghost_writer.R +\name{ghost_writer_addin} +\alias{ghost_writer_addin} +\title{Writing suggestions} +\usage{ +ghost_writer_addin() +} +\description{ +Writing suggestions +}