Skip to content
Draft
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 157 additions & 10 deletions sentry_sdk/integrations/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,17 @@
return 0


def _calculate_token_usage(
def _calculate_streaming_completions_token_usage(
messages: "Optional[Iterable[ChatCompletionMessageParam]]",
response: "Any",
span: "Span",
streaming_message_responses: "Optional[List[str]]",
count_tokens: "Callable[..., Any]",
) -> None:
"""
Sets token attribute for streaming Completions calls.
Sets token usage if available in the response object, or computes the usage for textual input and output with tiktoken.
"""
input_tokens: "Optional[int]" = 0
input_tokens_cached: "Optional[int]" = 0
output_tokens: "Optional[int]" = 0
Expand Down Expand Up @@ -219,6 +223,151 @@
)


def _calculate_streaming_responses_token_usage(
    messages: "Optional[Union[str, ResponseInputParam]]",
    response: "Any",
    span: "Span",
    streaming_message_responses: "Optional[List[str]]",
    count_tokens: "Callable[..., Any]",
) -> None:
    """
    Sets token attributes on *span* for streaming Responses API calls.

    Token usage is taken from the response's ``usage`` object when available;
    otherwise the textual input and output are tokenized manually with
    ``count_tokens`` as a fallback.

    :param messages: The request input — either a bare string or a
        Responses-style input list (``None`` if unavailable).
    :param response: The (possibly accumulated) response object.
    :param span: The span to record token usage on.
    :param streaming_message_responses: Text chunks collected while streaming,
        used to count output tokens when the API did not report them.
    :param count_tokens: Tokenizer callable (tiktoken-based) mapping text to a
        token count.
    """
    input_tokens: "Optional[int]" = 0
    input_tokens_cached: "Optional[int]" = 0
    output_tokens: "Optional[int]" = 0
    output_tokens_reasoning: "Optional[int]" = 0
    total_tokens: "Optional[int]" = 0

    if hasattr(response, "usage"):
        input_tokens = _get_usage(response.usage, ["input_tokens", "prompt_tokens"])
        if hasattr(response.usage, "input_tokens_details"):
            input_tokens_cached = _get_usage(
                response.usage.input_tokens_details, ["cached_tokens"]
            )

        output_tokens = _get_usage(
            response.usage, ["output_tokens", "completion_tokens"]
        )
        if hasattr(response.usage, "output_tokens_details"):
            output_tokens_reasoning = _get_usage(
                response.usage.output_tokens_details, ["reasoning_tokens"]
            )

        total_tokens = _get_usage(response.usage, ["total_tokens"])

    # Manually count input tokens when the API did not report them.
    if input_tokens == 0:
        if isinstance(messages, str):
            # The Responses API accepts a bare string as input; tokenize it as
            # a whole rather than iterating over it character by character.
            input_tokens = count_tokens(messages)
        else:
            for message in messages or []:
                if isinstance(message, str):
                    input_tokens += count_tokens(message)
                elif isinstance(message, dict):
                    message_content = message.get("content")
                    if message_content is None:
                        continue
                    # Deliberate use of Completions function for both
                    # Completions and Responses input format.
                    text_items = _get_text_items(message_content)
                    input_tokens += sum(count_tokens(text) for text in text_items)

    # Manually count output tokens, preferring the streamed text chunks.
    if output_tokens == 0:
        if streaming_message_responses is not None:
            for message in streaming_message_responses:
                output_tokens += count_tokens(message)
        elif hasattr(response, "choices"):
            for choice in response.choices:
                if hasattr(choice, "message") and hasattr(choice.message, "content"):
                    output_tokens += count_tokens(choice.message.content)

    # Do not set token data if it is 0
    input_tokens = input_tokens or None
    input_tokens_cached = input_tokens_cached or None
    output_tokens = output_tokens or None
    output_tokens_reasoning = output_tokens_reasoning or None
    total_tokens = total_tokens or None

    record_token_usage(
        span,
        input_tokens=input_tokens,
        input_tokens_cached=input_tokens_cached,
        output_tokens=output_tokens,
        output_tokens_reasoning=output_tokens_reasoning,
        total_tokens=total_tokens,
    )


