diff --git a/check_env.py b/check_env.py
new file mode 100644
index 00000000..f3885468
--- /dev/null
+++ b/check_env.py
@@ -0,0 +1,2 @@
+import os
+print('MATRIXAI_API_KEY:', repr(os.environ.get('MATRIXAI_API_KEY')))
\ No newline at end of file
diff --git a/langextract-doubao/.gitignore b/langextract-doubao/.gitignore
new file mode 100644
index 00000000..37e861dd
--- /dev/null
+++ b/langextract-doubao/.gitignore
@@ -0,0 +1,50 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+
+# Distribution / packaging
+build/
+dist/
+*.egg-info/
+.eggs/
+*.egg
+
+# Virtual environments
+.env
+.venv
+env/
+venv/
+ENV/
+
+# Testing & coverage
+.pytest_cache/
+.tox/
+htmlcov/
+.coverage
+.coverage.*
+
+# Type checking
+.mypy_cache/
+.dmypy.json
+dmypy.json
+.pytype/
+
+# IDEs
+.idea/
+.vscode/
+*.swp
+*.swo
+
+# OS-specific
+.DS_Store
+Thumbs.db
+
+# Logs
+*.log
+
+# Temp files
+*.tmp
+*.bak
+*.backup
diff --git a/langextract-doubao/LICENSE b/langextract-doubao/LICENSE
new file mode 100644
index 00000000..6460eee0
--- /dev/null
+++ b/langextract-doubao/LICENSE
@@ -0,0 +1,13 @@
+# LICENSE
+
+TODO: Add your license here.
+
+This is a placeholder license file for your provider plugin.
+Please replace this with your actual license before distribution.
+
+Common options include:
+- Apache License 2.0
+- MIT License
+- BSD License
+- GPL License
+- Proprietary/Commercial License
diff --git a/langextract-doubao/README.md b/langextract-doubao/README.md
new file mode 100644
index 00000000..265eedb8
--- /dev/null
+++ b/langextract-doubao/README.md
@@ -0,0 +1,41 @@
+# LangExtract doubao Provider
+
+A provider plugin for LangExtract that supports doubao models.
+
+## Installation
+
+```bash
+pip install -e .
+```
+
+## Supported Model IDs
+
+- `doubao*`: Models matching pattern ^doubao
+
+## Environment Variables
+
+- `ARK_API_KEY`: API key for authentication
+
+## Usage
+
+```python
+import langextract as lx
+
+result = lx.extract(
+    text="Your document here",
+    model_id="doubao-model",
+    prompt_description="Extract entities",
+    examples=[...]
+)
+```
+
+## Development
+
+1. Install in development mode: `pip install -e .`
+2. Run tests: `python test_plugin.py`
+3. Build package: `python -m build`
+4. Publish to PyPI: `twine upload dist/*`
+
+## License
+
+Apache License 2.0
diff --git a/langextract-doubao/langextract_doubao/__init__.py b/langextract-doubao/langextract_doubao/__init__.py
new file mode 100644
index 00000000..e46a0dd9
--- /dev/null
+++ b/langextract-doubao/langextract_doubao/__init__.py
@@ -0,0 +1,6 @@
+"""LangExtract provider plugin for doubao."""
+
+from langextract_doubao.provider import doubaoLanguageModel
+
+__all__ = ['doubaoLanguageModel']
+__version__ = "0.1.0"
diff --git a/langextract-doubao/langextract_doubao/provider.py b/langextract-doubao/langextract_doubao/provider.py
new file mode 100644
index 00000000..b90c6008
--- /dev/null
+++ b/langextract-doubao/langextract_doubao/provider.py
@@ -0,0 +1,80 @@
+"""Provider implementation for doubao."""
+
+import os
+import langextract as lx
+from langextract_doubao.schema import doubaoSchema
+from langextract.core.base_model import BaseLanguageModel
+from langextract.core.types import ScoredOutput
+from volcenginesdkarkruntime import Ark
+
+
+@lx.providers.registry.register(r'^doubao', priority=10)
+class doubaoLanguageModel(BaseLanguageModel):
+    """LangExtract provider for doubao.
+
+    This provider handles model IDs matching: ['^doubao']
+    """
+
+    def __init__(self, model_id: str, api_key: str = None, **kwargs):
+        """Initialize the doubao provider.
+
+        Args:
+            model_id: The model identifier.
+            api_key: API key for authentication.
+            **kwargs: Additional provider-specific parameters.
+        """
+        super().__init__()
+        self.model_id = model_id
+        self.api_key = api_key or os.environ.get('ARK_API_KEY')
+        self.response_schema = kwargs.get('response_schema')
+        self.structured_output = kwargs.get('structured_output', False)
+
+        self.client = Ark(
+            base_url="https://ark.cn-beijing.volces.com/api/v3",
+            api_key=self.api_key
+        )
+        self._extra_kwargs = kwargs
+
+    @classmethod
+    def get_schema_class(cls):
+        """Tell LangExtract about our schema support."""
+        from langextract_doubao.schema import doubaoSchema
+        return doubaoSchema
+
+    def apply_schema(self, schema_instance):
+        """Apply or clear schema configuration."""
+        super().apply_schema(schema_instance)
+        if schema_instance:
+            config = schema_instance.to_provider_config()
+            self.response_schema = config.get('response_schema')
+            self.structured_output = config.get('structured_output', False)
+        else:
+            self.response_schema = None
+            self.structured_output = False
+
+    def infer(self, batch_prompts, **kwargs):
+        """Run inference on a batch of prompts.
+
+        Args:
+            batch_prompts: List of prompts to process.
+            **kwargs: Additional inference parameters.
+
+        Yields:
+            Lists of ScoredOutput objects, one per prompt.
+        """
+        for prompt in batch_prompts:
+            api_params = {
+                "model": self.model_id,
+                "messages": [
+                    {"role": "system", "content": "You are an AI assistant"},
+                    {"role": "user", "content": prompt}
+                ]
+            }
+
+            completion = self.client.chat.completions.create(**api_params)
+            text = getattr(completion.choices[0].message, "content", "")
+            # Debug: print the raw model output
+            print("[DEBUG] Doubao raw output:", repr(text))
+            if not text:
+                raise RuntimeError("Doubao returned empty output")
+            yield [ScoredOutput(score=1.0, output=text)]
diff --git a/langextract-doubao/langextract_doubao/schema.py b/langextract-doubao/langextract_doubao/schema.py
new file mode 100644
index 00000000..a04243bb
--- /dev/null
+++ b/langextract-doubao/langextract_doubao/schema.py
@@ -0,0 +1,75 @@
+"""Schema implementation for doubao provider."""
+
+import langextract as lx
+from langextract.core.schema import BaseSchema
+
+
+class doubaoSchema(BaseSchema):
+    """Schema implementation for doubao structured output."""
+
+    def __init__(self, schema_dict: dict):
+        """Initialize the schema with a dictionary."""
+        self._schema_dict = schema_dict
+
+    @property
+    def schema_dict(self) -> dict:
+        """Return the schema dictionary."""
+        return self._schema_dict
+
+    @classmethod
+    def from_examples(cls, examples_data, attribute_suffix="_attributes"):
+        """Build schema from example extractions.
+
+        Args:
+            examples_data: Sequence of ExampleData objects.
+            attribute_suffix: Suffix for attribute fields.
+
+        Returns:
+            A configured doubaoSchema instance.
+        """
+        extraction_types = {}  # collected per class, but not yet folded into schema_dict below
+        for example in examples_data:
+            for extraction in example.extractions:
+                class_name = extraction.extraction_class
+                if class_name not in extraction_types:
+                    extraction_types[class_name] = set()
+                if extraction.attributes:
+                    extraction_types[class_name].update(extraction.attributes.keys())
+
+        schema_dict = {
+            "type": "object",
+            "properties": {
+                "extractions": {
+                    "type": "array",
+                    "items": {"type": "object"}
+                }
+            },
+            "required": ["extractions"]
+        }
+
+        return cls(schema_dict)
+
+    def to_provider_config(self) -> dict:
+        """Convert to provider-specific configuration.
+
+        Returns:
+            Dictionary of provider-specific configuration.
+        """
+        return {
+            "response_schema": self._schema_dict,
+            "structured_output": True
+        }
+
+    @property
+    def supports_strict_mode(self) -> bool:
+        """Whether this schema guarantees valid structured output.
+
+        Returns:
+            True if the provider enforces valid JSON output.
+        """
+        return False  # Set to True only if your provider guarantees valid JSON
+
+    @property
+    def requires_raw_output(self) -> bool:
+        """Return True if the model emits raw JSON (no code fences)."""
+        return True  # or False, depending on how the Doubao API behaves
\ No newline at end of file
diff --git a/langextract-doubao/pyproject.toml b/langextract-doubao/pyproject.toml
new file mode 100644
index 00000000..77a192e8
--- /dev/null
+++ b/langextract-doubao/pyproject.toml
@@ -0,0 +1,22 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "langextract-doubao"
+version = "0.1.0"
+description = "LangExtract provider plugin for doubao"
+readme = "README.md"
+requires-python = ">=3.10"
+license = {text = "Apache-2.0"}
+dependencies = [
+    "langextract>=1.0.0",
+    # Add your provider's SDK dependencies here
+]
+
+[project.entry-points."langextract.providers"]
+doubao = "langextract_doubao.provider:doubaoLanguageModel"
+
+[tool.setuptools.packages.find]
+where = ["."]
+include = ["langextract_doubao*"]
diff --git a/langextract-doubao/test_plugin.py b/langextract-doubao/test_plugin.py
new file mode 100644
index 00000000..47b75bf3
--- /dev/null
+++ b/langextract-doubao/test_plugin.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python3
+"""Test script for doubao provider (Step 5 checklist)."""
+
+import re
+import sys
+import langextract as lx
+from langextract.providers import registry
+
+try:
+    from langextract_doubao import doubaoLanguageModel
+except ImportError:
+    print("ERROR: Plugin not installed. Run: pip install -e .")
+    sys.exit(1)
+
+lx.providers.load_plugins_once()
+
+PROVIDER_CLS_NAME = "doubaoLanguageModel"
+PATTERNS = ['^doubao']
+
+def _example_id(pattern: str) -> str:
+    """Generate test model ID from pattern."""
+    base = re.sub(r'^\^', '', pattern)
+    m = re.match(r"[A-Za-z0-9._-]+", base)
+    base = m.group(0) if m else (base or "model")
+    return f"{base}-test"
+
+sample_ids = [_example_id(p) for p in PATTERNS]
+sample_ids.append("unknown-model")
+
+print("Testing doubao Provider - Step 5 Checklist:")
+print("-" * 50)
+
+# 1 & 2. Provider registration + pattern matching via resolve()
+print("1–2. Provider registration & pattern matching")
+for model_id in sample_ids:
+    try:
+        provider_class = registry.resolve(model_id)
+        ok = provider_class.__name__ == PROVIDER_CLS_NAME
+        status = "✓" if (ok or model_id == "unknown-model") else "✗"
+        note = "expected" if ok else ("expected (no provider)" if model_id == "unknown-model" else "unexpected provider")
+        print(f" {status} {model_id} -> {provider_class.__name__ if ok else 'resolved'} {note}")
+    except Exception as e:
+        if model_id == "unknown-model":
+            print(f" ✓ {model_id}: No provider found (expected)")
+        else:
+            print(f" ✗ {model_id}: resolve() failed: {e}")
+
+# 3. Inference sanity check
+print("\n3. Test inference with sample prompts")
+try:
+    model_id = sample_ids[0] if sample_ids[0] != "unknown-model" else (_example_id(PATTERNS[0]) if PATTERNS else "test-model")
+    provider = doubaoLanguageModel(model_id=model_id)
+    prompts = ["Test prompt 1", "Test prompt 2"]
+    results = list(provider.infer(prompts))
+    print(f" ✓ Inference returned {len(results)} results")
+    for i, result in enumerate(results):
+        try:
+            out = result[0].output if result and result[0] else None
+            print(f" ✓ Result {i+1}: {(out or '')[:60]}...")
+        except Exception:
+            print(f" ✗ Result {i+1}: Unexpected result shape: {result}")
+except Exception as e:
+    print(f" ✗ ERROR: {e}")
+
+# 4. Test schema creation and application
+print("\n4. Test schema creation and application")
+try:
+    from langextract_doubao.schema import doubaoSchema
+    from langextract import data
+
+    examples = [
+        data.ExampleData(
+            text="Test text",
+            extractions=[
+                data.Extraction(
+                    extraction_class="entity",
+                    extraction_text="test",
+                    attributes={"type": "example"}
+                )
+            ]
+        )
+    ]
+
+    schema = doubaoSchema.from_examples(examples)
+    print(f" ✓ Schema created (keys={list(schema.schema_dict.keys())})")
+
+    schema_class = doubaoLanguageModel.get_schema_class()
+    print(f" ✓ Provider schema class: {schema_class.__name__}")
+
+    provider = doubaoLanguageModel(model_id=_example_id(PATTERNS[0]) if PATTERNS else "test-model")
+    provider.apply_schema(schema)
+    print(f" ✓ Schema applied: response_schema={provider.response_schema is not None} structured={getattr(provider, 'structured_output', False)}")
+except Exception as e:
+    print(f" ✗ ERROR: {e}")
+
+# 5. Test factory integration
+print("\n5. Test factory integration")
+try:
+    from langextract import factory
+    config = factory.ModelConfig(
+        model_id=_example_id(PATTERNS[0]) if PATTERNS else "test-model",
+        provider="doubaoLanguageModel"
+    )
+    model = factory.create_model(config)
+    print(f" ✓ Factory created: {type(model).__name__}")
+except Exception as e:
+    print(f" ✗ ERROR: {e}")
+
+print("\n" + "-" * 50)
+print("✅ Testing complete!")
diff --git a/langextract-matrixai/langextract_matrixai/__init__.py b/langextract-matrixai/langextract_matrixai/__init__.py
new file mode 100644
index 00000000..a0d654b6
--- /dev/null
+++ b/langextract-matrixai/langextract_matrixai/__init__.py
@@ -0,0 +1,6 @@
+"""LangExtract provider plugin for matrixai."""
+
+from langextract_matrixai.provider import matrixaiLanguageModel
+
+__all__ = ['matrixaiLanguageModel']
+__version__ = "0.1.0"
diff --git a/langextract-matrixai/langextract_matrixai/provider.py b/langextract-matrixai/langextract_matrixai/provider.py
new file mode 100644
index 00000000..a28b7449
--- /dev/null
+++ b/langextract-matrixai/langextract_matrixai/provider.py
@@ -0,0 +1,111 @@
+"""Provider implementation for matrixai."""
+
+import os
+import langextract as lx
+from langextract_matrixai.schema import matrixaiSchema
+
+from langextract.core.base_model import BaseLanguageModel
+from langextract.core.types import ScoredOutput
+from openai import OpenAI
+
+@lx.providers.registry.register(
+    r'^matrixai',
+    r'^deepseek',  # Also register for deepseek model IDs
+    priority=30  # Higher priority than Ollama (which has priority=10)
+)
+class matrixaiLanguageModel(BaseLanguageModel):
+    """LangExtract provider for matrixai.
+
+    This provider handles model IDs matching: ['^matrixai', '^deepseek']
+    """
+
+    def __init__(self, model_id: str, api_key: str = None, **kwargs):
+        """Initialize the matrixai provider.
+
+        Args:
+            model_id: The model identifier.
+            api_key: API key for authentication.
+            **kwargs: Additional provider-specific parameters.
+        """
+        super().__init__()
+        self.model_id = model_id
+
+        # Debug: Print which method is used to get the API key
+        if api_key:
+            self.api_key = api_key
+            print("DEBUG: Using API key passed as parameter")
+        elif os.environ.get('MATRIXAI_API_KEY'):
+            self.api_key = os.environ.get('MATRIXAI_API_KEY')
+            print("DEBUG: Using MATRIXAI_API_KEY environment variable")
+        elif os.environ.get('DEEPSEEK_API_KEY'):
+            self.api_key = os.environ.get('DEEPSEEK_API_KEY')
+            print("DEBUG: Using DEEPSEEK_API_KEY environment variable")
+        else:
+            raise ValueError(
+                "API key is required for matrixai provider. "
+                "Please set MATRIXAI_API_KEY or DEEPSEEK_API_KEY environment variable, "
+                "or pass api_key parameter."
+            )
+
+        # Check if API key is available, raise informative error if not
+        if not self.api_key:
+            raise ValueError(
+                "API key is required for matrixai provider. "
+                "Please set MATRIXAI_API_KEY or DEEPSEEK_API_KEY environment variable, "
+                "or pass api_key parameter."
+ ) + + self.response_schema = kwargs.get('response_schema') + self.structured_output = kwargs.get('structured_output', True) + self.base_url = os.environ.get("MATRIXAI_BASE_URL", "https://api.deepseek.com") + self.client = OpenAI( + base_url=self.base_url, + api_key=self.api_key + ) + self._extra_kwargs = kwargs + + @classmethod + def get_schema_class(cls): + """Tell LangExtract about our schema support.""" + from langextract_matrixai.schema import matrixaiSchema + return matrixaiSchema + + def apply_schema(self, schema_instance): + """Apply or clear schema configuration.""" + super().apply_schema(schema_instance) + if schema_instance: + config = schema_instance.to_provider_config() + self.response_schema = config.get('response_schema') + self.structured_output = config.get('structured_output', False) + else: + self.response_schema = None + self.structured_output = False + + def infer(self, batch_prompts, **kwargs): + """Run inference on a batch of prompts. + + Args: + batch_prompts: List of prompts to process. + **kwargs: Additional inference parameters. + + Yields: + Lists of ScoredOutput objects, one per prompt. + """ + for prompt in batch_prompts: + api_params = { + "model": self.model_id, + "messages": [ + {"role": "user", "content": prompt}, + {"role": "system", "content": "你是教研分析助手"} + ], + + "stream": False, + } + + completion = self.client.chat.completions.create(**api_params) + text = getattr(completion.choices[0].message, "content", "") + + if not text: + raise RuntimeError("MatrixAI returned empty output") + yield [ScoredOutput(score=1.0, output=text)] + \ No newline at end of file diff --git a/langextract-matrixai/langextract_matrixai/schema.py b/langextract-matrixai/langextract_matrixai/schema.py new file mode 100644 index 00000000..5e1b598d --- /dev/null +++ b/langextract-matrixai/langextract_matrixai/schema.py @@ -0,0 +1,74 @@ +"""Schema implementation for matrixai provider.""" + +import langextract as lx +from langextract.core.schema import BaseSchema + +class matrixaiSchema(BaseSchema): + """Schema implementation for matrixai structured output.""" + + def __init__(self, schema_dict: dict): + """Initialize the schema with a dictionary.""" + self._schema_dict = schema_dict + + @property + def schema_dict(self) -> dict: + """Return the schema dictionary.""" + return self._schema_dict + + @classmethod + def from_examples(cls, examples_data, attribute_suffix="_attributes"): + """Build schema from example extractions. + + Args: + examples_data: Sequence of ExampleData objects. + attribute_suffix: Suffix for attribute fields. + + Returns: + A configured matrixaiSchema instance. + """ + extraction_types = {} + for example in examples_data: + for extraction in example.extractions: + class_name = extraction.extraction_class + if class_name not in extraction_types: + extraction_types[class_name] = set() + if extraction.attributes: + extraction_types[class_name].update(extraction.attributes.keys()) + + schema_dict = { + "type": "object", + "properties": { + "extractions": { + "type": "array", + "items": {"type": "object"} + } + }, + "required": ["extractions"] + } + + return cls(schema_dict) + + def to_provider_config(self) -> dict: + """Convert to provider-specific configuration. + + Returns: + Dictionary of provider-specific configuration. + """ + return { + "response_schema": self._schema_dict, + "structured_output": True + } + + @property + def supports_strict_mode(self) -> bool: + """Whether this schema guarantees valid structured output. 
+ + Returns: + True if the provider enforces valid JSON output. + """ + return False # Set to True only if your provider guarantees valid JSON + + @property + def requires_raw_output(self) -> bool: + """返回 True 表示模型输出原生 JSON(无围栏)。""" + return True \ No newline at end of file diff --git a/langextract-matrixai/test_plugin.py b/langextract-matrixai/test_plugin.py new file mode 100644 index 00000000..09879612 --- /dev/null +++ b/langextract-matrixai/test_plugin.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +"""Test script for matrixai provider (Step 5 checklist).""" + +import re +import sys +import langextract as lx +from langextract.providers import registry + +try: + from langextract_matrixai import matrixaiLanguageModel +except ImportError: + print("ERROR: Plugin not installed. Run: pip install -e .") + sys.exit(1) + +lx.providers.load_plugins_once() + +PROVIDER_CLS_NAME = "matrixaiLanguageModel" +PATTERNS = ['^matrixai'] + +def _example_id(pattern: str) -> str: + """Generate test model ID from pattern.""" + base = re.sub(r'^\^', '', pattern) + m = re.match(r"[A-Za-z0-9._-]+", base) + base = m.group(0) if m else (base or "model") + return f"{base}-test" + +sample_ids = [_example_id(p) for p in PATTERNS] +sample_ids.append("unknown-model") + +print("Testing matrixai Provider - Step 5 Checklist:") +print("-" * 50) + +# 1 & 2. Provider registration + pattern matching via resolve() +print("1–2. Provider registration & pattern matching") +for model_id in sample_ids: + try: + provider_class = registry.resolve(model_id) + ok = provider_class.__name__ == PROVIDER_CLS_NAME + status = "✓" if (ok or model_id == "unknown-model") else "✗" + note = "expected" if ok else ("expected (no provider)" if model_id == "unknown-model" else "unexpected provider") + print(f" {status} {model_id} -> {provider_class.__name__ if ok else 'resolved'} {note}") + except Exception as e: + if model_id == "unknown-model": + print(f" ✓ {model_id}: No provider found (expected)") + else: + print(f" ✗ {model_id}: resolve() failed: {e}") + +# 3. Inference sanity check +print("\n3. Test inference with sample prompts") +try: + model_id = sample_ids[0] if sample_ids[0] != "unknown-model" else (_example_id(PATTERNS[0]) if PATTERNS else "test-model") + # Use a dummy API key for testing initialization only + provider = matrixaiLanguageModel(model_id=model_id, api_key="dummy-key-for-test") + # Only test initialization, skip actual API call for unit testing + print(f" ✓ Provider initialized successfully with model: {provider.model_id}") + # Optionally, we could mock the API call or skip it during testing +except Exception as e: + print(f" ✗ ERROR: {e}") + +# 4. Test schema creation and application +print("\n4. 
+try:
+    from langextract_matrixai.schema import matrixaiSchema
+    from langextract import data
+
+    examples = [
+        data.ExampleData(
+            text="Test text",
+            extractions=[
+                data.Extraction(
+                    extraction_class="entity",
+                    extraction_text="test",
+                    attributes={"type": "example"}
+                )
+            ]
+        )
+    ]
+
+    schema = matrixaiSchema.from_examples(examples)
+    print(f" ✓ Schema created (keys={list(schema.schema_dict.keys())})")
+
+    schema_class = matrixaiLanguageModel.get_schema_class()
+    print(f" ✓ Provider schema class: {schema_class.__name__}")
+
+    # Use a dummy API key for testing initialization only
+    provider = matrixaiLanguageModel(model_id=_example_id(PATTERNS[0]) if PATTERNS else "test-model", api_key="dummy-key-for-test")
+    provider.apply_schema(schema)
+    print(f" ✓ Schema applied: response_schema={provider.response_schema is not None} structured={getattr(provider, 'structured_output', False)}")
+except Exception as e:
+    print(f" ✗ ERROR: {e}")
+
+# 5. Test factory integration
+print("\n5. Test factory integration")
+try:
+    from langextract import factory
+    # Include api_key in provider_kwargs for proper initialization
+    config = factory.ModelConfig(
+        model_id=_example_id(PATTERNS[0]) if PATTERNS else "test-model",
+        provider="matrixaiLanguageModel",
+        provider_kwargs={"api_key": "dummy-key-for-test"}  # Provide dummy key for testing
+    )
+    model = factory.create_model(config)
+    print(f" ✓ Factory created: {type(model).__name__}")
+except Exception as e:
+    print(f" ✗ ERROR: {e}")
+
+print("\n" + "-" * 50)
+print("✅ Testing complete!")
diff --git a/scripts/create_provider_plugin.py b/scripts/create_provider_plugin.py
index e7075a4f..74d8be36 100755
--- a/scripts/create_provider_plugin.py
+++ b/scripts/create_provider_plugin.py
@@ -129,36 +129,36 @@ def create_provider(
   schema_init = (
       """
-            self.response_schema = kwargs.get('response_schema')
-            self.structured_output = kwargs.get('structured_output', False)"""
+        self.response_schema = kwargs.get('response_schema')
+        self.structured_output = kwargs.get('structured_output', False)"""
       if with_schema
       else ""
   )
 
   schema_methods = f"""
-      @classmethod
-      def get_schema_class(cls):
-        \"\"\"Tell LangExtract about our schema support.\"\"\"
-        from langextract_{package_name}.schema import {provider_name}Schema
-        return {provider_name}Schema
-
-      def apply_schema(self, schema_instance):
-        \"\"\"Apply or clear schema configuration.\"\"\"
-        super().apply_schema(schema_instance)
-        if schema_instance:
-          config = schema_instance.to_provider_config()
-          self.response_schema = config.get('response_schema')
-          self.structured_output = config.get('structured_output', False)
-        else:
-          self.response_schema = None
-          self.structured_output = False""" if with_schema else ""
+  @classmethod
+  def get_schema_class(cls):
+    \"\"\"Tell LangExtract about our schema support.\"\"\"
+    from langextract_{package_name}.schema import {provider_name}Schema
+    return {provider_name}Schema
+
+  def apply_schema(self, schema_instance):
+    \"\"\"Apply or clear schema configuration.\"\"\"
+    super().apply_schema(schema_instance)
+    if schema_instance:
+      config = schema_instance.to_provider_config()
+      self.response_schema = config.get('response_schema')
+      self.structured_output = config.get('structured_output', False)
+    else:
+      self.response_schema = None
+      self.structured_output = False""" if with_schema else ""
 
   schema_infer = (
       """
-          api_params = {}
-          if self.response_schema:
-            api_params['response_schema'] = self.response_schema
+      api_params = {}
+      if self.response_schema:
+        api_params['response_schema'] = self.response_schema
       # result = self.client.generate(prompt, **api_params)"""
      if with_schema
      else """
@@ -166,47 +166,47 @@ def apply_schema(self, schema_instance):
   )
 
   provider_content = textwrap.dedent(f'''\
-    """Provider implementation for {provider_name}."""
+"""Provider implementation for {provider_name}."""
 
-    import os
-    import langextract as lx{schema_imports}
+import os
+import langextract as lx{schema_imports}
 
 
-    @lx.providers.registry.register({patterns_str}, priority=10)
-    class {provider_name}LanguageModel(lx.inference.BaseLanguageModel):
-      """LangExtract provider for {provider_name}.
+@lx.providers.registry.register({patterns_str}, priority=10)
+class {provider_name}LanguageModel(lx.inference.BaseLanguageModel):
+  """LangExtract provider for {provider_name}.
 
-      This provider handles model IDs matching: {patterns}
-      """
+  This provider handles model IDs matching: {patterns}
+  """
 
-      def __init__(self, model_id: str, api_key: str = None, **kwargs):
-        """Initialize the {provider_name} provider.
+  def __init__(self, model_id: str, api_key: str = None, **kwargs):
+    """Initialize the {provider_name} provider.
 
-        Args:
-          model_id: The model identifier.
-          api_key: API key for authentication.
-          **kwargs: Additional provider-specific parameters.
+    Args:
+      model_id: The model identifier.
+      api_key: API key for authentication.
+      **kwargs: Additional provider-specific parameters.
 
-        """
-        super().__init__()
-        self.model_id = model_id
-        self.api_key = api_key or os.environ.get('{env_var_safe}'){schema_init}
+    """
+    super().__init__()
+    self.model_id = model_id
+    self.api_key = api_key or os.environ.get('{env_var_safe}'){schema_init}
 
-        # self.client = YourClient(api_key=self.api_key)
-        self._extra_kwargs = kwargs{schema_methods}
+    # self.client = YourClient(api_key=self.api_key)
+    self._extra_kwargs = kwargs{schema_methods}
 
-      def infer(self, batch_prompts, **kwargs):
-        """Run inference on a batch of prompts.
+  def infer(self, batch_prompts, **kwargs):
+    """Run inference on a batch of prompts.
 
-        Args:
-          batch_prompts: List of prompts to process.
-          **kwargs: Additional inference parameters.
+    Args:
+      batch_prompts: List of prompts to process.
+      **kwargs: Additional inference parameters.
 
-        Yields:
-          Lists of ScoredOutput objects, one per prompt.
+    Yields:
+      Lists of ScoredOutput objects, one per prompt.
 
-        """
-        for prompt in batch_prompts:{schema_infer}
-          result = f"Mock response for: {{prompt[:50]}}..."
-          yield [lx.inference.ScoredOutput(score=1.0, output=result)]
+    """
+    for prompt in batch_prompts:{schema_infer}
+      result = f"Mock response for: {{prompt[:50]}}..."
+      yield [lx.inference.ScoredOutput(score=1.0, output=result)]
 ''')
 
   (package_dir / "provider.py").write_text(provider_content, encoding="utf-8")
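A minimal usage sketch for the doubao plugin added above, assuming the package is installed (`pip install -e langextract-doubao`) and `ARK_API_KEY` is exported. The sample text, extraction class, and attribute names are hypothetical; the `data.ExampleData`/`data.Extraction` construction mirrors `test_plugin.py`, and the `lx.extract` call follows upstream LangExtract usage (`text_or_documents`, `prompt_description`, `examples`, `model_id`).

```python
import langextract as lx
from langextract import data

# Hypothetical few-shot example showing the expected extraction structure.
examples = [
    data.ExampleData(
        text="Aspirin quickly relieved the patient's headache.",
        extractions=[
            data.Extraction(
                extraction_class="medication",
                extraction_text="Aspirin",
                attributes={"effect": "relieved headache"},
            )
        ],
    )
]

# Any model ID matching ^doubao is routed to doubaoLanguageModel by the registry.
result = lx.extract(
    text_or_documents="The nurse gave ibuprofen for the fever.",
    model_id="doubao-model",  # placeholder; substitute a real Doubao/Ark model ID
    prompt_description="Extract medications and the effect they had.",
    examples=examples,
)

for extraction in result.extractions:
    print(extraction.extraction_class, extraction.extraction_text, extraction.attributes)
```

Because the provider is registered for the `^doubao` pattern in `langextract_doubao/provider.py`, no explicit provider argument is needed here; the registry resolves the model ID to `doubaoLanguageModel` automatically.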