diff --git a/crates/goose/src/providers/canonical/mod.rs b/crates/goose/src/providers/canonical/mod.rs index 07ea6bd9bac7..7f2fd75be188 100644 --- a/crates/goose/src/providers/canonical/mod.rs +++ b/crates/goose/src/providers/canonical/mod.rs @@ -67,12 +67,39 @@ pub fn recommended_models_from_registry(provider: &str) -> Vec { pub fn maybe_get_canonical_model(provider: &str, model: &str) -> Option { let registry = CanonicalModelRegistry::bundled().ok()?; - // map_to_canonical_model returns the canonical ID (provider/model) - // Parse it to get provider and model parts for registry lookup let canonical_id = map_to_canonical_model(provider, model, registry)?; - if let Some((canon_provider, canon_model)) = canonical_id.split_once('/') { - registry.get(canon_provider, canon_model).cloned() + let mut canonical = if let Some((canon_provider, canon_model)) = canonical_id.split_once('/') { + registry.get(canon_provider, canon_model).cloned()? } else { - None + return None; + }; + + // TODO: replace with a flag on the provider once we have one + if matches!(provider, "ollama" | "local") { + canonical.cost = Pricing::default(); + } + + Some(canonical) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn ollama_models_have_zero_cost() { + let canonical = maybe_get_canonical_model("ollama", "mistral-nemo") + .expect("mistral-nemo should resolve via ollama"); + assert_eq!(canonical.cost.input, None); + assert_eq!(canonical.cost.output, None); + assert!(canonical.limit.context > 0); + } + + #[test] + fn cloud_provider_retains_cost() { + let canonical = maybe_get_canonical_model("anthropic", "claude-3-5-sonnet-20241022") + .expect("claude-3.5-sonnet should resolve"); + assert!(canonical.cost.input.is_some()); + assert!(canonical.cost.output.is_some()); } }