Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions crates/goose/src/providers/canonical/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,56 @@ pub fn recommended_models_from_registry(provider: &str) -> Vec<String> {
.collect()
}

/// Providers that run models locally — their cost is always zero regardless
/// of what the canonical registry says for the underlying model architecture.
fn is_local_provider(provider: &str) -> bool {
matches!(provider, "ollama" | "local")
}

pub fn maybe_get_canonical_model(provider: &str, model: &str) -> Option<CanonicalModel> {
let registry = CanonicalModelRegistry::bundled().ok()?;

// map_to_canonical_model returns the canonical ID (provider/model)
// Parse it to get provider and model parts for registry lookup
let canonical_id = map_to_canonical_model(provider, model, registry)?;
if let Some((canon_provider, canon_model)) = canonical_id.split_once('/') {
registry.get(canon_provider, canon_model).cloned()
let mut canonical = if let Some((canon_provider, canon_model)) = canonical_id.split_once('/') {
registry.get(canon_provider, canon_model).cloned()?
} else {
None
return None;
};

// Local providers run models on the user's own hardware — zero out cloud
// pricing so every consumer (CLI, server, etc.) sees the correct cost.
if is_local_provider(provider) {
canonical.cost = Pricing::default();
}

Some(canonical)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn ollama_models_have_zero_cost() {
// "mistral-nemo" resolves to mistralai/mistral-nemo which has non-zero cloud pricing.
// When accessed via ollama, cost must be zeroed out.
let canonical = maybe_get_canonical_model("ollama", "mistral-nemo")
.expect("mistral-nemo should resolve via ollama");
assert_eq!(canonical.cost.input, None);
assert_eq!(canonical.cost.output, None);
assert!(
canonical.limit.context > 0,
"context limit should be preserved"
);
}

#[test]
fn cloud_provider_retains_cost() {
let canonical = maybe_get_canonical_model("anthropic", "claude-3-5-sonnet-20241022")
.expect("claude-3.5-sonnet should resolve");
assert!(canonical.cost.input.is_some());
assert!(canonical.cost.output.is_some());
}
}
Loading