-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathmodel_name_utils.py
More file actions
173 lines (126 loc) · 4.85 KB
/
model_name_utils.py
File metadata and controls
173 lines (126 loc) · 4.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for model name validation and parsing."""
from __future__ import annotations
import re
from typing import Optional
from packaging.version import InvalidVersion
from packaging.version import Version
from .env_utils import is_env_enabled
_DISABLE_GEMINI_MODEL_ID_CHECK_ENV_VAR = 'ADK_DISABLE_GEMINI_MODEL_ID_CHECK'
def is_gemini_model_id_check_disabled() -> bool:
"""Returns True when Gemini model-id validation should be bypassed.
This opt-in environment variable is intended for internal usage where model
ids may not follow the public ``gemini-*`` naming convention.
"""
return is_env_enabled(_DISABLE_GEMINI_MODEL_ID_CHECK_ENV_VAR)
def extract_model_name(model_string: str) -> str:
"""Extract the actual model name from either simple or path-based format.
Args:
model_string: Either a simple model name like "gemini-2.5-pro" or a
path-based model name like "projects/.../models/gemini-2.0-flash-001"
Returns:
The extracted model name (e.g., "gemini-2.5-pro")
"""
# Pattern for path-based model names
# Need to support both Vertex/Gemini and Apigee model paths.
path_patterns = (
r'^projects/[^/]+/locations/[^/]+/publishers/[^/]+/models/(.+)$',
r'^apigee/(?:[^/]+/)?(?:[^/]+/)?(.+)$',
)
# Check against all path-based patterns
for pattern in path_patterns:
match = re.match(pattern, model_string)
if match:
# Return the captured group (the model name)
return match.group(1)
# Handle 'models/' prefixed names like "models/gemini-2.5-pro"
if model_string.startswith('models/'):
return model_string[len('models/') :]
# If it's not a path-based model, return as-is (simple model name)
return model_string
def is_gemini_model(model_string: Optional[str]) -> bool:
"""Check if the model is a Gemini model using regex patterns.
Args:
model_string: Either a simple model name or path-based model name
Returns:
True if it's a Gemini model, False otherwise
"""
if not model_string:
return False
model_name = extract_model_name(model_string)
return re.match(r'^gemini-', model_name) is not None
def is_gemini_1_model(model_string: Optional[str]) -> bool:
"""Check if the model is a Gemini 1.x model using regex patterns.
Args:
model_string: Either a simple model name or path-based model name
Returns:
True if it's a Gemini 1.x model, False otherwise
"""
if not model_string:
return False
model_name = extract_model_name(model_string)
return re.match(r'^gemini-1\.\d+', model_name) is not None
def is_gemini_2_or_above(model_string: Optional[str]) -> bool:
"""Check if the model is a Gemini 2.0 or newer model using semantic versions.
Args:
model_string: Either a simple model name or path-based model name
Returns:
True if it's a Gemini 2.0+ model, False otherwise
"""
if not model_string:
return False
model_name = extract_model_name(model_string)
if not model_name.startswith('gemini-'):
return False
version_string = model_name[len('gemini-') :].split('-', 1)[0]
if not version_string:
return False
try:
parsed_version = Version(version_string)
except InvalidVersion:
return False
return bool(parsed_version.major >= 2)
def is_gemini_3_1_flash_live(model_string: Optional[str]) -> bool:
"""Check if the model is a Gemini 3.1 Flash Live model.
Note: This is a very specific model name for live bidi streaming, so we check
for exact match.
Args:
model_string: The model name
Returns:
True if it's a Gemini 3.1 Flash Live model, False otherwise
"""
if not model_string:
return False
return model_string == 'gemini-3.1-flash-live-preview'
def is_gemini_3_or_above(model_string: Optional[str]) -> bool:
"""Check if the model is a Gemini 3.0 or newer model using semantic versions.
Args:
model_string: Either a simple model name or path-based model name
Returns:
True if it's a Gemini 3.0+ model, False otherwise
"""
if not model_string:
return False
model_name = extract_model_name(model_string)
if not model_name.startswith('gemini-'):
return False
version_string = model_name[len('gemini-') :].split('-', 1)[0]
if not version_string:
return False
try:
parsed_version = Version(version_string)
except InvalidVersion:
return False
return bool(parsed_version.major >= 3)