Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions crates/goose/src/providers/canonical/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,39 @@ pub fn recommended_models_from_registry(provider: &str) -> Vec<String> {
pub fn maybe_get_canonical_model(provider: &str, model: &str) -> Option<CanonicalModel> {
let registry = CanonicalModelRegistry::bundled().ok()?;

// map_to_canonical_model returns the canonical ID (provider/model)
// Parse it to get provider and model parts for registry lookup
let canonical_id = map_to_canonical_model(provider, model, registry)?;
if let Some((canon_provider, canon_model)) = canonical_id.split_once('/') {
registry.get(canon_provider, canon_model).cloned()
let mut canonical = if let Some((canon_provider, canon_model)) = canonical_id.split_once('/') {
registry.get(canon_provider, canon_model).cloned()?
} else {
None
return None;
};

// TODO: replace with a flag on the provider once we have one
if matches!(provider, "ollama" | "local") {
canonical.cost = Pricing::default();
}

Some(canonical)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn ollama_models_have_zero_cost() {
let canonical = maybe_get_canonical_model("ollama", "mistral-nemo")
.expect("mistral-nemo should resolve via ollama");
assert_eq!(canonical.cost.input, None);
assert_eq!(canonical.cost.output, None);
assert!(canonical.limit.context > 0);
}

#[test]
fn cloud_provider_retains_cost() {
let canonical = maybe_get_canonical_model("anthropic", "claude-3-5-sonnet-20241022")
.expect("claude-3.5-sonnet should resolve");
assert!(canonical.cost.input.is_some());
assert!(canonical.cost.output.is_some());
}
}
Loading