-
Notifications
You must be signed in to change notification settings - Fork 64
Expand file tree
/
Copy pathgemini.py
More file actions
431 lines (372 loc) · 15.1 KB
/
gemini.py
File metadata and controls
431 lines (372 loc) · 15.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
import os
import time
import uuid
from typing import Any, Dict, Optional
from posthog.ai.types import TokenUsage, StreamingEventData
from posthog.ai.utils import merge_system_prompt
try:
from google import genai
except ImportError:
raise ModuleNotFoundError(
"Please install the Google Gemini SDK to use this feature: 'pip install google-genai'"
)
from posthog import setup
from posthog.ai.utils import (
call_llm_and_track_usage,
capture_streaming_event,
merge_usage_stats,
)
from posthog.ai.gemini.gemini_converter import (
extract_gemini_usage_from_chunk,
extract_gemini_content_from_chunk,
extract_gemini_stop_reason_from_chunk,
format_gemini_streaming_output,
)
from posthog.ai.sanitization import sanitize_gemini
from posthog.client import Client as PostHogClient
class Client:
    """
    Drop-in replacement for ``genai.Client`` that reports every LLM call to
    PostHog.

    Default PostHog settings (distinct id, properties, privacy mode, groups)
    are fixed at construction time and can be overridden on each call to the
    ``models`` accessor.

    Usage:
        client = Client(
            api_key="your_api_key",
            posthog_client=posthog_client,
            posthog_distinct_id="default_user",  # Optional defaults
            posthog_properties={"team": "ai"}  # Optional defaults
        )
        response = client.models.generate_content(
            model="gemini-2.0-flash",
            contents=["Hello world"],
            posthog_distinct_id="specific_user"  # Override default
        )
    """

    _ph_client: PostHogClient

    def __init__(
        self,
        api_key: Optional[str] = None,
        vertexai: Optional[bool] = None,
        credentials: Optional[Any] = None,
        project: Optional[str] = None,
        location: Optional[str] = None,
        debug_config: Optional[Any] = None,
        http_options: Optional[Any] = None,
        posthog_client: Optional[PostHogClient] = None,
        posthog_distinct_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: bool = False,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Args:
            api_key: Google AI API key; falls back to the GOOGLE_API_KEY or
                API_KEY environment variables (not required for Vertex AI).
            vertexai: Whether to use Vertex AI authentication.
            credentials: Vertex AI credentials object.
            project: GCP project ID for Vertex AI.
            location: GCP location for Vertex AI.
            debug_config: Debug configuration for the client.
            http_options: HTTP options for the client.
            posthog_client: PostHog client used for usage tracking.
            posthog_distinct_id: Default distinct ID (overridable per call).
            posthog_properties: Default event properties (overridable per call).
            posthog_privacy_mode: Default privacy mode (overridable per call).
            posthog_groups: Default group analytics (overridable per call).
            **kwargs: Extra arguments, reserved for future compatibility.

        Raises:
            ValueError: when neither ``posthog_client`` nor a globally
                configured PostHog instance (via ``setup()``) is available.
        """
        # Resolve the analytics client first so a missing one fails fast.
        resolved_client = posthog_client or setup()
        if resolved_client is None:
            raise ValueError("posthog_client is required for PostHog tracking")
        self._ph_client = resolved_client

        # All Gemini work is delegated to the tracked Models facade.
        self.models = Models(
            api_key=api_key,
            vertexai=vertexai,
            credentials=credentials,
            project=project,
            location=location,
            debug_config=debug_config,
            http_options=http_options,
            posthog_client=self._ph_client,
            posthog_distinct_id=posthog_distinct_id,
            posthog_properties=posthog_properties,
            posthog_privacy_mode=posthog_privacy_mode,
            posthog_groups=posthog_groups,
            **kwargs,
        )
class Models:
    """
    Models interface that mimics genai.Client().models with PostHog tracking.

    Non-PostHog keyword arguments are forwarded untouched to the wrapped
    genai.Client, so this object can stand in for the SDK's models accessor.
    """

    _ph_client: PostHogClient  # Not None after __init__ validation

    def __init__(
        self,
        api_key: Optional[str] = None,
        vertexai: Optional[bool] = None,
        credentials: Optional[Any] = None,
        project: Optional[str] = None,
        location: Optional[str] = None,
        debug_config: Optional[Any] = None,
        http_options: Optional[Any] = None,
        posthog_client: Optional[PostHogClient] = None,
        posthog_distinct_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: bool = False,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Args:
            api_key: Google AI API key. If not provided, will use GOOGLE_API_KEY or API_KEY environment variable (not required for Vertex AI)
            vertexai: Whether to use Vertex AI authentication
            credentials: Vertex AI credentials object
            project: GCP project ID for Vertex AI
            location: GCP location for Vertex AI
            debug_config: Debug configuration for the client
            http_options: HTTP options for the client
            posthog_client: PostHog client for tracking usage
            posthog_distinct_id: Default distinct ID for all calls
            posthog_properties: Default properties for all calls
            posthog_privacy_mode: Default privacy mode for all calls
            posthog_groups: Default groups for all calls
            **kwargs: Additional arguments (for future compatibility)

        Raises:
            ValueError: if no PostHog client can be resolved, or if no API key
                is available in non-Vertex AI mode.
        """
        self._ph_client = posthog_client or setup()
        if self._ph_client is None:
            raise ValueError("posthog_client is required for PostHog tracking")
        # Store default PostHog settings
        self._default_distinct_id = posthog_distinct_id
        self._default_properties = posthog_properties or {}
        self._default_privacy_mode = posthog_privacy_mode
        self._default_groups = posthog_groups
        # Build genai.Client arguments
        client_args: Dict[str, Any] = {}
        # Add Vertex AI parameters if provided (only non-None values are
        # forwarded, so the SDK's own defaults apply otherwise)
        if vertexai is not None:
            client_args["vertexai"] = vertexai
        if credentials is not None:
            client_args["credentials"] = credentials
        if project is not None:
            client_args["project"] = project
        if location is not None:
            client_args["location"] = location
        if debug_config is not None:
            client_args["debug_config"] = debug_config
        if http_options is not None:
            client_args["http_options"] = http_options
        # Handle API key authentication
        if vertexai:
            # For Vertex AI, api_key is optional
            if api_key is not None:
                client_args["api_key"] = api_key
        else:
            # For non-Vertex AI mode, api_key is required (backwards compatibility)
            if api_key is None:
                api_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("API_KEY")
            if api_key is None:
                raise ValueError(
                    "API key must be provided either as parameter or via GOOGLE_API_KEY/API_KEY environment variable"
                )
            client_args["api_key"] = api_key
        self._client = genai.Client(**client_args)
        # Recorded on captured events as the provider base URL (passed to
        # call_llm_and_track_usage and StreamingEventData below)
        self._base_url = "https://generativelanguage.googleapis.com"

    def _merge_posthog_params(
        self,
        call_distinct_id: Optional[str],
        call_trace_id: Optional[str],
        call_properties: Optional[Dict[str, Any]],
        call_privacy_mode: Optional[bool],
        call_groups: Optional[Dict[str, Any]],
    ):
        """Merge call-level PostHog parameters with client defaults.

        Returns a ``(distinct_id, trace_id, properties, privacy_mode, groups)``
        tuple; a trace id is generated when the caller did not supply one.
        """
        # Use call-level values if provided, otherwise fall back to defaults
        distinct_id = (
            call_distinct_id
            if call_distinct_id is not None
            else self._default_distinct_id
        )
        privacy_mode = (
            call_privacy_mode
            if call_privacy_mode is not None
            else self._default_privacy_mode
        )
        groups = call_groups if call_groups is not None else self._default_groups
        # Merge properties: default properties + call properties (call properties override)
        properties = dict(self._default_properties)
        if call_properties:
            properties.update(call_properties)
        if call_trace_id is None:
            call_trace_id = str(uuid.uuid4())
        return distinct_id, call_trace_id, properties, privacy_mode, groups

    def generate_content(
        self,
        model: str,
        contents,
        posthog_distinct_id: Optional[str] = None,
        posthog_trace_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: Optional[bool] = None,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ):
        """
        Generate content using Gemini's API while tracking usage in PostHog.

        This method signature exactly matches genai.Client().models.generate_content()
        with additional PostHog tracking parameters.

        Args:
            model: The model to use (e.g., 'gemini-2.0-flash')
            contents: The input content for generation
            posthog_distinct_id: ID to associate with the usage event (overrides client default)
            posthog_trace_id: Trace UUID for linking events (auto-generated if not provided)
            posthog_properties: Extra properties to include in the event (merged with client defaults)
            posthog_privacy_mode: Whether to redact sensitive information (overrides client default)
            posthog_groups: Group analytics properties (overrides client default)
            **kwargs: Arguments passed to Gemini's generate_content

        Returns:
            Whatever call_llm_and_track_usage returns — presumably the
            underlying SDK response (NOTE(review): confirm against the helper).
        """
        # Merge PostHog parameters
        distinct_id, trace_id, properties, privacy_mode, groups = (
            self._merge_posthog_params(
                posthog_distinct_id,
                posthog_trace_id,
                posthog_properties,
                posthog_privacy_mode,
                posthog_groups,
            )
        )
        kwargs_with_contents = {"model": model, "contents": contents, **kwargs}
        return call_llm_and_track_usage(
            distinct_id,
            self._ph_client,
            "gemini",
            trace_id,
            properties,
            privacy_mode,
            groups,
            self._base_url,
            self._client.models.generate_content,
            **kwargs_with_contents,
        )

    def _generate_content_streaming(
        self,
        model: str,
        contents,
        distinct_id: Optional[str],
        trace_id: Optional[str],
        properties: Optional[Dict[str, Any]],
        privacy_mode: bool,
        groups: Optional[Dict[str, Any]],
        **kwargs: Any,
    ):
        """
        Open the SDK stream and wrap it in a tracking generator.

        The SDK stream is started eagerly (before the generator is first
        consumed). The returned generator re-yields each chunk unchanged while
        accumulating token usage, content blocks and the stop reason; a single
        PostHog event is captured in the ``finally`` block, so it fires when
        iteration completes, raises, or the generator is closed early.
        """
        start_time = time.time()
        usage_stats: TokenUsage = TokenUsage(input_tokens=0, output_tokens=0)
        accumulated_content = []
        stop_reason: Optional[str] = None
        # NOTE(review): name is historical — nothing named "stream" is
        # actually stripped from kwargs here.
        kwargs_without_stream = {"model": model, "contents": contents, **kwargs}
        response = self._client.models.generate_content_stream(**kwargs_without_stream)

        def generator():
            nonlocal usage_stats
            nonlocal accumulated_content
            nonlocal stop_reason
            try:
                for chunk in response:
                    # Extract usage stats from chunk
                    chunk_usage = extract_gemini_usage_from_chunk(chunk)
                    if chunk_usage:
                        # Gemini reports cumulative totals, not incremental values
                        merge_usage_stats(usage_stats, chunk_usage, mode="cumulative")
                    # Extract content from chunk (now returns content blocks)
                    content_block = extract_gemini_content_from_chunk(chunk)
                    if content_block is not None:
                        accumulated_content.append(content_block)
                    # Extract stop reason from chunk
                    chunk_stop_reason = extract_gemini_stop_reason_from_chunk(chunk)
                    if chunk_stop_reason is not None:
                        stop_reason = chunk_stop_reason
                    yield chunk
            finally:
                # Capture exactly once, whatever way iteration ends.
                end_time = time.time()
                latency = end_time - start_time
                self._capture_streaming_event(
                    model,
                    contents,
                    distinct_id,
                    trace_id,
                    properties,
                    privacy_mode,
                    groups,
                    kwargs,
                    usage_stats,
                    latency,
                    accumulated_content,
                    stop_reason=stop_reason,
                )

        return generator()

    def _capture_streaming_event(
        self,
        model: str,
        contents,
        distinct_id: Optional[str],
        trace_id: Optional[str],
        properties: Optional[Dict[str, Any]],
        privacy_mode: bool,
        groups: Optional[Dict[str, Any]],
        kwargs: Dict[str, Any],
        usage_stats: TokenUsage,
        latency: float,
        output: Any,
        stop_reason: Optional[str] = None,
    ):
        """
        Build a standardized StreamingEventData payload and send it to PostHog.

        Input is merged with any system prompt and sanitized before capture;
        the accumulated stream output is normalized via
        format_gemini_streaming_output.
        """
        # Prepare standardized event data
        formatted_input = self._format_input(contents, **kwargs)
        sanitized_input = sanitize_gemini(formatted_input)
        event_data = StreamingEventData(
            provider="gemini",
            model=model,
            base_url=self._base_url,
            kwargs=kwargs,
            formatted_input=sanitized_input,
            formatted_output=format_gemini_streaming_output(output),
            usage_stats=usage_stats,
            latency=latency,
            distinct_id=distinct_id,
            trace_id=trace_id,
            properties=properties,
            privacy_mode=privacy_mode,
            groups=groups,
            stop_reason=stop_reason,
        )
        # Use the common capture function
        capture_streaming_event(self._ph_client, event_data)

    def _format_input(self, contents, **kwargs):
        """Format input contents for PostHog tracking"""
        # Create kwargs dict with contents for merge_system_prompt
        input_kwargs = {"contents": contents, **kwargs}
        return merge_system_prompt(input_kwargs, "gemini")

    def generate_content_stream(
        self,
        model: str,
        contents,
        posthog_distinct_id: Optional[str] = None,
        posthog_trace_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: Optional[bool] = None,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ):
        """
        Streaming counterpart of generate_content with PostHog tracking.

        Accepts the same posthog_* overrides as generate_content and returns a
        generator of SDK chunks; a single usage event is captured when the
        stream finishes (see _generate_content_streaming).
        """
        # Merge PostHog parameters
        distinct_id, trace_id, properties, privacy_mode, groups = (
            self._merge_posthog_params(
                posthog_distinct_id,
                posthog_trace_id,
                posthog_properties,
                posthog_privacy_mode,
                posthog_groups,
            )
        )
        return self._generate_content_streaming(
            model,
            contents,
            distinct_id,
            trace_id,
            properties,
            privacy_mode,
            groups,
            **kwargs,
        )