Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 86 additions & 14 deletions langextract/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import concurrent.futures
import dataclasses
from typing import Any, Iterator, Sequence
import warnings

from langextract.core import base_model
from langextract.core import data
Expand All @@ -28,6 +29,7 @@
from langextract.core import types as core_types
from langextract.providers import patterns
from langextract.providers import router
from langextract.providers import schemas


@router.register(
Expand All @@ -42,6 +44,9 @@ class OpenAILanguageModel(base_model.BaseLanguageModel):
api_key: str | None = None
base_url: str | None = None
organization: str | None = None
openai_schema: schemas.openai.OpenAISchema | None = dataclasses.field(
default=None, repr=False, compare=False
)
format_type: data.FormatType = data.FormatType.JSON
temperature: float | None = None
max_workers: int = 10
Expand All @@ -50,10 +55,46 @@ class OpenAILanguageModel(base_model.BaseLanguageModel):
default_factory=dict, repr=False, compare=False
)

@classmethod
def get_schema_class(cls) -> type[schema.BaseSchema] | None:
"""Return the OpenAISchema class for structured output support."""
return schemas.openai.OpenAISchema

def apply_schema(self, schema_instance: schema.BaseSchema | None) -> None:
"""Applies an OpenAI schema instance to this provider.

Args:
schema_instance: An OpenAISchema to enforce, or None to clear.

Raises:
InferenceConfigError: if schema_instance is a non-OpenAI BaseSchema
subclass, or if applying an OpenAI schema would conflict with a
non-JSON format_type.
"""
if schema_instance is None:
self.openai_schema = None
elif isinstance(schema_instance, schemas.openai.OpenAISchema):
if self.format_type != data.FormatType.JSON:
raise exceptions.InferenceConfigError(
schemas.openai.JSON_SCHEMA_FORMAT_ERROR
)
self.openai_schema = schema_instance
else:
raise exceptions.InferenceConfigError(
'OpenAILanguageModel only accepts OpenAISchema instances; got '
f'{type(schema_instance).__name__}. Use the matching provider '
'for this schema or construct an OpenAISchema via '
'OpenAISchema.from_examples.'
)
super().apply_schema(schema_instance)

@property
def requires_fence_output(self) -> bool:
"""OpenAI JSON mode returns raw JSON without fences."""
if self.format_type == data.FormatType.JSON:
"""OpenAI JSON mode returns raw JSON unless callers override fences."""
if (
self._fence_output_override is None
and self.format_type == data.FormatType.JSON
):
return False
return super().requires_fence_output

Expand All @@ -63,6 +104,7 @@ def __init__(
api_key: str | None = None,
base_url: str | None = None,
organization: str | None = None,
openai_schema: schemas.openai.OpenAISchema | None = None,
format_type: data.FormatType = data.FormatType.JSON,
temperature: float | None = None,
max_workers: int = 10,
Expand All @@ -75,13 +117,13 @@ def __init__(
api_key: API key for OpenAI service.
base_url: Base URL for OpenAI service.
organization: Optional OpenAI organization ID.
openai_schema: Optional schema for structured output.
format_type: Output format (JSON or YAML).
temperature: Sampling temperature.
max_workers: Maximum number of parallel API calls.
**kwargs: Ignored extra parameters so callers can pass a superset of
arguments shared across back-ends without raising ``TypeError``.
arguments shared across back-ends without raising TypeError.
"""
# Lazy import: OpenAI package required
try:
# pylint: disable=import-outside-toplevel
import openai
Expand All @@ -91,33 +133,46 @@ def __init__(
'Install with: pip install langextract[openai]'
) from e

# Constructor-provided schemas use BaseLanguageModel state when applied.
super().__init__(
constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE)
)

self.model_id = model_id
self.api_key = api_key
self.base_url = base_url
self.organization = organization
self.openai_schema = None
self.format_type = format_type
self.temperature = temperature
self.max_workers = max_workers
self._extra_kwargs = kwargs or {}

if not self.api_key:
raise exceptions.InferenceConfigError('API key not provided.')

# Initialize the OpenAI client
if openai_schema is not None:
self.apply_schema(openai_schema)

# Keep SDK initialization after schema validation so LangExtract reports
# configuration errors before any client-side transport checks.
self._client = openai.OpenAI(
api_key=self.api_key,
base_url=self.base_url,
organization=self.organization,
)

super().__init__(
constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE)
)
self._extra_kwargs = kwargs or {}
def _validate_schema_config(self) -> None:
"""Rejects schema settings the OpenAI API cannot honor."""
if self.openai_schema and self.format_type != data.FormatType.JSON:
raise exceptions.InferenceConfigError(
schemas.openai.JSON_SCHEMA_FORMAT_ERROR
)

def _process_single_prompt(
self, prompt: str, config: dict
) -> core_types.ScoredOutput:
"""Process a single prompt and return a ScoredOutput."""
"""Sends one prompt while preserving provider-specific error types."""
try:
normalized_config = config.copy()

Expand Down Expand Up @@ -145,8 +200,23 @@ def _process_single_prompt(
if temp is not None:
api_params['temperature'] = temp

if self.format_type == data.FormatType.JSON:
api_params.setdefault('response_format', {'type': 'json_object'})
runtime_response_format = normalized_config.get('response_format')
if self.openai_schema and runtime_response_format is None:
self._validate_schema_config()
api_params['response_format'] = self.openai_schema.response_format
elif runtime_response_format is not None:
if self.openai_schema:
# Advanced callers may deliberately override response_format at
# runtime; warn because that bypasses the configured schema.
warnings.warn(
'openai_schema is set but a runtime response_format kwarg '
'was provided; the schema is bypassed for this call.',
UserWarning,
stacklevel=3,
)
api_params['response_format'] = runtime_response_format
elif self.format_type == data.FormatType.JSON:
api_params['response_format'] = {'type': 'json_object'}

if (v := normalized_config.get('max_output_tokens')) is not None:
api_params['max_tokens'] = v
Expand All @@ -160,18 +230,18 @@ def _process_single_prompt(
'logprobs',
'top_logprobs',
'reasoning_effort',
'response_format',
]:
if (v := normalized_config.get(key)) is not None:
api_params[key] = v

response = self._client.chat.completions.create(**api_params)

# Extract the response text using the v1.x response format
output_text = response.choices[0].message.content

return core_types.ScoredOutput(score=1.0, output=output_text)

except exceptions.InferenceConfigError:
raise
except Exception as e:
raise exceptions.InferenceRuntimeError(
f'OpenAI API error: {str(e)}', original=e
Expand Down Expand Up @@ -233,6 +303,8 @@ def infer(
index = future_to_index[future]
try:
results[index] = future.result()
except exceptions.InferenceConfigError:
raise
except Exception as e:
raise exceptions.InferenceRuntimeError(
f'Parallel inference error: {str(e)}', original=e
Expand Down
4 changes: 3 additions & 1 deletion langextract/providers/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from __future__ import annotations

from langextract.providers.schemas import gemini
from langextract.providers.schemas import openai

GeminiSchema = gemini.GeminiSchema # Backward compat
OpenAISchema = openai.OpenAISchema

__all__ = ["GeminiSchema"]
__all__ = ["GeminiSchema", "OpenAISchema"]
Loading