def _calculate_token_usage(
    messages: "Optional[Union[Iterable[ChatCompletionMessageParam], str, ResponseInputParam]]",
    response: "Any",
    span: "Span",
    count_tokens: "Callable[..., Any]",
) -> None:
    """
    Sets token attributes on *span* for non-streaming Completions and
    Responses API calls.

    Token usage is taken from the response's ``usage`` object when available;
    otherwise the textual input and output are tokenized manually with
    ``count_tokens`` as a fallback.

    :param messages: The request input — Completions-style message dicts, a
        bare string, or a Responses-style input list (``None`` if unavailable).
    :param response: The response object returned by the API.
    :param span: The span to record token usage on.
    :param count_tokens: Tokenizer callable (tiktoken-based) mapping text to a
        token count.
    """
    input_tokens: "Optional[int]" = 0
    input_tokens_cached: "Optional[int]" = 0
    output_tokens: "Optional[int]" = 0
    output_tokens_reasoning: "Optional[int]" = 0
    total_tokens: "Optional[int]" = 0

    if hasattr(response, "usage"):
        input_tokens = _get_usage(response.usage, ["input_tokens", "prompt_tokens"])
        if hasattr(response.usage, "input_tokens_details"):
            input_tokens_cached = _get_usage(
                response.usage.input_tokens_details, ["cached_tokens"]
            )

        output_tokens = _get_usage(
            response.usage, ["output_tokens", "completion_tokens"]
        )
        if hasattr(response.usage, "output_tokens_details"):
            output_tokens_reasoning = _get_usage(
                response.usage.output_tokens_details, ["reasoning_tokens"]
            )

        total_tokens = _get_usage(response.usage, ["total_tokens"])

    # Manually count input tokens when the API did not report them.
    if input_tokens == 0:
        if isinstance(messages, str):
            # The Responses API accepts a bare string as input; tokenize it as
            # a whole rather than iterating over it character by character.
            input_tokens = count_tokens(messages)
        else:
            for message in messages or []:
                if isinstance(message, str):
                    input_tokens += count_tokens(message)
                elif isinstance(message, dict):
                    message_content = message.get("content")
                    if message_content is None:
                        continue
                    # Deliberate use of Completions function for both
                    # Completions and Responses input format.
                    text_items = _get_text_items(message_content)
                    input_tokens += sum(count_tokens(text) for text in text_items)

    # Manually count output tokens from the Completions-style choices.
    if output_tokens == 0 and hasattr(response, "choices"):
        for choice in response.choices:
            if hasattr(choice, "message") and hasattr(choice.message, "content"):
                output_tokens += count_tokens(choice.message.content)

    # Do not set token data if it is 0
    input_tokens = input_tokens or None
    input_tokens_cached = input_tokens_cached or None
    output_tokens = output_tokens or None
    output_tokens_reasoning = output_tokens_reasoning or None
    total_tokens = total_tokens or None

    record_token_usage(
        span,
        input_tokens=input_tokens,
        input_tokens_cached=input_tokens_cached,
        output_tokens=output_tokens,
        output_tokens_reasoning=output_tokens_reasoning,
        total_tokens=total_tokens,
    )


def _set_responses_api_input_data(
span: "Span",
kwargs: "dict[str, Any]",
Expand Down Expand Up @@ -493,7 +642,7 @@
if len(response_text) > 0:
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, response_text)

_calculate_token_usage(input, response, span, None, integration.count_tokens)
_calculate_token_usage(input, response, span, integration.count_tokens)

Check failure on line 645 in sentry_sdk/integrations/openai.py

View check run for this annotation

@sentry/warden / warden: find-bugs

Tests call _calculate_token_usage with wrong number of arguments

The `_calculate_token_usage` function signature was changed from 5 parameters to 4 parameters (removing `streaming_message_responses`). However, the tests at lines 1800, 1833, 1866, 1900, and 1926 still pass 5 arguments including `streaming_message_responses`. This will cause test failures since the 4th positional argument will be interpreted as `count_tokens` instead of as the callable function.

if finish_span:
span.__exit__(None, None, None)
Expand Down Expand Up @@ -529,12 +678,12 @@
span, SPANDATA.GEN_AI_RESPONSE_TEXT, output_messages["response"]
)

_calculate_token_usage(input, response, span, None, integration.count_tokens)
_calculate_token_usage(input, response, span, integration.count_tokens)

if finish_span:
span.__exit__(None, None, None)
else:
_calculate_token_usage(input, response, span, None, integration.count_tokens)
_calculate_token_usage(input, response, span, integration.count_tokens)
if finish_span:
span.__exit__(None, None, None)

Expand Down Expand Up @@ -658,7 +807,7 @@
set_data_normalized(
span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses
)
_calculate_token_usage(
_calculate_streaming_completions_token_usage(
messages,
response,
span,
Expand Down Expand Up @@ -702,7 +851,7 @@
set_data_normalized(
span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses
)
_calculate_token_usage(
_calculate_streaming_completions_token_usage(
messages,
response,
span,
Expand Down Expand Up @@ -777,7 +926,6 @@
input,
x.response,
span,
None,
integration.count_tokens,
)
count_tokens_manually = False
Expand All @@ -796,7 +944,7 @@
span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses
)
if count_tokens_manually:
_calculate_token_usage(
_calculate_streaming_responses_token_usage(
input,
response,
span,
Expand Down Expand Up @@ -826,7 +974,6 @@
input,
x.response,
span,
None,
integration.count_tokens,
)
count_tokens_manually = False
Expand All @@ -845,7 +992,7 @@
span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses
)
if count_tokens_manually:
_calculate_token_usage(
_calculate_streaming_responses_token_usage(
input,
response,
span,
Expand Down
Loading