diff --git a/.gitignore b/.gitignore
index 810b28c44..1361a851e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,8 @@
 dist
 .DS_Store
 node_modules
+target
+Cargo.lock
 data/tokenspeed-monitor.sqlite
 data/tokenspeed-monitor.sqlite-shm
 data/tokenspeed-monitor.sqlite-wal
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 000000000..67f9cf33f
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,72 @@
+[workspace]
+members = ["crates/*"]
+resolver = "2"
+
+[workspace.package]
+edition = "2024"
+version = "0.1.0"
+rust-version = "1.87.0"
+publish = false
+
+[workspace.dependencies]
+serde = { version = "1.0.228", features = ["derive"] }
+serde_json = "1.0.149"
+thiserror = "2.0.18"
+toml = "1.0.4"
+
+[workspace.lints.rust]
+keyword_idents_2024 = "forbid"
+missing_unsafe_on_extern = "forbid"
+unsafe_code = "deny"
+unsafe_op_in_unsafe_fn = "forbid"
+unused_results = "forbid"
+
+[workspace.lints.clippy]
+absolute_paths = "deny"
+as_conversions = "deny"
+as_pointer_underscore = "deny"
+as_underscore = "deny"
+cast_possible_truncation = "deny"
+cast_possible_wrap = "deny"
+cast_precision_loss = "deny"
+cast_sign_loss = "deny"
+clone_on_ref_ptr = "deny"
+dbg_macro = "deny"
+empty_drop = "deny"
+exhaustive_enums = "allow"
+exhaustive_structs = "allow"
+filetype_is_file = "deny"
+fn_to_numeric_cast_any = "deny"
+get_unwrap = "deny"
+infinite_loop = "deny"
+let_underscore_must_use = "deny"
+mem_forget = "deny"
+missing_asserts_for_indexing = "deny"
+module_name_repetitions = "allow"
+multiple_crate_versions = "allow"
+multiple_unsafe_ops_per_block = "deny"
+mutex_atomic = "deny"
+panic = "deny"
+rc_buffer = "deny"
+rc_mutex = "deny"
+str_to_string = "deny"
+string_slice = "deny"
+struct_field_names = "allow"
+tests_outside_test_module = "deny"
+todo = "deny"
+try_err = "deny"
+unimplemented = "deny"
+unnecessary_safety_comment = "deny"
+unnecessary_safety_doc = "deny"
+unreachable = "deny"
+unwrap_in_result = "deny"
+wildcard_imports = "allow"
+
+cargo = { level = "deny", priority = -1 }
+complexity = { level = "deny", priority = -1 }
+correctness = { level = "deny", priority = -1 }
+nursery = { level = "deny", priority = -1 }
+pedantic = { level = "deny", priority = -1 }
+perf = { level = "deny", priority = -1 }
+style = { level = "deny", priority = -1 }
+suspicious = { level = "deny", priority = -1 }
diff --git a/Makefile b/Makefile
new file mode 100644
index 000000000..af0840fb3
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,10 @@
+.PHONY: check lint fmt
+
+check:
+	cargo check && cargo check --tests
+
+lint:
+	cargo clippy && cargo clippy --tests
+
+fmt:
+	cargo fmt --all
diff --git a/crates/modelsdev/Cargo.toml b/crates/modelsdev/Cargo.toml
new file mode 100644
index 000000000..22d8c797e
--- /dev/null
+++ b/crates/modelsdev/Cargo.toml
@@ -0,0 +1,19 @@
+[package]
+name = "modelsdev"
+version.workspace = true
+edition.workspace = true
+rust-version.workspace = true
+publish.workspace = true
+build = "build.rs"
+
+[dependencies]
+serde = { workspace = true }
+serde_json = { workspace = true }
+thiserror = { workspace = true }
+
+[build-dependencies]
+serde_json = { workspace = true }
+toml = { workspace = true }
+
+[lints]
+workspace = true
diff --git a/crates/modelsdev/build.rs b/crates/modelsdev/build.rs
new file mode 100644
index 000000000..6d27ca312
--- /dev/null
+++ b/crates/modelsdev/build.rs
@@ -0,0 +1,112 @@
+use std::collections::BTreeMap;
+use std::env;
+use std::fs;
+use std::io;
+use std::path::{Path, PathBuf};
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?);
+    let repo_root = manifest_dir
+        .parent()
+        .and_then(Path::parent)
+        .ok_or_else(|| io::Error::other("failed to resolve repo root"))?;
+    let providers_dir = repo_root.join("providers");
+    let out_dir = PathBuf::from(env::var("OUT_DIR")?);
+    let catalog_path = out_dir.join("catalog.json");
+
+    println!("cargo:rerun-if-changed={}", providers_dir.display());
+
+    let catalog = serde_json::json!({
+        "providers": build_catalog(&providers_dir)?,
+    });
+    let json = serde_json::to_string(&catalog)?;
+    fs::write(catalog_path, json)?;
+    Ok(())
+}
+
+fn build_catalog(providers_dir: &Path) -> Result<BTreeMap<String, serde_json::Value>, io::Error> {
+    let mut providers = BTreeMap::new();
+    for entry in fs::read_dir(providers_dir)? {
+        let entry = entry?;
+        if !entry.file_type()?.is_dir() {
+            continue;
+        }
+        let provider_id = entry.file_name().to_string_lossy().to_string();
+        let provider_dir = entry.path();
+        let provider_toml = provider_dir.join("provider.toml");
+        if !provider_toml.is_file() {
+            continue;
+        }
+
+        let mut provider = read_toml_value(&provider_toml)?;
+        let Some(provider_object) = provider.as_object_mut() else {
+            return Err(io::Error::other(format!(
+                "provider TOML must be an object: {}",
+                provider_toml.display()
+            )));
+        };
+        let _ = provider_object.insert(
+            "id".to_owned(),
+            serde_json::Value::String(provider_id.clone()),
+        );
+        let _ = provider_object.insert(
+            "models".to_owned(),
+            serde_json::Value::Object(serde_json::Map::new()),
+        );
+
+        let models_dir = provider_dir.join("models");
+        if models_dir.is_dir() {
+            let mut models = BTreeMap::new();
+            collect_models(&models_dir, &models_dir, &mut models)?;
+            let models_object = models.into_iter().collect::<serde_json::Map<_, _>>();
+            let _ = provider_object.insert(
+                "models".to_owned(),
+                serde_json::Value::Object(models_object),
+            );
+        }
+
+        let _ = providers.insert(provider_id, provider);
+    }
+    Ok(providers)
+}
+
+fn collect_models(
+    models_dir: &Path,
+    current_dir: &Path,
+    models: &mut BTreeMap<String, serde_json::Value>,
+) -> Result<(), io::Error> {
+    for entry in fs::read_dir(current_dir)? {
+        let entry = entry?;
+        let file_type = entry.file_type()?;
+        let path = entry.path();
+        if file_type.is_dir() {
+            collect_models(models_dir, &path, models)?;
+            continue;
+        }
+        if path.extension().and_then(|ext| ext.to_str()) != Some("toml") || !path.is_file() {
+            continue;
+        }
+        let relative = path.strip_prefix(models_dir).map_err(io::Error::other)?;
+        let model_id = relative
+            .to_string_lossy()
+            .strip_suffix(".toml")
+            .unwrap_or_default()
+            .replace('\\', "/");
+        let mut model = read_toml_value(&path)?;
+        let Some(model_object) = model.as_object_mut() else {
+            return Err(io::Error::other(format!(
+                "model TOML must be an object: {}",
+                path.display()
+            )));
+        };
+        let _ = model_object.insert("id".to_owned(), serde_json::Value::String(model_id.clone()));
+        let _ = models.insert(model_id, model);
+    }
+    Ok(())
+}
+
+fn read_toml_value(path: &Path) -> Result<serde_json::Value, io::Error> {
+    let text = fs::read_to_string(path)?;
+    let toml_value = toml::from_str::<toml::Value>(&text).map_err(io::Error::other)?;
+    serde_json::to_value(toml_value).map_err(io::Error::other)
+}
diff --git a/crates/modelsdev/src/lib.rs b/crates/modelsdev/src/lib.rs
new file mode 100644
index 000000000..cd37bdafe
--- /dev/null
+++ b/crates/modelsdev/src/lib.rs
@@ -0,0 +1,201 @@
+#![forbid(unsafe_code)]
+
+use std::collections::BTreeMap;
+use std::sync::OnceLock;
+
+use serde::{Deserialize, Serialize};
+use thiserror::Error;
+
+static CATALOG: OnceLock<Catalog> = OnceLock::new();
+
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct Catalog {
+    pub providers: BTreeMap<String, Provider>,
+}
+
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct Provider {
+    pub id: String,
+    pub env: Vec<String>,
+    pub npm: String,
+    pub api: Option<String>,
+    pub name: String,
+    pub doc: String,
+    pub models: BTreeMap<String, Model>,
+}
+
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct Model {
+    pub id: String,
+    pub name: String,
+    pub family: Option<String>,
+    pub attachment: bool,
+    pub reasoning: bool,
+    pub tool_call: bool,
+    pub interleaved: Option<Interleaved>,
+    pub structured_output: Option<bool>,
+    pub temperature: Option<bool>,
+    pub knowledge: Option<String>,
+    pub release_date: String,
+    pub last_updated: String,
+    pub modalities: Modalities,
+    pub open_weights: bool,
+    pub cost: Option<Cost>,
+    pub limit: Limit,
+    pub status: Option<ModelStatus>,
+    pub provider: Option<ModelProvider>,
+}
+
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum Interleaved {
+    Enabled(bool),
+    Field(InterleavedField),
+}
+
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct InterleavedField {
+    pub field: String,
+}
+
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct Modalities {
+    pub input: Vec<String>,
+    pub output: Vec<String>,
+}
+
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct Cost {
+    pub input: f64,
+    pub output: f64,
+    pub reasoning: Option<f64>,
+    pub cache_read: Option<f64>,
+    pub cache_write: Option<f64>,
+    pub input_audio: Option<f64>,
+    pub output_audio: Option<f64>,
+    pub context_over_200k: Option<Box<Cost>>,
+}
+
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct Limit {
+    pub context: i64,
+    pub input: Option<i64>,
+    pub output: i64,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum ModelStatus {
+    Alpha,
+    Beta,
+    Deprecated,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct ModelProvider {
+    pub npm: Option<String>,
+    pub api: Option<String>,
+    pub shape: Option<String>,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ProviderTransport {
+    pub npm: String,
+    pub api: Option<String>,
+    pub shape: Option<String>,
+}
+
+#[derive(Debug, Error, Clone, PartialEq, Eq)]
+pub enum LookupError {
+    #[error("unknown provider: {provider_id}")]
+    UnknownProvider { provider_id: String },
+    #[error("unknown model for provider {provider_id}: {model_id}")]
+    UnknownModel {
+        provider_id: String,
+        model_id: String,
+    },
+}
+
+#[must_use]
+pub fn catalog() -> &'static Catalog {
+    CATALOG.get_or_init(|| {
+        serde_json::from_str(CATALOG_JSON).expect("embedded modelsdev catalog must stay valid")
+    })
+}
+
+impl Catalog {
+    pub fn provider(&self, provider_id: &str) -> Result<&Provider, LookupError> {
+        self.providers
+            .get(provider_id.trim())
+            .ok_or_else(|| LookupError::UnknownProvider {
+                provider_id: provider_id.trim().to_owned(),
+            })
+    }
+
+    pub fn model(&self, provider_id: &str, model_id: &str) -> Result<&Model, LookupError> {
+        let provider = self.provider(provider_id)?;
+        provider.model(model_id)
+    }
+
+    #[must_use]
+    pub fn supported_provider_ids(&self) -> Vec<&str> {
+        self.providers
+            .values()
+            .filter(|provider| provider.transport(None).family().is_some())
+            .map(|provider| provider.id.as_str())
+            .collect()
+    }
+}
+
+impl Provider {
+    pub fn model(&self, model_id: &str) -> Result<&Model, LookupError> {
+        self.models
+            .get(model_id.trim())
+            .ok_or_else(|| LookupError::UnknownModel {
+                provider_id: self.id.clone(),
+                model_id: model_id.trim().to_owned(),
+            })
+    }
+
+    #[must_use]
+    pub fn transport(&self, model: Option<&Model>) -> ProviderTransport {
+        let model_provider = model.and_then(|value| value.provider.as_ref());
+        ProviderTransport {
+            npm: model_provider
+                .and_then(|value| value.npm.clone())
+                .unwrap_or_else(|| self.npm.clone()),
+            api: model_provider
+                .and_then(|value| value.api.clone())
+                .or_else(|| self.api.clone()),
+            shape: model_provider.and_then(|value| value.shape.clone()),
+        }
+    }
+}
+
+impl ProviderTransport {
+    #[must_use]
+    pub fn family(&self) -> Option<&'static str> {
+        match self.npm.as_str() {
+            "@ai-sdk/anthropic" => Some("anthropic"),
+            "@ai-sdk/openai" | "@ai-sdk/openai-compatible" => Some("openai"),
+            _ => None,
+        }
+    }
+
+    #[must_use]
+    pub fn base_url(&self) -> Option<&str> {
+        if let Some(api) = self.api.as_deref() {
+            return Some(api);
+        }
+        match self.npm.as_str() {
+            "@ai-sdk/anthropic" => Some("https://api.anthropic.com"),
+            "@ai-sdk/openai" => Some("https://api.openai.com/v1"),
+            _ => None,
+        }
+    }
+}
+
+const CATALOG_JSON: &str = include_str!(concat!(env!("OUT_DIR"), "/catalog.json"));
+
+#[cfg(test)]
+mod tests;
diff --git a/crates/modelsdev/src/tests.rs b/crates/modelsdev/src/tests.rs
new file mode 100644
index 000000000..e8eec1246
--- /dev/null
+++ b/crates/modelsdev/src/tests.rs
@@ -0,0 +1,45 @@
+use super::{LookupError, catalog};
+
+#[test]
+fn catalog_loads_openai_provider() {
+    let catalog = catalog();
+    let provider = catalog.provider("openai").expect("openai should exist");
+    assert_eq!(provider.id, "openai");
+    assert!(provider.model("gpt-5").is_ok());
+    assert_eq!(provider.transport(None).family(), Some("openai"));
+}
+
+#[test]
+fn catalog_loads_anthropic_model() {
+    let catalog = catalog();
+    let provider = catalog
+        .provider("anthropic")
+        .expect("anthropic should exist");
+    let model = provider
+        .model("claude-sonnet-4-0")
+        .expect("claude-sonnet-4-0 should exist");
+    assert_eq!(model.id, "claude-sonnet-4-0");
+}
+
+#[test]
+fn catalog_loads_opencode_provider() {
+    let catalog = catalog();
+    let provider = catalog.provider("opencode").expect("opencode should exist");
+    let transport = provider.transport(None);
+    assert_eq!(transport.family(), Some("openai"));
+    assert_eq!(transport.base_url(), Some("https://opencode.ai/zen/v1"));
+}
+
+#[test]
+fn unknown_provider_returns_lookup_error() {
+    let catalog = catalog();
+    let error = catalog
+        .provider("does-not-exist")
+        .expect_err("unknown provider should fail");
+    assert_eq!(
+        error,
+        LookupError::UnknownProvider {
+            provider_id: "does-not-exist".to_owned(),
+        }
+    );
+}