diff --git a/langextract/core/base_model.py b/langextract/core/base_model.py
index eda41836..796a78fb 100644
--- a/langextract/core/base_model.py
+++ b/langextract/core/base_model.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """Base interfaces for language models."""
+
 from __future__ import annotations
 
 import abc
@@ -135,7 +136,7 @@ def infer(
     """
 
   def infer_batch(
-      self, prompts: Sequence[str], batch_size: int = 32  # pylint: disable=unused-argument
+      self, prompts: Sequence[str], batch_size: int = 32
   ) -> list[list[types.ScoredOutput]]:
     """Batch inference with configurable batch size.
 
@@ -143,13 +144,23 @@ def infer_batch(
 
     Args:
       prompts: List of prompts to process.
-      batch_size: Batch size (currently unused, for future optimization).
+      batch_size: Batch size hint for providers.
+
+        This method passes `batch_size` through to `infer()` as a keyword
+        argument. Providers may interpret it to control true server-side
+        batching (e.g., a batch job size), concurrency, or throttling.
 
     Returns:
       List of lists of ScoredOutput objects.
+
+    Raises:
+      ValueError: If `batch_size` is not greater than zero.
     """
+    if batch_size <= 0:
+      raise ValueError('batch_size must be > 0')
+
     results = []
-    for output in self.infer(prompts):
+    for output in self.infer(prompts, batch_size=batch_size):
       results.append(list(output))
     return results
 
diff --git a/langextract/providers/gemini.py b/langextract/providers/gemini.py
index a82afe1e..bf99b01a 100644
--- a/langextract/providers/gemini.py
+++ b/langextract/providers/gemini.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """Gemini provider for LangExtract."""
+
 # pylint: disable=duplicate-code
 
 from __future__ import annotations
@@ -237,6 +238,8 @@ def infer(
     Yields:
       Lists of ScoredOutputs.
     """
+    # Discard the batch_size hint forwarded by infer_batch(); unused here.
+    kwargs.pop('batch_size', None)
     merged_kwargs = self.merge_kwargs(kwargs)
 
     config = {
diff --git a/langextract/providers/ollama.py b/langextract/providers/ollama.py
index c6be9379..b556f7ff 100644
--- a/langextract/providers/ollama.py
+++ b/langextract/providers/ollama.py
@@ -79,6 +79,7 @@
     2. Pull the model: ollama pull gemma2:2b
     3. Ollama server will start automatically when you use extract()
 """
+
 # pylint: disable=duplicate-code
 
 from __future__ import annotations
@@ -256,6 +257,8 @@ def infer(
     Yields:
       Lists of ScoredOutputs.
     """
+    # Discard the batch_size hint forwarded by infer_batch(); unused here.
+    kwargs.pop('batch_size', None)
     combined_kwargs = self.merge_kwargs(kwargs)
 
     for prompt in batch_prompts:
diff --git a/langextract/providers/openai.py b/langextract/providers/openai.py
index 8f45c77f..08a20467 100644
--- a/langextract/providers/openai.py
+++ b/langextract/providers/openai.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """OpenAI provider for LangExtract."""
+
 # pylint: disable=duplicate-code
 
 from __future__ import annotations
@@ -26,6 +27,7 @@
 from langextract.core import exceptions
 from langextract.core import schema
 from langextract.core import types as core_types
+from langextract.providers import openai_batch
 from langextract.providers import patterns
 from langextract.providers import router
 
@@ -46,6 +48,9 @@ class OpenAILanguageModel(base_model.BaseLanguageModel):
   temperature: float | None = None
   max_workers: int = 10
   _client: Any = dataclasses.field(default=None, repr=False, compare=False)
+  _batch_cfg: openai_batch.BatchConfig = dataclasses.field(
+      default_factory=openai_batch.BatchConfig, repr=False, compare=False
+  )
   _extra_kwargs: dict[str, Any] = dataclasses.field(
       default_factory=dict, repr=False, compare=False
   )
@@ -99,6 +104,10 @@ def __init__(
     self.temperature = temperature
     self.max_workers = max_workers
 
+    # Extract batch config before storing remaining kwargs.
+    batch_cfg_dict = kwargs.pop('batch', None)
+    self._batch_cfg = openai_batch.BatchConfig.from_dict(batch_cfg_dict)
+
     if not self.api_key:
       raise exceptions.InferenceConfigError('API key not provided.')
 
@@ -114,6 +123,70 @@ def __init__(
     )
     self._extra_kwargs = kwargs or {}
 
+  def _build_chat_completions_body(self, prompt: str, config: dict) -> dict:
+    """Build a /v1/chat/completions request body for a single prompt.
+
+    Args:
+      prompt: Prompt text, sent as the 'user' message.
+      config: Per-call generation options (e.g. temperature, top_p,
+        max_output_tokens); normalized before use.
+
+    Returns:
+      Request parameters: model, messages and n=1, plus every optional
+      parameter that is set in the normalized config.
+    """
+    normalized_config = self._normalize_reasoning_params(config)
+
+    # A system message is added only for the JSON and YAML output formats.
+    system_message = ''
+    if self.format_type == data.FormatType.JSON:
+      system_message = (
+          'You are a helpful assistant that responds in JSON format.'
+      )
+    elif self.format_type == data.FormatType.YAML:
+      system_message = (
+          'You are a helpful assistant that responds in YAML format.'
+      )
+
+    messages = [{'role': 'user', 'content': prompt}]
+    if system_message:
+      messages.insert(0, {'role': 'system', 'content': system_message})
+
+    api_params: dict[str, Any] = {
+        'model': self.model_id,
+        'messages': messages,
+        'n': 1,
+    }
+
+    temp = normalized_config.get('temperature', self.temperature)
+    if temp is not None:
+      api_params['temperature'] = temp
+
+    if self.format_type == data.FormatType.JSON:
+      # Default to JSON mode; a response_format in config overrides it below.
+      api_params.setdefault('response_format', {'type': 'json_object'})
+
+    # The config key is max_output_tokens; the request key is max_tokens.
+    if (v := normalized_config.get('max_output_tokens')) is not None:
+      api_params['max_tokens'] = v
+    if (v := normalized_config.get('top_p')) is not None:
+      api_params['top_p'] = v
+
+    for key in [
+        'frequency_penalty',
+        'presence_penalty',
+        'seed',
+        'stop',
+        'logprobs',
+        'top_logprobs',
+        'reasoning',
+        'response_format',
+    ]:
+      if (v := normalized_config.get(key)) is not None:
+        api_params[key] = v
+
+    return api_params
+
   def _normalize_reasoning_params(self, config: dict) -> dict:
     """Normalize reasoning parameters for API compatibility.
 
@@ -135,52 +208,7 @@ def _process_single_prompt(
   ) -> core_types.ScoredOutput:
     """Process a single prompt and return a ScoredOutput."""
     try:
-      normalized_config = self._normalize_reasoning_params(config)
-
-      system_message = ''
-      if self.format_type == data.FormatType.JSON:
-        system_message = (
-            'You are a helpful assistant that responds in JSON format.'
-        )
-      elif self.format_type == data.FormatType.YAML:
-        system_message = (
-            'You are a helpful assistant that responds in YAML format.'
-        )
-
-      messages = [{'role': 'user', 'content': prompt}]
-      if system_message:
-        messages.insert(0, {'role': 'system', 'content': system_message})
-
-      api_params = {
-          'model': self.model_id,
-          'messages': messages,
-          'n': 1,
-      }
-
-      temp = normalized_config.get('temperature', self.temperature)
-      if temp is not None:
-        api_params['temperature'] = temp
-
-      if self.format_type == data.FormatType.JSON:
-        api_params.setdefault('response_format', {'type': 'json_object'})
-
-      if (v := normalized_config.get('max_output_tokens')) is not None:
-        api_params['max_tokens'] = v
-      if (v := normalized_config.get('top_p')) is not None:
-        api_params['top_p'] = v
-      for key in [
-          'frequency_penalty',
-          'presence_penalty',
-          'seed',
-          'stop',
-          'logprobs',
-          'top_logprobs',
-          'reasoning',
-          'response_format',
-      ]:
-        if (v := normalized_config.get(key)) is not None:
-          api_params[key] = v
-
+      api_params = self._build_chat_completions_body(prompt, config)
       response = self._client.chat.completions.create(**api_params)
 
       # Extract the response text using the v1.x response format
@@ -205,6 +233,8 @@ def infer(
     Yields:
       Lists of ScoredOutputs.
     """
+    # Optional per-job request cap for Batch API mode; None when not supplied.
+    batch_size = kwargs.pop('batch_size', None)
     merged_kwargs = self.merge_kwargs(kwargs)
 
     config = {}
@@ -231,6 +261,33 @@ def infer(
       if key in merged_kwargs:
         config[key] = merged_kwargs[key]
 
+    # OpenAI Batch API mode (async job + polling) when enabled and threshold met.
+    if (
+        self._batch_cfg.enabled
+        and len(batch_prompts) >= self._batch_cfg.threshold
+    ):
+      try:
+        texts = openai_batch.infer_batch(
+            client=self._client,
+            model_id=self.model_id,
+            prompts=batch_prompts,
+            cfg=self._batch_cfg,
+            request_builder=lambda p: self._build_chat_completions_body(
+                p, config
+            ),
+            batch_size=batch_size,
+        )
+      except exceptions.InferenceError:
+        raise
+      except Exception as e:
+        raise exceptions.InferenceRuntimeError(
+            f'OpenAI Batch API error: {str(e)}', original=e, provider='OpenAI'
+        ) from e
+
+      for text in texts:
+        yield [core_types.ScoredOutput(score=1.0, output=text)]
+      return
+
     # Use parallel processing for batches larger than 1
     if len(batch_prompts) > 1 and self.max_workers > 1:
       with concurrent.futures.ThreadPoolExecutor(
diff --git a/langextract/providers/openai_batch.py b/langextract/providers/openai_batch.py
new file mode 100644
index 00000000..3f1df2ff
--- /dev/null
+++ b/langextract/providers/openai_batch.py
@@ -0,0 +1,443 @@
+# Copyright 2025 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""OpenAI Batch API helper module for LangExtract.
+
+This module is intentionally written to be testable without importing the
+`openai` package: it accepts a generic client object with the expected
+`files.*` and `batches.*` methods.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Mapping, Sequence
+import dataclasses
+import io
+import json
+import time
+from typing import Any
+
+from langextract.core import exceptions
+
+_DEFAULT_ENDPOINT = "/v1/chat/completions"
+
+
+@dataclasses.dataclass(slots=True, frozen=True)
+class BatchConfig:
+  """Define and validate OpenAI Batch API configuration.
+
+  Attributes:
+    enabled: Whether batch mode is enabled.
+    threshold: Minimum prompts to trigger batch processing.
+    completion_window: Optional OpenAI completion window string (e.g., "24h").
+      If unset, LangExtract will omit it from the batch-create call.
+    poll_interval: Seconds between status checks.
+    timeout: Maximum seconds to wait for completion.
+    max_requests_per_job: Safety cap on the number of requests per batch job.
+    metadata: Optional metadata dict attached to the batch job.
+    on_job_create: Optional hook invoked with the created job object.
+  """
+
+  enabled: bool = False
+  threshold: int = 50
+  completion_window: str | None = None
+  poll_interval: int = 10
+  timeout: int = 3600
+  max_requests_per_job: int = 50000
+  metadata: Mapping[str, Any] | None = None
+  on_job_create: Callable[[Any], None] | None = None
+
+  def __post_init__(self):
+    """Validate field values, raising ValueError on the first violation."""
+    validations = [
+        (self.threshold >= 1, "batch.threshold must be >= 1"),
+        (self.poll_interval > 0, "batch.poll_interval must be > 0"),
+        (self.timeout > 0, "batch.timeout must be > 0"),
+        (
+            self.max_requests_per_job > 0,
+            "batch.max_requests_per_job must be > 0",
+        ),
+    ]
+    for is_valid, msg in validations:
+      if not is_valid:
+        raise ValueError(msg)
+
+    if self.completion_window is not None and not self.completion_window:
+      raise ValueError(
+          "batch.completion_window must be a non-empty string when set"
+      )
+
+  @classmethod
+  def from_dict(cls, d: dict | None) -> BatchConfig:
+    """Create BatchConfig from dictionary, using defaults for missing keys.
+
+    Args:
+      d: Option mapping, or None/empty to get a disabled config. A non-empty
+        dict without an "enabled" key is treated as enabled.
+
+    Returns:
+      A BatchConfig; integer options are coerced with int().
+
+    Raises:
+      ValueError: If a supplied value fails validation in __post_init__.
+    """
+    if not d:
+      return cls(enabled=False)
+
+    # Allow either {enabled: true, ...} or a truthy dict without enabled.
+    kwargs: dict[str, Any] = {"enabled": bool(d.get("enabled", True))}
+
+    # Pass only the keys that were supplied and let the dataclass fill in its
+    # own defaults. With slots=True the class attributes are slot descriptors
+    # rather than default values, so `cls.threshold` is not a usable fallback.
+    int_options = (
+        "threshold",
+        "poll_interval",
+        "timeout",
+        "max_requests_per_job",
+    )
+    for name in int_options:
+      if d.get(name) is not None:
+        kwargs[name] = int(d[name])
+    for name in ("completion_window", "metadata", "on_job_create"):
+      if name in d:
+        kwargs[name] = d[name]
+    return cls(**kwargs)
+
+
+def _custom_id(idx: int) -> str:
+  """Return the request custom_id that encodes the global prompt index."""
+  return f"idx-{idx:06d}"
+
+
+def _get_field(obj: Any, name: str) -> Any:
+  """Read `name` from a mapping (by key) or any other object (by attribute).
+
+  Returns None when the field is absent, so callers never hit AttributeError
+  on response objects that do not implement `.get`.
+  """
+  if isinstance(obj, Mapping):
+    return obj.get(name)
+  return getattr(obj, name, None)
+
+
+def _extract_text_from_response_body(body: Mapping[str, Any]) -> str:
+  """Return `choices[0].message.content` from a chat-completions body.
+
+  Raises:
+    InferenceRuntimeError: If the body has no choices or no message content.
+  """
+  try:
+    choices = body.get("choices")
+    if not choices:
+      raise KeyError("choices")
+    message = choices[0].get("message") or {}
+    content = message.get("content")
+    if content is None:
+      raise KeyError("message.content")
+    return content
+  except Exception as e:
+    raise exceptions.InferenceRuntimeError(
+        f"Failed to extract text from OpenAI batch response body: {e}",
+        original=e,
+        provider="OpenAI",
+    ) from e
+
+
+def _content_to_text(content: Any) -> str:
+  """Best-effort conversion of OpenAI SDK file content responses to text."""
+  if content is None:
+    result = ""
+  elif isinstance(content, str):
+    result = content
+  elif isinstance(content, bytes):
+    result = content.decode("utf-8")
+  else:
+    text = getattr(content, "text", None)
+    if isinstance(text, str):
+      result = text
+    else:
+      read = getattr(content, "read", None)
+      if callable(read):
+        data = read()
+        if isinstance(data, bytes):
+          result = data.decode("utf-8")
+        elif isinstance(data, str):
+          result = data
+        else:
+          result = str(content)
+      else:
+        result = str(content)
+
+  return result
+
+
+def infer_batch(
+    *,
+    client: Any,
+    model_id: str,
+    prompts: Sequence[str],
+    cfg: BatchConfig,
+    request_builder: Callable[[str], Mapping[str, Any]],
+    endpoint: str = _DEFAULT_ENDPOINT,
+    batch_size: int | None = None,
+) -> list[str]:
+  """Execute batch inference on multiple prompts using OpenAI Batch API.
+
+  Args:
+    client: OpenAI client instance (or compatible fake for testing).
+    model_id: OpenAI model id.
+    prompts: Prompt strings.
+    cfg: Batch configuration.
+    request_builder: Callable that produces the request body for one prompt.
+    endpoint: The OpenAI endpoint string for the batch (default chat completions).
+    batch_size: Optional per-call limit that caps requests per batch job.
+
+  Returns:
+    List of output texts aligned with prompts.
+
+  Raises:
+    InferenceConfigError: If `cfg.enabled` is False.
+    ValueError: If `batch_size` is not greater than zero.
+    InferenceRuntimeError: On job failure, timeout, or per-item errors.
+  """
+  if not prompts:
+    return []
+
+  if not cfg.enabled:
+    raise exceptions.InferenceConfigError(
+        "OpenAI batch mode is not enabled (cfg.enabled=False)"
+    )
+
+  if batch_size is not None and batch_size <= 0:
+    raise ValueError("batch_size must be > 0")
+
+  per_job_limit = cfg.max_requests_per_job
+  if batch_size is not None:
+    per_job_limit = min(per_job_limit, batch_size)
+
+  outputs: list[str] = [""] * len(prompts)
+
+  # Submit in chunks to avoid huge jobs and to honor batch_size.
+  for offset in range(0, len(prompts), per_job_limit):
+    chunk = list(prompts[offset : offset + per_job_limit])
+    chunk_outputs = _infer_batch_one_job(
+        client=client,
+        model_id=model_id,
+        prompts=chunk,
+        cfg=cfg,
+        request_builder=request_builder,
+        endpoint=endpoint,
+        base_index=offset,
+    )
+    outputs[offset : offset + len(chunk_outputs)] = chunk_outputs
+
+  return outputs
+
+
+def _infer_batch_one_job(
+    *,
+    client: Any,
+    model_id: str,
+    prompts: Sequence[str],
+    cfg: BatchConfig,
+    request_builder: Callable[[str], Mapping[str, Any]],
+    endpoint: str,
+    base_index: int,
+) -> list[str]:
+  """Run one Batch API job for `prompts` and return texts in prompt order.
+
+  Uploads a JSONL input file, creates the batch job, polls until it reaches a
+  terminal status, then downloads and parses the output file.
+
+  Args:
+    client: Client exposing `files.create`, `files.content`,
+      `batches.create` and `batches.retrieve`.
+    model_id: Model id used when the built request body has no "model".
+    prompts: Prompts for this job only.
+    cfg: Batch configuration (polling, timeout, metadata, hooks).
+    request_builder: Callable that produces the request body for one prompt.
+    endpoint: Endpoint written to each request line and to the job.
+    base_index: Global index of `prompts[0]`, used to build custom_ids.
+
+  Returns:
+    Output texts aligned with `prompts`.
+
+  Raises:
+    InferenceRuntimeError: On upload/create/retrieve/download failure,
+      timeout, non-completed status, per-item errors, or missing outputs.
+  """
+  lines: list[str] = []
+  for i, prompt in enumerate(prompts):
+    idx = base_index + i
+    body = dict(request_builder(prompt))
+    body.setdefault("model", model_id)
+
+    req = {
+        "custom_id": _custom_id(idx),
+        "method": "POST",
+        "url": endpoint,
+        "body": body,
+    }
+    lines.append(json.dumps(req, ensure_ascii=False))
+
+  jsonl = "\n".join(lines) + "\n"
+
+  # Use an in-memory buffer with a name attribute for broad compatibility.
+  buf = io.BytesIO(jsonl.encode("utf-8"))
+  buf.name = "langextract_openai_batch_input.jsonl"  # type: ignore[attr-defined]
+
+  try:
+    input_file = client.files.create(file=buf, purpose="batch")
+    input_file_id = _get_field(input_file, "id")
+    if not input_file_id:
+      raise ValueError("upload response has no file id")
+  except Exception as e:
+    raise exceptions.InferenceRuntimeError(
+        f"OpenAI Batch API input file upload failed: {e}",
+        original=e,
+        provider="OpenAI",
+    ) from e
+
+  try:
+    create_kwargs: dict[str, Any] = {
+        "input_file_id": input_file_id,
+        "endpoint": endpoint,
+        "metadata": dict(cfg.metadata or {}),
+    }
+    if cfg.completion_window:
+      create_kwargs["completion_window"] = cfg.completion_window
+
+    job = client.batches.create(**create_kwargs)
+    if cfg.on_job_create:
+      cfg.on_job_create(job)
+    batch_id = _get_field(job, "id")
+    if not batch_id:
+      raise ValueError("batch create response has no job id")
+  except Exception as e:
+    raise exceptions.InferenceRuntimeError(
+        f"OpenAI Batch API job create failed: {e}",
+        original=e,
+        provider="OpenAI",
+    ) from e
+
+  start = time.time()
+  last_status = None
+  while True:
+    if time.time() - start > cfg.timeout:
+      raise exceptions.InferenceRuntimeError(
+          f"OpenAI Batch API job timed out after {cfg.timeout}s"
+          f" (last_status={last_status})",
+          provider="OpenAI",
+      )
+
+    try:
+      job = client.batches.retrieve(batch_id)
+    except Exception as e:
+      raise exceptions.InferenceRuntimeError(
+          f"OpenAI Batch API job retrieve failed: {e}",
+          original=e,
+          provider="OpenAI",
+      ) from e
+
+    status = _get_field(job, "status")
+    last_status = status
+
+    if status in ("completed", "failed", "expired", "cancelled"):
+      break
+
+    time.sleep(cfg.poll_interval)
+
+  if status != "completed":
+    # OpenAI's Batch object exposes failures as `errors`; `error` is kept as
+    # a fallback for other client shapes.
+    err = _get_field(job, "errors") or _get_field(job, "error")
+    raise exceptions.InferenceRuntimeError(
+        f"OpenAI Batch API job did not complete (status={status}, error={err})",
+        provider="OpenAI",
+    )
+
+  output_file_id = _get_field(job, "output_file_id") or _get_field(
+      job, "output_file"
+  )
+  if not output_file_id:
+    raise exceptions.InferenceRuntimeError(
+        "OpenAI Batch API job completed but has no output_file_id",
+        provider="OpenAI",
+    )
+
+  try:
+    content = client.files.content(output_file_id)
+    text = _content_to_text(content)
+  except Exception as e:
+    raise exceptions.InferenceRuntimeError(
+        f"OpenAI Batch API output download failed: {e}",
+        original=e,
+        provider="OpenAI",
+    ) from e
+
+  # Parse output JSONL.
+  outputs_by_idx: dict[int, str] = {}
+  errors: list[str] = []
+  for raw_line in text.splitlines():
+    line = raw_line.strip()
+    if not line:
+      continue
+    try:
+      obj = json.loads(line)
+    except Exception as e:
+      raise exceptions.InferenceRuntimeError(
+          f"OpenAI Batch API output JSONL parse error: {e}",
+          original=e,
+          provider="OpenAI",
+      ) from e
+
+    cid = obj.get("custom_id")
+    if not cid or not isinstance(cid, str) or not cid.startswith("idx-"):
+      continue
+
+    try:
+      idx = int(cid.split("-", 1)[1])
+    except ValueError:
+      continue
+
+    item_error = obj.get("error")
+    if item_error:
+      errors.append(f"{cid}: {item_error}")
+      continue
+
+    response = obj.get("response") or {}
+    body = response.get("body") or {}
+    try:
+      outputs_by_idx[idx] = _extract_text_from_response_body(body)
+    except exceptions.InferenceRuntimeError as e:
+      errors.append(f"{cid}: {e}")
+
+  if errors:
+    raise exceptions.InferenceRuntimeError(
+        "OpenAI Batch API per-item errors: " + "; ".join(errors),
+        provider="OpenAI",
+    )
+
+  # Ensure we have every prompt.
+  chunk_outputs: list[str] = []
+  for i in range(base_index, base_index + len(prompts)):
+    if i not in outputs_by_idx:
+      raise exceptions.InferenceRuntimeError(
+          f"OpenAI Batch API missing output for custom_id={_custom_id(i)}",
+          provider="OpenAI",
+      )
+    chunk_outputs.append(outputs_by_idx[i])
+
+  return chunk_outputs
diff --git a/tests/openai_batch_test.py b/tests/openai_batch_test.py
new file mode 100644
index 00000000..ae27db7d
--- /dev/null
+++ b/tests/openai_batch_test.py
@@ -0,0 +1,274 @@
+# Copyright 2025 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for OpenAI Batch API helper."""
+
+# pylint: disable=too-few-public-methods
+
+from __future__ import annotations
+
+import json
+import types as py_types
+import unittest
+from unittest import mock
+
+from langextract.core import exceptions
+from langextract.providers import openai_batch
+
+
+class _FakeFiles:
+  """In-memory stand-in for `client.files` that records uploads."""
+
+  def __init__(self):
+    self.created = []
+    self._content_by_id = {}
+
+  def create(self, *, file, purpose):
+    self.created.append({"file": file, "purpose": purpose})
+    return py_types.SimpleNamespace(id=f"file-{len(self.created)}")
+
+  def content(self, file_id):
+    # Returned via `.text`, one of the shapes `_content_to_text` accepts.
+    return py_types.SimpleNamespace(text=self._content_by_id[file_id])
+
+  def set_content(self, file_id: str, text: str) -> None:
+    self._content_by_id[file_id] = text
+
+
+class _FakeBatches:
+  """Stand-in for `client.batches`; `retrieve` replays a queued sequence."""
+
+  def __init__(self):
+    self.created = []
+    self._retrieve_queue = []
+
+  def create(self, **kwargs):
+    self.created.append(kwargs)
+    return py_types.SimpleNamespace(id=f"batch-{len(self.created)}")
+
+  def retrieve(self, _batch_id):
+    if not self._retrieve_queue:
+      raise RuntimeError("retrieve queue empty")
+    return self._retrieve_queue.pop(0)
+
+  def push_retrieve(self, obj):
+    self._retrieve_queue.append(obj)
+
+
+class _FakeClient:
+  """Minimal client exposing the `files` and `batches` fakes."""
+
+  def __init__(self):
+    self.files = _FakeFiles()
+    self.batches = _FakeBatches()
+
+
+def _make_output_line(custom_id: str, content: str) -> str:
+  """Return one successful Batch output JSONL line for `custom_id`."""
+  obj = {
+      "custom_id": custom_id,
+      "response": {
+          "body": {
+              "choices": [
+                  {"message": {"content": content}},
+              ]
+          }
+      },
+      "error": None,
+  }
+  return json.dumps(obj)
+
+
+class OpenAIBatchHelperTest(unittest.TestCase):
+  """Tests for `openai_batch.infer_batch` against the fake client."""
+
+  @mock.patch(
+      "langextract.providers.openai_batch.time.sleep", return_value=None
+  )
+  def test_orders_results_by_custom_id(self, _mock_sleep):
+    """Outputs follow prompt order even when the output file is shuffled."""
+    client = _FakeClient()
+
+    # Job status: in_progress -> completed.
+    client.batches.push_retrieve(py_types.SimpleNamespace(status="in_progress"))
+    client.batches.push_retrieve(
+        py_types.SimpleNamespace(status="completed", output_file_id="out-1")
+    )
+
+    # Output JSONL is intentionally out of order.
+    out = "\n".join([
+        _make_output_line("idx-000001", "B"),
+        _make_output_line("idx-000000", "A"),
+    ])
+    client.files.set_content("out-1", out)
+
+    cfg = openai_batch.BatchConfig(
+        enabled=True,
+        threshold=1,
+        completion_window="24h",
+        poll_interval=1,
+        timeout=5,
+    )
+
+    res = openai_batch.infer_batch(
+        client=client,
+        model_id="gpt-test",
+        prompts=["p0", "p1"],
+        cfg=cfg,
+        request_builder=lambda p: {
+            "model": "gpt-test",
+            "messages": [{"role": "user", "content": p}],
+        },
+    )
+
+    self.assertEqual(res, ["A", "B"])
+
+  @mock.patch(
+      "langextract.providers.openai_batch.time.sleep", return_value=None
+  )
+  def test_splits_jobs_by_batch_size(self, _mock_sleep):
+    """Five prompts with batch_size=2 produce three jobs and three uploads."""
+    client = _FakeClient()
+
+    # For 3 jobs we will do: [in_progress, completed] x 3
+    for job_idx in range(3):
+      client.batches.push_retrieve(
+          py_types.SimpleNamespace(status="in_progress")
+      )
+      client.batches.push_retrieve(
+          py_types.SimpleNamespace(
+              status="completed", output_file_id=f"out-{job_idx}"
+          )
+      )
+
+    # Each job returns exactly the lines for its indices.
+    client.files.set_content(
+        "out-0",
+        "\n".join([
+            _make_output_line("idx-000000", "0"),
+            _make_output_line("idx-000001", "1"),
+        ]),
+    )
+    client.files.set_content(
+        "out-1",
+        "\n".join([
+            _make_output_line("idx-000002", "2"),
+            _make_output_line("idx-000003", "3"),
+        ]),
+    )
+    client.files.set_content(
+        "out-2",
+        _make_output_line("idx-000004", "4"),
+    )
+
+    cfg = openai_batch.BatchConfig(
+        enabled=True,
+        threshold=1,
+        completion_window="24h",
+        poll_interval=1,
+        timeout=5,
+        max_requests_per_job=100,
+    )
+
+    prompts = ["p0", "p1", "p2", "p3", "p4"]
+    res = openai_batch.infer_batch(
+        client=client,
+        model_id="gpt-test",
+        prompts=prompts,
+        cfg=cfg,
+        request_builder=lambda p: {
+            "model": "gpt-test",
+            "messages": [{"role": "user", "content": p}],
+        },
+        batch_size=2,
+    )
+
+    self.assertEqual(res, ["0", "1", "2", "3", "4"])
+    self.assertEqual(len(client.batches.created), 3)
+    self.assertEqual(len(client.files.created), 3)
+
+  def test_item_error_raises(self):
+    """A per-item error in the output file raises InferenceRuntimeError."""
+    client = _FakeClient()
+
+    # No sleep patch needed: the first retrieve already reports "completed".
+    client.batches.push_retrieve(
+        py_types.SimpleNamespace(status="completed", output_file_id="out-1")
+    )
+
+    obj = {
+        "custom_id": "idx-000000",
+        "error": {"message": "boom"},
+        "response": None,
+    }
+    client.files.set_content("out-1", json.dumps(obj))
+
+    cfg = openai_batch.BatchConfig(
+        enabled=True,
+        threshold=1,
+        completion_window="24h",
+        poll_interval=1,
+        timeout=5,
+    )
+
+    with self.assertRaises(exceptions.InferenceRuntimeError):
+      _ = openai_batch.infer_batch(
+          client=client,
+          model_id="gpt-test",
+          prompts=["p0"],
+          cfg=cfg,
+          request_builder=lambda p: {
+              "model": "gpt-test",
+              "messages": [{"role": "user", "content": p}],
+          },
+      )
+
+  @mock.patch(
+      "langextract.providers.openai_batch.time.sleep", return_value=None
+  )
+  def test_completion_window_is_optional(self, _mock_sleep):
+    """completion_window=None is omitted from the batches.create call."""
+    client = _FakeClient()
+
+    client.batches.push_retrieve(
+        py_types.SimpleNamespace(status="completed", output_file_id="out-1")
+    )
+    client.files.set_content("out-1", _make_output_line("idx-000000", "ok"))
+
+    cfg = openai_batch.BatchConfig(
+        enabled=True,
+        threshold=1,
+        completion_window=None,
+        poll_interval=1,
+        timeout=5,
+    )
+
+    res = openai_batch.infer_batch(
+        client=client,
+        model_id="gpt-test",
+        prompts=["p0"],
+        cfg=cfg,
+        request_builder=lambda p: {
+            "model": "gpt-test",
+            "messages": [{"role": "user", "content": p}],
+        },
+    )
+
+    self.assertEqual(res, ["ok"])
+    self.assertEqual(len(client.batches.created), 1)
+    self.assertNotIn("completion_window", client.batches.created[0])
+
+
+if __name__ == "__main__":
+  unittest.main()