diff --git a/README.md b/README.md index 1aa1fbbe..9d7e6b18 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ - [*Romeo and Juliet* Full Text Extraction](#romeo-and-juliet-full-text-extraction) - [Medication Extraction](#medication-extraction) - [Radiology Report Structuring: RadExtract](#radiology-report-structuring-radextract) +- [Architecture Documentation](#architecture-documentation) - [Community Providers](#community-providers) - [Contributing](#contributing) - [Testing](#testing) @@ -206,6 +207,213 @@ docker build -t langextract . docker run --rm -e LANGEXTRACT_API_KEY="your-api-key" langextract python your_script.py ``` +## Configuration & Logging + +LangExtract provides a flexible configuration system and unified logging interface. + +### Configuring Log Levels + +By default, LangExtract logs at `WARNING` level. You can enable more detailed logging: + +**Option 1: Using `configure()` (global setting)** + +```python +import langextract as lx + +# Enable DEBUG level logging globally +lx.configure(log_level="DEBUG") + +# Or enable INFO level +lx.configure(log_level="INFO") +``` + +**Option 2: Using context manager (temporary setting)** + +```python +import langextract as lx + +# Temporarily enable DEBUG logging for a specific code block +with lx.config(log_level="DEBUG"): + result = lx.extract(...) 
# Debug logs will be shown here + +# Back to default WARNING level outside the context +``` + +**Option 3: Using environment variables** + +```bash +# Set log level via environment variable +export LANGEXTRACT_LOG_LEVEL="DEBUG" + +# Other configuration options +export LANGEXTRACT_REQUEST_TIMEOUT="120.0" +export LANGEXTRACT_MAX_RETRIES="5" +export LANGEXTRACT_DEFAULT_MODEL="gemini-2.5-flash" +export LANGEXTRACT_CACHE_ENABLED="true" +``` + +### Available Configuration Options + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `log_level` | str | `"WARNING"` | Logging level: DEBUG, INFO, WARNING, ERROR, CRITICAL | +| `request_timeout` | float | `60.0` | Request timeout in seconds | +| `max_retries` | int | `3` | Maximum number of retries for failed requests | +| `default_model` | str | `None` | Default model ID to use | +| `default_max_tokens` | int | `None` | Default maximum tokens for generation | +| `cache_enabled` | bool | `True` | Whether to enable caching | +| `cache_dir` | str | `None` | Directory for cache files | +| `progress_enabled` | bool | `True` | Whether to show terminal progress bars and completion messages | + +### Configuration Priority + +Configuration values are applied in the following priority order (highest to lowest): + +1. **Explicit parameters** - Passed directly to `configure()` or `config()` +2. **Environment variables** - `LANGEXTRACT_*` prefixed variables +3. **Default values** - Built-in defaults + +### Advanced Logging Configuration + +For production use, you may want to configure file-based logging with rotation, +or add custom handlers like JSON formatting. 
+ +#### Example 1: Rotating File Handler (Log Persistence) + +```python +import logging +from logging.handlers import RotatingFileHandler +import langextract as lx + +# Enable INFO level logging +lx.configure(log_level="INFO") + +# Get the root langextract logger +root_logger = logging.getLogger("langextract") + +# Create a RotatingFileHandler: max 10MB per file, keep 5 backup files +file_handler = RotatingFileHandler( + "langextract.log", + maxBytes=10 * 1024 * 1024, # 10MB + backupCount=5, + encoding="utf-8", +) + +# Set the format +formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +file_handler.setFormatter(formatter) + +# Add the handler to the root logger +root_logger.addHandler(file_handler) + +# Optional: Also log to console (stderr) +console_handler = logging.StreamHandler() +console_handler.setFormatter(formatter) +root_logger.addHandler(console_handler) +``` + +#### Example 2: JSON Formatted Logs + +For structured logging (e.g., for log aggregation systems like ELK Stack): + +```python +import logging +import json +from datetime import datetime, timezone +import langextract as lx + +lx.configure(log_level="DEBUG") + +class JSONFormatter(logging.Formatter): + """Custom JSON formatter for structured logging.""" + + def format(self, record): + log_record = { + "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(), + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), + "module": record.module, + "function": record.funcName, + "line": record.lineno, + } + + # Include exception info if present + if record.exc_info: + log_record["exception"] = self.formatException(record.exc_info) + + return json.dumps(log_record) + +# Get the root logger +root_logger = logging.getLogger("langextract") + +# Create handler with JSON formatter +handler = logging.StreamHandler() +handler.setFormatter(JSONFormatter()) +root_logger.addHandler(handler) +``` + +#### Example 3: Disabling Progress Display + +In non-interactive 
environments (e.g., production servers, cron jobs), +you may want to disable the terminal progress bars: + +```python +import langextract as lx + +# Disable progress bars and completion messages +lx.configure(progress_enabled=False, log_level="INFO") + +# Now extractions will run without terminal progress display +# but all log messages will still go through the logging system +result = lx.extract(...) +``` + +Or via environment variable: + +```bash +export LANGEXTRACT_PROGRESS_ENABLED="0" +``` + +### Thread Safety Note + +The `config()` context manager uses Python's `contextvars` module for +thread-safe configuration. However, note that: + +1. **`contextvars` are NOT automatically inherited by new threads** in Python. + If you use `threading.Thread`, the context configuration will not be + automatically propagated. + +2. **Recommendations for multi-threaded code:** + - Use `configure()` for global configuration + - Or use `contextvars.copy_context()` to explicitly propagate context + - Or pass Config objects directly to functions that need them + +Example with explicit context propagation: + +```python +import threading +import contextvars +import langextract as lx + +def worker(): + # No lx.config() call is needed here: the worker runs inside the + # copied context, so it already sees log_level="DEBUG" + result = lx.extract(...) + +# Run in the current context +with lx.config(log_level="DEBUG"): + # Copy the current context + ctx = contextvars.copy_context() + + # Run the worker with the copied context + t = threading.Thread(target=lambda: ctx.run(worker)) + t.start() + t.join() +``` + ## API Key Setup for Cloud Models When using LangExtract with cloud-hosted models (like Gemini or OpenAI), you'll need to diff --git a/docs/architecture/TODO.md b/docs/architecture/TODO.md new file mode 100644 index 00000000..9d5254e9 --- /dev/null +++ b/docs/architecture/TODO.md @@ -0,0 +1,347 @@ +# 架构文档 TODO + +本文档记录在编写架构文档过程中发现的代码注释或文档改进点。 + +--- + +## 架构导航 + +- **← [返回概览](overview.md)** +- **→ [Schema 设计](schema.md)** +- **→ [Prompt 
组装](prompt.md)** +- **→ [输出解析与实体对齐](alignment.md)** +- **→ [长文档分块](chunking.md)** + +--- + +## 目录 + +- [待办事项列表](#待办事项列表) +- [Issue 模板](#issue-模板) + +--- + +## 待办事项列表 + +| ID | 优先级 | 状态 | 标题 | 关联文件 | +|----|--------|------|------|----------| +| 1 | P1 | open | `BaseSchema.from_examples()` 方法缺少详细 docstring | `langextract/core/schema.py` | +| 2 | P2 | open | `FormatModeSchema.requires_raw_output` 行为差异说明 | `langextract/core/schema.py` | +| 3 | P1 | open | `WordAligner.align_extractions()` 中 `delim` 选择理由注释 | `langextract/resolver.py` | +| 4 | P2 | open | `_accept_lcs_match` 双重门控设计 rationale | `langextract/resolver.py` | +| 5 | P2 | open | `ChunkIterator` 分块策略设计选择高层文档 | `langextract/chunking.py` | +| 6 | P1 | open | `parse_output` 兼容路径决策树注释 | `langextract/core/format_handler.py` | +| 7 | P2 | open | `_THINK_TAG_RE` 存在理由补充 | `langextract/core/format_handler.py` | +| 8 | P1 | open | `resolver_params` 对齐参数详细说明 | `langextract/extraction.py` | + +--- + +### 详细说明 + +#### TODO 1: `BaseSchema.from_examples()` 方法缺少详细 docstring + +| 字段 | 内容 | +|------|------| +| **优先级** | P1 | +| **状态** | open | +| **关联文件** | `langextract/core/schema.py` | +| **当前问题** | `from_examples` 方法缺少详细 docstring,说明如何从 examples 推断 schema | +| **建议改进** | 补充 docstring,说明:1) 如何提取 extraction_class;2) 如何推断值类型;3) 如何处理 attributes | +| **影响范围** | 新开发者理解 schema 推断机制的入口点 | + +**Issue 模板**: + +``` +### 标题: 补充 BaseSchema.from_examples() 方法的 docstring + +### 背景 +在编写架构文档时,发现 `langextract/core/schema.py` 中的 `BaseSchema.from_examples()` 抽象方法缺少详细的 docstring。这个方法是理解 LangExtract example-driven schema 推断机制的关键入口点。 + +### 问题描述 +- 当前 `from_examples` 只有一个简单的签名说明 +- 没有说明从 examples 推断 schema 的具体逻辑: + - 如何从 `Extraction.extraction_class` 提取实体类型? + - 如何分析 `extraction_text` 的值类型? + - 如何处理 `attributes` 字段? 
+ +### 验收标准 +- [ ] `BaseSchema.from_examples()` 方法有完整的 docstring +- [ ] docstring 包含参数说明(`examples_data`、`attribute_suffix`) +- [ ] docstring 包含返回值说明(`BaseSchema` 实例) +- [ ] docstring 包含 schema 推断的逻辑说明 +``` + +--- + +#### TODO 2: `FormatModeSchema.requires_raw_output` 行为差异说明 + +| 字段 | 内容 | +|------|------| +| **优先级** | P2 | +| **状态** | open | +| **关联文件** | `langextract/core/schema.py` | +| **当前问题** | `requires_raw_output` 属性的行为在不同 provider 之间的差异需要更清晰的说明 | +| **建议改进** | 补充说明:1) JSON 模式 vs YAML 模式的差异;2) 不同 provider 如何处理 `requires_raw_output` | +| **影响范围** | 开发者理解不同 provider 的行为差异 | + +**Issue 模板**: + +``` +### 标题: 补充 FormatModeSchema.requires_raw_output 行为说明 + +### 背景 +`FormatModeSchema.requires_raw_output` 属性决定了输出是否需要围栏标记。但这个行为在 JSON 模式和 YAML 模式下有所不同,且与具体 provider 的实现有关。 + +### 问题描述 +- 当前 `requires_raw_output` 的实现是:`self._format == "json"` +- 这意味着 JSON 格式输出原始 JSON(无围栏),YAML 格式需要围栏 +- 但不同 provider(Gemini、OpenAI、Ollama)对这个属性的处理可能不同 + +### 验收标准 +- [ ] 在 `FormatModeSchema` 类的 docstring 中补充 `requires_raw_output` 的行为说明 +- [ ] 说明 JSON 模式和 YAML 模式的差异 +- [ ] 说明不同 provider 可能的实现差异 +``` + +--- + +#### TODO 3: `WordAligner.align_extractions()` 中 `delim` 选择理由注释 + +| 字段 | 内容 | +|------|------| +| **优先级** | P1 | +| **状态** | open | +| **关联文件** | `langextract/resolver.py` | +| **当前问题** | `WordAligner.align_extractions` 方法中使用 `\u241F` 作为分隔符,但缺少为什么选择这个字符的注释 | +| **建议改进** | 补充注释说明:1) 为什么选择 `\u241F`(Unicode 单元分隔符);2) 为什么不使用其他分隔符 | +| **影响范围** | 理解精确匹配机制的关键设计选择 | + +**Issue 模板**: + +``` +### 标题: 补充 WordAligner.align_extractions 中 delim 选择理由的注释 + +### 背景 +在 `langextract/resolver.py` 的 `WordAligner.align_extractions()` 方法中,使用 `\u241F`(Unicode 单元分隔符)作为分隔符连接多个 extraction_text。这个选择是精确匹配机制中的关键设计。 + +### 问题描述 +- 分隔符选择:`delim = "\u241F"` +- 代码中没有注释说明为什么选择这个字符 +- 新开发者可能不理解: + - 为什么是 `\u241F` 而不是空格或其他字符? + - 这个字符有什么特殊属性? 
+ +### 验收标准 +- [ ] 在 `delim = "\u241F"` 语句前添加注释 +- [ ] 说明这是 Unicode 单元分隔符(Unit Separator) +- [ ] 说明选择理由:不会出现在正常文本中,确保精确匹配的准确性 +``` + +--- + +#### TODO 4: `_accept_lcs_match` 双重门控设计 rationale + +| 字段 | 内容 | +|------|------| +| **优先级** | P2 | +| **状态** | open | +| **关联文件** | `langextract/resolver.py` | +| **当前问题** | `_accept_lcs_match` 中的双重门控(coverage + density)的设计 rationale 可以补充说明 | +| **建议改进** | 补充注释说明:1) 为什么需要双重门控;2) 每个门控解决什么问题;3) 默认值(0.75, 1/3)的选择理由 | +| **影响范围** | 理解模糊匹配算法的核心逻辑 | + +**Issue 模板**: + +``` +### 标题: 补充 _accept_lcs_match 双重门控的设计 rationale + +### 背景 +`_accept_lcs_match` 函数实现了 LCS 模糊匹配的双重门控验证: +1. Coverage Gate: 匹配的 token 数 >= 阈值比例 +2. Density Gate: 匹配的 token 数 / 匹配区间长度 >= min_density + +### 问题描述 +- 当前代码只有简单的实现逻辑 +- 没有说明为什么需要双重门控 +- 没有说明默认值(0.75, 1/3)的选择理由 + +### 验收标准 +- [ ] 在 `_accept_lcs_match` 函数前添加或补充 docstring +- [ ] 说明 Coverage Gate 解决的问题(防止匹配过少 tokens) +- [ ] 说明 Density Gate 解决的问题(防止匹配的 tokens 分散在太长区间) +- [ ] 补充默认值选择的 rationale(如果有) +``` + +--- + +#### TODO 5: `ChunkIterator` 分块策略设计选择高层文档 + +| 字段 | 内容 | +|------|------| +| **优先级** | P2 | +| **状态** | open | +| **关联文件** | `langextract/chunking.py` | +| **当前问题** | `ChunkIterator` 中分块策略的设计选择(为什么优先换行 > 句子 > token)缺少高层文档 | +| **建议改进** | 在模块级或类级 docstring 中补充:1) 分块策略的优先级;2) 每个优先级的设计理由;3) `broken_sentence` 标志的使用场景 | +| **影响范围** | 理解长文档分块机制 | + +**Issue 模板**: + +``` +### 标题: 补充 ChunkIterator 分块策略的高层文档 + +### 背景 +`ChunkIterator` 实现了 LangExtract 的长文档分块逻辑,策略优先级为: +1. 优先按换行符分割(保持格式) +2. 然后按句子边界分割(保持语义) +3. 
最后按 token 分割(超长句处理) + +### 问题描述 +- 当前代码中,分块逻辑分散在 `__next__` 方法中 +- 没有统一的高层文档说明策略设计选择 +- `broken_sentence` 标志的使用场景不够清晰 + +### 验收标准 +- [ ] 在 `ChunkIterator` 类的 docstring 中补充分块策略说明 +- [ ] 说明优先级:换行 > 句子 > token +- [ ] 说明每个优先级的设计理由 +- [ ] 说明 `broken_sentence` 标志的作用和使用场景 +``` + +--- + +#### TODO 6: `parse_output` 兼容路径决策树注释 + +| 字段 | 内容 | +|------|------| +| **优先级** | P1 | +| **状态** | open | +| **关联文件** | `langextract/core/format_handler.py` | +| **当前问题** | `parse_output` 方法中各种兼容路径(wrapper vs 非 wrapper, strict vs 非 strict)的决策树可以用注释说明 | +| **建议改进** | 用注释或流程图形式说明:1) strict 模式 vs 非 strict 模式的行为差异;2) wrapper 模式 vs 非 wrapper 模式的处理流程 | +| **影响范围** | 理解输出解析的容错机制 | + +**Issue 模板**: + +``` +### 标题: 补充 parse_output 兼容路径的决策树注释 + +### 背景 +`FormatHandler.parse_output()` 方法处理多种兼容情况: +- strict vs 非 strict 模式 +- wrapper vs 非 wrapper 模式 +- 顶级列表支持 + +### 问题描述 +- 代码逻辑复杂,包含多个条件分支 +- 没有统一的决策树说明 +- 新开发者难以理解不同参数组合的行为 + +### 验收标准 +- [ ] 在 `parse_output` 方法中添加决策树注释 +- [ ] 说明 `strict` 参数的影响 +- [ ] 说明 `use_wrapper`、`allow_top_level_list` 的交互 +- [ ] 用 ASCII 图或表格形式展示参数组合的行为 +``` + +--- + +#### TODO 7: `_THINK_TAG_RE` 存在理由补充 + +| 字段 | 内容 | +|------|------| +| **优先级** | P2 | +| **状态** | open | +| **关联文件** | `langextract/core/format_handler.py` | +| **当前问题** | `_THINK_TAG_RE` 的存在理由(支持哪些模型)可以补充 | +| **建议改进** | 补充注释说明:1) 正则表达式的用途;2) 哪些模型会输出 `<think>` 标签;3) 这个容错机制的设计背景 | +| **影响范围** | 理解推理模型的输出格式处理 | + +**Issue 模板**: + +``` +### 标题: 补充 _THINK_TAG_RE 存在理由的注释 + +### 背景 +`_THINK_TAG_RE` 正则表达式用于处理推理模型(如 DeepSeek-R1、QwQ)的输出,这些模型会在 JSON 输出前先输出思考过程。 + +### 问题描述 +- 当前定义:`_THINK_TAG_RE = re.compile(r"<think>[\s\S]*?</think>\s*")` +- 没有注释说明: + - 这个正则的用途是什么? + - 哪些模型会输出 `<think>` 标签? + - 这个容错机制是何时、为何引入的? 
+ +### 验收标准 +- [ ] 在 `_THINK_TAG_RE` 定义前添加注释 +- [ ] 说明支持的模型(DeepSeek-R1、QwQ 等推理模型) +- [ ] 说明这些模型的输出特点(先思考后输出) +- [ ] 说明 `_parse_with_fallback` 中如何使用这个正则 +``` + +--- + +#### TODO 8: `resolver_params` 对齐参数详细说明 + +| 字段 | 内容 | +|------|------| +| **优先级** | P1 | +| **状态** | open | +| **关联文件** | `langextract/extraction.py` | +| **当前问题** | `lx.extract()` 的 docstring 很详细,但 `resolver_params` 中的各个对齐参数可以增加更详细的说明 | +| **建议改进** | 在 docstring 中补充:1) 对齐参数的默认值;2) 推荐范围;3) 调整建议(何时调高调低) | +| **影响范围** | Public API 文档,用户快速上手 | + +**Issue 模板**: + +``` +### 标题: 补充 resolver_params 中对齐参数的详细说明 + +### 背景 +`extract()` 函数的 `resolver_params` 参数允许用户配置对齐行为,包括: +- `enable_fuzzy_alignment` +- `fuzzy_alignment_threshold` +- `fuzzy_alignment_min_density` +- `fuzzy_alignment_algorithm` +- `accept_match_lesser` +- `suppress_parse_errors` + +### 问题描述 +- 当前 docstring 中缺少这些参数的详细说明 +- 用户可能不知道: + - 默认值是什么? + - 推荐调整范围是什么? + - 何时应该调整这些参数? + +### 验收标准 +- [ ] 在 `extract()` 函数的 docstring 中补充 `resolver_params` 的详细说明 +- [ ] 列出所有支持的对齐参数及其默认值 +- [ ] 给出推荐范围和调整建议 +- [ ] 可以考虑添加一个配置示例表格 +``` + +--- + +## Issue 模板 + +为方便提交 Issue,以下是标准模板: + +```markdown +### 标题: [简短描述问题] + +### 背景 +[问题的背景信息,为什么需要改进] + +### 问题描述 +[具体问题是什么,当前代码的状态] + +### 验收标准 +- [ ] 可验证的标准 1 +- [ ] 可验证的标准 2 +- [ ] 可验证的标准 3 +``` + +--- + +**本文档基于代码版本**: langextract (docs/schema-design 分支) +**最后更新**: 2026-04-19 diff --git a/docs/architecture/alignment.md b/docs/architecture/alignment.md new file mode 100644 index 00000000..bd6590a4 --- /dev/null +++ b/docs/architecture/alignment.md @@ -0,0 +1,655 @@ +# 输出解析与实体对齐 + +LLM 返回的原始文本需要经过解析才能转换为结构化的 `Extraction` 对象。解析过程由 `FormatHandler` 和 `Resolver` 协同完成。 + +实体对齐是 LangExtract 的核心能力之一——它将 LLM 抽取出的文本片段回溯到原文中的精确字符位置。这使得抽取结果可验证、可可视化。 + +--- + +## 架构导航 + +- **← [返回概览](overview.md)** +- **→ [Schema 设计](schema.md)** +- **→ [Prompt 组装](prompt.md)** +- **→ [长文档分块](chunking.md)** + +--- + +## 目录 + +- [输出解析](#输出解析) + - [解析流程](#解析流程) + - [关键代码解析](#关键代码解析) + - [格式错误时的 Fallback 策略](#格式错误时的-fallback-策略) +- [实体对齐](#实体对齐) + - 
[对齐流程](#对齐流程) + - [关键代码解析](#关键代码解析-1) + - [对齐状态说明](#对齐状态说明) + - [对齐失败时的处理](#对齐失败时的处理) + - [对齐参数配置](#对齐参数配置) + +--- + +## 输出解析 + +### 解析流程 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ 输出解析流程 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ LLM 原始输出 (Raw Output) │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 某些推理模型可能先输出思考过程: │ │ +│ │ <think>Let me analyze this text. I see Dr. Smith mentioned...</think> │ │ +│ │ │ │ +│ │ 然后是结构化输出: │ │ +│ │ ```json │ │ +│ │ { │ │ +│ │ "extractions": [ │ │ +│ │ {"person": "Dr. Smith"}, │ │ +│ │ {"medication": "Aspirin"} │ │ +│ │ ] │ │ +│ │ } │ │ +│ │ ``` │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 1: <think> 标签过滤 (可选) │ │ +│ │ │ │ +│ │ 正则: <think>[\s\S]*?</think>\s* │ │ +│ │ 位置: langextract/core/format_handler.py:46 │ │ +│ │ │ │ +│ │ 原因: DeepSeek-R1, QwQ 等推理模型会先输出思考过程 │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 2: 围栏提取 (Fence Extraction) │ │ +│ │ │ │ +│ │ 正则: ```(?P<lang>[A-Za-z0-9_+-]+)?\s*\n(?P<body>[\s\S]*?)``` │ │ +│ │ 位置: langextract/core/format_handler.py:41-44 │ │ +│ │ │ │ +│ │ 规则: │ │ +│ │ - strict_fences=True: 必须恰好一个 ```json 或 ```yaml 块 │ │ +│ │ - strict_fences=False: 宽松模式,支持无语言标签或无围栏 │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 3: 格式解析 (Format Parsing) │ │ +│ │ │ │ +│ │ JSON: json.loads(content) │ │ +│ │ YAML: yaml.safe_load(content) │ │ +│ │ │ │ +│ │ 容错: 如果第一次解析失败且有 <think> 标签,尝试去除后再解析 │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ 
+│ │ Step 4: 结构提取 (Structure Extraction) │ │ +│ │ │ │ +│ │ 期望结构 (wrapper 模式): │ │ +│ │ {"extractions": [{"key1": "value1"}, {"key2": "value2"}]} │ │ +│ │ │ │ +│ │ 兼容结构 (非 wrapper 模式): │ │ +│ │ [{"key1": "value1"}, {"key2": "value2"}] │ │ +│ │ │ │ +│ │ 位置: langextract/core/format_handler.py:151-245 (parse_output) │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 5: 转换为 Extraction 对象 │ │ +│ │ │ │ +│ │ 每个字典项: │ │ +│ │ {"person": "John", "person_attributes": {"age": "30"}} │ │ +│ │ ↓ │ │ +│ │ Extraction( │ │ +│ │ extraction_class="person", │ │ +│ │ extraction_text="John", │ │ +│ │ attributes={"age": "30"} │ │ +│ │ ) │ │ +│ │ │ │ +│ │ 位置: langextract/resolver.py:424-523 (extract_ordered_extractions)│ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 关键代码解析 + +#### 1. 围栏提取 (`_extract_content`) + +```python +# langextract/core/format_handler.py:278-333 +def _extract_content(self, text: str) -> str: + """从文本中提取内容,处理围栏""" + if not self.use_fences: + return text.strip() # 无围栏模式,直接返回 + + matches = list(_FENCE_RE.finditer(text)) + + # 验证语言标签 (json/yaml/yml) + valid_tags = { + data.FormatType.YAML: {"yaml", "yml"}, + data.FormatType.JSON: {"json"}, + } + candidates = [m for m in matches if self._is_valid_language_tag(...)] + + if self.strict_fences: + # 严格模式: 必须恰好一个有效围栏块 + if len(candidates) != 1: + raise exceptions.FormatParseError("...") + return candidates[0].group("body").strip() + + # 宽松模式 + if len(candidates) == 1: + return candidates[0].group("body").strip() + elif len(candidates) > 1: + raise exceptions.FormatParseError("Multiple fenced blocks found") + + # 最后尝试: 任意围栏或无围栏 + if matches and len(matches) == 1: + return matches[0].group("body").strip() + + return text.strip() # 无围栏,直接使用 +``` + +#### 2. 
解析输出 (`parse_output`) + +```python +# langextract/core/format_handler.py:151-245 +def parse_output( + self, text: str, *, strict: bool | None = None +) -> Sequence[Mapping[str, ExtractionValueType]]: + """解析模型输出为提取数据""" + if not text: + raise exceptions.FormatParseError("Empty or invalid input string.") + + # Step 1: 提取内容 (围栏处理) + content = self._extract_content(text) + + # Step 2: 解析 JSON/YAML (含 <think> 标签容错) + try: + parsed = self._parse_with_fallback(content, strict) + except (yaml.YAMLError, json.JSONDecodeError) as e: + raise exceptions.FormatParseError(...) from e + + # Step 3: 提取 extractions 列表 + require_wrapper = self.wrapper_key is not None and ( + self.use_wrapper or bool(strict) + ) + + if isinstance(parsed, dict): + # Wrapper 模式: {"extractions": [...]} + if require_wrapper: + if self.wrapper_key not in parsed: + raise exceptions.FormatParseError( + f"Content must contain an '{self.wrapper_key}' key." + ) + items = parsed[self.wrapper_key] + else: + # 兼容: 尝试已知的 wrapper key + if data.EXTRACTIONS_KEY in parsed: + items = parsed[data.EXTRACTIONS_KEY] + elif self.wrapper_key and self.wrapper_key in parsed: + items = parsed[self.wrapper_key] + else: + items = [parsed] # 单个对象作为单元素列表 + elif isinstance(parsed, list): + # 非 wrapper 模式: [...] + if require_wrapper and (strict or not self.allow_top_level_list): + raise exceptions.FormatParseError(...) + items = parsed + else: + raise exceptions.FormatParseError( + f"Expected list or dict, got {type(parsed)}" + ) + + # Step 4: 验证每个 item 是字典 + for item in items: + if not isinstance(item, dict): + raise exceptions.FormatParseError( + "Each item in the sequence must be a mapping." + ) + + return items +``` + +#### 3. 
`<think>` 标签容错 (`_parse_with_fallback`) + +```python +# langextract/core/format_handler.py:261-276 +def _parse_with_fallback(self, content: str, strict: bool): + """解析内容,失败时尝试去除 <think> 标签""" + try: + if self.format_type == data.FormatType.YAML: + return yaml.safe_load(content) + return json.loads(content) + except (yaml.YAMLError, json.JSONDecodeError): + if strict: + raise + # 推理模型 (DeepSeek-R1, QwQ) 会在 JSON 前输出 <think>...</think> + if _THINK_TAG_RE.search(content): + stripped = _THINK_TAG_RE.sub("", content).strip() + if self.format_type == data.FormatType.YAML: + return yaml.safe_load(stripped) + return json.loads(stripped) + raise +``` + +#### 4. 转换为 Extraction 对象 (`extract_ordered_extractions`) + +```python +# langextract/resolver.py:424-523 +def extract_ordered_extractions( + self, + extraction_data: Sequence[Mapping[str, fh.ExtractionValueType]], +) -> Sequence[data.Extraction]: + """将解析后的数据转换为 Extraction 对象列表""" + processed_extractions = [] + extraction_index = 0 + index_suffix = self.extraction_index_suffix # 可选: "_index" + attributes_suffix = self.format_handler.attribute_suffix # "_attributes" + + for group_index, group in enumerate(extraction_data): + for extraction_class, extraction_value in group.items(): + # 跳过索引字段 (如果使用 index_suffix) + if index_suffix and extraction_class.endswith(index_suffix): + continue + + # 跳过属性字段 (单独处理) + if attributes_suffix and extraction_class.endswith(attributes_suffix): + continue + + # 值类型验证: 必须是 str/int/float + if not isinstance(extraction_value, (str, int, float)): + raise ValueError( + "Extraction text must be a string, integer, or float." 
+ ) + + # 统一转为字符串 + if not isinstance(extraction_value, str): + extraction_value = str(extraction_value) + + # 查找对应的索引 (如果有) + if index_suffix: + index_key = extraction_class + index_suffix + extraction_index = group.get(index_key, None) + if extraction_index is None: + continue # 无索引则跳过 + else: + extraction_index += 1 + + # 查找对应的属性 + attributes = None + if attributes_suffix: + attributes_key = extraction_class + attributes_suffix + attributes = group.get(attributes_key, None) + + # 创建 Extraction 对象 + processed_extractions.append( + data.Extraction( + extraction_class=extraction_class, + extraction_text=extraction_value, + extraction_index=extraction_index, + group_index=group_index, + attributes=attributes, + ) + ) + + # 按索引排序 (如果使用 index_suffix) + processed_extractions.sort(key=operator.attrgetter("extraction_index")) + return processed_extractions +``` + +### 格式错误时的 Fallback 策略 + +| 场景 | 处理方式 | 控制参数 | +|------|----------|----------| +| 解析失败 (JSON/YAML 语法错误) | `suppress_parse_errors=True` 时返回空列表,否则抛异常 | `resolver_params={"suppress_parse_errors": True}` | +| 多个围栏块 | 存在多个带有效语言标签 (json/yaml) 的围栏块时,严格模式与宽松模式均抛 `FormatParseError` | - | +| 无围栏标签 | 宽松模式尝试直接解析整段文本 | `strict_fences=False` | +| 包含 `<think>` 标签 | 自动去除后重试解析 | 内置 (非 strict 模式) | +| 缺少 `extractions` wrapper | 宽松模式接受顶级列表 | `use_wrapper=False` 或 `allow_top_level_list=True` | + +**注意**: `suppress_parse_errors` 在 `extract()` 中默认为 `True`,这意味着单个 chunk 的解析失败不会导致整个文档处理失败。 + +--- + +## 实体对齐 + +### 对齐流程 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ 实体对齐流程 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 输入 │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ resolved_extractions: [ │ │ +│ │ Extraction(extraction_text="Dr. Smith"), │ │ +│ │ Extraction(extraction_text="Aspirin 10mg") │ │ +│ │ ] │ │ +│ │ │ │ +│ │ source_text: "Dr. Smith prescribed Aspirin 10mg to the patient." 
│ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 0: Tokenization & 归一化 │ │ +│ │ │ │ +│ │ 原文 token 化: │ │ +│ │ ["dr", "smith", "prescribed", "aspirin", "10mg", "to", ...] │ │ +│ │ │ │ +│ │ 提取文本 token 化 + 归一化: │ │ +│ │ - 小写: "Dr. Smith" → "dr. smith" │ │ +│ │ - 轻量词干化: "patients" → "patient" (去除 s 后缀) │ │ +│ │ │ │ +│ │ 位置: langextract/resolver.py:1034-1069 (_tokenize_with_lowercase) │ │ +│ │ langextract/resolver.py:1063-1069 (_normalize_token) │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 1: 精确匹配 (Exact Match) │ │ +│ │ │ │ +│ │ 算法: difflib.SequenceMatcher (Python 标准库) │ │ +│ │ 位置: langextract/resolver.py:921-977 │ │ +│ │ │ │ +│ │ 策略: │ │ +│ │ 1. 将所有 extraction_text 用特殊分隔符连接 │ │ +│ │ 2. 与 source_text 进行全局序列匹配 │ │ +│ │ 3. 对每个匹配块,判断是完全匹配还是部分匹配 │ │ +│ │ │ │ +│ │ 匹配状态: │ │ +│ │ - MATCH_EXACT: extraction_text 与原文完全一致 │ │ +│ │ - MATCH_LESSER: 匹配的文本比 extraction_text 短 │ │ +│ │ (extraction 更长,只匹配到一部分) │ │ +│ │ - 不匹配: 进入模糊匹配阶段 │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 2: 模糊匹配 (Fuzzy Match) - 仅当精确匹配失败时 │ │ +│ │ │ │ +│ │ 有两种算法: │ │ +│ │ │ │ +│ │ A) Legacy 算法 (deprecated) │ │ +│ │ - difflib.SequenceMatcher.ratio() │ │ +│ │ - 滑动窗口遍历所有可能的匹配位置 │ │ +│ │ - 位置: langextract/resolver.py:578-702 (_fuzzy_align_extraction)│ │ +│ │ │ │ +│ │ B) LCS 算法 (默认,推荐) │ │ +│ │ - 最长公共子序列 (Longest Common Subsequence) │ │ +│ │ - 动态规划 O(n*m²) 时间复杂度 │ │ +│ │ - 双重门控: coverage + density │ │ +│ │ - 位置: langextract/resolver.py:704-774 (_lcs_fuzzy_align_extraction)│ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ 
┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 3: 计算偏移量 & 设置状态 │ │ +│ │ │ │ +│ │ 计算: │ │ +│ │ - token_interval: 在 chunk 内的 token 索引 + token_offset │ │ +│ │ - char_interval: 通过 token 的 char_interval 计算字符偏移 │ │ +│ │ - alignment_status: MATCH_EXACT / MATCH_FUZZY / MATCH_LESSER │ │ +│ │ │ │ +│ │ 对齐失败: │ │ +│ │ - char_interval = None │ │ +│ │ - token_interval = None │ │ +│ │ - alignment_status = None │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 关键代码解析 + +#### 1. 精确匹配主流程 (`align_extractions`) + +```python +# langextract/resolver.py:776-1031 +def align_extractions( + self, + extraction_groups: Sequence[Sequence[data.Extraction]], + source_text: str, + token_offset: int = 0, + char_offset: int = 0, + enable_fuzzy_alignment: bool = True, + fuzzy_alignment_threshold: float = 0.75, + ... +) -> Sequence[Sequence[data.Extraction]]: + """将 extractions 对齐到原文""" + # Step 1: 准备 tokens + source_tokens = list(_tokenize_with_lowercase(source_text, ...)) + + # Step 2: 用特殊分隔符连接所有 extraction_text + # 分隔符: "\u241F" (Unicode 单元分隔符),确保不会出现在正常文本中 + delim = "\u241F" + extraction_tokens = list(_tokenize_with_lowercase( + f" {delim} ".join( + extraction.extraction_text + for extraction in itertools.chain(*extraction_groups) + ), + tokenizer_impl=tokenizer_impl, + )) + + # Step 3: 精确匹配 (difflib.SequenceMatcher) + self._set_seqs(source_tokens, extraction_tokens) + + # 遍历匹配块 + for i, j, n in self._get_matching_blocks()[:-1]: + # i: source 中的起始 token 索引 + # j: extraction 中的起始 token 索引 + # n: 匹配的 token 数量 + + # 查找对应的 extraction + extraction, _ = index_to_extraction_group.get(j, (None, None)) + + # 设置 token_interval + extraction.token_interval = tokenizer_lib.TokenInterval( + start_index=i + token_offset, + end_index=i + n + token_offset, + ) + + # 通过 token 计算 char_interval + start_token = tokenized_text.tokens[i] + end_token = 
tokenized_text.tokens[i + n - 1] + extraction.char_interval = data.CharInterval( + start_pos=char_offset + start_token.char_interval.start_pos, + end_pos=char_offset + end_token.char_interval.end_pos, + ) + + # 判断匹配类型 + extraction_text_len = len(extraction_tokens_for_this_extraction) + if extraction_text_len == n: + extraction.alignment_status = data.AlignmentStatus.MATCH_EXACT + exact_matches += 1 + else: + # 部分匹配 (extraction 更长,只匹配到一部分) + if accept_match_lesser: + extraction.alignment_status = data.AlignmentStatus.MATCH_LESSER + lesser_matches += 1 + else: + # 不接受部分匹配,重置 + extraction.token_interval = None + extraction.char_interval = None + extraction.alignment_status = None + + # Step 4: 模糊匹配 (对精确匹配失败的 extractions) + if enable_fuzzy_alignment and unaligned_extractions: + for extraction in unaligned_extractions: + if fuzzy_alignment_algorithm == "lcs": + aligned = self._lcs_fuzzy_align_extraction(...) + else: + aligned = self._fuzzy_align_extraction(...) + + if aligned: + aligned_extractions.append(aligned) + + return aligned_extraction_groups +``` + +#### 2. LCS 模糊匹配算法 (`_lcs_fuzzy_align_extraction`) + +```python +# langextract/resolver.py:704-774 +def _lcs_fuzzy_align_extraction( + self, + extraction: data.Extraction, + source_tokens_norm: list[str], # 已归一化的原文 tokens + tokenized_text: tokenizer_lib.TokenizedText, + token_offset: int, + char_offset: int, + fuzzy_alignment_threshold: float = 0.75, + fuzzy_alignment_min_density: float = 1/3, + ... 
+) -> data.Extraction | None: + """使用 LCS 算法进行模糊对齐""" + # Step 1: Tokenize 和归一化 extraction_text + extraction_tokens = list(_tokenize_with_lowercase(extraction.extraction_text, ...)) + extraction_tokens_norm = [_normalize_token(t) for t in extraction_tokens] + + # Step 2: 计算所有可能的 LCS 匹配 + # 返回: {match_count: LcsSpan(matches, start, end)} + spans = _best_lcs_spans(source_tokens_norm, extraction_tokens_norm) + + # Step 3: 按匹配数量从高到低尝试,找到第一个通过双重门控的 + for k in sorted(spans.keys(), reverse=True): + candidate = spans[k] + if _accept_lcs_match( + candidate, + len(extraction_tokens_norm), + threshold=fuzzy_alignment_threshold, + min_density=fuzzy_alignment_min_density, + ): + accepted = candidate + break + + if accepted is None: + return None + + # Step 4: 设置 intervals 和状态 + extraction.token_interval = tokenizer_lib.TokenInterval( + start_index=accepted.start + token_offset, + end_index=accepted.end + 1 + token_offset, + ) + + start_token = tokenized_text.tokens[accepted.start] + end_token = tokenized_text.tokens[accepted.end] + extraction.char_interval = data.CharInterval( + start_pos=char_offset + start_token.char_interval.start_pos, + end_pos=char_offset + end_token.char_interval.end_pos, + ) + + extraction.alignment_status = data.AlignmentStatus.MATCH_FUZZY + return extraction +``` + +#### 3. 
LCS 双重门控 (`_accept_lcs_match`) + +```python +# langextract/resolver.py:1165-1192 +def _accept_lcs_match( + span: LcsSpan, + extraction_len: int, + threshold: float = 0.75, + min_density: float = 1/3, +) -> bool: + """应用覆盖度和密度双重门控""" + if span.matches == 0 or extraction_len == 0: + return False + + # Coverage Gate (覆盖度): 匹配的 token 数 >= 阈值比例 + # 例如: extraction 有 4 个 tokens,threshold=0.75,需要至少匹配 3 个 + needed = math.ceil(extraction_len * threshold) + if span.matches < needed: + return False + + # Density Gate (密度): 匹配的 token 数 / 匹配区间长度 >= min_density + # 防止匹配的 tokens 分散在太长的区间中 + # 例如: 匹配 2 个 tokens,但分散在 10 个 token 的区间中 → 密度 0.2 < 1/3 → 拒绝 + if span.span_len <= 0: + return False + density = span.matches / span.span_len + return density >= min_density +``` + +#### 4. Token 归一化 (`_normalize_token`) + +```python +# langextract/resolver.py:1063-1069 +@functools.lru_cache(maxsize=10000) +def _normalize_token(token: str) -> str: + """小写 + 轻量词干化 (去除复数 s)""" + token = token.lower() + # 长度 > 3 且以 s 结尾且不以 ss 结尾 → 去除 s + if len(token) > 3 and token.endswith("s") and not token.endswith("ss"): + token = token[:-1] + return token +``` + +### 对齐状态说明 + +| 状态 | 值 | 含义 | 示例 | +|------|-----|------|------| +| `MATCH_EXACT` | `"match_exact"` | 精确匹配 | extraction_text="John",原文中恰好有 "John" | +| `MATCH_LESSER` | `"match_lesser"` | 部分匹配 (匹配文本更短) | extraction_text="John Smith",只匹配到 "John" | +| `MATCH_FUZZY` | `"match_fuzzy"` | 模糊匹配 (基于 token 归一化) | extraction_text="headaches",去除复数 s 后匹配到原文的 "headache" | +| `None` | - | 对齐失败 | 无法在原文中找到对应片段 | + +### 对齐失败时的处理 + +对齐失败的 extraction 会保留,但 `char_interval` 和 `token_interval` 为 `None`。用户可以通过过滤来只保留成功对齐的结果: + +```python +# 只保留成功对齐的 extractions +grounded_extractions = [ + e for e in result.extractions + if e.char_interval is not None +] +``` + +**原因**: LLM 可能从 few-shot examples 中"幻觉"出内容,或者提取的文本与原文表述不完全一致。LangExtract 不会丢弃这些结果,而是让用户决定如何处理。 + +### 对齐参数配置 + +对齐参数通过 `resolver_params` 传递给 `extract()`: + +```python +result = lx.extract( + text_or_documents=text, + prompt_description=prompt, + 
examples=examples, + resolver_params={ + # 模糊匹配开关 + "enable_fuzzy_alignment": True, + + # 覆盖度阈值: 至少匹配 75% 的 tokens + "fuzzy_alignment_threshold": 0.75, + + # 密度阈值: 匹配 tokens / 区间长度 >= 1/3 + "fuzzy_alignment_min_density": 1/3, + + # 算法选择: "lcs" (默认) 或 "legacy" (deprecated) + "fuzzy_alignment_algorithm": "lcs", + + # 是否接受部分匹配 (MATCH_LESSER) + "accept_match_lesser": True, + + # 解析错误时是否抑制异常 + "suppress_parse_errors": True, + } +) +``` + +--- + +**本文档基于代码版本**: langextract (docs/schema-design 分支) +**最后更新**: 2026-04-19 diff --git a/docs/architecture/chunking.md b/docs/architecture/chunking.md new file mode 100644 index 00000000..1fef7370 --- /dev/null +++ b/docs/architecture/chunking.md @@ -0,0 +1,381 @@ +# 长文档分块 + +当输入文本超过 LLM 的上下文窗口或 `max_char_buffer` 限制时,LangExtract 会将文档分割成多个 chunks 分别处理。 + +--- + +## 架构导航 + +- **← [返回概览](overview.md)** +- **→ [Schema 设计](schema.md)** +- **→ [Prompt 组装](prompt.md)** +- **→ [输出解析与实体对齐](alignment.md)** + +--- + +## 目录 + +- [分块策略](#分块策略) +- [关键代码解析](#关键代码解析) +- [Overlap 与上下文窗口](#overlap-与上下文窗口) +- [跨 Chunk 实体合并与去重](#跨-chunk-实体合并与去重) +- [分块参数配置](#分块参数配置) + +--- + +## 分块策略 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ 长文档分块策略 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 核心原则 │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 1. 优先按句子边界分割 (保持语义完整性) │ │ +│ │ 2. 尊重换行符 (诗歌、列表等格式) │ │ +│ │ 3. 单句过长时按 token 分割 │ │ +│ │ 4. 单个 token 超过 buffer 时单独成块 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 分块场景 │ +│ │ +│ 场景 A: 单句超长,需要在句内分割 │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 原文 (诗歌): │ │ +│ │ "No man is an island, │ │ +│ │ Entire of itself, │ │ +│ │ Every man is a piece of the continent, │ │ +│ │ A part of the main." 
│ │ +│ │ │ │ +│ │ max_char_buffer=40 │ │ +│ │ │ │ +│ │ 分块结果: │ │ +│ │ Chunk 1: "No man is an island,\nEntire of itself," (38 chars)│ │ +│ │ Chunk 2: "Every man is a piece of the continent," (38 chars)│ │ +│ │ Chunk 3: "A part of the main." (19 chars) │ │ +│ │ │ │ +│ │ 特点: 尊重换行符,在换行处优先分割 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 场景 B: 单个 token 超长 │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 原文: "This is antidisestablishmentarianism." │ │ +│ │ max_char_buffer=20 │ │ +│ │ │ │ +│ │ 分块结果: │ │ +│ │ Chunk 1: "This is" (7 chars) │ │ +│ │ Chunk 2: "antidisestablishmentarianism" (28 chars) │ │ +│ │ Chunk 3: "." (1 char) │ │ +│ │ │ │ +│ │ 特点: 超长 token 即使超过 buffer 也单独成块 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 场景 C: 多短句可合并 │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 原文: "Roses are red. Violets are blue. Flowers are nice. And so │ │ +│ │ are you." │ │ +│ │ max_char_buffer=60 │ │ +│ │ │ │ +│ │ 分块结果: │ │ +│ │ Chunk 1: "Roses are red. Violets are blue. Flowers are nice." │ │ +│ │ (50 chars) │ │ +│ │ Chunk 2: "And so are you." (15 chars) │ │ +│ │ │ │ +│ │ 特点: 多个完整句子可合并到一个 chunk (不超过 buffer) │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +--- + +## 关键代码解析 + +### 1. 
ChunkIterator 主逻辑 (`__next__`) + +```python +# langextract/chunking.py:441-506 +def __next__(self) -> TextChunk: + # 获取下一个句子 (或句子的剩余部分) + sentence = next(self.sentence_iter) + + # 策略 1: 如果第一个 token 就超过 buffer,单独成块 + curr_chunk = create_token_interval( + sentence.start_index, sentence.start_index + 1 + ) + if self._tokens_exceed_buffer(curr_chunk): + self.sentence_iter = SentenceIterator( + self.tokenized_text, curr_token_pos=sentence.start_index + 1 + ) + self.broken_sentence = True + return TextChunk(token_interval=curr_chunk, document=self.document) + + # 策略 2: 在句子内追加 tokens,直到接近 buffer + start_of_new_line = -1 + for token_index in range(curr_chunk.start_index, sentence.end_index): + # 记录换行位置 (用于优先在换行处分割) + if self.tokenized_text.tokens[token_index].first_token_after_newline: + start_of_new_line = token_index + + test_chunk = create_token_interval( + curr_chunk.start_index, token_index + 1 + ) + + if self._tokens_exceed_buffer(test_chunk): + # 超过 buffer 了 + # 优先在最近的换行处分割 (如果有) + if start_of_new_line > 0 and start_of_new_line > curr_chunk.start_index: + curr_chunk = create_token_interval( + curr_chunk.start_index, start_of_new_line + ) + # 更新句子迭代器,下次从这里继续 + self.sentence_iter = SentenceIterator( + self.tokenized_text, curr_token_pos=curr_chunk.end_index + ) + self.broken_sentence = True + return TextChunk(token_interval=curr_chunk, document=self.document) + else: + curr_chunk = test_chunk # 继续追加 + + # 策略 3: 整句没超过 buffer,尝试合并更多句子 + if self.broken_sentence: + self.broken_sentence = False + else: + for sentence in self.sentence_iter: + test_chunk = create_token_interval( + curr_chunk.start_index, sentence.end_index + ) + if self._tokens_exceed_buffer(test_chunk): + self.sentence_iter = SentenceIterator( + self.tokenized_text, curr_token_pos=curr_chunk.end_index + ) + return TextChunk(token_interval=curr_chunk, document=self.document) + else: + curr_chunk = test_chunk # 合并整句 + + return TextChunk(token_interval=curr_chunk, document=self.document) +``` + +### 2. 
句子边界检测 (`SentenceIterator`) + +```python +# langextract/chunking.py:282-340 +class SentenceIterator: + """迭代 tokenized 文本的句子""" + + def __next__(self) -> tokenizer_lib.TokenInterval: + # 找到包含当前 token 的句子范围 + sentence_range = tokenizer_lib.find_sentence_range( + self.tokenized_text.text, + self.tokenized_text.tokens, + self.curr_token_pos, + ) + # 从当前位置开始,而不是句子开头 + # (如果我们在句子中间,从这里继续) + sentence_range = create_token_interval( + self.curr_token_pos, sentence_range.end_index + ) + self.curr_token_pos = sentence_range.end_index + return sentence_range +``` + +--- + +## Overlap 与上下文窗口 + +LangExtract **没有使用传统的 chunk overlap 机制**,而是提供了 **`context_window_chars`** 参数来解决跨 chunk 的指代消解问题。 + +| 机制 | 说明 | 示例 | +|------|------|------| +| 传统 overlap | 相邻 chunks 共享部分文本 | Chunk1: [0-100], Chunk2: [80-180] | +| LangExtract context_window | 前一个 chunk 的尾部文本作为 prompt 上下文 | Chunk2 的 prompt 包含 Chunk1 的最后 N 个字符 | + +**ContextAwarePromptBuilder 实现**: + +```python +# langextract/prompting.py:179-276 +class ContextAwarePromptBuilder(PromptBuilder): + """支持跨 chunk 上下文追踪的 prompt builder""" + + _CONTEXT_PREFIX = "[Previous text]: ..." 
+ + def __init__( + self, + generator: QAPromptGenerator, + context_window_chars: int | None = None, # 例如: 100 + ): + super().__init__(generator) + self._context_window_chars = context_window_chars + self._prev_chunk_by_doc_id: dict[str, str] = {} # 按文档追踪 + + def build_prompt( + self, + chunk_text: str, + document_id: str, + additional_context: str | None = None, + ) -> str: + # 构建有效上下文 (前一个 chunk + 额外上下文) + effective_context = self._build_effective_context( + document_id, additional_context + ) + + prompt = self._generator.render( + question=chunk_text, + additional_context=effective_context, + ) + + # 更新状态: 保存当前 chunk 供下一个使用 + self._update_state(document_id, chunk_text) + return prompt + + def _build_effective_context( + self, document_id: str, additional_context: str | None + ) -> str | None: + context_parts: list[str] = [] + + # 注入前一个 chunk 的尾部 + if self._context_window_chars and document_id in self._prev_chunk_by_doc_id: + prev_text = self._prev_chunk_by_doc_id[document_id] + window = prev_text[-self._context_window_chars :] # 取尾部 + context_parts.append(f"{self._CONTEXT_PREFIX}{window}") + + if additional_context: + context_parts.append(additional_context) + + return "\n\n".join(context_parts) if context_parts else None +``` + +**使用示例**: + +```python +result = lx.extract( + text_or_documents=long_text, + prompt_description=prompt, + examples=examples, + context_window_chars=100, # 每个 chunk 包含前一个 chunk 的最后 100 字符 +) +``` + +**效果**: + +假设文档被分为两个 chunks: +- Chunk1: "Dr. Sarah Johnson is a cardiologist at the hospital. She" +- Chunk2: " specializes in heart disease and hypertension." + +没有 context_window 时,Chunk2 本身不含主语("She" 及其指代的 "Dr. Sarah Johnson" 都在 Chunk1 中),LLM 无法判断这句话描述的是谁。 + +有 `context_window_chars=60` 时(Chunk1 共 56 个字符,全部落在窗口内),Chunk2 的 prompt 会包含: +``` +[Previous text]: ...Dr. Sarah Johnson is a cardiologist at the hospital. She + +Q: specializes in heart disease and hypertension. +A: +``` + +这样 LLM 就能知道 "She" 指的是 "Dr. 
Sarah Johnson"。 + +--- + +## 跨 Chunk 实体合并与去重 + +LangExtract 目前 **没有自动的跨 chunk 实体去重机制**。每个 chunk 的处理是独立的,结果累积到 `per_doc` 字典中。 + +```python +# langextract/annotation.py:307-332 (Annotator._annotate_documents_single_pass) +def _annotate_documents_single_pass(...): + per_doc: DefaultDict[str, list[data.Extraction]] = collections.defaultdict(list) + + for batch in batch_iter: + # ... 推理、解析、对齐 ... + + for text_chunk, scored_outputs in zip(batch, outputs): + # ... + + aligned_extractions = resolver.align(...) + + for extraction in aligned_extractions: + # 直接追加,没有去重 + per_doc[text_chunk.document_id].append(extraction) +``` + +**用户需要自己处理去重**,可以基于: +1. `char_interval` 重叠检测 +2. `extraction_text` + `extraction_class` 相似度 + +**例外: Sequential Extraction Passes** + +当使用 `extraction_passes > 1` 时,多次抽取的结果会进行非重叠合并: + +```python +# langextract/annotation.py:46-84 +def _merge_non_overlapping_extractions( + all_extractions: list[Iterable[data.Extraction]], +) -> list[data.Extraction]: + """合并多次抽取的结果,重叠时保留较早的抽取""" + if not all_extractions: + return [] + if len(all_extractions) == 1: + return list(all_extractions[0]) + + merged_extractions = list(all_extractions[0]) # 第一次抽取的结果 + + for pass_extractions in all_extractions[1:]: + for extraction in pass_extractions: + # 检查是否与已合并的结果重叠 + overlaps = False + if extraction.char_interval is not None: + for existing_extraction in merged_extractions: + if existing_extraction.char_interval is not None: + if _extractions_overlap(extraction, existing_extraction): + overlaps = True + break + + # 只有不重叠时才添加 + if not overlaps: + merged_extractions.append(extraction) + + return merged_extractions + +def _extractions_overlap( + extraction1: data.Extraction, extraction2: data.Extraction +) -> bool: + """检查两个 extraction 的字符区间是否重叠""" + # [start1, end1) 与 [start2, end2) 重叠 + return start1 < end2 and start2 < end1 +``` + +**注意**: 这是同一文档多次抽取的合并策略,不是跨 chunk 去重。 + +--- + +## 分块参数配置 + +```python +result = lx.extract( + text_or_documents=text, + 
prompt_description=prompt, + examples=examples, + # 分块相关参数 + max_char_buffer=1000, # 每个 chunk 的最大字符数 + batch_length=10, # 每批处理的 chunk 数量 + max_workers=10, # 并行 worker 数 + context_window_chars=100, # 前一个 chunk 的上下文字符数 (可选) + extraction_passes=1, # 抽取次数 (可选,多次抽取时合并非重叠结果) +) +``` + +**参数说明**: + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `max_char_buffer` | 1000 | 每个 chunk 的最大字符数。调小可提高准确率但增加 API 调用。 | +| `batch_length` | 10 | 每批处理的 chunk 数量。与 `max_workers` 共同决定并行度。 | +| `max_workers` | 10 | 最大并行 worker 数。有效并行度受限于 `min(batch_length, max_workers)`。 | +| `context_window_chars` | `None` | 前一个 chunk 的上下文字符数。用于指代消解。 | +| `extraction_passes` | 1 | 抽取次数。> 1 时执行多次抽取并合并非重叠结果。 | + +--- + +**本文档基于代码版本**: langextract (docs/schema-design 分支) +**最后更新**: 2026-04-19 diff --git a/docs/architecture/overview.md b/docs/architecture/overview.md new file mode 100644 index 00000000..de4e2745 --- /dev/null +++ b/docs/architecture/overview.md @@ -0,0 +1,389 @@ +# LangExtract 架构概览 + +本文档系统梳理 LangExtract 的核心内部机制,帮助开发者理解信息抽取的完整流程。 + +--- + +## 快速导航 + +| 文档 | 内容概述 | 适用对象 | +|------|----------|----------| +| [Schema 设计](schema.md) | 字段类型、实体定义、JSON 与 Python 类对应关系 | 需要理解如何定义抽取任务的开发者 | +| [Prompt 组装](prompt.md) | 从 schema 到最终 prompt 的完整流程、Q:A 格式 | 需要定制 prompt 或理解 LLM 交互的开发者 | +| [输出解析与实体对齐](alignment.md) | LLM 输出解析、容错策略、精确匹配 + LCS 模糊匹配算法 | 需要处理对齐失败或调试抽取结果的开发者 | +| [长文档分块](chunking.md) | Chunk 策略、context_window_chars、跨 chunk 合并 | 需要处理长文档或优化分块参数的开发者 | +| [文档 TODO](TODO.md) | 代码注释改进计划、优先级、Issue 模板 | 贡献者、维护者 | + +--- + +## 架构导航 + +- **← [返回概览](overview.md)** +- **→ [Schema 设计](schema.md)** +- **→ [Prompt 组装](prompt.md)** +- **→ [输出解析与实体对齐](alignment.md)** +- **→ [长文档分块](chunking.md)** + +--- + +## Overview: 抽取流程一页纸 + +LangExtract 的信息抽取流程是一个典型的 **schema-driven** 流水线,从用户输入到最终返回结构化数据,经过以下阶段: + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ LangExtract 抽取流程 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 用户输入 (User Input) 
│ +│ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │ +│ │ text_or_documents │ │ prompt_description │ │ examples │ │ +│ │ (待抽取文本) │ │ (抽取指令) │ │ (少量示例) │ │ +│ └────────┬────────┘ └────────┬────────┘ └────────┬────────┘ │ +│ │ │ │ │ +│ └────────────────────┼────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Schema 推断 (Schema Inference) │ │ +│ │ - 从 examples 中提取 extraction_class (实体类型) │ │ +│ │ - 分析 extraction_text 的值类型 (string/number/dict/list) │ │ +│ │ - 构建 BaseSchema 或 FormatModeSchema 实例 │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Prompt 组装 (Prompt Assembly) │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ 1. System Instruction: prompt_description │ │ │ +│ │ │ 2. Few-shot Examples: 格式化 examples 为 JSON/YAML │ │ │ +│ │ │ 3. Question: 当前 chunk 文本 (Q: ...) │ │ │ +│ │ │ 4. Answer Prefix: 引导模型输出 (A: ) │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ 关键类: QAPromptGenerator, ContextAwarePromptBuilder │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ LLM 推理 (Inference) │ │ +│ │ - 支持 Gemini, OpenAI, Ollama 等多种 provider │ │ +│ │ - 部分模型支持 schema constraints (结构化输出约束) │ │ +│ │ - 批量处理 (batch) 提高吞吐量 │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 输出解析 (Output Parsing) │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ 1. 围栏提取: 从 ```json / ```yaml 中提取内容 │ │ │ +│ │ │ 2. 格式解析: JSON.parse / yaml.safe_load │ │ │ +│ │ │ 3. 容错处理: 标签过滤、宽松解析模式 │ │ │ +│ │ │ 4. 
结构转换: 转为 Extraction 对象序列 │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ 关键类: FormatHandler, Resolver │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 实体对齐 (Entity Alignment) │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ 1. 精确匹配: difflib.SequenceMatcher 逐 token 匹配 │ │ │ +│ │ │ 2. 模糊匹配: LCS (最长公共子序列) 算法 │ │ │ +│ │ │ 3. 归一化: 小写 + 轻量词干化 (去除 s 后缀) │ │ │ +│ │ │ 4. 状态标记: MATCH_EXACT / MATCH_FUZZY / MATCH_LESSER │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ 关键类: WordAligner │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 结果返回 (Result Return) │ │ +│ │ - AnnotatedDocument: 包含 extractions 列表 │ │ +│ │ - 每个 Extraction 包含: char_interval, alignment_status, attributes │ │ +│ │ - 支持 visualization 可视化 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 核心数据流向伪代码 + +```python +# langextract/extraction.py:37-377 (extract 函数) +def extract(text_or_documents, prompt_description, examples, ...): + # 1. 验证 examples 的对齐质量 (可选但推荐) + if prompt_validation_level != OFF: + alignment_report = validate_prompt_alignment(examples) + + # 2. 创建 Prompt 模板 + prompt_template = PromptTemplateStructured( + description=prompt_description, + examples=examples + ) + + # 3. 初始化 LLM (自动选择 provider) + language_model = factory.create_model( + model_id=model_id, + examples=examples if use_schema_constraints else None + ) + + # 4. 创建 Annotator + annotator = Annotator( + language_model=language_model, + prompt_template=prompt_template, + format_handler=format_handler + ) + + # 5. 
执行抽取 (内部包含分块、推理、解析、对齐) + result = annotator.annotate_text( + text=text_or_documents, + resolver=Resolver(...), + max_char_buffer=max_char_buffer, + ... + ) + + return result +``` + +--- + +## 已知限制与 FAQ + +### Q1: 为什么我的实体对齐失败了? + +**常见原因**: + +1. **LLM 提取的文本与原文不一致** + - LLM 可能 paraphrase(转述)原文,例如原文是 "John Smith",但 LLM 返回 "Mr. Smith" + - 解决方案:在 prompt_description 中强调 "Use exact text from the source. Do not paraphrase." + +2. **提取文本跨越 chunk 边界** + - 如果一个实体被分割在两个 chunks 中,对齐可能失败 + - 解决方案:使用 `context_window_chars` 参数,或调整 `max_char_buffer` + +3. **模糊匹配阈值设置过高** + - 默认 `fuzzy_alignment_threshold=0.75`,如果提取文本与原文差异较大,可能无法匹配 + - 解决方案:调低阈值 `resolver_params={"fuzzy_alignment_threshold": 0.6}` + +4. **特殊字符或大小写问题** + - 虽然有归一化处理,但某些特殊字符可能导致问题 + - 检查 `extraction_text` 中是否有不可见字符 + +**调试方法**: +```python +# 查看对齐失败的 extractions +failed = [e for e in result.extractions if e.alignment_status is None] +for e in failed: + print(f"Failed: class={e.extraction_class}, text={e.extraction_text}") +``` + +--- + +### Q2: 为什么 schema 里的 Optional 字段没被抽取? + +**LangExtract 没有传统的 "Optional" 概念**。 + +LangExtract 的 schema 是 **example-driven** 的,不是类型驱动的。这意味着: + +1. **schema 从 examples 推断** + - 如果你在 examples 中定义了某个 extraction_class,LLM 会被引导去抽取这类实体 + - 但这不是强制的——LLM 可能抽取也可能不抽取 + +2. **没有 "必填/可选" 标记** + - 传统 Pydantic 模型有 `Optional[]` 或 `required=True/False` + - LangExtract 没有这个机制 + +3. **如何控制抽取行为** + - 通过 `prompt_description` 描述应该抽取什么 + - 通过 `examples` 展示抽取模式 + - 如果某些实体经常被遗漏,增加更多相关 examples + +**注意**: 如果 LLM 没有抽取某个实体,结果中不会有对应的 Extraction 对象(值为 null 或空字符串也不会被表示)。 + +--- + +### Q3: 为什么我的输出解析失败了? + +**常见场景**: + +1. **LLM 没有返回 JSON/YAML 格式** + - 某些模型可能忽略格式指令,返回自然语言 + - 解决方案: + - 确保 examples 格式正确 + - 使用支持 schema constraints 的模型(如 Gemini) + - 检查 `use_schema_constraints=True`(默认) + +2. **多个围栏块冲突** + - LLM 可能返回多个 ```json 块 + - 检查 `strict_fences` 设置(默认 False,取第一个有效块) + +3. **推理模型的 `<think>` 标签** + - DeepSeek-R1, QwQ 等模型会在 JSON 前输出思考过程 + - LangExtract 会自动处理(非 strict 模式),但如果格式太复杂可能失败 + +4. 
**缺少 `extractions` wrapper** + - 某些模型可能直接返回 `[...]` 而不是 `{"extractions": [...]}` + - 默认 `allow_top_level_list=True` 会处理这种情况 + +**调试方法**: +```python +# 使用 debug=True 查看原始输出 +result = lx.extract( + ..., + debug=True, # 启用详细日志 +) +``` + +--- + +### Q4: 为什么同一个实体会被多次抽取? + +**原因**: + +1. **跨 chunk 边界** + - 一个实体可能出现在多个 chunks 中(如果 `context_window_chars` 包含了它) + - LangExtract 目前没有自动去重 + +2. **多次抽取 (`extraction_passes > 1`)** + - 虽然多次抽取会合并非重叠结果,但如果同一个实体在不同位置有相似文本,可能被多次抽取 + +3. **LLM 自身的不稳定性** + - 即使是相同的 prompt,LLM 也可能返回略有不同的结果 + +**解决方案**(用户自行处理): +```python +# 基于 char_interval 去重 +def deduplicate(extractions): + seen = set() + result = [] + for e in extractions: + if e.char_interval is None: + continue + key = (e.extraction_class, e.char_interval.start_pos, e.char_interval.end_pos) + if key not in seen: + seen.add(key) + result.append(e) + return result +``` + +--- + +### Q5: `max_char_buffer` 应该设多大? + +**考虑因素**: + +1. **模型上下文窗口** + - `max_char_buffer` 应该远小于模型的最大 token 限制 + - 因为 prompt 本身(description + examples)也占用 tokens + +2. **抽取精度 vs API 成本** + - 较小的 `max_char_buffer` → 更多 chunks → 更多 API 调用 → 更高成本,但可能更准确 + - 较大的 `max_char_buffer` → 更少 chunks → 更低成本,但可能遗漏信息 + +3. **经验建议** + - 默认 `1000` 是一个平衡值 + - 简单任务(如抽取人名)可以用较大值(如 `2000-3000`) + - 复杂任务(如关系抽取)建议用较小值(如 `500-1000`) + +4. 
**与 token 数量的关系** + - `max_char_buffer` 是字符数,不是 token 数 + - 粗略估计:英文 ~1 token = 4 chars,中文 ~1 token = 2 chars + +**配置示例**: +```python +# 高精度模式 +result = lx.extract( + ..., + max_char_buffer=500, # 较小 chunk + extraction_passes=3, # 多次抽取提高召回 +) + +# 低成本模式 +result = lx.extract( + ..., + max_char_buffer=2000, # 较大 chunk + extraction_passes=1, # 单次抽取 +) +``` + +**参数说明**: + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `max_char_buffer` | 1000 | 每个 chunk 的最大字符数。调小可提高准确率但增加 API 调用。 | +| `batch_length` | 10 | 每批处理的 chunk 数量。与 `max_workers` 共同决定并行度。 | +| `max_workers` | 10 | 最大并行 worker 数。有效并行度受限于 `min(batch_length, max_workers)`。 | +| `context_window_chars` | `None` | 前一个 chunk 的上下文字符数。用于指代消解。 | +| `extraction_passes` | 1 | 抽取次数。> 1 时执行多次抽取并合并非重叠结果。 | + +--- + +## 附录:核心类关系图 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ 核心类关系图 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ +│ │ ExampleData │────▶│ Extraction │────▶│ CharInterval │ │ +│ │ (示例数据) │ │ (抽取结果) │ │ (字符区间) │ │ +│ └──────────────┘ └──────────────┘ └──────────────┘ │ +│ │ │ │ +│ │ ▼ │ +│ │ ┌──────────────┐ │ +│ │ │TokenInterval │ │ +│ │ │ (token 区间) │ │ +│ │ └──────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌──────────────────────────────────────────────────────────────────────┐ │ +│ │ extract() 入口函数 │ │ +│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │ +│ │ │ Annotator │ │ Resolver │ │FormatHandler│ │ │ +│ │ │ (协调器) │ │ (解析对齐) │ │ (格式处理) │ │ │ +│ │ └──────┬──────┘ └──────┬──────┘ └─────────────┘ │ │ +│ │ │ │ │ │ +│ │ ▼ ▼ │ │ +│ │ ┌─────────────┐ ┌─────────────┐ │ │ +│ │ │ChunkIterator│ │ WordAligner │ │ │ +│ │ │ (分块器) │ │ (对齐器) │ │ │ +│ │ └─────────────┘ └─────────────┘ │ │ +│ └──────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌──────────────────────────────────────────────────────────────────────┐ │ +│ │ Prompt 组装层 │ │ +│ │ ┌─────────────────────┐ 
┌─────────────────────────┐ │ │ +│ │ │PromptTemplateStruct │ │ QAPromptGenerator │ │ │ +│ │ │ (模板数据) │ │ (prompt 生成器) │ │ │ +│ │ └──────────┬──────────┘ └────────────┬────────────┘ │ │ +│ │ │ │ │ │ +│ │ ▼ ▼ │ │ +│ │ ┌─────────────────────────────────────────────────────────┐ │ │ +│ │ │ ContextAwarePromptBuilder │ │ │ +│ │ │ (支持跨 chunk 上下文的 prompt builder) │ │ │ +│ │ └─────────────────────────────────────────────────────────┘ │ │ +│ └──────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌──────────────────────────────────────────────────────────────────────┐ │ +│ │ Schema 层 │ │ +│ │ ┌──────────────┐ │ │ +│ │ │ BaseSchema │ (抽象基类) │ │ +│ │ └──────┬───────┘ │ │ +│ │ │ │ │ +│ │ ▼ │ │ +│ │ ┌──────────────────┐ │ │ +│ │ │ FormatModeSchema │ (当前主要实现: JSON/YAML 格式约束) │ │ +│ │ └──────────────────┘ │ │ +│ └──────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +--- + +**本文档基于代码版本**: langextract (docs/schema-design 分支) +**最后更新**: 2026-04-19 diff --git a/docs/architecture/prompt.md b/docs/architecture/prompt.md new file mode 100644 index 00000000..81310337 --- /dev/null +++ b/docs/architecture/prompt.md @@ -0,0 +1,255 @@ +# Prompt 组装 + +Prompt 组装是 LangExtract 的核心环节,它将用户的 `prompt_description` 和 `examples` 转换为 LLM 可理解的指令格式。 + +--- + +## 架构导航 + +- **← [返回概览](overview.md)** +- **→ [Schema 设计](schema.md)** +- **→ [输出解析与实体对齐](alignment.md)** +- **→ [长文档分块](chunking.md)** + +--- + +## 目录 + +- [完整流程](#完整流程) +- [关键代码路径](#关键代码路径) +- [完整 Prompt 示例](#完整-prompt-示例) + +--- + +## 完整流程 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Prompt 组装流程 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 输入 │ +│ ┌─────────────────┐ ┌─────────────────────────────────────────────┐ │ +│ │ prompt_description │ │ examples │ │ +│ │ "Extract persons..." 
│ │ [ExampleData(text=..., extractions=...)] │ │ +│ └────────┬────────┘ └─────────────────────┬───────────────────────┘ │ +│ │ │ │ +│ └────────────────┬────────────────┘ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ PromptTemplateStructured (数据容器) │ │ +│ │ ┌───────────────────────────────────────────────────────────────┐ │ │ +│ │ │ description: "Extract persons and medications from text..." │ │ │ +│ │ │ examples: [ExampleData, ExampleData, ...] │ │ │ +│ │ └───────────────────────────────────────────────────────────────┘ │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ FormatHandler (示例格式化) │ │ +│ │ │ │ +│ │ 每个 ExampleData.extractions 被格式化为: │ │ +│ │ { │ │ +│ │ "extractions": [ │ │ +│ │ {"person": "John", "person_attributes": {"age": "30"}}, │ │ +│ │ {"medication": "Aspirin", "medication_attributes": {...}} │ │ +│ │ ] │ │ +│ │ } │ │ +│ │ │ │ +│ │ 输出格式: JSON 或 YAML,带或不带 ``` 围栏 │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ QAPromptGenerator (最终组装) │ │ +│ │ │ │ +│ │ render(question=chunk_text) 生成: │ │ +│ │ ┌───────────────────────────────────────────────────────────────┐ │ │ +│ │ │ [description] │ │ │ +│ │ │ │ │ │ +│ │ │ Examples │ │ │ +│ │ │ Q: [example_1.text] │ │ │ +│ │ │ A: [formatted_extractions_1] │ │ │ +│ │ │ │ │ │ +│ │ │ Q: [example_2.text] │ │ │ +│ │ │ A: [formatted_extractions_2] │ │ │ +│ │ │ │ │ │ +│ │ │ Q: [current_chunk_text] │ │ │ +│ │ │ A: │ │ │ +│ │ └───────────────────────────────────────────────────────────────┘ │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ ContextAwarePromptBuilder (跨 chunk 上下文) │ │ +│ │ │ │ +│ │ 可选功能: 
注入前一个 chunk 的尾部文本作为上下文 │ │ +│ │ │ │ +│ │ [Previous text]: ...the patient was prescribed │ │ │ +│ │ [additional_context] │ │ │ +│ │ │ │ +│ │ 帮助解决指代消解问题: "She" → "Dr. Sarah Johnson" │ │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +--- + +## 关键代码路径 + +### 1. 示例格式化 (`format_extraction_example`) + +```python +# langextract/core/format_handler.py:116-149 +def format_extraction_example( + self, extractions: list[data.Extraction] +) -> str: + """将 extractions 格式化为 prompt 中的示例""" + items = [ + { + ext.extraction_class: ext.extraction_text, + f"{ext.extraction_class}{self.attribute_suffix}": ( + ext.attributes or {} + ), + } + for ext in extractions + ] + + if self.use_wrapper and self.wrapper_key: + payload = {self.wrapper_key: items} # {"extractions": [...]} + else: + payload = items + + if self.format_type == data.FormatType.YAML: + formatted = yaml.safe_dump(payload, ...) + else: + formatted = json.dumps(payload, indent=2, ensure_ascii=False) + + return self._add_fences(formatted) if self.use_fences else formatted +``` + +### 2. Prompt 渲染 (`QAPromptGenerator.render`) + +```python +# langextract/prompting.py:115-138 +def render(self, question: str, additional_context: str | None = None) -> str: + """生成完整的 prompt 文本""" + prompt_lines: list[str] = [f"{self.template.description}\n"] + + if additional_context: + prompt_lines.append(f"{additional_context}\n") + + if self.template.examples: + prompt_lines.append(self.examples_heading) # "Examples" + for ex in self.template.examples: + prompt_lines.append(self.format_example_as_text(ex)) + + prompt_lines.append(f"{self.question_prefix}{question}") # "Q: ..." + prompt_lines.append(self.answer_prefix) # "A: " + return "\n".join(prompt_lines) +``` + +### 3. 
示例格式化 (`format_example_as_text`) + +```python +# langextract/prompting.py:98-113 +def format_example_as_text(self, example: data.ExampleData) -> str: + """将单个 example 格式化为 Q:A 对""" + question = example.text + answer = self.format_handler.format_extraction_example(example.extractions) + + return "\n".join([ + f"{self.question_prefix}{question}", + f"{self.answer_prefix}{answer}\n", + ]) +``` + +### 4. 跨 chunk 上下文 (`ContextAwarePromptBuilder`) + +```python +# langextract/prompting.py:242-266 +def _build_effective_context( + self, + document_id: str, + additional_context: str | None, +) -> str | None: + """组合前一个 chunk 的上下文和额外上下文""" + context_parts: list[str] = [] + + if self._context_window_chars and document_id in self._prev_chunk_by_doc_id: + prev_text = self._prev_chunk_by_doc_id[document_id] + window = prev_text[-self._context_window_chars :] + context_parts.append(f"{self._CONTEXT_PREFIX}{window}") + # 例如: "[Previous text]: ...the patient visited the clinic" + + if additional_context: + context_parts.append(additional_context) + + return "\n\n".join(context_parts) if context_parts else None +``` + +--- + +## 完整 Prompt 示例 + +假设用户提供: + +```python +prompt_description = """Extract all persons and their roles from the text. +Use exact text from the source. Do not paraphrase.""" + +examples = [ + lx.data.ExampleData( + text="Dr. Smith, the chief surgeon, operated on patient Johnson.", + extractions=[ + lx.data.Extraction( + extraction_class="person", + extraction_text="Dr. Smith", + attributes={"role": "chief surgeon"} + ), + lx.data.Extraction( + extraction_class="person", + extraction_text="Johnson", + attributes={"role": "patient"} + ), + ] + ) +] +``` + +生成的 prompt 将是: + +``` +Extract all persons and their roles from the text. +Use exact text from the source. Do not paraphrase. + +Examples +Q: Dr. Smith, the chief surgeon, operated on patient Johnson. +A: ```json +{ + "extractions": [ + { + "person": "Dr. 
Smith", + "person_attributes": { + "role": "chief surgeon" + } + }, + { + "person": "Johnson", + "person_attributes": { + "role": "patient" + } + } + ] +} +``` + +Q: [当前待处理的 chunk 文本] +A: +``` + +--- + +**本文档基于代码版本**: langextract (docs/schema-design 分支) +**最后更新**: 2026-04-19 diff --git a/docs/architecture/provider_layer.md b/docs/architecture/provider_layer.md new file mode 100644 index 00000000..2b01883d --- /dev/null +++ b/docs/architecture/provider_layer.md @@ -0,0 +1,639 @@ +# Provider 层设计 + +LangExtract 的 provider 层提供了统一的 LLM 抽象接口,支持多种模型后端(Gemini、OpenAI、Ollama 等),并通过插件机制支持第三方扩展。 + +--- + +## 架构导航 + +- **← [返回概览](overview.md)** +- **→ [Schema 设计](schema.md)** +- **→ [Prompt 组装](prompt.md)** + +--- + +## 目录 + +- [LLMProvider 契约](#llmprovider-契约) +- [内置 Provider 对比](#内置-provider-对比) +- [ProviderRegistry 与模型解析](#providerregistry-与模型解析) +- [Entry Points 插件机制](#entry-points-插件机制) +- [MockProvider 使用指南](#mockprovider-使用指南) +- [Context Manager 资源管理](#context-manager-资源管理) +- [高层 API 与 Registry 关系](#高层-api-与-registry-关系) + +--- + +## LLMProvider 契约 + +`LLMProvider` 是所有 LLM provider 必须实现的抽象基类,定义了统一的接口契约。 + +### 核心接口 + +```python +# langextract/core/base_model.py:64-193 +class LLMProvider(abc.ABC): + @property + @abc.abstractmethod + def name(self) -> str: + """Provider 名称标识符,如 "gemini", "openai", "mock"。""" + + @property + @abc.abstractmethod + def supported_models(self) -> TypingSequence[str]: + """支持的模型 ID 正则表达式模式列表。""" + + @abc.abstractmethod + def generate( + self, prompt: str, model: str | None = None, **kwargs: Any + ) -> GenerateResult: + """同步生成文本。""" + + @abc.abstractmethod + async def agenerate( + self, prompt: str, model: str | None = None, **kwargs: Any + ) -> GenerateResult: + """异步生成文本。""" + + @abc.abstractmethod + def close(self) -> None: + """清理资源(关闭 HTTP 客户端、连接池等)。""" +``` + +### 返回类型 + +**GenerateResult**: 封装生成结果的数据类 + +```python +# langextract/core/base_model.py:49-62 +@dataclasses.dataclass +class GenerateResult: + text: str # 生成的文本输出 + usage: Usage 
| None # 可选的 token 使用量信息 + raw_response: Any # 原始 API 响应 +``` + +**Usage**: Token 使用量信息 + +```python +# langextract/core/base_model.py:34-47 +@dataclasses.dataclass +class Usage: + input_tokens: int | None # 输入 prompt 的 token 数 + output_tokens: int | None # 输出的 token 数 + total_tokens: int | None # 总 token 数 +``` + +### BaseLanguageModel 默认实现 + +`BaseLanguageModel` 继承 `LLMProvider`,提供了默认实现以保持向后兼容: + +| 方法 | 默认实现 | 说明 | +|------|----------|------| +| `name` | 从类名推导 | 去掉 "LanguageModel" 后缀并小写 | +| `supported_models` | 返回空列表 | Provider 应覆盖此属性 | +| `generate` | 包装 `infer()` | 基于旧的 `infer()` API | +| `agenerate` | 线程池包装 | 在 `ThreadPoolExecutor` 中运行 `generate()` | +| `close` | 空操作 | Provider 应覆盖以清理资源 | + +--- + +## 内置 Provider 对比 + +| Provider | 名称 | 模型模式 | Schema 支持 | 异步支持 | 依赖包 | +|----------|------|----------|-------------|----------|--------| +| Gemini | `gemini` | `^gemini`, `^models/gemini` | 完整支持 (response_schema) | 线程池包装 | `google-genai` | +| OpenAI | `openai` | `^gpt`, `^o1`, `^o3`, `^o4` | 格式支持 (JSON mode) | 线程池包装 | `openai` | +| Ollama | `ollama` | `^ollama:` | 格式支持 | 线程池包装 | 无 (HTTP) | +| Mock | `mock` | `^mock$`, `^mock-` | 无 (测试用) | 原生同步 | 无 | + +### Gemini Provider + +```python +# langextract/providers/gemini.py +@router.register( + *patterns.GEMINI_PATTERNS, + priority=patterns.GEMINI_PRIORITY, +) +class GeminiLanguageModel(BaseLanguageModel): + def __init__( + self, + model_id: str = 'gemini-2.5-flash', + api_key: str | None = None, + vertexai: bool = False, + project: str | None = None, + location: str | None = None, + temperature: float = 0.0, + max_workers: int = 10, + **kwargs, + ): ... 
+``` + +**Key Features**: +- 完整的 `GeminiSchema` 结构化输出支持 +- 支持 Vertex AI (Enterprise) 模式 +- 内置 Batch API 支持 +- `google-genai` 官方 SDK + +### OpenAI Provider + +```python +# langextract/providers/openai.py +@router.register( + *patterns.OPENAI_PATTERNS, + priority=patterns.OPENAI_PRIORITY, +) +class OpenAILanguageModel(BaseLanguageModel): + def __init__( + self, + model_id: str = 'gpt-4o-mini', + api_key: str | None = None, + base_url: str | None = None, + organization: str | None = None, + temperature: float | None = None, + max_workers: int = 10, + **kwargs, + ): ... +``` + +**Key Features**: +- JSON mode 格式约束 +- 支持 `base_url` 自定义端点(兼容 Azure OpenAI、本地兼容 API) +- `openai` 官方 SDK + +--- + +## ProviderRegistry 与模型解析 + +`ProviderRegistry` 提供了中心化的 provider 注册和查找机制。 + +### 基本用法 + +```python +from langextract.providers import ProviderRegistry + +# 获取全局 registry 实例 +registry = ProviderRegistry.get_global() + +# 按名称获取 provider 类 +gemini_cls = registry.get("gemini") +mock_cls = registry.get("mock") + +# 按模型 ID 解析 provider +provider_cls = registry.get_for_model("gemini-2.5-flash") + +# 注册自定义 provider +registry.register(MyCustomProvider, patterns=[r"^my-model"]) + +# 清空 registry (用于测试) +registry.clear() +``` + +### 模型解析机制 + +模型解析基于正则表达式模式匹配,**高优先级**的 provider 优先匹配: + +``` +模型 ID: "gemini-2.5-flash" + │ + ▼ +┌───────────────────────────────────────────────┐ +│ Router Pattern Matching │ +├───────────────────────────────────────────────┤ +│ 1. 遍历所有注册的 pattern (按优先级排序) │ +│ 2. 检查 `re.match(pattern, model_id)` │ +│ 3. 返回第一个匹配的 provider │ +└───────────────────────────────────────────────┘ + │ + ▼ +返回: GeminiLanguageModel +``` + +### Priority 机制 + +```python +# 高优先级会覆盖低优先级 +@router.register(r"^gemini", priority=0) # 低优先级 +class DefaultGeminiProvider(...): ... + +@router.register(r"^gemini", priority=100) # 高优先级 +class CustomGeminiProvider(...): ... 
+ +# 解析 "gemini-pro" → 返回 CustomGeminiProvider +``` + +**默认优先级值**: + +| Provider | 默认 Priority | +|----------|---------------| +| 内置 Provider (Gemini/OpenAI) | 0 | +| MockProvider | 100 | +| 第三方插件 (默认) | 20 | + +--- + +## Entry Points 插件机制 + +LangExtract 使用 Python 的 `entry_points` 机制实现第三方 provider 的自动发现和注册。 + +### 工作原理 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Entry Points 自动发现流程 │ +├─────────────────────────────────────────────────────────────┤ +│ │ +│ 1. 用户安装第三方包: pip install langextract-myprovider │ +│ │ +│ 2. ProviderRegistry 初始化时调用 load_plugins_once() │ +│ │ +│ 3. 查询 importlib.metadata.entry_points() │ +│ ── 筛选 group="langextract.providers" ── │ +│ │ +│ 4. 加载 entry point 指向的类 │ +│ ── 自动调用 @router.register 注册 ── │ +│ │ +│ 5. 完成!用户可通过 registry.get("myprovider") 使用 │ +│ │ +└─────────────────────────────────────────────────────────────┘ +``` + +### 插件开发步骤 + +#### 1. 创建 Provider 类 + +```python +# my_package/provider.py +from collections.abc import Sequence + +from langextract.core import base_model +from langextract.core import types as core_types +from langextract.providers import router + +@router.register( + r"^my-model", r"^myprovider:", + priority=50, +) +class MyProviderLanguageModel(base_model.BaseLanguageModel): + model_id: str = "my-model-default" + + @property + def name(self) -> str: + return "myprovider" + + @property + def supported_models(self) -> list[str]: + return [r"^my-model", r"^myprovider:"] + + @classmethod + def get_model_patterns(cls) -> list[str]: + """可选:类方法,返回支持的模型模式""" + return [r"^my-model", r"^myprovider:"] + + def __init__(self, model_id: str, api_key: str | None = None, **kwargs): + super().__init__() + self.model_id = model_id + self._client = self._init_client(api_key) + + def _init_client(self, api_key: str | None): + # 初始化你的 API 客户端 + ... 
+ + def infer(self, batch_prompts: Sequence[str], **kwargs): + # 实现推理逻辑 + for prompt in batch_prompts: + response = self._call_api(prompt, **kwargs) + yield [core_types.ScoredOutput(score=1.0, output=response)] + + def close(self) -> None: + # 清理资源 + if hasattr(self, '_client'): + self._client.close() +``` + +#### 2. 配置 pyproject.toml + +```toml +[project] +name = "langextract-myprovider" +version = "0.1.0" +description = "My custom provider for LangExtract" +requires-python = ">=3.10" +dependencies = [ + "langextract>=1.2.0", + # 你的其他依赖 +] + +[project.entry-points."langextract.providers"] +myprovider = "my_package.provider:MyProviderLanguageModel" + +# 带 priority 后缀的语法 (可选) +# myprovider = "my_package.provider:MyProviderLanguageModel:priority=100" +``` + +### Priority 后缀语法 + +Entry point value 支持 `:priority=N` 后缀来指定注册优先级: + +```toml +[project.entry-points."langextract.providers"] +# 基本语法 +basic_provider = "my_pkg:BasicProvider" + +# 带优先级的语法 +high_priority_provider = "my_pkg:HighPriorityProvider:priority=100" +low_priority_provider = "my_pkg:LowPriorityProvider:priority=0" +``` + +**优先级解析规则**: +1. 如果 entry point value 有 `:priority=N` 后缀,使用该值 +2. 否则使用类的 `pattern_priority` 属性(默认 20) +3. 
最后使用默认插件优先级 20 + +### 禁用插件加载 + +```bash +# 环境变量禁用 +export LANGEXTRACT_DISABLE_PLUGINS=1 +``` + +```python +# 或在代码中设置 +import os +os.environ["LANGEXTRACT_DISABLE_PLUGINS"] = "1" +``` + +--- + +## MockProvider 使用指南 + +`MockProvider` 是专为测试设计的 provider,不需要真实 API 调用。 + +### 基本用法 + +```python +from langextract.providers import MockProvider, ProviderRegistry + +# 方式 1: 直接实例化 +mock = MockProvider(fixed_response='{"result": "test"}') +result = mock.generate("Hello") +print(result.text) # '{"result": "test"}' + +# 方式 2: 通过 registry 获取 (推荐) +registry = ProviderRegistry() +mock_cls = registry.get("mock") +mock = mock_cls(fixed_response="test response") + +# 方式 3: 使用响应函数 +def custom_response(prompt: str, **kwargs) -> str: + if "name" in prompt: + return '{"name": "Alice"}' + elif "value" in prompt: + return '{"value": 42}' + return '{"default": true}' + +mock = MockProvider(response_fn=custom_response) +mock.generate("What is your name?") # '{"name": "Alice"}' +mock.generate("What is the value?") # '{"value": 42}' +``` + +### 与 Usage 信息 + +```python +from langextract.providers import MockProvider, Usage + +mock = MockProvider( + fixed_response='{"result": "ok"}', + usage=Usage(input_tokens=100, output_tokens=50, total_tokens=150) +) + +result = mock.generate("Test") +assert result.usage.input_tokens == 100 +assert result.usage.output_tokens == 50 +``` + +### 作为 pytest fixture + +```python +import pytest +from langextract.providers import MockProvider, ProviderRegistry + +@pytest.fixture +def mock_provider(): + """Fixture 提供已配置的 MockProvider""" + return MockProvider(fixed_response='{"test": true}') + +@pytest.fixture +def registry_mock(): + """Fixture 提供带有 MockProvider 的 registry""" + registry = ProviderRegistry() + registry.clear() # 重置状态 + # MockProvider 会自动注册 + return registry + +def test_with_mock(mock_provider): + result = mock_provider.generate("Hello") + assert '"test": true' in result.text + +def test_via_registry(registry_mock): + mock_cls = registry_mock.get("mock") + mock = 
mock_cls(fixed_response="custom") + result = mock.generate("test") + assert result.text == "custom" +``` + +--- + +## Context Manager 资源管理 + +`LLMProvider` 支持同步和异步两种上下文管理器,确保资源自动释放。 + +### 同步用法 + +```python +from langextract.providers import ProviderRegistry + +# 方式 1: 手动管理 (容易忘记 close) +provider = ProviderRegistry.get_global().get("gemini")(api_key="...") +try: + result = provider.generate("Hello") +finally: + provider.close() + +# 方式 2: with 语句 (推荐) +with ProviderRegistry.get_global().get("gemini")(api_key="...") as provider: + result = provider.generate("Hello") +# provider.close() 已自动调用 +``` + +### 异步用法 + +```python +import asyncio +from langextract.providers import MockProvider + +async def async_usage(): + # async with 语句 + async with MockProvider(fixed_response="async test") as mock: + result = await mock.agenerate("Hello") + print(result.text) + + # mock.close() 已自动调用 + +asyncio.run(async_usage()) +``` + +### 测试 close() 调用 + +`MockProvider.close_called` 计数器可用于验证资源清理: + +```python +from langextract.providers import MockProvider + +def test_context_manager_calls_close(): + mock = MockProvider() + assert mock.close_called == 0 + + with mock: + result = mock.generate("Test") + assert mock.close_called == 0 # 还在上下文中 + + assert mock.close_called == 1 # 已退出上下文 + +def test_async_context_manager(): + import asyncio + + async def test(): + mock = MockProvider() + async with mock: + result = await mock.agenerate("Test") + assert mock.close_called == 0 + assert mock.close_called == 1 + + asyncio.run(test()) +``` + +### 异常时的资源清理 + +上下文管理器保证即使发生异常,`close()` 也会被调用: + +```python +from langextract.providers import MockProvider + +mock = MockProvider() + +try: + with mock: + raise ValueError("Something went wrong") +except ValueError: + pass + +assert mock.close_called == 1 # close() 仍然被调用 +``` + +--- + +## 高层 API 与 Registry 关系 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ 高层 API → Registry → Provider 调用链 │ 
+├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 用户调用: lx.extract() │ │ +│ │ (langextract/extraction.py) │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ factory.create_model() │ │ +│ │ (langextract/factory.py) │ │ +│ │ - 解析 model_id │ │ +│ │ - 查找对应的 provider 类 │ │ +│ │ - 实例化 provider │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ ProviderRegistry / router │ │ +│ │ (langextract/providers/registry.py, router.py) │ │ +│ │ - 管理已注册的 provider 模式 │ │ +│ │ - 按优先级匹配模型 ID │ │ +│ │ - 自动加载 entry points 插件 │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ LLMProvider 实例 │ │ +│ │ (GeminiLanguageModel, OpenAILanguageModel, MockProvider, ...) │ │ +│ │ - 封装真实 API 调用 │ │ +│ │ - 管理 HTTP 客户端资源 │ │ +│ │ - 支持 generate() / agenerate() │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 完整调用流程示例 + +```python +import langextract as lx + +# 用户调用高层 API +result = lx.extract( + text_or_documents="ROMEO. But soft! What light...", + prompt_description="Extract character names", + examples=[...], + config=lx.factory.ModelConfig( + model_id="gemini-2.5-flash", # 决定使用哪个 provider + provider_kwargs={"api_key": "your-api-key"}, + ), +) + +# 内部流程: +# 1. extract() → factory.create_model() +# 2. create_model() → registry.get_for_model("gemini-2.5-flash") +# 3. registry 匹配 pattern "^gemini" → GeminiLanguageModel +# 4. 实例化 GeminiLanguageModel(api_key="...") +# 5. 
调用 provider.generate() / infer() 执行推理 +``` + +### 直接使用 Provider + +有时你可能想直接使用 provider 而不经过完整的 extract 流程: + +```python +from langextract.providers import ProviderRegistry + +# 获取 provider 类 +gemini_cls = ProviderRegistry.get_global().get("gemini") + +# 实例化 +with gemini_cls( + model_id="gemini-2.5-flash", + api_key="your-api-key", + temperature=0.0, +) as provider: + # 直接调用 generate + result = provider.generate( + prompt="Translate 'Hello' to French", + model="gemini-2.5-flash", + ) + print(result.text) # "Bonjour" +``` + +--- + +## 附录: 关键文件位置 + +| 文件 | 说明 | +|------|------| +| `langextract/core/base_model.py` | `LLMProvider`, `BaseLanguageModel`, `GenerateResult`, `Usage` | +| `langextract/providers/registry.py` | `ProviderRegistry`, `ProviderInfo`, `MockProvider` | +| `langextract/providers/router.py` | 底层模式匹配与注册机制 | +| `langextract/providers/__init__.py` | 导出入口、插件加载 `load_plugins_once()` | +| `langextract/providers/gemini.py` | Gemini provider 实现 | +| `langextract/providers/openai.py` | OpenAI provider 实现 | +| `tests/provider_abstraction_test.py` | Provider 层单元测试 | + +--- + +**本文档基于代码版本**: langextract (main) +**最后更新**: 2026-06-18 diff --git a/docs/architecture/schema.md b/docs/architecture/schema.md new file mode 100644 index 00000000..58478d33 --- /dev/null +++ b/docs/architecture/schema.md @@ -0,0 +1,188 @@ +# Schema 设计 + +LangExtract 的 schema 设计采用 **example-driven** 模式——schema 从用户提供的 `examples` 中自动推断,而非显式定义 Pydantic 模型。 + +--- + +## 架构导航 + +- **← [返回概览](overview.md)** +- **→ [Prompt 组装](prompt.md)** +- **→ [输出解析与实体对齐](alignment.md)** +- **→ [长文档分块](chunking.md)** + +--- + +## 目录 + +- [支持的字段类型](#支持的字段类型) +- [实体定义方式](#实体定义方式) +- [Schema 的 JSON 表达与 Python 类对应关系](#schema-的-json-表达与-python-类对应关系) +- [Schema 约束如何应用到 LLM](#schema-约束如何应用到-llm) + +--- + +## 支持的字段类型 + +| 类型 | 说明 | 示例值 | +|------|------|--------| +| `string` | 字符串 (主要类型) | `"John Smith"`, `"2024-01-15"` | +| `number` | 数值 (自动转为 string) | `42`, `3.14` | +| `dict` | 属性字典 (通过后缀识别) | `{"dosage": "10mg", "route": 
"oral"}` | +| `list` | 列表 (在 attributes 中) | `["symptom1", "symptom2"]` | + +--- + +## 实体定义方式 + +实体通过 `Extraction` 类定义,每个实体包含三个核心属性: + +```python +# langextract/core/data.py:64-118 +@dataclasses.dataclass +class Extraction: + extraction_class: str # 实体类型 (如 "person", "medication") + extraction_text: str # 实体文本 (从原文提取的内容) + char_interval: CharInterval | None = None # 原文中的字符偏移 + alignment_status: AlignmentStatus | None = None # 对齐状态 + extraction_index: int | None = None # 排序索引 + group_index: int | None = None # 分组索引 + attributes: dict[str, str | list[str]] | None = None # 附加属性 +``` + +--- + +## Schema 的 JSON 表达与 Python 类对应关系 + +LangExtract 的 schema 系统有两层抽象: + +### 1. 数据层 (Data Layer) + +用户通过 `ExampleData` 和 `Extraction` 定义示例: + +```python +# 来自 README.md 的示例 +examples = [ + lx.data.ExampleData( + text="ROMEO. But soft! What light through yonder window breaks?", + extractions=[ + lx.data.Extraction( + extraction_class="character", # 对应 JSON 键 + extraction_text="ROMEO", # 对应 JSON 值 + attributes={"emotional_state": "wonder"} # 附加属性 + ), + lx.data.Extraction( + extraction_class="emotion", + extraction_text="But soft!", + attributes={"feeling": "gentle awe"} + ), + ] + ) +] +``` + +对应的 JSON 表达 (在 prompt 中): + +```json +{ + "extractions": [ + { + "character": "ROMEO", + "character_attributes": { + "emotional_state": "wonder" + } + }, + { + "emotion": "But soft!", + "emotion_attributes": { + "feeling": "gentle awe" + } + } + ] +} +``` + +### 2. 
Schema 层 (Schema Layer) + +`BaseSchema` 是抽象基类,定义了从 examples 生成 provider 配置的接口: + +```python +# langextract/core/schema.py:38-91 +class BaseSchema(abc.ABC): + @classmethod + @abc.abstractmethod + def from_examples( + cls, + examples_data: Sequence[data.ExampleData], + attribute_suffix: str = data.ATTRIBUTE_SUFFIX, + ) -> BaseSchema: + """从示例数据构建 schema 实例""" + + @abc.abstractmethod + def to_provider_config(self) -> dict[str, Any]: + """转换为 provider 特定的配置 (如 Gemini 的 response_schema)""" + + @property + @abc.abstractmethod + def requires_raw_output(self) -> bool: + """是否输出原始 JSON/YAML (无围栏标记)""" +``` + +### 3. FormatModeSchema: 通用格式约束 + +`FormatModeSchema` 是当前主要使用的 schema 实现,它不强制字段级结构,只保证输出格式: + +```python +# langextract/core/schema.py:93-139 +class FormatModeSchema(BaseSchema): + def __init__(self, format_type: types.FormatType = types.FormatType.JSON): + self.format_type = format_type + self._format = "json" if format_type == types.FormatType.JSON else "yaml" + + @classmethod + def from_examples( + cls, + examples_data: Sequence[data.ExampleData], + attribute_suffix: str = data.ATTRIBUTE_SUFFIX, + ) -> FormatModeSchema: + """从 examples 构建 schema (当前默认使用 JSON 格式)""" + return cls(format_type=types.FormatType.JSON) + + def to_provider_config(self) -> dict[str, Any]: + """返回 provider 配置""" + return {"format": self._format} + + @property + def requires_raw_output(self) -> bool: + """JSON 格式输出原始 JSON (无围栏),YAML 则需要围栏""" + return self._format == "json" +``` + +--- + +## Schema 约束如何应用到 LLM + +在 `extract()` 函数中,schema 被传递给 model factory: + +```python +# langextract/extraction.py:298-303 +language_model = factory.create_model( + config=config, + examples=prompt_template.examples if use_schema_constraints else None, + use_schema_constraints=use_schema_constraints, + fence_output=fence_output, +) +``` + +不同 provider 对 schema 的支持程度不同: + +| Provider | Schema 支持 | 实现方式 | +|----------|-------------|----------| +| Gemini | 完整支持 | `response_schema` + 结构化输出 | +| OpenAI | 格式支持 | JSON 
mode | +| Ollama | 格式支持 | `format: "json"` 参数 | + +--- + +**本文档基于代码版本**: langextract (docs/schema-design 分支) +**最后更新**: 2026-04-19 diff --git a/docs/architecture/schema_and_prompt_design.md b/docs/architecture/schema_and_prompt_design.md new file mode 100644 index 00000000..b6f80925 --- /dev/null +++ b/docs/architecture/schema_and_prompt_design.md @@ -0,0 +1,1766 @@ +# LangExtract Schema & Prompt 架构设计 + +本文档系统梳理 LangExtract 的核心内部机制,包括 schema 设计方式、prompt 组装逻辑、实体对齐机制、输出解析策略和长文档分块方法。 + +--- + +## Table of Contents + +- [Overview: 抽取流程一页纸](#overview-抽取流程一页纸) +- [Schema 设计](#schema-设计) +- [Prompt 组装](#prompt-组装) +- [输出解析](#输出解析) +- [实体对齐 (Alignment)](#实体对齐-alignment) +- [长文档分块](#长文档分块) +- [已知限制与 FAQ](#已知限制与-faq) +- [文档 TODO](#文档-todo) + +--- + +## Overview: 抽取流程一页纸 + +LangExtract 的信息抽取流程是一个典型的 **schema-driven** 流水线,从用户输入到最终返回结构化数据,经过以下阶段: + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ LangExtract 抽取流程 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 用户输入 (User Input) │ +│ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │ +│ │ text_or_documents │ │ prompt_description │ │ examples │ │ +│ │ (待抽取文本) │ │ (抽取指令) │ │ (少量示例) │ │ +│ └────────┬────────┘ └────────┬────────┘ └────────┬────────┘ │ +│ │ │ │ │ +│ └────────────────────┼────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Schema 推断 (Schema Inference) │ │ +│ │ - 从 examples 中提取 extraction_class (实体类型) │ │ +│ │ - 分析 extraction_text 的值类型 (string/number/dict/list) │ │ +│ │ - 构建 BaseSchema 或 FormatModeSchema 实例 │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Prompt 组装 (Prompt Assembly) │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ 1. System Instruction: prompt_description │ │ │ +│ │ │ 2. 
Few-shot Examples: 格式化 examples 为 JSON/YAML │ │ │ +│ │ │ 3. Question: 当前 chunk 文本 (Q: ...) │ │ │ +│ │ │ 4. Answer Prefix: 引导模型输出 (A: ) │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ 关键类: QAPromptGenerator, ContextAwarePromptBuilder │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ LLM 推理 (Inference) │ │ +│ │ - 支持 Gemini, OpenAI, Ollama 等多种 provider │ │ +│ │ - 部分模型支持 schema constraints (结构化输出约束) │ │ +│ │ - 批量处理 (batch) 提高吞吐量 │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 输出解析 (Output Parsing) │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ 1. 围栏提取: 从 ```json / ```yaml 中提取内容 │ │ │ +│ │ │ 2. 格式解析: JSON.parse / yaml.safe_load │ │ │ +│ │ │ 3. 容错处理: 标签过滤、宽松解析模式 │ │ │ +│ │ │ 4. 结构转换: 转为 Extraction 对象序列 │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ 关键类: FormatHandler, Resolver │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 实体对齐 (Entity Alignment) │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ 1. 精确匹配: difflib.SequenceMatcher 逐 token 匹配 │ │ │ +│ │ │ 2. 模糊匹配: LCS (最长公共子序列) 算法 │ │ │ +│ │ │ 3. 归一化: 小写 + 轻量词干化 (去除 s 后缀) │ │ │ +│ │ │ 4. 
状态标记: MATCH_EXACT / MATCH_FUZZY / MATCH_LESSER │ │ │ +│ │ └─────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ 关键类: WordAligner │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 结果返回 (Result Return) │ │ +│ │ - AnnotatedDocument: 包含 extractions 列表 │ │ +│ │ - 每个 Extraction 包含: char_interval, alignment_status, attributes │ │ +│ │ - 支持 visualization 可视化 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 核心数据流向伪代码 + +```python +# langextract/extraction.py:37-377 (extract 函数) +def extract(text_or_documents, prompt_description, examples, ...): + # 1. 验证 examples 的对齐质量 (可选但推荐) + if prompt_validation_level != OFF: + alignment_report = validate_prompt_alignment(examples) + + # 2. 创建 Prompt 模板 + prompt_template = PromptTemplateStructured( + description=prompt_description, + examples=examples + ) + + # 3. 初始化 LLM (自动选择 provider) + language_model = factory.create_model( + model_id=model_id, + examples=examples if use_schema_constraints else None + ) + + # 4. 创建 Annotator + annotator = Annotator( + language_model=language_model, + prompt_template=prompt_template, + format_handler=format_handler + ) + + # 5. 执行抽取 (内部包含分块、推理、解析、对齐) + result = annotator.annotate_text( + text=text_or_documents, + resolver=Resolver(...), + max_char_buffer=max_char_buffer, + ... 
+ ) + + return result +``` + +--- + +## Schema 设计 + +LangExtract 的 schema 设计采用 **example-driven** 模式——schema 从用户提供的 `examples` 中自动推断,而非显式定义 Pydantic 模型。 + +### 支持的字段类型 + +| 类型 | 说明 | 示例值 | +|------|------|--------| +| `string` | 字符串 (主要类型) | `"John Smith"`, `"2024-01-15"` | +| `number` | 数值 (自动转为 string) | `42`, `3.14` | +| `dict` | 属性字典 (通过后缀识别) | `{"dosage": "10mg", "route": "oral"}` | +| `list` | 列表 (在 attributes 中) | `["symptom1", "symptom2"]` | + +### 实体定义方式 + +实体通过 `Extraction` 类定义,每个实体包含三个核心属性: + +```python +# langextract/core/data.py:64-118 +@dataclasses.dataclass +class Extraction: + extraction_class: str # 实体类型 (如 "person", "medication") + extraction_text: str # 实体文本 (从原文提取的内容) + char_interval: CharInterval | None = None # 原文中的字符偏移 + alignment_status: AlignmentStatus | None = None # 对齐状态 + extraction_index: int | None = None # 排序索引 + group_index: int | None = None # 分组索引 + attributes: dict[str, str | list[str]] | None = None # 附加属性 +``` + +### Schema 的 JSON 表达与 Python 类对应关系 + +LangExtract 的 schema 系统有两层抽象: + +#### 1. 数据层 (Data Layer) + +用户通过 `ExampleData` 和 `Extraction` 定义示例: + +```python +# 来自 README.md 的示例 +examples = [ + lx.data.ExampleData( + text="ROMEO. But soft! What light through yonder window breaks?", + extractions=[ + lx.data.Extraction( + extraction_class="character", # 对应 JSON 键 + extraction_text="ROMEO", # 对应 JSON 值 + attributes={"emotional_state": "wonder"} # 附加属性 + ), + lx.data.Extraction( + extraction_class="emotion", + extraction_text="But soft!", + attributes={"feeling": "gentle awe"} + ), + ] + ) +] +``` + +对应的 JSON 表达 (在 prompt 中): + +```json +{ + "extractions": [ + { + "character": "ROMEO", + "character_attributes": { + "emotional_state": "wonder" + } + }, + { + "emotion": "But soft!", + "emotion_attributes": { + "feeling": "gentle awe" + } + } + ] +} +``` + +#### 2. 
Schema 层 (Schema Layer) + +`BaseSchema` 是抽象基类,定义了从 examples 生成 provider 配置的接口: + +```python +# langextract/core/schema.py:38-91 +class BaseSchema(abc.ABC): + @classmethod + @abc.abstractmethod + def from_examples( + cls, + examples_data: Sequence[data.ExampleData], + attribute_suffix: str = data.ATTRIBUTE_SUFFIX, + ) -> BaseSchema: + """从示例数据构建 schema 实例""" + + @abc.abstractmethod + def to_provider_config(self) -> dict[str, Any]: + """转换为 provider 特定的配置 (如 Gemini 的 response_schema)""" + + @property + @abc.abstractmethod + def requires_raw_output(self) -> bool: + """是否输出原始 JSON/YAML (无围栏标记)""" +``` + +#### 3. FormatModeSchema: 通用格式约束 + +`FormatModeSchema` 是当前主要使用的 schema 实现,它不强制字段级结构,只保证输出格式: + +```python +# langextract/core/schema.py:93-139 +class FormatModeSchema(BaseSchema): + def __init__(self, format_type: types.FormatType = types.FormatType.JSON): + self.format_type = format_type + self._format = "json" if format_type == types.FormatType.JSON else "yaml" + + @classmethod + def from_examples( + cls, + examples_data: Sequence[data.ExampleData], + attribute_suffix: str = data.ATTRIBUTE_SUFFIX, + ) -> FormatModeSchema: + """从 examples 构建 schema (当前默认使用 JSON 格式)""" + return cls(format_type=types.FormatType.JSON) + + def to_provider_config(self) -> dict[str, Any]: + """返回 provider 配置""" + return {"format": self._format} + + @property + def requires_raw_output(self) -> bool: + """JSON 格式输出原始 JSON (无围栏),YAML 则需要围栏""" + return self._format == "json" +``` + +### Schema 约束如何应用到 LLM + +在 `extract()` 函数中,schema 被传递给 model factory: + +```python +# langextract/extraction.py:298-303 +language_model = factory.create_model( + config=config, + examples=prompt_template.examples if use_schema_constraints else None, + use_schema_constraints=use_schema_constraints, + fence_output=fence_output, +) +``` + +不同 provider 对 schema 的支持程度不同: + +| Provider | Schema 支持 | 实现方式 | +|----------|-------------|----------| +| Gemini | 完整支持 | `response_schema` + 结构化输出 | +| OpenAI | 格式支持 | JSON mode | 
+| Ollama | 格式支持 | `format: "json"` 参数 | + +--- + +## Prompt 组装 + +Prompt 组装是 LangExtract 的核心环节,它将用户的 `prompt_description` 和 `examples` 转换为 LLM 可理解的指令格式。 + +### 完整流程 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Prompt 组装流程 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 输入 │ +│ ┌─────────────────┐ ┌─────────────────────────────────────────────┐ │ +│ │ prompt_description │ │ examples │ │ +│ │ "Extract persons..." │ │ [ExampleData(text=..., extractions=...)] │ │ +│ └────────┬────────┘ └─────────────────────┬───────────────────────┘ │ +│ │ │ │ +│ └────────────────┬────────────────┘ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ PromptTemplateStructured (数据容器) │ │ +│ │ ┌───────────────────────────────────────────────────────────────┐ │ │ +│ │ │ description: "Extract persons and medications from text..." │ │ │ +│ │ │ examples: [ExampleData, ExampleData, ...] │ │ │ +│ │ └───────────────────────────────────────────────────────────────┘ │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ FormatHandler (示例格式化) │ │ +│ │ │ │ +│ │ 每个 ExampleData.extractions 被格式化为: │ │ +│ │ { │ │ +│ │ "extractions": [ │ │ +│ │ {"person": "John", "person_attributes": {"age": "30"}}, │ │ +│ │ {"medication": "Aspirin", "medication_attributes": {...}} │ │ +│ │ ] │ │ +│ │ } │ │ +│ │ │ │ +│ │ 输出格式: JSON 或 YAML,带或不带 ``` 围栏 │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ QAPromptGenerator (最终组装) │ │ +│ │ │ │ +│ │ render(question=chunk_text) 生成: │ │ +│ │ ┌───────────────────────────────────────────────────────────────┐ │ │ +│ │ │ [description] │ │ │ +│ │ │ │ │ │ +│ │ │ Examples │ │ │ +│ │ │ Q: [example_1.text] │ │ │ +│ │ │ 
A: [formatted_extractions_1] │ │ │ +│ │ │ │ │ │ +│ │ │ Q: [example_2.text] │ │ │ +│ │ │ A: [formatted_extractions_2] │ │ │ +│ │ │ │ │ │ +│ │ │ Q: [current_chunk_text] │ │ │ +│ │ │ A: │ │ │ +│ │ └───────────────────────────────────────────────────────────────┘ │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ ContextAwarePromptBuilder (跨 chunk 上下文) │ │ +│ │ │ │ +│ │ 可选功能: 注入前一个 chunk 的尾部文本作为上下文 │ │ +│ │ │ │ +│ │ [Previous text]: ...the patient was prescribed │ │ │ +│ │ [additional_context] │ │ │ +│ │ │ │ +│ │ 帮助解决指代消解问题: "She" → "Dr. Sarah Johnson" │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 关键代码路径 + +#### 1. 示例格式化 (`format_extraction_example`) + +```python +# langextract/core/format_handler.py:116-149 +def format_extraction_example( + self, extractions: list[data.Extraction] +) -> str: + """将 extractions 格式化为 prompt 中的示例""" + items = [ + { + ext.extraction_class: ext.extraction_text, + f"{ext.extraction_class}{self.attribute_suffix}": ( + ext.attributes or {} + ), + } + for ext in extractions + ] + + if self.use_wrapper and self.wrapper_key: + payload = {self.wrapper_key: items} # {"extractions": [...]} + else: + payload = items + + if self.format_type == data.FormatType.YAML: + formatted = yaml.safe_dump(payload, ...) + else: + formatted = json.dumps(payload, indent=2, ensure_ascii=False) + + return self._add_fences(formatted) if self.use_fences else formatted +``` + +#### 2. 
Prompt 渲染 (`QAPromptGenerator.render`) + +```python +# langextract/prompting.py:115-138 +def render(self, question: str, additional_context: str | None = None) -> str: + """生成完整的 prompt 文本""" + prompt_lines: list[str] = [f"{self.template.description}\n"] + + if additional_context: + prompt_lines.append(f"{additional_context}\n") + + if self.template.examples: + prompt_lines.append(self.examples_heading) # "Examples" + for ex in self.template.examples: + prompt_lines.append(self.format_example_as_text(ex)) + + prompt_lines.append(f"{self.question_prefix}{question}") # "Q: ..." + prompt_lines.append(self.answer_prefix) # "A: " + return "\n".join(prompt_lines) +``` + +#### 3. 示例格式化 (`format_example_as_text`) + +```python +# langextract/prompting.py:98-113 +def format_example_as_text(self, example: data.ExampleData) -> str: + """将单个 example 格式化为 Q:A 对""" + question = example.text + answer = self.format_handler.format_extraction_example(example.extractions) + + return "\n".join([ + f"{self.question_prefix}{question}", + f"{self.answer_prefix}{answer}\n", + ]) +``` + +#### 4. 跨 chunk 上下文 (`ContextAwarePromptBuilder`) + +```python +# langextract/prompting.py:242-266 +def _build_effective_context( + self, + document_id: str, + additional_context: str | None, +) -> str | None: + """组合前一个 chunk 的上下文和额外上下文""" + context_parts: list[str] = [] + + if self._context_window_chars and document_id in self._prev_chunk_by_doc_id: + prev_text = self._prev_chunk_by_doc_id[document_id] + window = prev_text[-self._context_window_chars :] + context_parts.append(f"{self._CONTEXT_PREFIX}{window}") + # 例如: "[Previous text]: ...the patient visited the clinic" + + if additional_context: + context_parts.append(additional_context) + + return "\n\n".join(context_parts) if context_parts else None +``` + +### 完整 Prompt 示例 + +假设用户提供: + +```python +prompt_description = """Extract all persons and their roles from the text. +Use exact text from the source. 
Do not paraphrase.""" + +examples = [ + lx.data.ExampleData( + text="Dr. Smith, the chief surgeon, operated on patient Johnson.", + extractions=[ + lx.data.Extraction( + extraction_class="person", + extraction_text="Dr. Smith", + attributes={"role": "chief surgeon"} + ), + lx.data.Extraction( + extraction_class="person", + extraction_text="Johnson", + attributes={"role": "patient"} + ), + ] + ) +] +``` + +生成的 prompt 将是: + +``` +Extract all persons and their roles from the text. +Use exact text from the source. Do not paraphrase. + +Examples +Q: Dr. Smith, the chief surgeon, operated on patient Johnson. +A: ```json +{ + "extractions": [ + { + "person": "Dr. Smith", + "person_attributes": { + "role": "chief surgeon" + } + }, + { + "person": "Johnson", + "person_attributes": { + "role": "patient" + } + } + ] +} +``` + +Q: [当前待处理的 chunk 文本] +A: +``` + +--- + +## 输出解析 + +LLM 返回的原始文本需要经过解析才能转换为结构化的 `Extraction` 对象。解析过程由 `FormatHandler` 和 `Resolver` 协同完成。 + +### 解析流程 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ 输出解析流程 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ LLM 原始输出 (Raw Output) │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 某些推理模型可能先输出思考过程: │ │ +│ │ Let me analyze this text. I see Dr. Smith mentioned... │ │ +│ │ │ │ +│ │ 然后是结构化输出: │ │ +│ │ ```json │ │ +│ │ { │ │ +│ │ "extractions": [ │ │ +│ │ {"person": "Dr. 
Smith"}, │ │ +│ │ {"medication": "Aspirin"} │ │ +│ │ ] │ │ +│ │ } │ │ +│ │ ``` │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 1: <think> 标签过滤 (可选) │ │ +│ │ │ │ +│ │ 正则: <think>[\s\S]*?</think>\s* │ │ +│ │ 位置: langextract/core/format_handler.py:46 │ │ +│ │ │ │ +│ │ 原因: DeepSeek-R1, QwQ 等推理模型会先输出思考过程 │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 2: 围栏提取 (Fence Extraction) │ │ +│ │ │ │ +│ │ 正则: ```(?P<lang>[A-Za-z0-9_+-]+)?\s*\n(?P<body>[\s\S]*?)``` │ │ +│ │ 位置: langextract/core/format_handler.py:41-44 │ │ +│ │ │ │ +│ │ 规则: │ │ +│ │ - strict_fences=True: 必须恰好一个 ```json 或 ```yaml 块 │ │ +│ │ - strict_fences=False: 宽松模式,支持无语言标签或无围栏 │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 3: 格式解析 (Format Parsing) │ │ +│ │ │ │ +│ │ JSON: json.loads(content) │ │ +│ │ YAML: yaml.safe_load(content) │ │ +│ │ │ │ +│ │ 容错: 如果第一次解析失败且有 <think> 标签,尝试去除后再解析 │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 4: 结构提取 (Structure Extraction) │ │ +│ │ │ │ +│ │ 期望结构 (wrapper 模式): │ │ +│ │ {"extractions": [{"key1": "value1"}, {"key2": "value2"}]} │ │ +│ │ │ │ +│ │ 兼容结构 (非 wrapper 模式): │ │ +│ │ [{"key1": "value1"}, {"key2": "value2"}] │ │ +│ │ │ │ +│ │ 位置: langextract/core/format_handler.py:151-245 (parse_output) │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 5: 转换为 Extraction 对象 │ │ +│ │ │ │ +│ │ 每个字典项: │ │ +│ │ {"person": "John", "person_attributes": {"age": "30"}} │ │ +│ │ ↓ │ │
+│ │ Extraction( │ │ +│ │ extraction_class="person", │ │ +│ │ extraction_text="John", │ │ +│ │ attributes={"age": "30"} │ │ +│ │ ) │ │ +│ │ │ │ +│ │ 位置: langextract/resolver.py:424-523 (extract_ordered_extractions)│ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 关键代码解析 + +#### 1. 围栏提取 (`_extract_content`) + +```python +# langextract/core/format_handler.py:278-333 +def _extract_content(self, text: str) -> str: + """从文本中提取内容,处理围栏""" + if not self.use_fences: + return text.strip() # 无围栏模式,直接返回 + + matches = list(_FENCE_RE.finditer(text)) + + # 验证语言标签 (json/yaml/yml) + valid_tags = { + data.FormatType.YAML: {"yaml", "yml"}, + data.FormatType.JSON: {"json"}, + } + candidates = [m for m in matches if self._is_valid_language_tag(...)] + + if self.strict_fences: + # 严格模式: 必须恰好一个有效围栏块 + if len(candidates) != 1: + raise exceptions.FormatParseError("...") + return candidates[0].group("body").strip() + + # 宽松模式 + if len(candidates) == 1: + return candidates[0].group("body").strip() + elif len(candidates) > 1: + raise exceptions.FormatParseError("Multiple fenced blocks found") + + # 最后尝试: 任意围栏或无围栏 + if matches and len(matches) == 1: + return matches[0].group("body").strip() + + return text.strip() # 无围栏,直接使用 +``` + +#### 2. 解析输出 (`parse_output`) + +```python +# langextract/core/format_handler.py:151-245 +def parse_output( + self, text: str, *, strict: bool | None = None +) -> Sequence[Mapping[str, ExtractionValueType]]: + """解析模型输出为提取数据""" + if not text: + raise exceptions.FormatParseError("Empty or invalid input string.") + + # Step 1: 提取内容 (围栏处理) + content = self._extract_content(text) + + # Step 2: 解析 JSON/YAML (含 标签容错) + try: + parsed = self._parse_with_fallback(content, strict) + except (yaml.YAMLError, json.JSONDecodeError) as e: + raise exceptions.FormatParseError(...) 
from e + + # Step 3: 提取 extractions 列表 + require_wrapper = self.wrapper_key is not None and ( + self.use_wrapper or bool(strict) + ) + + if isinstance(parsed, dict): + # Wrapper 模式: {"extractions": [...]} + if require_wrapper: + if self.wrapper_key not in parsed: + raise exceptions.FormatParseError( + f"Content must contain an '{self.wrapper_key}' key." + ) + items = parsed[self.wrapper_key] + else: + # 兼容: 尝试已知的 wrapper key + if data.EXTRACTIONS_KEY in parsed: + items = parsed[data.EXTRACTIONS_KEY] + elif self.wrapper_key and self.wrapper_key in parsed: + items = parsed[self.wrapper_key] + else: + items = [parsed] # 单个对象作为单元素列表 + elif isinstance(parsed, list): + # 非 wrapper 模式: [...] + if require_wrapper and (strict or not self.allow_top_level_list): + raise exceptions.FormatParseError(...) + items = parsed + else: + raise exceptions.FormatParseError( + f"Expected list or dict, got {type(parsed)}" + ) + + # Step 4: 验证每个 item 是字典 + for item in items: + if not isinstance(item, dict): + raise exceptions.FormatParseError( + "Each item in the sequence must be a mapping." + ) + + return items +``` + +#### 3. `<think>` 标签容错 (`_parse_with_fallback`) + +```python +# langextract/core/format_handler.py:261-276 +def _parse_with_fallback(self, content: str, strict: bool): + """解析内容,失败时尝试去除 <think> 标签""" + try: + if self.format_type == data.FormatType.YAML: + return yaml.safe_load(content) + return json.loads(content) + except (yaml.YAMLError, json.JSONDecodeError): + if strict: + raise + # 推理模型 (DeepSeek-R1, QwQ) 会在 JSON 前输出 <think>...</think> + if _THINK_TAG_RE.search(content): + stripped = _THINK_TAG_RE.sub("", content).strip() + if self.format_type == data.FormatType.YAML: + return yaml.safe_load(stripped) + return json.loads(stripped) + raise +``` + +#### 4. 
转换为 Extraction 对象 (`extract_ordered_extractions`) + +```python +# langextract/resolver.py:424-523 +def extract_ordered_extractions( + self, + extraction_data: Sequence[Mapping[str, fh.ExtractionValueType]], +) -> Sequence[data.Extraction]: + """将解析后的数据转换为 Extraction 对象列表""" + processed_extractions = [] + extraction_index = 0 + index_suffix = self.extraction_index_suffix # 可选: "_index" + attributes_suffix = self.format_handler.attribute_suffix # "_attributes" + + for group_index, group in enumerate(extraction_data): + for extraction_class, extraction_value in group.items(): + # 跳过索引字段 (如果使用 index_suffix) + if index_suffix and extraction_class.endswith(index_suffix): + continue + + # 跳过属性字段 (单独处理) + if attributes_suffix and extraction_class.endswith(attributes_suffix): + continue + + # 值类型验证: 必须是 str/int/float + if not isinstance(extraction_value, (str, int, float)): + raise ValueError( + "Extraction text must be a string, integer, or float." + ) + + # 统一转为字符串 + if not isinstance(extraction_value, str): + extraction_value = str(extraction_value) + + # 查找对应的索引 (如果有) + if index_suffix: + index_key = extraction_class + index_suffix + extraction_index = group.get(index_key, None) + if extraction_index is None: + continue # 无索引则跳过 + else: + extraction_index += 1 + + # 查找对应的属性 + attributes = None + if attributes_suffix: + attributes_key = extraction_class + attributes_suffix + attributes = group.get(attributes_key, None) + + # 创建 Extraction 对象 + processed_extractions.append( + data.Extraction( + extraction_class=extraction_class, + extraction_text=extraction_value, + extraction_index=extraction_index, + group_index=group_index, + attributes=attributes, + ) + ) + + # 按索引排序 (如果使用 index_suffix) + processed_extractions.sort(key=operator.attrgetter("extraction_index")) + return processed_extractions +``` + +### 格式错误时的 Fallback 策略 + +| 场景 | 处理方式 | 控制参数 | +|------|----------|----------| +| 解析失败 (JSON/YAML 语法错误) | `suppress_parse_errors=True` 时返回空列表,否则抛异常 | 
`resolver_params={"suppress_parse_errors": True}` | +| 多个围栏块 | 存在多个带有效语言标签的围栏块时,严格与宽松模式均抛出 `FormatParseError` | `strict_fences` | +| 无围栏标签 | 宽松模式尝试直接解析整段文本 | `strict_fences=False` | +| 包含 `<think>` 标签 | 自动去除后重试解析 | 内置 (非 strict 模式) | +| 缺少 `extractions` wrapper | 宽松模式接受顶级列表 | `use_wrapper=False` 或 `allow_top_level_list=True` | + +**注意**: `suppress_parse_errors` 在 `extract()` 中默认为 `True`,这意味着单个 chunk 的解析失败不会导致整个文档处理失败。 + +--- + +## 实体对齐 (Alignment) + +实体对齐是 LangExtract 的核心能力之一——它将 LLM 抽取出的文本片段回溯到原文中的精确字符位置。这使得抽取结果可验证、可可视化。 + +### 对齐流程 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ 实体对齐流程 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 输入 │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ resolved_extractions: [ │ │ +│ │ Extraction(extraction_text="Dr. Smith"), │ │ +│ │ Extraction(extraction_text="Aspirin 10mg") │ │ +│ │ ] │ │ +│ │ │ │ +│ │ source_text: "Dr. Smith prescribed Aspirin 10mg to the patient." │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 0: Tokenization & 归一化 │ │ +│ │ │ │ +│ │ 原文 token 化: │ │ +│ │ ["dr", "smith", "prescribed", "aspirin", "10mg", "to", ...] │ │ +│ │ │ │ +│ │ 提取文本 token 化 + 归一化: │ │ +│ │ - 小写: "Dr. Smith" → "dr. smith" │ │ +│ │ - 轻量词干化: "patients" → "patient" (去除 s 后缀) │ │ +│ │ │ │ +│ │ 位置: langextract/resolver.py:1034-1069 (_tokenize_with_lowercase) │ │ +│ │ langextract/resolver.py:1063-1069 (_normalize_token) │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 1: 精确匹配 (Exact Match) │ │ +│ │ │ │ +│ │ 算法: difflib.SequenceMatcher (Python 标准库) │ │ +│ │ 位置: langextract/resolver.py:921-977 │ │ +│ │ │ │ +│ │ 策略: │ │ +│ │ 1. 将所有 extraction_text 用特殊分隔符连接 │ │ +│ │ 2. 与 source_text 进行全局序列匹配 │ │ +│ │ 3. 
对每个匹配块,判断是完全匹配还是部分匹配 │ │ +│ │ │ │ +│ │ 匹配状态: │ │ +│ │ - MATCH_EXACT: extraction_text 与原文完全一致 │ │ +│ │ - MATCH_LESSER: 匹配的文本比 extraction_text 短 │ │ +│ │ (extraction 更长,只匹配到一部分) │ │ +│ │ - 不匹配: 进入模糊匹配阶段 │ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 2: 模糊匹配 (Fuzzy Match) - 仅当精确匹配失败时 │ │ +│ │ │ │ +│ │ 有两种算法: │ │ +│ │ │ │ +│ │ A) Legacy 算法 (deprecated) │ │ +│ │ - difflib.SequenceMatcher.ratio() │ │ +│ │ - 滑动窗口遍历所有可能的匹配位置 │ │ +│ │ - 位置: langextract/resolver.py:578-702 (_fuzzy_align_extraction)│ │ +│ │ │ │ +│ │ B) LCS 算法 (默认,推荐) │ │ +│ │ - 最长公共子序列 (Longest Common Subsequence) │ │ +│ │ - 动态规划 O(n*m²) 时间复杂度 │ │ +│ │ - 双重门控: coverage + density │ │ +│ │ - 位置: langextract/resolver.py:704-774 (_lcs_fuzzy_align_extraction)│ │ +│ └─────────────────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 3: 计算偏移量 & 设置状态 │ │ +│ │ │ │ +│ │ 计算: │ │ +│ │ - token_interval: 在 chunk 内的 token 索引 + token_offset │ │ +│ │ - char_interval: 通过 token 的 char_interval 计算字符偏移 │ │ +│ │ - alignment_status: MATCH_EXACT / MATCH_FUZZY / MATCH_LESSER │ │ +│ │ │ │ +│ │ 对齐失败: │ │ +│ │ - char_interval = None │ │ +│ │ - token_interval = None │ │ +│ │ - alignment_status = None │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 关键代码解析 + +#### 1. 精确匹配主流程 (`align_extractions`) + +```python +# langextract/resolver.py:776-1031 +def align_extractions( + self, + extraction_groups: Sequence[Sequence[data.Extraction]], + source_text: str, + token_offset: int = 0, + char_offset: int = 0, + enable_fuzzy_alignment: bool = True, + fuzzy_alignment_threshold: float = 0.75, + ... 
+) -> Sequence[Sequence[data.Extraction]]: + """将 extractions 对齐到原文""" + # Step 1: 准备 tokens + source_tokens = list(_tokenize_with_lowercase(source_text, ...)) + + # Step 2: 用特殊分隔符连接所有 extraction_text + # 分隔符: "\u241F" (Unicode 单元分隔符),确保不会出现在正常文本中 + delim = "\u241F" + extraction_tokens = list(_tokenize_with_lowercase( + f" {delim} ".join( + extraction.extraction_text + for extraction in itertools.chain(*extraction_groups) + ), + tokenizer_impl=tokenizer_impl, + )) + + # Step 3: 精确匹配 (difflib.SequenceMatcher) + self._set_seqs(source_tokens, extraction_tokens) + + # 遍历匹配块 + for i, j, n in self._get_matching_blocks()[:-1]: + # i: source 中的起始 token 索引 + # j: extraction 中的起始 token 索引 + # n: 匹配的 token 数量 + + # 查找对应的 extraction + extraction, _ = index_to_extraction_group.get(j, (None, None)) + + # 设置 token_interval + extraction.token_interval = tokenizer_lib.TokenInterval( + start_index=i + token_offset, + end_index=i + n + token_offset, + ) + + # 通过 token 计算 char_interval + start_token = tokenized_text.tokens[i] + end_token = tokenized_text.tokens[i + n - 1] + extraction.char_interval = data.CharInterval( + start_pos=char_offset + start_token.char_interval.start_pos, + end_pos=char_offset + end_token.char_interval.end_pos, + ) + + # 判断匹配类型 + extraction_text_len = len(extraction_tokens_for_this_extraction) + if extraction_text_len == n: + extraction.alignment_status = data.AlignmentStatus.MATCH_EXACT + exact_matches += 1 + else: + # 部分匹配 (extraction 更长,只匹配到一部分) + if accept_match_lesser: + extraction.alignment_status = data.AlignmentStatus.MATCH_LESSER + lesser_matches += 1 + else: + # 不接受部分匹配,重置 + extraction.token_interval = None + extraction.char_interval = None + extraction.alignment_status = None + + # Step 4: 模糊匹配 (对精确匹配失败的 extractions) + if enable_fuzzy_alignment and unaligned_extractions: + for extraction in unaligned_extractions: + if fuzzy_alignment_algorithm == "lcs": + aligned = self._lcs_fuzzy_align_extraction(...) 
+ else: + aligned = self._fuzzy_align_extraction(...) + + if aligned: + aligned_extractions.append(aligned) + + return aligned_extraction_groups +``` + +#### 2. LCS 模糊匹配算法 (`_lcs_fuzzy_align_extraction`) + +```python +# langextract/resolver.py:704-774 +def _lcs_fuzzy_align_extraction( + self, + extraction: data.Extraction, + source_tokens_norm: list[str], # 已归一化的原文 tokens + tokenized_text: tokenizer_lib.TokenizedText, + token_offset: int, + char_offset: int, + fuzzy_alignment_threshold: float = 0.75, + fuzzy_alignment_min_density: float = 1/3, + ... +) -> data.Extraction | None: + """使用 LCS 算法进行模糊对齐""" + # Step 1: Tokenize 和归一化 extraction_text + extraction_tokens = list(_tokenize_with_lowercase(extraction.extraction_text, ...)) + extraction_tokens_norm = [_normalize_token(t) for t in extraction_tokens] + + # Step 2: 计算所有可能的 LCS 匹配 + # 返回: {match_count: LcsSpan(matches, start, end)} + spans = _best_lcs_spans(source_tokens_norm, extraction_tokens_norm) + + # Step 3: 按匹配数量从高到低尝试,找到第一个通过双重门控的 + for k in sorted(spans.keys(), reverse=True): + candidate = spans[k] + if _accept_lcs_match( + candidate, + len(extraction_tokens_norm), + threshold=fuzzy_alignment_threshold, + min_density=fuzzy_alignment_min_density, + ): + accepted = candidate + break + + if accepted is None: + return None + + # Step 4: 设置 intervals 和状态 + extraction.token_interval = tokenizer_lib.TokenInterval( + start_index=accepted.start + token_offset, + end_index=accepted.end + 1 + token_offset, + ) + + start_token = tokenized_text.tokens[accepted.start] + end_token = tokenized_text.tokens[accepted.end] + extraction.char_interval = data.CharInterval( + start_pos=char_offset + start_token.char_interval.start_pos, + end_pos=char_offset + end_token.char_interval.end_pos, + ) + + extraction.alignment_status = data.AlignmentStatus.MATCH_FUZZY + return extraction +``` + +#### 3. 
LCS 双重门控 (`_accept_lcs_match`) + +```python +# langextract/resolver.py:1165-1192 +def _accept_lcs_match( + span: LcsSpan, + extraction_len: int, + threshold: float = 0.75, + min_density: float = 1/3, +) -> bool: + """应用覆盖度和密度双重门控""" + if span.matches == 0 or extraction_len == 0: + return False + + # Coverage Gate (覆盖度): 匹配的 token 数 >= 阈值比例 + # 例如: extraction 有 4 个 tokens,threshold=0.75,需要至少匹配 3 个 + needed = math.ceil(extraction_len * threshold) + if span.matches < needed: + return False + + # Density Gate (密度): 匹配的 token 数 / 匹配区间长度 >= min_density + # 防止匹配的 tokens 分散在太长的区间中 + # 例如: 匹配 2 个 tokens,但分散在 10 个 token 的区间中 → 密度 0.2 < 1/3 → 拒绝 + if span.span_len <= 0: + return False + density = span.matches / span.span_len + return density >= min_density +``` + +#### 4. Token 归一化 (`_normalize_token`) + +```python +# langextract/resolver.py:1063-1069 +@functools.lru_cache(maxsize=10000) +def _normalize_token(token: str) -> str: + """小写 + 轻量词干化 (去除复数 s)""" + token = token.lower() + # 长度 > 3 且以 s 结尾且不以 ss 结尾 → 去除 s + if len(token) > 3 and token.endswith("s") and not token.endswith("ss"): + token = token[:-1] + return token +``` + +### 对齐状态说明 + +| 状态 | 值 | 含义 | 示例 | +|------|-----|------|------| +| `MATCH_EXACT` | `"match_exact"` | 精确匹配 | extraction_text="John",原文中恰好有 "John" | +| `MATCH_LESSER` | `"match_lesser"` | 部分匹配 (匹配文本更短) | extraction_text="John Smith",只匹配到 "John" | +| `MATCH_FUZZY` | `"match_fuzzy"` | 模糊匹配 | extraction_text="Jon",匹配到原文的 "John" | +| `None` | - | 对齐失败 | 无法在原文中找到对应片段 | + +### 对齐失败时的处理 + +对齐失败的 extraction 会保留,但 `char_interval` 和 `token_interval` 为 `None`。用户可以通过过滤来只保留成功对齐的结果: + +```python +# 只保留成功对齐的 extractions +grounded_extractions = [ + e for e in result.extractions + if e.char_interval is not None +] +``` + +**原因**: LLM 可能从 few-shot examples 中"幻觉"出内容,或者提取的文本与原文表述不完全一致。LangExtract 不会丢弃这些结果,而是让用户决定如何处理。 + +### 对齐参数配置 + +对齐参数通过 `resolver_params` 传递给 `extract()`: + +```python +result = lx.extract( + text_or_documents=text, + prompt_description=prompt, + 
examples=examples, + resolver_params={ + # 模糊匹配开关 + "enable_fuzzy_alignment": True, + + # 覆盖度阈值: 至少匹配 75% 的 tokens + "fuzzy_alignment_threshold": 0.75, + + # 密度阈值: 匹配 tokens / 区间长度 >= 1/3 + "fuzzy_alignment_min_density": 1/3, + + # 算法选择: "lcs" (默认) 或 "legacy" (deprecated) + "fuzzy_alignment_algorithm": "lcs", + + # 是否接受部分匹配 (MATCH_LESSER) + "accept_match_lesser": True, + + # 解析错误时是否抑制异常 + "suppress_parse_errors": True, + } +) +``` + +--- + +## 长文档分块 + +当输入文本超过 LLM 的上下文窗口或 `max_char_buffer` 限制时,LangExtract 会将文档分割成多个 chunks 分别处理。 + +### 分块策略 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ 长文档分块策略 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 核心原则 │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 1. 优先按句子边界分割 (保持语义完整性) │ │ +│ │ 2. 尊重换行符 (诗歌、列表等格式) │ │ +│ │ 3. 单句过长时按 token 分割 │ │ +│ │ 4. 单个 token 超过 buffer 时单独成块 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 分块场景 │ +│ │ +│ 场景 A: 单句超长,需要在句内分割 │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 原文 (诗歌): │ │ +│ │ "No man is an island, │ │ +│ │ Entire of itself, │ │ +│ │ Every man is a piece of the continent, │ │ +│ │ A part of the main." │ │ +│ │ │ │ +│ │ max_char_buffer=40 │ │ +│ │ │ │ +│ │ 分块结果: │ │ +│ │ Chunk 1: "No man is an island,\nEntire of itself," (38 chars)│ │ +│ │ Chunk 2: "Every man is a piece of the continent," (38 chars)│ │ +│ │ Chunk 3: "A part of the main." (19 chars) │ │ +│ │ │ │ +│ │ 特点: 尊重换行符,在换行处优先分割 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 场景 B: 单个 token 超长 │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 原文: "This is antidisestablishmentarianism." │ │ +│ │ max_char_buffer=20 │ │ +│ │ │ │ +│ │ 分块结果: │ │ +│ │ Chunk 1: "This is" (7 chars) │ │ +│ │ Chunk 2: "antidisestablishmentarianism" (28 chars) │ │ +│ │ Chunk 3: "." 
(1 char) │ │ +│ │ │ │ +│ │ 特点: 超长 token 即使超过 buffer 也单独成块 │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 场景 C: 多短句可合并 │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 原文: "Roses are red. Violets are blue. Flowers are nice. And so │ │ +│ │ are you." │ │ +│ │ max_char_buffer=60 │ │ +│ │ │ │ +│ │ 分块结果: │ │ +│ │ Chunk 1: "Roses are red. Violets are blue. Flowers are nice." │ │ +│ │ (50 chars) │ │ +│ │ Chunk 2: "And so are you." (15 chars) │ │ +│ │ │ │ +│ │ 特点: 多个完整句子可合并到一个 chunk (不超过 buffer) │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### 关键代码解析 + +#### 1. ChunkIterator 主逻辑 (`__next__`) + +```python +# langextract/chunking.py:441-506 +def __next__(self) -> TextChunk: + # 获取下一个句子 (或句子的剩余部分) + sentence = next(self.sentence_iter) + + # 策略 1: 如果第一个 token 就超过 buffer,单独成块 + curr_chunk = create_token_interval( + sentence.start_index, sentence.start_index + 1 + ) + if self._tokens_exceed_buffer(curr_chunk): + self.sentence_iter = SentenceIterator( + self.tokenized_text, curr_token_pos=sentence.start_index + 1 + ) + self.broken_sentence = True + return TextChunk(token_interval=curr_chunk, document=self.document) + + # 策略 2: 在句子内追加 tokens,直到接近 buffer + start_of_new_line = -1 + for token_index in range(curr_chunk.start_index, sentence.end_index): + # 记录换行位置 (用于优先在换行处分割) + if self.tokenized_text.tokens[token_index].first_token_after_newline: + start_of_new_line = token_index + + test_chunk = create_token_interval( + curr_chunk.start_index, token_index + 1 + ) + + if self._tokens_exceed_buffer(test_chunk): + # 超过 buffer 了 + # 优先在最近的换行处分割 (如果有) + if start_of_new_line > 0 and start_of_new_line > curr_chunk.start_index: + curr_chunk = create_token_interval( + curr_chunk.start_index, start_of_new_line + ) + # 更新句子迭代器,下次从这里继续 + self.sentence_iter = SentenceIterator( + 
self.tokenized_text, curr_token_pos=curr_chunk.end_index + ) + self.broken_sentence = True + return TextChunk(token_interval=curr_chunk, document=self.document) + else: + curr_chunk = test_chunk # 继续追加 + + # 策略 3: 整句没超过 buffer,尝试合并更多句子 + if self.broken_sentence: + self.broken_sentence = False + else: + for sentence in self.sentence_iter: + test_chunk = create_token_interval( + curr_chunk.start_index, sentence.end_index + ) + if self._tokens_exceed_buffer(test_chunk): + self.sentence_iter = SentenceIterator( + self.tokenized_text, curr_token_pos=curr_chunk.end_index + ) + return TextChunk(token_interval=curr_chunk, document=self.document) + else: + curr_chunk = test_chunk # 合并整句 + + return TextChunk(token_interval=curr_chunk, document=self.document) +``` + +#### 2. 句子边界检测 (`SentenceIterator`) + +```python +# langextract/chunking.py:282-340 +class SentenceIterator: + """迭代 tokenized 文本的句子""" + + def __next__(self) -> tokenizer_lib.TokenInterval: + # 找到包含当前 token 的句子范围 + sentence_range = tokenizer_lib.find_sentence_range( + self.tokenized_text.text, + self.tokenized_text.tokens, + self.curr_token_pos, + ) + # 从当前位置开始,而不是句子开头 + # (如果我们在句子中间,从这里继续) + sentence_range = create_token_interval( + self.curr_token_pos, sentence_range.end_index + ) + self.curr_token_pos = sentence_range.end_index + return sentence_range +``` + +### Overlap 与上下文窗口 + +LangExtract **没有使用传统的 chunk overlap 机制**,而是提供了 **`context_window_chars`** 参数来解决跨 chunk 的指代消解问题。 + +| 机制 | 说明 | 示例 | +|------|------|------| +| 传统 overlap | 相邻 chunks 共享部分文本 | Chunk1: [0-100], Chunk2: [80-180] | +| LangExtract context_window | 前一个 chunk 的尾部文本作为 prompt 上下文 | Chunk2 的 prompt 包含 Chunk1 的最后 N 个字符 | + +**ContextAwarePromptBuilder 实现**: + +```python +# langextract/prompting.py:179-276 +class ContextAwarePromptBuilder(PromptBuilder): + """支持跨 chunk 上下文追踪的 prompt builder""" + + _CONTEXT_PREFIX = "[Previous text]: ..." 
+ + def __init__( + self, + generator: QAPromptGenerator, + context_window_chars: int | None = None, # 例如: 100 + ): + super().__init__(generator) + self._context_window_chars = context_window_chars + self._prev_chunk_by_doc_id: dict[str, str] = {} # 按文档追踪 + + def build_prompt( + self, + chunk_text: str, + document_id: str, + additional_context: str | None = None, + ) -> str: + # 构建有效上下文 (前一个 chunk + 额外上下文) + effective_context = self._build_effective_context( + document_id, additional_context + ) + + prompt = self._generator.render( + question=chunk_text, + additional_context=effective_context, + ) + + # 更新状态: 保存当前 chunk 供下一个使用 + self._update_state(document_id, chunk_text) + return prompt + + def _build_effective_context( + self, document_id: str, additional_context: str | None + ) -> str | None: + context_parts: list[str] = [] + + # 注入前一个 chunk 的尾部 + if self._context_window_chars and document_id in self._prev_chunk_by_doc_id: + prev_text = self._prev_chunk_by_doc_id[document_id] + window = prev_text[-self._context_window_chars :] # 取尾部 + context_parts.append(f"{self._CONTEXT_PREFIX}{window}") + + if additional_context: + context_parts.append(additional_context) + + return "\n\n".join(context_parts) if context_parts else None +``` + +**使用示例**: + +```python +result = lx.extract( + text_or_documents=long_text, + prompt_description=prompt, + examples=examples, + context_window_chars=100, # 每个 chunk 包含前一个 chunk 的最后 100 字符 +) +``` + +**效果**: + +假设文档被分为两个 chunks: +- Chunk1: "Dr. Sarah Johnson is a cardiologist at the hospital. She" +- Chunk2: " specializes in heart disease and hypertension." + +没有 context_window 时,Chunk2 的 "She" 可能无法正确解析。 + +有 `context_window_chars=50` 时,Chunk2 的 prompt 会包含: +``` +[Previous text]: ...cardiologist at the hospital. She + +Q: specializes in heart disease and hypertension. +A: +``` + +这样 LLM 就能知道 "She" 指的是 "Dr. 
Sarah Johnson"。 + +### 跨 Chunk 实体合并与去重 + +LangExtract 目前 **没有自动的跨 chunk 实体去重机制**。每个 chunk 的处理是独立的,结果累积到 `per_doc` 字典中。 + +```python +# langextract/annotation.py:307-332 (Annotator._annotate_documents_single_pass) +def _annotate_documents_single_pass(...): + per_doc: DefaultDict[str, list[data.Extraction]] = collections.defaultdict(list) + + for batch in batch_iter: + # ... 推理、解析、对齐 ... + + for text_chunk, scored_outputs in zip(batch, outputs): + # ... + + aligned_extractions = resolver.align(...) + + for extraction in aligned_extractions: + # 直接追加,没有去重 + per_doc[text_chunk.document_id].append(extraction) +``` + +**用户需要自己处理去重**,可以基于: +1. `char_interval` 重叠检测 +2. `extraction_text` + `extraction_class` 相似度 + +**例外: Sequential Extraction Passes** + +当使用 `extraction_passes > 1` 时,多次抽取的结果会进行非重叠合并: + +```python +# langextract/annotation.py:46-84 +def _merge_non_overlapping_extractions( + all_extractions: list[Iterable[data.Extraction]], +) -> list[data.Extraction]: + """合并多次抽取的结果,重叠时保留较早的抽取""" + if not all_extractions: + return [] + if len(all_extractions) == 1: + return list(all_extractions[0]) + + merged_extractions = list(all_extractions[0]) # 第一次抽取的结果 + + for pass_extractions in all_extractions[1:]: + for extraction in pass_extractions: + # 检查是否与已合并的结果重叠 + overlaps = False + if extraction.char_interval is not None: + for existing_extraction in merged_extractions: + if existing_extraction.char_interval is not None: + if _extractions_overlap(extraction, existing_extraction): + overlaps = True + break + + # 只有不重叠时才添加 + if not overlaps: + merged_extractions.append(extraction) + + return merged_extractions + +def _extractions_overlap( + extraction1: data.Extraction, extraction2: data.Extraction +) -> bool: + """检查两个 extraction 的字符区间是否重叠""" + # [start1, end1) 与 [start2, end2) 重叠 + return start1 < end2 and start2 < end1 +``` + +**注意**: 这是同一文档多次抽取的合并策略,不是跨 chunk 去重。 + +### 分块参数配置 + +```python +result = lx.extract( + text_or_documents=text, + prompt_description=prompt, + 
examples=examples, + # 分块相关参数 + max_char_buffer=1000, # 每个 chunk 的最大字符数 + batch_length=10, # 每批处理的 chunk 数量 + max_workers=10, # 并行 worker 数 + context_window_chars=100, # 前一个 chunk 的上下文字符数 (可选) + extraction_passes=1, # 抽取次数 (可选,多次抽取时合并非重叠结果) +) +``` + +**参数说明**: + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `max_char_buffer` | 1000 | 每个 chunk 的最大字符数。调小可提高准确率但增加 API 调用。 | +| `batch_length` | 10 | 每批处理的 chunk 数量。与 `max_workers` 共同决定并行度。 | +| `max_workers` | 10 | 最大并行 worker 数。有效并行度受限于 `min(batch_length, max_workers)`。 | +| `context_window_chars` | `None` | 前一个 chunk 的上下文字符数。用于指代消解。 | +| `extraction_passes` | 1 | 抽取次数。> 1 时执行多次抽取并合并非重叠结果。 | + +--- + +## 已知限制与 FAQ + +### Q1: 为什么我的实体对齐失败了? + +**常见原因**: + +1. **LLM 提取的文本与原文不一致** + - LLM 可能 paraphrase(转述)原文,例如原文是 "John Smith",但 LLM 返回 "Mr. Smith" + - 解决方案:在 prompt_description 中强调 "Use exact text from the source. Do not paraphrase." + +2. **提取文本跨越 chunk 边界** + - 如果一个实体被分割在两个 chunks 中,对齐可能失败 + - 解决方案:使用 `context_window_chars` 参数,或调整 `max_char_buffer` + +3. **模糊匹配阈值设置过高** + - 默认 `fuzzy_alignment_threshold=0.75`,如果提取文本与原文差异较大,可能无法匹配 + - 解决方案:调低阈值 `resolver_params={"fuzzy_alignment_threshold": 0.6}` + +4. **特殊字符或大小写问题** + - 虽然有归一化处理,但某些特殊字符可能导致问题 + - 检查 `extraction_text` 中是否有不可见字符 + +**调试方法**: +```python +# 查看对齐失败的 extractions +failed = [e for e in result.extractions if e.alignment_status is None] +for e in failed: + print(f"Failed: class={e.extraction_class}, text={e.extraction_text}") +``` + +--- + +### Q2: 为什么 schema 里的 Optional 字段没被抽取? + +**LangExtract 没有传统的 "Optional" 概念**。 + +LangExtract 的 schema 是 **example-driven** 的,不是类型驱动的。这意味着: + +1. **schema 从 examples 推断** + - 如果你在 examples 中定义了某个 extraction_class,LLM 会被引导去抽取这类实体 + - 但这不是强制的——LLM 可能抽取也可能不抽取 + +2. **没有 "必填/可选" 标记** + - 传统 Pydantic 模型有 `Optional[]` 或 `required=True/False` + - LangExtract 没有这个机制 + +3. 
**如何控制抽取行为** + - 通过 `prompt_description` 描述应该抽取什么 + - 通过 `examples` 展示抽取模式 + - 如果某些实体经常被遗漏,增加更多相关 examples + +**注意**: 如果 LLM 没有抽取某个实体,结果中不会有对应的 Extraction 对象(值为 null 或空字符串也不会被表示)。 + +--- + +### Q3: 为什么我的输出解析失败了? + +**常见场景**: + +1. **LLM 没有返回 JSON/YAML 格式** + - 某些模型可能忽略格式指令,返回自然语言 + - 解决方案: + - 确保 examples 格式正确 + - 使用支持 schema constraints 的模型(如 Gemini) + - 检查 `use_schema_constraints=True`(默认) + +2. **多个围栏块冲突** + - LLM 可能返回多个 ```json 块 + - `strict_fences`(默认 False)不影响此场景:存在多个带有效语言标签的围栏块时,严格与宽松模式都会抛出 `FormatParseError`,需在 prompt 或 examples 中引导模型只输出一个围栏块 + +3. **推理模型的 `<think>` 标签** + - DeepSeek-R1, QwQ 等模型会在 JSON 前输出思考过程 + - LangExtract 会自动处理(非 strict 模式),但如果格式太复杂可能失败 + +4. **缺少 `extractions` wrapper** + - 某些模型可能直接返回 `[...]` 而不是 `{"extractions": [...]}` + - 默认 `allow_top_level_list=True` 会处理这种情况 + +**调试方法**: +```python +# 使用 debug=True 查看原始输出 +result = lx.extract( + ..., + debug=True, # 启用详细日志 +) +``` + +--- + +### Q4: 为什么同一个实体会被多次抽取? + +**原因**: + +1. **跨 chunk 边界** + - 一个实体可能出现在多个 chunks 中(如果 `context_window_chars` 包含了它) + - LangExtract 目前没有自动去重 + +2. **多次抽取 (`extraction_passes > 1`)** + - 虽然多次抽取会合并非重叠结果,但如果同一个实体在不同位置有相似文本,可能被多次抽取 + +3. **LLM 自身的不稳定性** + - 即使是相同的 prompt,LLM 也可能返回略有不同的结果 + +**解决方案**(用户自行处理): +```python +# 基于 char_interval 去重 +def deduplicate(extractions): + seen = set() + result = [] + for e in extractions: + if e.char_interval is None: + continue + key = (e.extraction_class, e.char_interval.start_pos, e.char_interval.end_pos) + if key not in seen: + seen.add(key) + result.append(e) + return result +``` + +--- + +### Q5: `max_char_buffer` 应该设多大? + +**考虑因素**: + +1. **模型上下文窗口** + - `max_char_buffer` 应该远小于模型的最大 token 限制 + - 因为 prompt 本身(description + examples)也占用 tokens + +2. **抽取精度 vs API 成本** + - 较小的 `max_char_buffer` → 更多 chunks → 更多 API 调用 → 更高成本,但可能更准确 + - 较大的 `max_char_buffer` → 更少 chunks → 更低成本,但可能遗漏信息 + +3. **经验建议** + - 默认 `1000` 是一个平衡值 + - 简单任务(如抽取人名)可以用较大值(如 `2000-3000`) + - 复杂任务(如关系抽取)建议用较小值(如 `500-1000`) + +4. 
**与 token 数量的关系** + - `max_char_buffer` 是字符数,不是 token 数 + - 粗略估计:英文 ~1 token = 4 chars,中文 ~1 token = 2 chars + +**配置示例**: +```python +# 高精度模式 +result = lx.extract( + ..., + max_char_buffer=500, # 较小 chunk + extraction_passes=3, # 多次抽取提高召回 +) + +# 低成本模式 +result = lx.extract( + ..., + max_char_buffer=2000, # 较大 chunk + extraction_passes=1, # 单次抽取 +) +``` + +--- + +## 文档 TODO + +在编写本文档过程中,发现以下代码注释或文档可能需要改进: + +### 1. `langextract/core/schema.py` + +- `BaseSchema` 的 `from_examples` 方法缺少详细 docstring,说明如何从 examples 推断 schema +- `FormatModeSchema` 的 `requires_raw_output` 属性的行为在不同 provider 之间的差异需要更清晰的说明 + +### 2. `langextract/resolver.py` + +- `WordAligner.align_extractions` 方法的参数 `delim` 的选择理由(为什么是 `\u241F`)缺少注释 +- `_accept_lcs_match` 中的双重门控(coverage + density)的设计 rationale 可以补充说明 + +### 3. `langextract/chunking.py` + +- `ChunkIterator` 中分块策略的设计选择(为什么优先换行 > 句子 > token)缺少高层文档 +- `broken_sentence` 标志的使用场景需要更清晰的注释 + +### 4. `langextract/core/format_handler.py` + +- `parse_output` 方法中各种兼容路径(wrapper vs 非 wrapper, strict vs 非 strict)的决策树可以用注释说明 +- `_THINK_TAG_RE` 的存在理由(支持哪些模型)可以补充 + +### 5. 
Public API docstring + +- `lx.extract()` 的 docstring 很详细,但 `resolver_params` 中的各个对齐参数可以增加更详细的说明 +- 建议在 docstring 中添加对齐参数的默认值和推荐范围 + +--- + +## 附录:核心类关系图 + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ 核心类关系图 │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ +│ │ ExampleData │────▶│ Extraction │────▶│ CharInterval │ │ +│ │ (示例数据) │ │ (抽取结果) │ │ (字符区间) │ │ +│ └──────────────┘ └──────────────┘ └──────────────┘ │ +│ │ │ │ +│ │ ▼ │ +│ │ ┌──────────────┐ │ +│ │ │TokenInterval │ │ +│ │ │ (token 区间) │ │ +│ │ └──────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌──────────────────────────────────────────────────────────────────────┐ │ +│ │ extract() 入口函数 │ │ +│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │ +│ │ │ Annotator │ │ Resolver │ │FormatHandler│ │ │ +│ │ │ (协调器) │ │ (解析对齐) │ │ (格式处理) │ │ │ +│ │ └──────┬──────┘ └──────┬──────┘ └─────────────┘ │ │ +│ │ │ │ │ │ +│ │ ▼ ▼ │ │ +│ │ ┌─────────────┐ ┌─────────────┐ │ │ +│ │ │ChunkIterator│ │ WordAligner │ │ │ +│ │ │ (分块器) │ │ (对齐器) │ │ │ +│ │ └─────────────┘ └─────────────┘ │ │ +│ └──────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌──────────────────────────────────────────────────────────────────────┐ │ +│ │ Prompt 组装层 │ │ +│ │ ┌─────────────────────┐ ┌─────────────────────────┐ │ │ +│ │ │PromptTemplateStruct │ │ QAPromptGenerator │ │ │ +│ │ │ (模板数据) │ │ (prompt 生成器) │ │ │ +│ │ └──────────┬──────────┘ └────────────┬────────────┘ │ │ +│ │ │ │ │ │ +│ │ ▼ ▼ │ │ +│ │ ┌─────────────────────────────────────────────────────────┐ │ │ +│ │ │ ContextAwarePromptBuilder │ │ │ +│ │ │ (支持跨 chunk 上下文的 prompt builder) │ │ │ +│ │ └─────────────────────────────────────────────────────────┘ │ │ +│ └──────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌──────────────────────────────────────────────────────────────────────┐ │ +│ │ Schema 层 │ │ +│ │ 
┌──────────────┐ │ │ +│ │ │ BaseSchema │ (抽象基类) │ │ +│ │ └──────┬───────┘ │ │ +│ │ │ │ │ +│ │ ▼ │ │ +│ │ ┌──────────────────┐ │ │ +│ │ │ FormatModeSchema │ (当前主要实现: JSON/YAML 格式约束) │ │ +│ │ └──────────────────┘ │ │ +│ └──────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +--- + +**本文档基于代码版本**: langextract (docs/schema-design 分支) +**最后更新**: 2026-04-19 \ No newline at end of file diff --git a/langextract/_config.py b/langextract/_config.py new file mode 100644 index 00000000..0e1ad56a --- /dev/null +++ b/langextract/_config.py @@ -0,0 +1,230 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration system for LangExtract. + +This module provides the Config class and global configuration management +with support for: +- Constructor parameters (highest priority) +- Environment variables (LANGEXTRACT_* prefix) +- Default values (lowest priority) + +Thread Safety Note: + - The global configuration (_global_config) is shared across all threads. + Use configure() to modify it safely. + + - Context-local configuration (via contextvars in _logging.py) is + thread-local but NOT automatically inherited by new threads. + + For multi-threaded code: + 1. Use configure() for global settings (affects all threads) + 2. Use contextvars.copy_context() to propagate context-local config + 3. 
ENV_PREFIX = "LANGEXTRACT_"

# Accepted spellings for boolean environment variables (compared lower-cased).
_TRUE_STRINGS = frozenset({"1", "true", "yes", "on"})
_FALSE_STRINGS = frozenset({"0", "false", "no", "off"})


@dataclass
class Config:
    """Configuration for LangExtract.

    Supports three levels of configuration, from highest to lowest priority:
    1. Constructor parameters passed explicitly
    2. Environment variables with LANGEXTRACT_ prefix
    3. Default values defined in the dataclass

    Environment variable mapping:
    - LANGEXTRACT_LOG_LEVEL -> log_level
    - LANGEXTRACT_REQUEST_TIMEOUT -> request_timeout
    - LANGEXTRACT_MAX_RETRIES -> max_retries
    - LANGEXTRACT_DEFAULT_MODEL -> default_model
    - LANGEXTRACT_DEFAULT_MAX_TOKENS -> default_max_tokens
    - LANGEXTRACT_CACHE_ENABLED -> cache_enabled
    - LANGEXTRACT_CACHE_DIR -> cache_dir
    - LANGEXTRACT_PROGRESS_ENABLED -> progress_enabled

    Boolean variables accept 1/true/yes/on and 0/false/no/off (any case).

    NOTE: "explicitly passed" is detected by comparing a field with its
    default, so a constructor argument that equals the default is
    indistinguishable from an omitted one and is still overridden by the
    environment.

    Raises:
        ValueError: If a constructor or environment value is invalid.
    """

    log_level: str = field(default="WARNING")
    request_timeout: float = field(default=60.0)
    max_retries: int = field(default=3)
    default_model: Optional[str] = field(default=None)
    default_max_tokens: Optional[int] = field(default=None)
    cache_enabled: bool = field(default=True)
    cache_dir: Optional[str] = field(default=None)
    progress_enabled: bool = field(default=True)

    def __post_init__(self):
        """Apply environment overrides to fields left at their defaults."""
        for field_name, raw_value in self._parse_env_vars().items():
            if self._is_default_value(field_name):
                self._apply_env_value(field_name, raw_value)

        self._validate()

    @classmethod
    def _parse_env_vars(cls) -> dict[str, str]:
        """Collect the LANGEXTRACT_* variables that are currently set.

        Returns:
            A dict mapping field names to their raw string values.
        """
        return {
            field_name: os.environ[env_var]
            for field_name, env_var in cls._get_field_env_mapping().items()
            if env_var in os.environ
        }

    @staticmethod
    def _get_field_env_mapping() -> dict[str, str]:
        """Get the mapping from field names to environment variable names."""
        return {
            "log_level": f"{ENV_PREFIX}LOG_LEVEL",
            "request_timeout": f"{ENV_PREFIX}REQUEST_TIMEOUT",
            "max_retries": f"{ENV_PREFIX}MAX_RETRIES",
            "default_model": f"{ENV_PREFIX}DEFAULT_MODEL",
            "default_max_tokens": f"{ENV_PREFIX}DEFAULT_MAX_TOKENS",
            "cache_enabled": f"{ENV_PREFIX}CACHE_ENABLED",
            "cache_dir": f"{ENV_PREFIX}CACHE_DIR",
            "progress_enabled": f"{ENV_PREFIX}PROGRESS_ENABLED",
        }

    def _is_default_value(self, field_name: str) -> bool:
        """Return True if the field still holds its dataclass default.

        Used to decide whether an environment variable may override it.
        """
        return getattr(self, field_name) == self._get_field_default(field_name)

    @classmethod
    def _get_field_default(cls, field_name: str) -> Any:
        """Get the default value for a field from dataclass metadata.

        Reads the metadata instead of instantiating Config, which would
        re-enter __post_init__.
        """
        for f in dataclasses.fields(cls):
            if f.name != field_name:
                continue
            if f.default is not dataclasses.MISSING:
                return f.default
            if f.default_factory is not dataclasses.MISSING:
                return f.default_factory()
            return None
        return None

    @staticmethod
    def _convert_number(caster, text: str, env_var: str):
        """Convert text with caster (int or float), naming env_var on failure."""
        try:
            return caster(text)
        except ValueError as err:
            raise ValueError(
                f"Invalid value for {env_var}: {text!r} "
                f"(expected {caster.__name__})"
            ) from err

    @staticmethod
    def _parse_bool(text: str, env_var: str) -> bool:
        """Parse a boolean environment value, rejecting unknown spellings."""
        lowered = text.lower()
        if lowered in _TRUE_STRINGS:
            return True
        if lowered in _FALSE_STRINGS:
            return False
        raise ValueError(
            f"Invalid value for {env_var}: {text!r} (expected one of "
            f"{sorted(_TRUE_STRINGS | _FALSE_STRINGS)})"
        )

    def _apply_env_value(self, field_name: str, env_value: str) -> None:
        """Store an environment value on a field after type conversion.

        Surrounding whitespace is ignored. An empty value clears the optional
        fields (None) and leaves boolean fields at their current value.

        Raises:
            ValueError: If the value cannot be converted to the field type.
        """
        env_var = self._get_field_env_mapping().get(field_name, field_name)
        text = env_value.strip()

        if field_name == "log_level":
            self.log_level = text
        elif field_name == "request_timeout":
            self.request_timeout = self._convert_number(float, text, env_var)
        elif field_name == "max_retries":
            self.max_retries = self._convert_number(int, text, env_var)
        elif field_name == "default_model":
            self.default_model = text or None
        elif field_name == "default_max_tokens":
            self.default_max_tokens = (
                self._convert_number(int, text, env_var) if text else None
            )
        elif field_name in ("cache_enabled", "progress_enabled"):
            # Both flags share one parser so the same spelling always means
            # the same thing; an empty variable is treated as "not set".
            if text:
                setattr(self, field_name, self._parse_bool(text, env_var))
        elif field_name == "cache_dir":
            self.cache_dir = text or None

    def _validate(self) -> None:
        """Validate the configuration values.

        Raises:
            ValueError: If any field holds an unsupported value.
        """
        valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        if (
            not isinstance(self.log_level, str)
            or self.log_level.upper() not in valid_levels
        ):
            raise ValueError(
                f"Invalid log_level: {self.log_level}. "
                f"Must be one of: {valid_levels}"
            )

        if self.request_timeout <= 0:
            raise ValueError(
                f"request_timeout must be positive, got {self.request_timeout}"
            )

        if self.max_retries < 0:
            raise ValueError(
                f"max_retries must be non-negative, got {self.max_retries}"
            )

        if self.default_max_tokens is not None and self.default_max_tokens <= 0:
            raise ValueError(
                f"default_max_tokens must be positive, got {self.default_max_tokens}"
            )

    def model_copy(self, update: dict | None = None) -> Config:
        """Create a copy of this Config with optional updates.

        The copy is made without re-running __post_init__, so environment
        variables are not re-applied on top of the copied values.

        Args:
            update: Dictionary of field values to update in the new instance.
                Keys that are not Config fields are skipped with a warning.

        Returns:
            A new, validated Config instance with the updated values.

        Raises:
            ValueError: If the updated configuration is invalid.
        """
        import copy
        import warnings

        new_config = copy.deepcopy(self)

        if update:
            # Only dataclass fields are settable; a plain hasattr() check
            # would also let callers overwrite methods such as _validate.
            known_fields = {f.name for f in dataclasses.fields(self)}
            for key, value in update.items():
                if key in known_fields:
                    setattr(new_config, key, value)
                else:
                    warnings.warn(
                        f"Ignoring unknown configuration option: {key!r}",
                        stacklevel=2,
                    )

        new_config._validate()
        return new_config


_global_config: Config | None = None


def get_global_config() -> Config:
    """Get the global configuration.

    If no global config has been set, creates a new one using defaults
    and environment variables.

    Returns:
        The global Config instance.
    """
    global _global_config
    if _global_config is None:
        _global_config = Config()
    return _global_config


def set_global_config(config: Config) -> None:
    """Set the global configuration.

    Args:
        config: The Config instance to use as the global configuration.
    """
    global _global_config
    _global_config = config
_LOGGER_CACHE: dict[str, logging.Logger] = {}

_ROOT_LOGGER_NAME = "langextract"

_config_var: contextvars.ContextVar[Optional[Config]] = contextvars.ContextVar(
    "langextract_config", default=None
)

# True once the package logger has received its defaults. Without this guard
# every get_logger() call for a new name would reset the level to WARNING and
# silently undo an earlier configure(log_level=...).
_root_logger_initialized = False


def _ensure_root_logger() -> logging.Logger:
    """Return the "langextract" logger, applying defaults on first use only.

    The first call attaches a NullHandler (if the logger has no handlers),
    sets the level to WARNING and disables propagation. Later calls return
    the logger untouched so that configured levels survive.
    """
    global _root_logger_initialized
    root_logger = logging.getLogger(_ROOT_LOGGER_NAME)

    if not _root_logger_initialized:
        if not root_logger.handlers:
            root_logger.addHandler(logging.NullHandler())
        root_logger.setLevel(logging.WARNING)
        root_logger.propagate = False
        _root_logger_initialized = True

    return root_logger


def get_logger(name: str) -> logging.Logger:
    """Get a logger for the given module name.

    Names outside the "langextract" namespace are prefixed with
    "langextract." so every logger hangs off the package logger. A logger
    created here has no explicit level until configure()/config() runs, so
    it inherits the package logger's current level.

    Args:
        name: The module name, typically __name__.

    Returns:
        The (cached) logging.Logger instance for that name.
    """
    if name == _ROOT_LOGGER_NAME or name.startswith(_ROOT_LOGGER_NAME + "."):
        full_name = name
    else:
        full_name = f"{_ROOT_LOGGER_NAME}.{name}"

    cached = _LOGGER_CACHE.get(full_name)
    if cached is not None:
        return cached

    _ensure_root_logger()

    logger = logging.getLogger(full_name)
    # Children forward records to the package logger; the package logger
    # itself keeps the propagation setting chosen in _ensure_root_logger().
    if full_name != _ROOT_LOGGER_NAME:
        logger.propagate = True

    _LOGGER_CACHE[full_name] = logger

    return logger


def _get_effective_config() -> Config:
    """Get the current effective configuration.

    Checks context-local configuration first, then falls back to global.

    Returns:
        The current effective Config instance.
    """
    context_config = _config_var.get()
    if context_config is not None:
        return context_config

    from langextract._config import get_global_config
    return get_global_config()


def _apply_config_to_loggers(config: Config) -> None:
    """Apply log level from config to all langextract loggers.

    Logger levels live in the process-wide logging module, so this affects
    every thread and task, not only the current context.

    Args:
        config: The Config instance containing log_level.
    """
    root_logger = _ensure_root_logger()

    level = getattr(logging, config.log_level.upper(), logging.WARNING)

    root_logger.setLevel(level)

    for logger in _LOGGER_CACHE.values():
        logger.setLevel(level)


def configure(**kwargs) -> None:
    """Configure global settings for LangExtract.

    Replaces the global configuration with an updated copy and applies its
    log level to all langextract loggers.

    Args:
        **kwargs: Configuration options to update. Valid keys include:
            - log_level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
            - request_timeout: Timeout for API requests in seconds
            - max_retries: Maximum number of retries for failed requests
            - default_model: Default model to use for extraction
            - default_max_tokens: Default maximum tokens for generation
            - cache_enabled: Whether to enable caching
            - cache_dir: Directory to use for caching
            - progress_enabled: Whether to show progress output
    """
    from langextract._config import get_global_config, set_global_config

    updated = get_global_config().model_copy(update=kwargs)
    set_global_config(updated)
    _apply_config_to_loggers(updated)


class _ConfigContext:
    """Context manager for temporary configuration overrides.

    The override itself is stored in a ContextVar and is therefore scoped to
    the current context. The resulting log level, however, is applied to the
    shared logging loggers and is visible process-wide while the block runs.
    """

    def __init__(self, **kwargs):
        self._kwargs = kwargs
        self._token: Optional[contextvars.Token] = None

    def __enter__(self):
        temp_config = _get_effective_config().model_copy(update=self._kwargs)

        self._token = _config_var.set(temp_config)
        _apply_config_to_loggers(temp_config)

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._token is not None:
            _config_var.reset(self._token)
            self._token = None  # a token may only be used for one reset

        # Re-apply whatever configuration is effective after the override.
        _apply_config_to_loggers(_get_effective_config())

        return False


def config(**kwargs):
    """Create a context manager for temporary configuration overrides.

    This allows you to temporarily change settings (like log level) within
    a specific code block, after which the previous settings are restored.

    Example:
        with langextract.config(log_level="DEBUG"):
            # Debug logging enabled here
            result = langextract.extract(...)
        # Back to previous log level

    Thread Safety Note:
        The override is held in a contextvar, which is NOT automatically
        inherited by new threads. If you need to propagate the configuration
        to a new thread, use contextvars.copy_context().

    Example with explicit context propagation:
        import threading
        import contextvars

        with lx.config(log_level="DEBUG"):
            ctx = contextvars.copy_context()
            t = threading.Thread(target=lambda: ctx.run(my_function))
            t.start()
            t.join()

    Args:
        **kwargs: Configuration options to override temporarily.

    Returns:
        A context manager that applies the configuration on enter and
        restores the previous configuration on exit.
    """
    return _ConfigContext(**kwargs)


def get_context_config() -> Optional[Config]:
    """Get the current context-local configuration, if any.

    Returns:
        The Config instance set by the current context manager, or None.
    """
    return _config_var.get()


_ensure_root_logger()
extraction_passes ) @@ -512,7 +513,7 @@ def _annotate_documents_sequential_passes( total_extractions = sum( len(extractions) for extractions in all_pass_extractions ) - logging.info( + logger.info( "Document %s: Merged %d extractions from %d passes into " "%d non-overlapping extractions.", doc_id, @@ -527,7 +528,7 @@ def _annotate_documents_sequential_passes( text=document_texts.get(doc_id, doc.text or ""), ) - logging.info("Sequential extraction passes completed.") + logger.info("Sequential extraction passes completed.") def annotate_text( self, diff --git a/langextract/chunking.py b/langextract/chunking.py index f9e914a4..aa2e8081 100644 --- a/langextract/chunking.py +++ b/langextract/chunking.py @@ -24,10 +24,12 @@ import dataclasses import re -from absl import logging +from langextract._logging import get_logger import more_itertools from langextract.core import data + +logger = get_logger(__name__) from langextract.core import exceptions from langextract.core import tokenizer as tokenizer_lib @@ -196,7 +198,7 @@ def get_token_interval_text( f"{token_interval.end_index}." 
) return_string = tokenizer_lib.tokens_text(tokenized_text, token_interval) - logging.debug( + logger.debug( "Token util returns string: %s for tokenized_text: %s, token_interval:" " %s", return_string, diff --git a/langextract/core/base_model.py b/langextract/core/base_model.py index eda41836..a873b26e 100644 --- a/langextract/core/base_model.py +++ b/langextract/core/base_model.py @@ -16,167 +16,421 @@ from __future__ import annotations import abc +import asyncio from collections.abc import Iterator, Sequence +import concurrent.futures +import dataclasses import json -from typing import Any, Mapping +from typing import Any, Mapping, Optional, Sequence as TypingSequence import yaml from langextract.core import schema from langextract.core import types -__all__ = ['BaseLanguageModel'] +__all__ = ['BaseLanguageModel', 'LLMProvider', 'GenerateResult', 'Usage'] -class BaseLanguageModel(abc.ABC): - """An abstract inference class for managing LLM inference. +@dataclasses.dataclass +class Usage: + """Token usage information. - Attributes: - _constraint: A `Constraint` object specifying constraints for model output. - """ - - def __init__(self, constraint: types.Constraint | None = None, **kwargs: Any): - """Initializes the BaseLanguageModel with an optional constraint. - - Args: - constraint: Applies constraints when decoding the output. Defaults to no - constraint. - **kwargs: Additional keyword arguments passed to the model. + Attributes: + input_tokens: Number of tokens in the input prompt. + output_tokens: Number of tokens in the generated output. + total_tokens: Total number of tokens used. 
""" - self._constraint = constraint or types.Constraint() - self._schema: schema.BaseSchema | None = None - self._fence_output_override: bool | None = None - self._extra_kwargs: dict[str, Any] = kwargs.copy() - - @classmethod - def get_schema_class(cls) -> type[Any] | None: - """Return the schema class this provider supports.""" - return None - def apply_schema(self, schema_instance: schema.BaseSchema | None) -> None: - """Apply a schema instance to this provider. + input_tokens: int | None = None + output_tokens: int | None = None + total_tokens: int | None = None - Optional method that providers can override to store the schema instance - for runtime use. The default implementation stores it as _schema. - - Args: - schema_instance: The schema instance to apply, or None to clear. - """ - self._schema = schema_instance - @property - def schema(self) -> schema.BaseSchema | None: - """The current schema instance if one is configured. +@dataclasses.dataclass +class GenerateResult: + """Result of a generate() call. - Returns: - The schema instance or None if no schema is applied. + Attributes: + text: The generated text output. + usage: Optional token usage information. + raw_response: The raw response from the provider API. """ - return self._schema - def set_fence_output(self, fence_output: bool | None) -> None: - """Set explicit fence output preference. + text: str + usage: Usage | None = None + raw_response: Any = None - Args: - fence_output: True to force fences, False to disable, None for auto. - """ - if not hasattr(self, '_fence_output_override'): - self._fence_output_override = None - self._fence_output_override = fence_output - - @property - def requires_fence_output(self) -> bool: - """Whether this model requires fence output for parsing. - - Uses explicit override if set, otherwise computes from schema. - Returns True if no schema or schema doesn't require raw output. 
- """ - if ( - hasattr(self, '_fence_output_override') - and self._fence_output_override is not None - ): - return self._fence_output_override - schema_obj = self.schema - if schema_obj is None: - return True - return not schema_obj.requires_raw_output +class LLMProvider(abc.ABC): + """Abstract base class for LLM providers. - def merge_kwargs( - self, runtime_kwargs: Mapping[str, Any] | None = None - ) -> dict[str, Any]: - """Merge stored extra kwargs with runtime kwargs. + This interface defines the contract that all LLM providers must implement. + It provides both synchronous and asynchronous generation methods, as well + as context manager support for automatic resource cleanup. - Runtime kwargs take precedence over stored kwargs. + Example usage: + # Synchronous usage + provider = MyProvider(api_key="...") + result = provider.generate("Hello", model="gpt-4") + print(result.text) + provider.close() - Args: - runtime_kwargs: Kwargs provided at inference time, or None. + # As context manager (auto-close) + with MyProvider(api_key="...") as provider: + result = provider.generate("Hello") - Returns: - Merged kwargs dictionary. + # Async usage + async with MyProvider(api_key="...") as provider: + result = await provider.agenerate("Hello") """ - base = getattr(self, '_extra_kwargs', {}) or {} - incoming = dict(runtime_kwargs or {}) - return {**base, **incoming} - - @abc.abstractmethod - def infer( - self, batch_prompts: Sequence[str], **kwargs - ) -> Iterator[Sequence[types.ScoredOutput]]: - """Implements language model inference. - - Args: - batch_prompts: Batch of inputs for inference. Single element list can be - used for a single input. - **kwargs: Additional arguments for inference, like temperature and - max_decode_steps. - - Returns: Batch of Sequence of probable output text outputs, sorted by - descending score. 
- """ - - def infer_batch( - self, prompts: Sequence[str], batch_size: int = 32 # pylint: disable=unused-argument - ) -> list[list[types.ScoredOutput]]: - """Batch inference with configurable batch size. - - This is a convenience method that collects all results from infer(). - Args: - prompts: List of prompts to process. - batch_size: Batch size (currently unused, for future optimization). - - Returns: - List of lists of ScoredOutput objects. + @property + @abc.abstractmethod + def name(self) -> str: + """Return the provider name identifier. + + Returns: + A short string identifying the provider (e.g., "gemini", "openai", "mock"). + """ + ... + + @property + @abc.abstractmethod + def supported_models(self) -> TypingSequence[str]: + """Return the model ID patterns supported by this provider. + + Returns: + A list of regex patterns matching model IDs that this provider handles. + """ + ... + + @abc.abstractmethod + def generate( + self, prompt: str, model: str | None = None, **kwargs: Any + ) -> GenerateResult: + """Generate text synchronously. + + Args: + prompt: The input prompt to send to the model. + model: Optional model ID to use. If None, uses the provider's default. + **kwargs: Additional provider-specific arguments. + + Returns: + A GenerateResult containing the generated text and metadata. + """ + ... + + @abc.abstractmethod + async def agenerate( + self, prompt: str, model: str | None = None, **kwargs: Any + ) -> GenerateResult: + """Generate text asynchronously. + + Args: + prompt: The input prompt to send to the model. + model: Optional model ID to use. If None, uses the provider's default. + **kwargs: Additional provider-specific arguments. + + Returns: + A GenerateResult containing the generated text and metadata. + """ + ... + + @abc.abstractmethod + def close(self) -> None: + """Clean up any resources held by this provider. + + This method should be called when the provider is no longer needed. + It closes HTTP clients, connection pools, etc. 
+ """ + ... + + def __enter__(self) -> LLMProvider: + """Enter context manager. + + Returns: + Self for use in with statement. + """ + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any | None, + ) -> None: + """Exit context manager, calling close(). + + Args: + exc_type: Exception type if an exception occurred. + exc_val: Exception value if an exception occurred. + exc_tb: Traceback if an exception occurred. + """ + self.close() + + async def __aenter__(self) -> LLMProvider: + """Enter async context manager. + + Returns: + Self for use in async with statement. + """ + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any | None, + ) -> None: + """Exit async context manager, calling close(). + + Args: + exc_type: Exception type if an exception occurred. + exc_val: Exception value if an exception occurred. + exc_tb: Traceback if an exception occurred. + """ + self.close() + + +class BaseLanguageModel(LLMProvider): + """An abstract inference class for managing LLM inference. + + This class implements the LLMProvider interface with default implementations + for backward compatibility. Providers should inherit from this class and + implement infer() at minimum. + + Attributes: + _constraint: A `Constraint` object specifying constraints for model output. """ - results = [] - for output in self.infer(prompts): - results.append(list(output)) - return results - - def parse_output(self, output: str) -> Any: - """Parses model output as JSON or YAML. - Note: This expects raw JSON/YAML without code fences. - Code fence extraction is handled by resolver.py. - - Args: - output: Raw output string from the model. - - Returns: - Parsed Python object (dict or list). - - Raises: - ValueError: If output cannot be parsed as JSON or YAML. 
- """ - # Check if we have a format_type attribute (providers should set this) - format_type = getattr(self, 'format_type', types.FormatType.JSON) - - try: - if format_type == types.FormatType.JSON: - return json.loads(output) - else: - return yaml.safe_load(output) - except Exception as e: - raise ValueError( - f'Failed to parse output as {format_type.name}: {str(e)}' - ) from e + def __init__(self, constraint: types.Constraint | None = None, **kwargs: Any): + """Initializes the BaseLanguageModel with an optional constraint. + + Args: + constraint: Applies constraints when decoding the output. Defaults to no + constraint. + **kwargs: Additional keyword arguments passed to the model. + """ + self._constraint = constraint or types.Constraint() + self._schema: schema.BaseSchema | None = None + self._fence_output_override: bool | None = None + self._extra_kwargs: dict[str, Any] = kwargs.copy() + + @property + def name(self) -> str: + """Return the provider name identifier. + + Default implementation derives the name from the class name by: + 1. Stripping "LanguageModel" suffix if present + 2. Converting to lowercase + + Returns: + The provider name. + """ + class_name = self.__class__.__name__ + if class_name.endswith("LanguageModel"): + return class_name[:-13].lower() + return class_name.lower() + + @property + def supported_models(self) -> TypingSequence[str]: + """Return the model ID patterns supported by this provider. + + Default implementation returns an empty list. Providers should override + this to return the actual patterns they support. + + Returns: + List of regex patterns. + """ + return [] + + @classmethod + def get_schema_class(cls) -> type[Any] | None: + """Return the schema class this provider supports.""" + return None + + def apply_schema(self, schema_instance: schema.BaseSchema | None) -> None: + """Apply a schema instance to this provider. + + Optional method that providers can override to store the schema instance + for runtime use. 
The default implementation stores it as _schema. + + Args: + schema_instance: The schema instance to apply, or None to clear. + """ + self._schema = schema_instance + + @property + def schema(self) -> schema.BaseSchema | None: + """The current schema instance if one is configured. + + Returns: + The schema instance or None if no schema is applied. + """ + return self._schema + + def set_fence_output(self, fence_output: bool | None) -> None: + """Set explicit fence output preference. + + Args: + fence_output: True to force fences, False to disable, None for auto. + """ + if not hasattr(self, '_fence_output_override'): + self._fence_output_override = None + self._fence_output_override = fence_output + + @property + def requires_fence_output(self) -> bool: + """Whether this model requires fence output for parsing. + + Uses explicit override if set, otherwise computes from schema. + Returns True if no schema or schema doesn't require raw output. + """ + if ( + hasattr(self, '_fence_output_override') + and self._fence_output_override is not None + ): + return self._fence_output_override + + schema_obj = self.schema + if schema_obj is None: + return True + return not schema_obj.requires_raw_output + + def merge_kwargs( + self, runtime_kwargs: Mapping[str, Any] | None = None + ) -> dict[str, Any]: + """Merge stored extra kwargs with runtime kwargs. + + Runtime kwargs take precedence over stored kwargs. + + Args: + runtime_kwargs: Kwargs provided at inference time, or None. + + Returns: + Merged kwargs dictionary. + """ + base = getattr(self, '_extra_kwargs', {}) or {} + incoming = dict(runtime_kwargs or {}) + return {**base, **incoming} + + @abc.abstractmethod + def infer( + self, batch_prompts: Sequence[str], **kwargs + ) -> Iterator[Sequence[types.ScoredOutput]]: + """Implements language model inference. + + Args: + batch_prompts: Batch of inputs for inference. Single element list can be + used for a single input. 
+ **kwargs: Additional arguments for inference, like temperature and + max_decode_steps. + + Returns: Batch of Sequence of probable output text outputs, sorted by + descending score. + """ + + def infer_batch( + self, prompts: Sequence[str], batch_size: int = 32 + ) -> list[list[types.ScoredOutput]]: + """Batch inference with configurable batch size. + + This is a convenience method that collects all results from infer(). + + Args: + prompts: List of prompts to process. + batch_size: Batch size (currently unused, for future optimization). + + Returns: + List of lists of ScoredOutput objects. + """ + results = [] + for output in self.infer(prompts): + results.append(list(output)) + return results + + def generate( + self, prompt: str, model: str | None = None, **kwargs: Any + ) -> GenerateResult: + """Generate text synchronously (default implementation based on infer()). + + This is a fallback implementation for providers that haven't implemented + the new generate() API yet. It wraps infer() to provide the new interface. + + Args: + prompt: The input prompt. + model: Optional model ID (passed to infer() if applicable). + **kwargs: Additional arguments. + + Returns: + A GenerateResult with the output. + """ + results = list(self.infer([prompt], **kwargs)) + if results and results[0]: + output_text = results[0][0].output + else: + output_text = "" + + return GenerateResult( + text=output_text, + usage=None, + raw_response={"infer_results": results}, + ) + + async def agenerate( + self, prompt: str, model: str | None = None, **kwargs: Any + ) -> GenerateResult: + """Generate text asynchronously (default implementation). + + Default implementation runs the synchronous generate() in a thread pool. + Providers should override this for native async support. + + Args: + prompt: The input prompt. + model: Optional model ID. + **kwargs: Additional arguments. + + Returns: + A GenerateResult with the output. 
+ """ + loop = asyncio.get_running_loop() + with concurrent.futures.ThreadPoolExecutor() as executor: + return await loop.run_in_executor( + executor, + lambda: self.generate(prompt, model, **kwargs), + ) + + def close(self) -> None: + """Clean up resources (default no-op implementation). + + Providers should override this if they hold resources like HTTP clients + that need to be explicitly closed. + """ + pass + + def parse_output(self, output: str) -> Any: + """Parses model output as JSON or YAML. + + Note: This expects raw JSON/YAML without code fences. + Code fence extraction is handled by resolver.py. + + Args: + output: Raw output string from the model. + + Returns: + Parsed Python object (dict or list). + + Raises: + ValueError: If output cannot be parsed as JSON or YAML. + """ + format_type = getattr(self, 'format_type', types.FormatType.JSON) + + try: + if format_type == types.FormatType.JSON: + return json.loads(output) + else: + return yaml.safe_load(output) + except Exception as e: + raise ValueError( + f'Failed to parse output as {format_type.name}: {str(e)}' + ) from e diff --git a/langextract/core/debug_utils.py b/langextract/core/debug_utils.py index 603dfcef..c4012f00 100644 --- a/langextract/core/debug_utils.py +++ b/langextract/core/debug_utils.py @@ -22,14 +22,9 @@ import time from typing import Any, Callable, Mapping -from absl import logging as absl_logging +from langextract._logging import configure, get_logger -_LOG = logging.getLogger("langextract.debug") - -# Add NullHandler to prevent "No handler found" warnings -_langextract_logger = logging.getLogger("langextract") -if not _langextract_logger.handlers: - _langextract_logger.addHandler(logging.NullHandler()) +logger = get_logger("debug") # Sensitive keys to redact _REDACT_KEYS = { @@ -111,7 +106,6 @@ def debug_log_calls(fn: Callable) -> Callable: @functools.wraps(fn) def wrapper(*args, **kwargs): - logger = _LOG if not logger.isEnabledFor(logging.DEBUG): return fn(*args, **kwargs) @@
-149,37 +143,9 @@ def wrapper(*args, **kwargs): def configure_debug_logging() -> None: - """Enable debug logging for the 'langextract' namespace only.""" - logger = logging.getLogger("langextract") - - # Skip if we already added our handler - our_handler_exists = any( - isinstance(h, logging.StreamHandler) - and getattr(h, "langextract_debug", False) - for h in logger.handlers - ) - if our_handler_exists: - return - - # Respect host handlers - only set level if they exist - non_null_handlers = [ - h for h in logger.handlers if not isinstance(h, logging.NullHandler) - ] - - if non_null_handlers: - logger.setLevel(logging.DEBUG) - else: - logger.setLevel(logging.DEBUG) - handler = logging.StreamHandler() - handler.setLevel(logging.DEBUG) - fmt = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - handler.setFormatter(logging.Formatter(fmt)) - handler.langextract_debug = True - logger.addHandler(handler) - logger.propagate = False - - # Best-effort absl configuration - try: - absl_logging.set_verbosity(absl_logging.DEBUG) - except Exception: - pass + """Enable debug logging for the 'langextract' namespace only. + + This function is kept for backward compatibility. + Prefer using langextract.configure(log_level="DEBUG") instead. 
+ """ + configure(log_level="DEBUG") diff --git a/langextract/io.py b/langextract/io.py index 7bfedf36..e20a40b9 100644 --- a/langextract/io.py +++ b/langextract/io.py @@ -27,11 +27,14 @@ import pandas as pd import requests +from langextract._logging import get_logger from langextract import data_lib from langextract import progress from langextract.core import data from langextract.core import exceptions +logger = get_logger(__name__) + DEFAULT_TIMEOUT_SECONDS = 30 @@ -293,7 +296,7 @@ def download_text_from_url( for ct in ['text/', 'application/json', 'application/xml'] ): # Try to proceed anyway, but warn - print(f"Warning: Content-Type '{content_type}' may not be text-based") + logger.warning("Content-Type '%s' may not be text-based", content_type) # Get content length for progress bar total_size = int(response.headers.get('Content-Length', 0)) diff --git a/langextract/plugins.py b/langextract/plugins.py index ca76e149..de6999ae 100644 --- a/langextract/plugins.py +++ b/langextract/plugins.py @@ -23,10 +23,11 @@ import importlib from importlib import metadata -from absl import logging - +from langextract._logging import get_logger from langextract.core import base_model +logger = get_logger(__name__) + __all__ = ["available_providers", "get_provider_class"] # Static mapping for built-in providers (always available) @@ -78,7 +79,7 @@ def _discovered() -> dict[str, str]: discovered.setdefault(ep.name, value) if discovered: - logging.debug( + logger.debug( "Discovered third-party providers: %s", list(discovered.keys()) ) @@ -170,7 +171,7 @@ def _load_class(spec: str) -> type[base_model.BaseLanguageModel]: f" {missing}" ) - logging.warning( + logger.warning( "Provider %s does not inherit from BaseLanguageModel but appears" " compatible", cls, diff --git a/langextract/progress.py b/langextract/progress.py index 37463392..ca96b533 100644 --- a/langextract/progress.py +++ b/langextract/progress.py @@ -12,14 +12,30 @@ # See the License for the specific language 
governing permissions and # limitations under the License. -"""Progress and visualization utilities for LangExtract.""" +"""Progress and visualization utilities for LangExtract. + +This module provides progress bars and completion messages for terminal display. +Note that these utilities use direct print() with ANSI color codes for visual +appeal in interactive terminals. For production logging (e.g., to files), use +the unified logging system via langextract._logging.get_logger(). + +Configuration: +- progress_enabled: Controls whether progress bars and print_* messages are shown. + Can be set via: + - langextract.configure(progress_enabled=False) + - Environment variable LANGEXTRACT_PROGRESS_ENABLED=0 + - Default: True +""" from __future__ import annotations +import sys from typing import Any import urllib.parse import tqdm +from langextract._logging import get_logger + # ANSI color codes for terminal output BLUE = "\033[94m" GREEN = "\033[92m" @@ -30,6 +46,38 @@ # Google Blue color for progress bars GOOGLE_BLUE = "#4285F4" +logger = get_logger(__name__) + + +def _is_progress_enabled() -> bool: + """Check if progress display is enabled. + + Returns: + True if progress bars and print_* messages should be shown. + """ + try: + from langextract._config import get_global_config + + config = get_global_config() + return config.progress_enabled + except Exception: + return True + + +def _strip_ansi(text: str) -> str: + """Strip ANSI color codes from text. + + Args: + text: Text containing ANSI codes. + + Returns: + Text without ANSI codes. 
+ """ + import re + + ansi_escape = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]") + return ansi_escape.sub("", text) + def create_download_progress_bar( total_size: int, url: str, ncols: int = 100, max_url_length: int = 50 @@ -61,6 +109,8 @@ def create_download_progress_bar( else: url_display = url + disable = not _is_progress_enabled() + return tqdm.tqdm( total=total_size, unit="B", @@ -75,6 +125,7 @@ def create_download_progress_bar( ), colour=GOOGLE_BLUE, ncols=ncols, + disable=disable, ) @@ -86,18 +137,20 @@ def create_extraction_progress_bar( Args: iterable: The iterable to wrap with progress bar. model_info: Optional model information to display (e.g., "gemini-1.5-pro"). - disable: Whether to disable the progress bar. + disable: Whether to disable the progress bar (overrides config). Returns: A configured tqdm progress bar. """ desc = format_extraction_progress(model_info) + effective_disable = disable or not _is_progress_enabled() + return tqdm.tqdm( iterable, desc=desc, bar_format="{desc} [{elapsed}]", - disable=disable, + disable=effective_disable, dynamic_ncols=True, ) @@ -107,21 +160,42 @@ def print_download_complete( ) -> None: """Print a styled download completion message. + This function outputs to stdout with ANSI colors for terminal display. + It also logs the same information (without ANSI codes) at INFO level. + Args: char_count: Number of characters downloaded. word_count: Number of words downloaded. filename: Name of the downloaded file. 
""" - print( + message = ( f"{GREEN}✓{RESET} Downloaded {BOLD}{char_count:,}{RESET} characters " - f"({BOLD}{word_count:,}{RESET} words) from {BLUE}{filename}{RESET}", - flush=True, + f"({BOLD}{word_count:,}{RESET} words) from {BLUE}{filename}{RESET}" ) + logger.info( + "Downloaded %d characters (%d words) from %s", + char_count, + word_count, + filename, + ) + + if _is_progress_enabled(): + print(message, flush=True, file=sys.stdout) + def print_extraction_complete() -> None: - """Print a generic extraction completion message.""" - print(f"{GREEN}✓{RESET} Extraction processing complete", flush=True) + """Print a generic extraction completion message. + + This function outputs to stdout with ANSI colors for terminal display. + It also logs the same information (without ANSI codes) at INFO level. + """ + message = f"{GREEN}✓{RESET} Extraction processing complete" + + logger.info("Extraction processing complete") + + if _is_progress_enabled(): + print(message, flush=True, file=sys.stdout) def print_extraction_summary( @@ -133,6 +207,9 @@ def print_extraction_summary( ) -> None: """Print a styled extraction summary with optional performance metrics. + This function outputs to stdout with ANSI colors for terminal display. + It also logs the same information (without ANSI codes) at INFO level. + Args: num_extractions: Total number of extractions. unique_classes: Number of unique extraction classes. @@ -140,28 +217,38 @@ def print_extraction_summary( chars_processed: Optional number of characters processed. num_chunks: Optional number of chunks processed. 
""" - print( + main_message = ( f"{GREEN}✓{RESET} Extracted {BOLD}{num_extractions}{RESET} entities " - f"({BOLD}{unique_classes}{RESET} unique types)", - flush=True, + f"({BOLD}{unique_classes}{RESET} unique types)" ) - if elapsed_time is not None: - metrics = [] + logger.info("Extracted %d entities (%d unique types)", num_extractions, unique_classes) + + metrics: list[str] = [] + log_metrics: list[str] = [] + if elapsed_time is not None: # Time metrics.append(f"Time: {BOLD}{elapsed_time:.2f}s{RESET}") + log_metrics.append(f"Time: {elapsed_time:.2f}s") # Speed if chars_processed is not None and elapsed_time > 0: speed = chars_processed / elapsed_time metrics.append(f"Speed: {BOLD}{speed:,.0f}{RESET} chars/sec") + log_metrics.append(f"Speed: {speed:,.0f} chars/sec") if num_chunks is not None: metrics.append(f"Chunks: {BOLD}{num_chunks}{RESET}") + log_metrics.append(f"Chunks: {num_chunks}") + + if log_metrics: + logger.info("Performance: %s", ", ".join(log_metrics)) + if _is_progress_enabled(): + print(main_message, flush=True, file=sys.stdout) for metric in metrics: - print(f" {CYAN}•{RESET} {metric}", flush=True) + print(f" {CYAN}•{RESET} {metric}", flush=True, file=sys.stdout) def create_save_progress_bar( @@ -177,12 +264,14 @@ def create_save_progress_bar( A configured tqdm progress bar. """ filename = output_path.split("/")[-1] + effective_disable = disable or not _is_progress_enabled() + return tqdm.tqdm( desc=( f"{BLUE}{BOLD}LangExtract{RESET}: Saving to {GREEN}{filename}{RESET}" ), unit=" docs", - disable=disable, + disable=effective_disable, ) @@ -200,6 +289,8 @@ def create_load_progress_bar( A configured tqdm progress bar. 
""" filename = file_path.split("/")[-1] + effective_disable = disable or not _is_progress_enabled() + if total_size: return tqdm.tqdm( total=total_size, @@ -208,7 +299,7 @@ def create_load_progress_bar( ), unit="B", unit_scale=True, - disable=disable, + disable=effective_disable, ) else: return tqdm.tqdm( @@ -216,39 +307,53 @@ def create_load_progress_bar( f"{BLUE}{BOLD}LangExtract{RESET}: Loading {GREEN}{filename}{RESET}" ), unit=" docs", - disable=disable, + disable=effective_disable, ) def print_save_complete(num_docs: int, file_path: str) -> None: """Print a save completion message. + This function outputs to stdout with ANSI colors for terminal display. + It also logs the same information (without ANSI codes) at INFO level. + Args: num_docs: Number of documents saved. file_path: Path to the saved file. """ filename = file_path.split("/")[-1] - print( + message = ( f"{GREEN}✓{RESET} Saved {BOLD}{num_docs}{RESET} documents to" - f" {GREEN}{filename}{RESET}", - flush=True, + f" {GREEN}{filename}{RESET}" ) + logger.info("Saved %d documents to %s", num_docs, filename) + + if _is_progress_enabled(): + print(message, flush=True, file=sys.stdout) + def print_load_complete(num_docs: int, file_path: str) -> None: """Print a load completion message. + This function outputs to stdout with ANSI colors for terminal display. + It also logs the same information (without ANSI codes) at INFO level. + Args: num_docs: Number of documents loaded. file_path: Path to the loaded file. """ filename = file_path.split("/")[-1] - print( + message = ( f"{GREEN}✓{RESET} Loaded {BOLD}{num_docs}{RESET} documents from" - f" {GREEN}{filename}{RESET}", - flush=True, + f" {GREEN}{filename}{RESET}" ) + logger.info("Loaded %d documents from %s", num_docs, filename) + + if _is_progress_enabled(): + print(message, flush=True, file=sys.stdout) + def get_model_info(language_model: Any) -> str | None: """Extract model information from a language model instance. 
@@ -313,13 +418,11 @@ def format_extraction_progress( Returns: Formatted description string. """ - # Base description if model_info: desc = f"{BLUE}{BOLD}LangExtract{RESET}: model={GREEN}{model_info}{RESET}" else: desc = f"{BLUE}{BOLD}LangExtract{RESET}: Processing" - # Add stats if provided if current_chars is not None and processed_chars is not None: current_str = f"{GREEN}{current_chars:,}{RESET}" processed_str = f"{GREEN}{processed_chars:,}{RESET}" @@ -340,14 +443,16 @@ def create_pass_progress_bar( Returns: A configured tqdm progress bar. """ + effective_disable = disable or not _is_progress_enabled() desc = f"{BLUE}{BOLD}LangExtract{RESET}: Extraction passes" + return tqdm.tqdm( total=total_passes, desc=desc, bar_format=( "{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}]" ), - disable=disable, + disable=effective_disable, colour=GOOGLE_BLUE, ncols=100, ) diff --git a/langextract/prompt_validation.py b/langextract/prompt_validation.py index 32543390..2d09c624 100644 --- a/langextract/prompt_validation.py +++ b/langextract/prompt_validation.py @@ -21,9 +21,10 @@ import dataclasses import enum -from absl import logging - +from langextract._logging import get_logger from langextract import resolver + +logger = get_logger(__name__) from langextract.core import data from langextract.core import tokenizer as tokenizer_lib @@ -238,11 +239,11 @@ def handle_alignment_report( for issue in report.issues: if issue.issue_kind is IssueKind.NON_EXACT: - logging.warning( + logger.warning( "Prompt alignment: non-exact match: %s", issue.short_msg() ) else: - logging.warning( + logger.warning( "Prompt alignment: FAILED to align: %s", issue.short_msg() ) diff --git a/langextract/providers/__init__.py b/langextract/providers/__init__.py index e5840d2b..70936b74 100644 --- a/langextract/providers/__init__.py +++ b/langextract/providers/__init__.py @@ -22,141 +22,230 @@ import importlib from importlib import metadata import os - -from absl import logging - +import 
re + +from langextract._logging import get_logger +from langextract.core.base_model import ( + GenerateResult, + LLMProvider, + Usage, +) from langextract.providers import builtin_registry from langextract.providers import router +from langextract.providers.registry import ( + MockProvider, + ProviderInfo, + ProviderRegistry, +) -registry = router # Backward compat alias +logger = get_logger(__name__) + +registry = router __all__ = [ "gemini", "openai", "ollama", "router", - "registry", # Backward compat + "registry", "schemas", "load_plugins_once", "load_builtins_once", + "GenerateResult", + "LLMProvider", + "Usage", + "MockProvider", + "ProviderInfo", + "ProviderRegistry", ] -# Track provider loading for lazy initialization -_plugins_loaded = False # pylint: disable=invalid-name -_builtins_loaded = False # pylint: disable=invalid-name +_plugins_loaded = False +_builtins_loaded = False def load_builtins_once() -> None: - """Load built-in providers to register their patterns. + """Load built-in providers to register their patterns. + + Idempotent function that ensures provider patterns are available + for model resolution. Uses lazy registration to ensure providers + can be re-registered after registry.clear() even if their modules + are already in sys.modules. + """ + global _builtins_loaded + + if _builtins_loaded: + return + + for config in builtin_registry.BUILTIN_PROVIDERS: + router.register_lazy( + *config["patterns"], + target=config["target"], + priority=config["priority"], + ) - Idempotent function that ensures provider patterns are available - for model resolution. Uses lazy registration to ensure providers - can be re-registered after registry.clear() even if their modules - are already in sys.modules. 
- """ - global _builtins_loaded # pylint: disable=global-statement + _builtins_loaded = True - if _builtins_loaded: - return - # Register built-ins lazily so they can be re-registered after a registry.clear() - # even if their modules were already imported earlier in the test run. - for config in builtin_registry.BUILTIN_PROVIDERS: - router.register_lazy( - *config["patterns"], - target=config["target"], - priority=config["priority"], - ) +def _parse_entry_point_value(value: str) -> tuple[str, int | None]: + """Parse entry point value for priority suffix. - _builtins_loaded = True + Supports format: "module.path:ClassName" or "module.path:ClassName:priority=N" + Args: + value: The entry point value string. + + Returns: + Tuple of (target_path, priority). priority is None if not specified. + """ + priority_match = re.search(r':priority=(\d+)$', value) + if priority_match: + target_path = value[:priority_match.start()] + priority = int(priority_match.group(1)) + return target_path, priority + return value, None -def load_plugins_once() -> None: - """Load provider plugins from installed packages. - - Discovers and loads langextract provider plugins using entry points. - This function is idempotent - multiple calls have no effect. 
- """ - global _plugins_loaded # pylint: disable=global-statement - if _plugins_loaded: - return - - if os.environ.get("LANGEXTRACT_DISABLE_PLUGINS", "").lower() in ( - "1", - "true", - "yes", - ): - logging.info("Plugin loading disabled via LANGEXTRACT_DISABLE_PLUGINS") - _plugins_loaded = True - return - - load_builtins_once() - - try: - - eps = metadata.entry_points() - - # Try different APIs based on what's available - if hasattr(eps, "select"): - # Python 3.10+ API - provider_eps = eps.select(group="langextract.providers") - elif hasattr(eps, "get"): - # Python 3.9 API - provider_eps = eps.get("langextract.providers", []) - else: - # Fallback for older versions - provider_eps = [ - ep - for ep in eps - if getattr(ep, "group", None) == "langextract.providers" - ] - - for entry_point in provider_eps: - try: - - provider_class = entry_point.load() - logging.info("Loaded provider plugin: %s", entry_point.name) - - if hasattr(provider_class, "get_model_patterns"): - patterns = provider_class.get_model_patterns() - for pattern in patterns: - router.register( - pattern, - priority=getattr( - provider_class, - "pattern_priority", - 20, # Default plugin priority - ), - )(provider_class) - logging.info( - "Registered %d patterns for %s", len(patterns), entry_point.name - ) - except Exception as e: - logging.warning( - "Failed to load provider plugin %s: %s", entry_point.name, e - ) - except Exception as e: - logging.warning("Error discovering provider plugins: %s", e) +def _load_entry_point_class(entry_point) -> tuple[type, int | None]: + """Load a class from an entry point. - _plugins_loaded = True + Args: + entry_point: The entry point object. + + Returns: + Tuple of (provider_class, priority). priority is None if not specified. 
+ """ + value = getattr(entry_point, 'value', None) + priority = None + + if value: + value, priority = _parse_entry_point_value(value) + + provider_class = entry_point.load() + return provider_class, priority + + +def load_plugins_once() -> None: + """Load provider plugins from installed packages. + + Discovers and loads langextract provider plugins using entry points. + This function is idempotent - multiple calls have no effect. + + Entry point format: + [project.entry-points."langextract.providers"] + my_provider = "my_pkg.provider:MyProvider" + my_provider_high_prio = "my_pkg.provider:MyProvider:priority=100" + + Priority suffix syntax: + - ":priority=N" can be appended to set registration priority + - Higher priority wins when multiple providers match the same pattern + - Default plugin priority is 20 if not specified + """ + global _plugins_loaded + if _plugins_loaded: + return + + if os.environ.get("LANGEXTRACT_DISABLE_PLUGINS", "").lower() in ( + "1", + "true", + "yes", + ): + logger.info("Plugin loading disabled via LANGEXTRACT_DISABLE_PLUGINS") + _plugins_loaded = True + return + + load_builtins_once() + + try: + eps = metadata.entry_points() + + if hasattr(eps, "select"): + provider_eps = eps.select(group="langextract.providers") + elif hasattr(eps, "get"): + provider_eps = eps.get("langextract.providers", []) + else: + provider_eps = [ + ep + for ep in eps + if getattr(ep, "group", None) == "langextract.providers" + ] + + for entry_point in provider_eps: + try: + provider_class, ep_priority = _load_entry_point_class(entry_point) + logger.info("Loaded provider plugin: %s", entry_point.name) + + if hasattr(provider_class, "get_model_patterns"): + patterns = provider_class.get_model_patterns() + + class_priority = getattr( + provider_class, + "pattern_priority", + 20, + ) + + priority = ep_priority if ep_priority is not None else class_priority + + for pattern in patterns: + router.register( + pattern, + priority=priority, + )(provider_class) + 
logger.info( + "Registered %d patterns for %s with priority %d", + len(patterns), + entry_point.name, + priority, + ) + else: + class_name = provider_class.__name__ + if class_name.endswith("LanguageModel"): + base_name = class_name[:-13].lower() + else: + base_name = class_name.lower() + + class_priority = getattr( + provider_class, + "pattern_priority", + 20, + ) + priority = ep_priority if ep_priority is not None else class_priority + + router.register( + f"^{base_name}", + priority=priority, + )(provider_class) + logger.info( + "Registered provider %s with pattern ^%s and priority %d", + entry_point.name, + base_name, + priority, + ) + except Exception as e: + logger.warning( + "Failed to load provider plugin %s: %s", entry_point.name, e + ) + + except Exception as e: + logger.warning("Error discovering provider plugins: %s", e) + + _plugins_loaded = True def _reset_for_testing() -> None: - """Reset plugin loading state for testing. Should only be used in tests.""" - global _plugins_loaded, _builtins_loaded # pylint: disable=global-statement - _plugins_loaded = False - _builtins_loaded = False + """Reset plugin loading state for testing. 
Should only be used in tests.""" + global _plugins_loaded, _builtins_loaded + _plugins_loaded = False + _builtins_loaded = False def __getattr__(name: str): - """Lazy loading for submodules.""" - if name == "router": - return importlib.import_module("langextract.providers.router") - elif name == "schemas": - return importlib.import_module("langextract.providers.schemas") - elif name == "_plugins_loaded": - return _plugins_loaded - elif name == "_builtins_loaded": - return _builtins_loaded - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + """Lazy loading for submodules.""" + if name == "router": + return importlib.import_module("langextract.providers.router") + elif name == "schemas": + return importlib.import_module("langextract.providers.schemas") + elif name == "_plugins_loaded": + return _plugins_loaded + elif name == "_builtins_loaded": + return _builtins_loaded + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/langextract/providers/base.py b/langextract/providers/base.py new file mode 100644 index 00000000..2efae6e2 --- /dev/null +++ b/langextract/providers/base.py @@ -0,0 +1,33 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base provider interface for LangExtract. + +This module re-exports the base provider classes from langextract.core.base_model +for convenience. New code should import from langextract.core.base_model directly. 
+""" + +from __future__ import annotations + +from langextract.core.base_model import ( + GenerateResult, + LLMProvider, + Usage, +) + +__all__ = [ + 'GenerateResult', + 'LLMProvider', + 'Usage', +] diff --git a/langextract/providers/gemini.py b/langextract/providers/gemini.py index a82afe1e..6835735d 100644 --- a/langextract/providers/gemini.py +++ b/langextract/providers/gemini.py @@ -21,14 +21,15 @@ import dataclasses from typing import Any, Final, Iterator, Sequence -from absl import logging - +from langextract._logging import get_logger from langextract.core import base_model from langextract.core import data from langextract.core import exceptions from langextract.core import schema from langextract.core import types as core_types from langextract.providers import gemini_batch + +logger = get_logger(__name__) from langextract.providers import patterns from langextract.providers import router from langextract.providers import schemas @@ -166,7 +167,7 @@ def __init__( ) if self.api_key and self.vertexai: - logging.warning( + logger.warning( 'Both API key and Vertex AI configuration provided. ' 'API key will take precedence for authentication.' ) @@ -294,7 +295,7 @@ def infer( yield [core_types.ScoredOutput(score=1.0, output=text)] return else: - logging.info( + logger.info( 'Gemini batch mode enabled but prompt count (%d) is below the' ' threshold (%d); using real-time API. 
Submit at least %d prompts' ' to trigger batch mode.', diff --git a/langextract/providers/gemini_batch.py b/langextract/providers/gemini_batch.py index 220a262d..8dfc5283 100644 --- a/langextract/providers/gemini_batch.py +++ b/langextract/providers/gemini_batch.py @@ -38,8 +38,10 @@ from typing import Any, Callable, Protocol import uuid -from absl import logging +from langextract._logging import get_logger from google import genai + +logger = get_logger(__name__) from google.api_core import exceptions as google_exceptions from google.cloud import storage @@ -130,7 +132,7 @@ def from_dict(cls, d: dict | None) -> BatchConfig: unknown = sorted(set(d.keys()) - valid_keys) if unknown: - logging.warning( + logger.warning( "Ignoring unknown batch config keys: %s", ", ".join(unknown) ) cfg = cls(**filtered_dict) @@ -152,8 +154,8 @@ def from_dict(cls, d: dict | None) -> BatchConfig: def _default_job_create_callback(job: Any) -> None: """Default callback to log batch job details.""" - logging.info("Batch job created successfully: %s", job.name) - logging.info("Job State: %s", job.state) + logger.info("Batch job created successfully: %s", job.name) + logger.info("Job State: %s", job.state) # Extract project and job ID for console URL try: # job.name format: projects/{project}/locations/{location}/batchPredictionJobs/{job_id} @@ -162,7 +164,7 @@ def _default_job_create_callback(job: Any) -> None: job_id = parts[-1] location = parts[3] project = parts[1] - logging.info( + logger.info( "Job Console URL:" " https://console.cloud.google.com/vertex-ai/locations/%s/batch-predictions/%s?project=%s", location, @@ -245,13 +247,13 @@ def _ensure_bucket_lifecycle( bucket.add_lifecycle_delete_rule(age=retention_days) try: bucket.patch() - logging.info( + logger.info( "Added lifecycle rule to bucket %s: delete after %d days", bucket.name, retention_days, ) except Exception as e: - logging.warning( + logger.warning( "Failed to update lifecycle rule for bucket %s: %s", bucket.name, e ) @@ 
-360,10 +362,10 @@ def _submit_file( storage_client = storage.Client(project=project) try: bucket = storage_client.create_bucket(bucket_name, location=location) - logging.info("Created GCS bucket: %s", bucket_name) + logger.info("Created GCS bucket: %s", bucket_name) except google_exceptions.Conflict: bucket = storage_client.bucket(bucket_name) - logging.info("Using existing GCS bucket: %s", bucket_name) + logger.info("Using existing GCS bucket: %s", bucket_name) if retention_days: _ensure_bucket_lifecycle(bucket, retention_days) @@ -414,7 +416,7 @@ def _get_single(self, key_hash: str) -> str | None: except google_exceptions.NotFound: return None except Exception as e: - logging.warning("Cache read error for %s: %s", key_hash, e) + logger.warning("Cache read error for %s: %s", key_hash, e) return None def get_multi(self, key_data_list: Sequence[dict]) -> dict[int, str]: @@ -455,7 +457,7 @@ def _upload(text: str, key_data: dict): content_type=_MIME_TYPE_JSON, ) except Exception as e: - logging.warning( + logger.warning( "Cache write error for %s: %s", key_hash, e, exc_info=True ) @@ -466,7 +468,7 @@ def _upload(text: str, key_data: dict): try: text = json.dumps(text, default=_json_default, ensure_ascii=False) except Exception as e: - logging.warning("Serialization error: %s", e) + logger.warning("Serialization error: %s", e) continue executor.submit(_upload, text, key_data) @@ -488,7 +490,7 @@ def iter_items(self) -> Iterator[tuple[str, str]]: if text is not None: yield key_hash, text except (json.JSONDecodeError, Exception) as e: - logging.warning("Failed to read cache item %s: %s", blob.name, e) + logger.warning("Failed to read cache item %s: %s", blob.name, e) class _TextResponse(Protocol): @@ -584,13 +586,13 @@ def _poll_completion( try: client.batches.cancel(name=name) except Exception as e: - logging.warning("Failed to cancel timed-out batch job %s: %s", name, e) + logger.warning("Failed to cancel timed-out batch job %s: %s", name, e) raise 
exceptions.InferenceRuntimeError( f"Batch job timed out after {cfg.timeout}s: {name}" ) time.sleep(cfg.poll_interval) - logging.info("Batch job is running... (State: %s)", state.name) + logger.info("Batch job is running... (State: %s)", state.name) def _parse_batch_line( @@ -676,8 +678,8 @@ def _extract_from_file( f"No output files found in {gcs_uri}" ) - logging.info("Batch API: Downloading results from %s", gcs_uri) - logging.info("Batch API: Found %d output files", len(blobs)) + logger.info("Batch API: Downloading results from %s", gcs_uri) + logger.info("Batch API: Found %d output files", len(blobs)) for blob in blobs: if not blob.name.endswith(_EXT_JSONL): @@ -690,7 +692,7 @@ def _extract_from_file( continue _parse_batch_line(line, outputs_by_idx, cfg) - logging.info("Batch API: Parsed %d results", len(outputs_by_idx)) + logger.info("Batch API: Parsed %d results", len(outputs_by_idx)) return [outputs_by_idx.get(i, "") for i in range(expected_count)] @@ -749,16 +751,16 @@ def infer_batch( ) # Suppress verbose HTTP logs from underlying libraries std_logging.getLogger("google.auth.transport.requests").setLevel( std_logging.WARNING ) std_logging.getLogger("urllib3.connectionpool").setLevel(std_logging.WARNING) std_logging.getLogger("httpx").setLevel(std_logging.WARNING) std_logging.getLogger("httpcore").setLevel(std_logging.WARNING) # Force disable httpx propagation or handlers if level setting fails std_logging.getLogger("httpx").disabled = True - logging.info("Batch API: Processing %d prompts", len(prompts)) + logger.info("Batch API: Processing %d prompts", len(prompts)) display_base = 
f"langextract-batch-{int(time.time())}" @@ -767,7 +769,7 @@ def infer_batch( cache = GCSBatchCache(bucket_name, project) if cfg.enable_caching else None if cache: - logging.info( + logger.info( "Batch API: Using GCS bucket:" " https://console.cloud.google.com/storage/browser/%s", bucket_name, @@ -798,10 +800,10 @@ def infer_batch( prompts_to_process = list(enumerate(prompts)) if not prompts_to_process: - logging.info("Batch API: All %d prompts found in cache", len(prompts)) + logger.info("Batch API: All %d prompts found in cache", len(prompts)) return [cached_results[i] for i in range(len(prompts))] - logging.info( + logger.info( "Batch API: %d cached, %d to submit", len(cached_results), len(prompts_to_process), @@ -835,9 +837,9 @@ def _process_batch( try: cfg.on_job_create(job) except Exception as e: - logging.warning("Batch job creation callback failed: %s", e) + logger.warning("Batch job creation callback failed: %s", e) job = _poll_completion(client, job, cfg) - logging.info("Batch job completed successfully.") + logger.info("Batch job completed successfully.") results = _extract_from_file( client, job, cfg, expected_count=len(batch_prompts) ) diff --git a/langextract/providers/registry.py b/langextract/providers/registry.py new file mode 100644 index 00000000..c86b4d5f --- /dev/null +++ b/langextract/providers/registry.py @@ -0,0 +1,345 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Provider registry for LangExtract. 
+ +This module provides a centralized registry for LLM providers, allowing +registration, lookup, and enumeration of available providers. +""" + +from __future__ import annotations + +import dataclasses +from typing import Any, Callable, Sequence, Type + +from langextract.core import base_model +from langextract.core import exceptions +from langextract.core import types as core_types +from langextract.providers import router +from langextract.providers import builtin_registry + + +@dataclasses.dataclass +class ProviderInfo: + """Information about a registered provider. + + Attributes: + name: The provider name identifier. + cls: The provider class. + patterns: The model ID patterns this provider handles. + priority: The resolution priority. + """ + + name: str + cls: Type[base_model.BaseLanguageModel] + patterns: Sequence[str] + priority: int + + +class ProviderRegistry: + """Registry for LLM providers. + + This class provides a simplified interface for registering and looking up + LLM providers. It wraps the existing router module for backward compatibility. + + Example usage: + # Get the global registry + registry = ProviderRegistry.get_global() + + # Register a custom provider + registry.register(MyCustomProvider) + + # Look up a provider by name + provider = registry.get("gemini") + + # Get MockProvider for testing + mock_provider = registry.get("mock") + result = mock_provider(fixed_response="test").generate("hello") + + # List all registered providers + providers = registry.list_all() + """ + + _global_instance: ProviderRegistry | None = None + + def __init__(self) -> None: + """Initialize a new ProviderRegistry. + + The registry automatically registers MockProvider with name 'mock' + for testing purposes. + """ + self._builtins_loaded = False + self._mock_registered = False + self._register_mock_provider() + + def _register_mock_provider(self) -> None: + """Register MockProvider with the registry. 
+ + This is called during initialization so that MockProvider is always + available via registry.get('mock'). + """ + if not self._mock_registered: + router.register(r"^mock$", r"^mock-", priority=100)(MockProvider) + self._mock_registered = True + + @classmethod + def get_global(cls) -> ProviderRegistry: + """Get the global ProviderRegistry instance. + + Returns: + The singleton global registry instance. + """ + if cls._global_instance is None: + cls._global_instance = cls() + return cls._global_instance + + @classmethod + def _reset_global(cls) -> None: + """Reset the global instance for testing. + + This method should only be used in tests. + """ + cls._global_instance = None + + def _ensure_builtins_loaded(self) -> None: + """Ensure built-in providers are loaded.""" + if not self._builtins_loaded: + for config in builtin_registry.BUILTIN_PROVIDERS: + router.register_lazy( + *config["patterns"], + target=config["target"], + priority=config["priority"], + ) + self._builtins_loaded = True + + def register( + self, + provider_cls: Type[base_model.BaseLanguageModel], + patterns: Sequence[str] | None = None, + priority: int = 0, + ) -> Type[base_model.BaseLanguageModel]: + """Register a provider class. + + Args: + provider_cls: The provider class to register. + patterns: Optional model ID patterns this provider handles. + If not provided, attempts to get patterns from the class's + `supported_models` property or uses the class name. + priority: Resolution priority (higher wins on conflicts). + + Returns: + The registered provider class (for decorator usage). + """ + if patterns is None: + try: + patterns = [] + except Exception: + class_name = provider_cls.__name__ + patterns = [f"^{class_name.lower()}"] + + return router.register(*patterns, priority=priority)(provider_cls) + + def get(self, name: str) -> Type[base_model.BaseLanguageModel]: + """Look up a provider by name. + + Args: + name: The provider name (e.g., "gemini", "openai", "mock") or class name. 
+ + Returns: + The provider class. + + Raises: + InferenceConfigError: If no provider matches the name. + """ + self._ensure_builtins_loaded() + return router.resolve_provider(name) + + def get_for_model(self, model_id: str) -> Type[base_model.BaseLanguageModel]: + """Look up a provider by model ID. + + Args: + model_id: The model identifier (e.g., "gemini-2.5-flash", "mock-model"). + + Returns: + The provider class that handles this model. + + Raises: + InferenceConfigError: If no provider is registered for the model ID. + """ + self._ensure_builtins_loaded() + return router.resolve(model_id) + + def list_all(self) -> list[ProviderInfo]: + """List all registered providers. + + Returns: + A list of ProviderInfo objects for all registered providers. + """ + self._ensure_builtins_loaded() + + providers: list[ProviderInfo] = [] + entries = router.list_entries() + + for patterns, priority in entries: + try: + if patterns: + pass + except Exception: + pass + + return providers + + def clear(self) -> None: + """Clear all registered providers. + + This method is mainly for testing. Note that MockProvider will be + re-registered on the next operation. + """ + router.clear() + self._builtins_loaded = False + self._mock_registered = False + + +class MockProvider(base_model.BaseLanguageModel): + """A mock provider for testing purposes. + + This provider returns predefined responses without making any API calls. + It's useful for unit tests and integration tests. 
+ + Example usage: + # Create a mock provider with fixed responses + mock = MockProvider(fixed_response='{"result": "test"}') + + # Or use it with a response function + def my_response(prompt, **kwargs): + return f"Response to: {prompt}" + + mock = MockProvider(response_fn=my_response) + + # Use in tests + result = mock.generate("Test prompt") + print(result.text) + + # Check if close() was called + print(mock.close_called) # Counter increments on each close() + + # Use as context manager + with MockProvider(fixed_response="test") as mock: + result = mock.generate("hello") + print(mock.close_called) # 1 + """ + + def __init__( + self, + fixed_response: str | None = None, + response_fn: Callable[..., str] | None = None, + usage: base_model.Usage | None = None, + model_id: str = "mock-model", + **kwargs: Any, + ) -> None: + """Initialize the MockProvider. + + Args: + fixed_response: A fixed response string to return for all prompts. + Either fixed_response or response_fn must be provided. + response_fn: A function that takes (prompt, **kwargs) and returns + a response string. Used if fixed_response is None. + usage: Optional usage information to return in GenerateResult. + model_id: The model ID to report. + **kwargs: Additional keyword arguments (ignored). 
+ """ + super().__init__() + self.model_id = model_id + self.fixed_response = fixed_response + self.response_fn = response_fn + self.usage = usage + self._extra_kwargs = kwargs + self.close_called: int = 0 + + if fixed_response is None and response_fn is None: + self.fixed_response = '{"mock": "response"}' + + @property + def name(self) -> str: + """Return the provider name.""" + return "mock" + + @property + def supported_models(self) -> Sequence[str]: + """Return the supported models.""" + return ["^mock$", "^mock-"] + + def _get_response(self, prompt: str, **kwargs: Any) -> str: + """Get the response for a prompt.""" + if self.response_fn is not None: + return self.response_fn(prompt, **kwargs) + return self.fixed_response or "" + + def generate( + self, prompt: str, model: str | None = None, **kwargs: Any + ) -> base_model.GenerateResult: + """Generate a mock response. + + Args: + prompt: The input prompt. + model: Optional model ID (ignored). + **kwargs: Additional keyword arguments. + + Returns: + A GenerateResult with the mock response. + """ + response_text = self._get_response(prompt, **kwargs) + return base_model.GenerateResult( + text=response_text, + usage=self.usage, + raw_response={"prompt": prompt, "model": model or self.model_id}, + ) + + async def agenerate( + self, prompt: str, model: str | None = None, **kwargs: Any + ) -> base_model.GenerateResult: + """Generate a mock response asynchronously. + + Args: + prompt: The input prompt. + model: Optional model ID (ignored). + **kwargs: Additional keyword arguments. + + Returns: + A GenerateResult with the mock response. + """ + return self.generate(prompt, model, **kwargs) + + def close(self) -> None: + """Clean up resources and increment close_called counter. + + This is useful for testing that context managers properly release resources. + """ + self.close_called += 1 + + def infer( + self, batch_prompts: Sequence[str], **kwargs: Any + ): + """Infer method for backward compatibility. 
+ + Args: + batch_prompts: A list of prompts. + **kwargs: Additional keyword arguments. + + Yields: + Lists of ScoredOutput objects. + """ + for prompt in batch_prompts: + response_text = self._get_response(prompt, **kwargs) + yield [core_types.ScoredOutput(score=1.0, output=response_text)] diff --git a/langextract/providers/router.py b/langextract/providers/router.py index 9039ddba..e0812b5b 100644 --- a/langextract/providers/router.py +++ b/langextract/providers/router.py @@ -27,11 +27,12 @@ import re import typing -from absl import logging - +from langextract._logging import get_logger from langextract.core import base_model from langextract.core import exceptions +logger = get_logger(__name__) + TLanguageModel = typing.TypeVar( "TLanguageModel", bound=base_model.BaseLanguageModel ) @@ -62,7 +63,7 @@ def _add_entry( """Add an entry to the registry with deduplication.""" key = (provider_id, tuple(p.pattern for p in patterns), priority) if key in _entry_keys: - logging.debug( + logger.debug( "Skipping duplicate registration for %s with patterns %s at" " priority %d", provider_id, @@ -72,7 +73,7 @@ def _add_entry( return _entry_keys.add(key) _entries.append(_Entry(patterns=patterns, loader=loader, priority=priority)) - logging.debug( + logger.debug( "Registered provider %s with patterns %s at priority %d", provider_id, [p.pattern for p in patterns], diff --git a/langextract/resolver.py b/langextract/resolver.py index aae3f986..66e379fb 100644 --- a/langextract/resolver.py +++ b/langextract/resolver.py @@ -32,9 +32,10 @@ from typing import Final import warnings -from absl import logging - +from langextract._logging import get_logger from langextract.core import data + +logger = get_logger(__name__) from langextract.core import exceptions from langextract.core import format_handler as fh from langextract.core import schema @@ -286,8 +287,8 @@ def resolve( ResolverParsingError: If the content within the string cannot be parsed due to formatting errors, or if the 
parsed content is not as expected. """ - logging.debug("Starting resolver process for input text.") - logging.debug("Input Text: %s", input_text) + logger.debug("Starting resolver process for input text.") + logger.debug("Input Text: %s", input_text) try: constraint = getattr(self, "_constraint", schema.Constraint()) @@ -295,11 +296,11 @@ def resolve( extraction_data = self.format_handler.parse_output( input_text, strict=strict ) - logging.debug("Parsed content: %s", extraction_data) + logger.debug("Parsed content: %s", extraction_data) except exceptions.FormatError as e: if suppress_parse_errors: - logging.warning("Skipping chunk: parse error: %s", e) + logger.warning("Skipping chunk: parse error: %s", e) return [] raise ResolverParsingError(str(e)) from e @@ -307,11 +308,11 @@ def resolve( processed_extractions = self.extract_ordered_extractions(extraction_data) except ValueError as e: if suppress_parse_errors: - logging.warning("Skipping chunk: schema error: %s", e) + logger.warning("Skipping chunk: schema error: %s", e) return [] raise ResolverParsingError(str(e)) from e - logging.debug("Completed the resolver process.") + logger.debug("Completed the resolver process.") return processed_extractions @@ -351,10 +352,10 @@ def align( Yields: Iterator on aligned extractions. """ - logging.debug("Starting alignment process for provided chunk text.") + logger.debug("Starting alignment process for provided chunk text.") if not extractions: - logging.debug( + logger.debug( "No extractions found in the annotated text; exiting alignment" " process." 
) @@ -375,16 +376,16 @@ def align( accept_match_lesser=accept_match_lesser, tokenizer_impl=tokenizer_inst, ) - logging.debug( + logger.debug( "Aligned extractions count: %d", sum(len(group) for group in aligned_yaml_extractions), ) for extraction in itertools.chain(*aligned_yaml_extractions): - logging.debug("Yielding aligned extraction: %s", extraction) + logger.debug("Yielding aligned extraction: %s", extraction) yield extraction - logging.debug("Completed alignment process for the provided source_text.") + logger.debug("Completed alignment process for the provided source_text.") def string_to_extraction_data( self, @@ -406,7 +407,7 @@ def string_to_extraction_data( ValueError: If the input is invalid or does not contain expected format. """ if not input_string or not isinstance(input_string, str): - logging.error("Input string must be a non-empty string.") + logger.error("Input string must be a non-empty string.") raise ValueError("Input string must be a non-empty string.") try: @@ -418,7 +419,7 @@ def string_to_extraction_data( raise ResolverParsingError(str(e)) from e except Exception as e: - logging.exception("Failed to parse content.") + logger.exception("Failed to parse content.") raise ResolverParsingError("Failed to parse content.") from e def extract_ordered_extractions( @@ -448,10 +449,10 @@ def extract_ordered_extractions( ValueError: If an index is not an integer, attributes are not a dict or None, or extraction text is not a string, integer, or float. 
""" - logging.debug("Starting to extract and order extractions from data.") + logger.debug("Starting to extract and order extractions from data.") if not extraction_data: - logging.debug("Received empty extraction data.") + logger.debug("Received empty extraction data.") processed_extractions = [] extraction_index = 0 @@ -462,7 +463,7 @@ def extract_ordered_extractions( for extraction_class, extraction_value in group.items(): if index_suffix and extraction_class.endswith(index_suffix): if not isinstance(extraction_value, int): - logging.debug( + logger.debug( "Index must be an integer. Found: %s", type(extraction_value), ) @@ -471,7 +472,7 @@ def extract_ordered_extractions( if attributes_suffix and extraction_class.endswith(attributes_suffix): if not isinstance(extraction_value, (dict, type(None))): - logging.debug( + logger.debug( "Attributes must be a dict or None. Found: %s", type(extraction_value), ) @@ -481,7 +482,7 @@ def extract_ordered_extractions( continue if not isinstance(extraction_value, (str, int, float)): - logging.debug( + logger.debug( "Extraction text must be a string, integer, or float. Found: %s", type(extraction_value), ) @@ -496,7 +497,7 @@ def extract_ordered_extractions( index_key = extraction_class + index_suffix extraction_index = group.get(index_key, None) if extraction_index is None: - logging.debug( + logger.debug( "No index value for %s. 
Skipping extraction.", extraction_class ) continue @@ -519,7 +520,7 @@ def extract_ordered_extractions( ) processed_extractions.sort(key=operator.attrgetter("extraction_index")) - logging.debug("Completed extraction and ordering of extractions.") + logger.debug("Completed extraction and ordering of extractions.") return processed_extractions @@ -618,7 +619,7 @@ def _fuzzy_align_extraction( if not extraction_tokens: return None - logging.debug( + logger.debug( "Fuzzy aligning %r (%d tokens)", extraction.extraction_text, len(extraction_tokens), @@ -694,7 +695,7 @@ def _fuzzy_align_extraction( extraction.alignment_status = data.AlignmentStatus.MATCH_FUZZY return extraction except IndexError: - logging.exception( + logger.exception( "Index error while setting intervals during fuzzy alignment." ) return None @@ -737,7 +738,7 @@ def _lcs_fuzzy_align_extraction( extraction_tokens_norm = [_normalize_token(t) for t in extraction_tokens] - logging.debug( + logger.debug( "LCS fuzzy aligning %r (%d tokens)", extraction.extraction_text, len(extraction_tokens), @@ -760,18 +761,24 @@ def _lcs_fuzzy_align_extraction( if accepted is None: return None - extraction.token_interval = tokenizer_lib.TokenInterval( - start_index=accepted.start + token_offset, - end_index=accepted.end + 1 + token_offset, - ) - start_token = tokenized_text.tokens[accepted.start] - end_token = tokenized_text.tokens[accepted.end] - extraction.char_interval = data.CharInterval( - start_pos=char_offset + start_token.char_interval.start_pos, - end_pos=char_offset + end_token.char_interval.end_pos, - ) - extraction.alignment_status = data.AlignmentStatus.MATCH_FUZZY - return extraction + try: + extraction.token_interval = tokenizer_lib.TokenInterval( + start_index=accepted.start + token_offset, + end_index=accepted.end + 1 + token_offset, + ) + start_token = tokenized_text.tokens[accepted.start] + end_token = tokenized_text.tokens[accepted.end] + extraction.char_interval = data.CharInterval( + 
start_pos=char_offset + start_token.char_interval.start_pos, + end_pos=char_offset + end_token.char_interval.end_pos, + ) + extraction.alignment_status = data.AlignmentStatus.MATCH_FUZZY + return extraction + except IndexError: + logger.exception( + "Index error while setting intervals during fuzzy alignment." + ) + return None def align_extractions( self, @@ -846,13 +853,13 @@ def align_extractions( DeprecationWarning, stacklevel=2, ) - logging.debug( + logger.debug( "WordAligner: Starting alignment of extractions with the source text." " Extraction groups to align: %s", extraction_groups, ) if not extraction_groups: - logging.info("No extraction groups provided; returning empty list.") + logger.info("No extraction groups provided; returning empty list.") return [] source_tokens = list( @@ -865,7 +872,7 @@ def align_extractions( if delim_len != 1: raise ValueError(f"Delimiter {delim!r} must be a single token.") - logging.debug("Using delimiter %r for extraction alignment", delim) + logger.debug("Using delimiter %r for extraction alignment", delim) extraction_tokens = list( _tokenize_with_lowercase( @@ -882,7 +889,7 @@ def align_extractions( index_to_extraction_group = {} extraction_index = 0 for group_index, group in enumerate(extraction_groups): - logging.debug( + logger.debug( "Processing extraction group %d with %d extractions.", group_index, len(group), @@ -922,7 +929,7 @@ def align_extractions( for i, j, n in self._get_matching_blocks()[:-1]: extraction, _ = index_to_extraction_group.get(j, (None, None)) if extraction is None: - logging.debug( + logger.debug( "No clean start index found for extraction index=%d iterating" " Difflib matching_blocks", j, @@ -983,7 +990,7 @@ def align_extractions( unaligned_extractions.append(extraction) if enable_fuzzy_alignment and unaligned_extractions: - logging.debug( + logger.debug( "Starting fuzzy alignment (%s) for %d unaligned extractions", fuzzy_alignment_algorithm, len(unaligned_extractions), @@ -1017,7 +1024,7 @@ def 
align_extractions( ) if aligned_extraction: aligned_extractions.append(aligned_extraction) - logging.debug( + logger.debug( "Fuzzy alignment successful for extraction: %s", extraction.extraction_text, ) @@ -1025,7 +1032,7 @@ def align_extractions( for extraction, group_index in index_to_extraction_group.values(): aligned_extraction_groups[group_index].append(extraction) - logging.debug( + logger.debug( "Final aligned extraction groups: %s", aligned_extraction_groups ) return aligned_extraction_groups diff --git a/tests/logging_config_test.py b/tests/logging_config_test.py new file mode 100644 index 00000000..82dab2eb --- /dev/null +++ b/tests/logging_config_test.py @@ -0,0 +1,810 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for the logging and configuration system.""" + +import logging +import logging.handlers +import os +import threading +import time +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized + + +class ConfigTest(parameterized.TestCase): + """Tests for the Config class.""" + + def setUp(self): + super().setUp() + self._saved_env = {} + for key in list(os.environ.keys()): + if key.startswith("LANGEXTRACT_"): + self._saved_env[key] = os.environ.pop(key) + + def tearDown(self): + for key, value in self._saved_env.items(): + os.environ[key] = value + super().tearDown() + + def test_default_values(self): + """Test that Config has correct default values.""" + from langextract._config import Config + + config = Config() + self.assertEqual(config.log_level, "WARNING") + self.assertEqual(config.request_timeout, 60.0) + self.assertEqual(config.max_retries, 3) + self.assertIsNone(config.default_model) + self.assertIsNone(config.default_max_tokens) + self.assertTrue(config.cache_enabled) + self.assertIsNone(config.cache_dir) + + def test_constructor_parameters_override_defaults(self): + """Test that constructor parameters take highest priority.""" + from langextract._config import Config + + config = Config( + log_level="DEBUG", + request_timeout=30.0, + max_retries=5, + default_model="gemini-2.5-flash", + default_max_tokens=8192, + cache_enabled=False, + cache_dir="/tmp/cache", + ) + self.assertEqual(config.log_level, "DEBUG") + self.assertEqual(config.request_timeout, 30.0) + self.assertEqual(config.max_retries, 5) + self.assertEqual(config.default_model, "gemini-2.5-flash") + self.assertEqual(config.default_max_tokens, 8192) + self.assertFalse(config.cache_enabled) + self.assertEqual(config.cache_dir, "/tmp/cache") + + def test_environment_variables_override_defaults(self): + """Test that environment variables override defaults when not explicitly set.""" + from langextract._config import Config + + 
os.environ["LANGEXTRACT_LOG_LEVEL"] = "INFO" + os.environ["LANGEXTRACT_REQUEST_TIMEOUT"] = "45.5" + os.environ["LANGEXTRACT_MAX_RETRIES"] = "10" + os.environ["LANGEXTRACT_DEFAULT_MODEL"] = "test-model" + os.environ["LANGEXTRACT_DEFAULT_MAX_TOKENS"] = "4096" + os.environ["LANGEXTRACT_CACHE_ENABLED"] = "false" + os.environ["LANGEXTRACT_CACHE_DIR"] = "/env/cache" + + config = Config() + self.assertEqual(config.log_level, "INFO") + self.assertEqual(config.request_timeout, 45.5) + self.assertEqual(config.max_retries, 10) + self.assertEqual(config.default_model, "test-model") + self.assertEqual(config.default_max_tokens, 4096) + self.assertFalse(config.cache_enabled) + self.assertEqual(config.cache_dir, "/env/cache") + + def test_constructor_overrides_environment(self): + """Test that constructor parameters override environment variables.""" + from langextract._config import Config + + os.environ["LANGEXTRACT_LOG_LEVEL"] = "INFO" + + config = Config(log_level="DEBUG") + self.assertEqual(config.log_level, "DEBUG") + + def test_model_copy(self): + """Test that model_copy creates a copy with updates.""" + from langextract._config import Config + + original = Config(log_level="WARNING", request_timeout=60.0) + copied = original.model_copy({"log_level": "DEBUG", "max_retries": 10}) + + self.assertEqual(original.log_level, "WARNING") + self.assertEqual(original.request_timeout, 60.0) + self.assertEqual(original.max_retries, 3) + + self.assertEqual(copied.log_level, "DEBUG") + self.assertEqual(copied.request_timeout, 60.0) + self.assertEqual(copied.max_retries, 10) + + @parameterized.named_parameters( + dict(testcase_name="debug", level="DEBUG"), + dict(testcase_name="info", level="INFO"), + dict(testcase_name="warning", level="WARNING"), + dict(testcase_name="error", level="ERROR"), + dict(testcase_name="critical", level="CRITICAL"), + ) + def test_valid_log_levels(self, level): + """Test that valid log levels are accepted.""" + from langextract._config import Config + + 
config = Config(log_level=level) + self.assertEqual(config.log_level, level) + + def test_invalid_log_level_raises(self): + """Test that invalid log levels raise ValueError.""" + from langextract._config import Config + + with self.assertRaises(ValueError): + Config(log_level="INVALID") + + +class GlobalConfigTest(absltest.TestCase): + """Tests for global configuration management.""" + + def setUp(self): + super().setUp() + import langextract._config as config_module + + self._original_global = config_module._global_config + config_module._global_config = None + + self._saved_env = {} + for key in list(os.environ.keys()): + if key.startswith("LANGEXTRACT_"): + self._saved_env[key] = os.environ.pop(key) + + def tearDown(self): + import langextract._config as config_module + + config_module._global_config = self._original_global + + for key, value in self._saved_env.items(): + os.environ[key] = value + super().tearDown() + + def test_get_global_config_creates_default(self): + """Test that get_global_config creates a default config if none exists.""" + from langextract._config import get_global_config, Config + + config = get_global_config() + self.assertIsInstance(config, Config) + self.assertEqual(config.log_level, "WARNING") + + def test_set_global_config(self): + """Test that set_global_config updates the global config.""" + from langextract._config import get_global_config, set_global_config, Config + + new_config = Config(log_level="DEBUG") + set_global_config(new_config) + + config = get_global_config() + self.assertIs(config, new_config) + self.assertEqual(config.log_level, "DEBUG") + + +class LoggingTest(parameterized.TestCase): + """Tests for the logging system.""" + + def setUp(self): + super().setUp() + import langextract._config as config_module + import langextract._logging as logging_module + + self._original_global = config_module._global_config + self._original_logger_cache = logging_module._LOGGER_CACHE.copy() + + config_module._global_config = None 
+ logging_module._LOGGER_CACHE.clear() + + self._saved_env = {} + for key in list(os.environ.keys()): + if key.startswith("LANGEXTRACT_"): + self._saved_env[key] = os.environ.pop(key) + + def tearDown(self): + import langextract._config as config_module + import langextract._logging as logging_module + + config_module._global_config = self._original_global + logging_module._LOGGER_CACHE.clear() + logging_module._LOGGER_CACHE.update(self._original_logger_cache) + + for key, value in self._saved_env.items(): + os.environ[key] = value + super().tearDown() + + def test_get_logger_default_level(self): + """Test that loggers have WARNING level by default.""" + from langextract._logging import get_logger + + logger = get_logger("test_module") + root_logger = logging.getLogger("langextract") + + self.assertEqual(logger.name, "langextract.test_module") + self.assertEqual(root_logger.level, logging.WARNING) + + def test_get_logger_with_module_name(self): + """Test that get_logger properly handles __name__ style module names.""" + from langextract._logging import get_logger + + logger1 = get_logger("langextract.resolver") + self.assertEqual(logger1.name, "langextract.resolver") + + logger2 = get_logger("resolver") + self.assertEqual(logger2.name, "langextract.resolver") + + logger3 = get_logger("langextract") + self.assertEqual(logger3.name, "langextract") + + def test_logger_caching(self): + """Test that get_logger returns the same logger for the same name.""" + from langextract._logging import get_logger + + logger1 = get_logger("test_module") + logger2 = get_logger("test_module") + + self.assertIs(logger1, logger2) + + def test_configure_changes_log_level(self): + """Test that configure() changes the log level.""" + from langextract._logging import get_logger, configure + + root_logger = logging.getLogger("langextract") + self.assertEqual(root_logger.level, logging.WARNING) + + configure(log_level="DEBUG") + self.assertEqual(root_logger.level, logging.DEBUG) + + 
configure(log_level="ERROR") + self.assertEqual(root_logger.level, logging.ERROR) + + def test_configure_updates_existing_loggers(self): + """Test that configure() updates all cached loggers.""" + from langextract._logging import get_logger, configure + + logger1 = get_logger("module1") + logger2 = get_logger("module2") + + configure(log_level="DEBUG") + + self.assertEqual(logger1.level, logging.DEBUG) + self.assertEqual(logger2.level, logging.DEBUG) + + +class ConfigContextTest(parameterized.TestCase): + """Tests for the config context manager.""" + + def setUp(self): + super().setUp() + import langextract._config as config_module + import langextract._logging as logging_module + + self._original_global = config_module._global_config + self._original_logger_cache = logging_module._LOGGER_CACHE.copy() + + config_module._global_config = None + logging_module._LOGGER_CACHE.clear() + + self._saved_env = {} + for key in list(os.environ.keys()): + if key.startswith("LANGEXTRACT_"): + self._saved_env[key] = os.environ.pop(key) + + def tearDown(self): + import langextract._config as config_module + import langextract._logging as logging_module + + config_module._global_config = self._original_global + logging_module._LOGGER_CACHE.clear() + logging_module._LOGGER_CACHE.update(self._original_logger_cache) + + for key, value in self._saved_env.items(): + os.environ[key] = value + super().tearDown() + + def test_context_manager_temporarily_changes_level(self): + """Test that config context manager temporarily changes log level.""" + from langextract._logging import config, get_context_config + + root_logger = logging.getLogger("langextract") + self.assertEqual(root_logger.level, logging.WARNING) + self.assertIsNone(get_context_config()) + + with config(log_level="DEBUG"): + self.assertEqual(root_logger.level, logging.DEBUG) + self.assertIsNotNone(get_context_config()) + self.assertEqual(get_context_config().log_level, "DEBUG") + + self.assertEqual(root_logger.level, 
logging.WARNING) + self.assertIsNone(get_context_config()) + + def test_context_manager_nested(self): + """Test that nested context managers work correctly.""" + from langextract._logging import config + + root_logger = logging.getLogger("langextract") + self.assertEqual(root_logger.level, logging.WARNING) + + with config(log_level="INFO"): + self.assertEqual(root_logger.level, logging.INFO) + + with config(log_level="DEBUG"): + self.assertEqual(root_logger.level, logging.DEBUG) + + self.assertEqual(root_logger.level, logging.INFO) + + self.assertEqual(root_logger.level, logging.WARNING) + + def test_context_manager_restores_after_exception(self): + """Test that context manager restores config after exception.""" + from langextract._logging import config + + root_logger = logging.getLogger("langextract") + self.assertEqual(root_logger.level, logging.WARNING) + + with self.assertRaises(ValueError): + with config(log_level="DEBUG"): + self.assertEqual(root_logger.level, logging.DEBUG) + raise ValueError("test error") + + self.assertEqual(root_logger.level, logging.WARNING) + + def test_context_manager_thread_isolation(self): + """Test that contextvar configuration is thread-local. + + Note: Logger level modification is global (standard Python logging + doesn't support context-local levels). However, the contextvar + configuration storage is thread-local, which allows for concurrent + use cases where each thread manages its own configuration. 
+ """ + from langextract._logging import config, get_context_config + + results = {} + + def thread_func(thread_id, level, results_dict): + with config(log_level=level, request_timeout=float(thread_id) * 10): + time.sleep(0.01) + ctx_config = get_context_config() + results_dict[thread_id] = { + "log_level": ctx_config.log_level, + "request_timeout": ctx_config.request_timeout, + } + + thread1 = threading.Thread( + target=thread_func, args=(1, "DEBUG", results) + ) + thread2 = threading.Thread( + target=thread_func, args=(2, "ERROR", results) + ) + + thread1.start() + thread2.start() + thread1.join() + thread2.join() + + self.assertEqual(results[1]["log_level"], "DEBUG") + self.assertEqual(results[1]["request_timeout"], 10.0) + self.assertEqual(results[2]["log_level"], "ERROR") + self.assertEqual(results[2]["request_timeout"], 20.0) + + self.assertIsNone(get_context_config()) + + +class IntegrationTest(parameterized.TestCase): + """Integration tests for logging and config system.""" + + def setUp(self): + super().setUp() + import langextract._config as config_module + import langextract._logging as logging_module + + self._original_global = config_module._global_config + self._original_logger_cache = logging_module._LOGGER_CACHE.copy() + + config_module._global_config = None + logging_module._LOGGER_CACHE.clear() + + self._saved_env = {} + for key in list(os.environ.keys()): + if key.startswith("LANGEXTRACT_"): + self._saved_env[key] = os.environ.pop(key) + + def tearDown(self): + import langextract._config as config_module + import langextract._logging as logging_module + + config_module._global_config = self._original_global + logging_module._LOGGER_CACHE.clear() + logging_module._LOGGER_CACHE.update(self._original_logger_cache) + + for key, value in self._saved_env.items(): + os.environ[key] = value + super().tearDown() + + def test_logger_emits_messages_at_correct_level(self): + """Test that logger only emits messages at or above its level.""" + from 
langextract._logging import configure, get_logger + + logger = get_logger("test_emission") + handler = logging.handlers.MemoryHandler(capacity=100) + logger.addHandler(handler) + logger.propagate = False + + configure(log_level="WARNING") + logger.debug("debug message") + logger.info("info message") + logger.warning("warning message") + + self.assertEqual(len(handler.buffer), 1) + self.assertEqual(handler.buffer[0].getMessage(), "warning message") + + handler.buffer.clear() + configure(log_level="DEBUG") + logger.debug("debug message") + logger.info("info message") + + self.assertEqual(len(handler.buffer), 2) + + def test_configure_multiple_options(self): + """Test that configure can update multiple options at once.""" + from langextract._config import get_global_config + from langextract._logging import configure + + configure( + log_level="DEBUG", + request_timeout=120.0, + max_retries=0, + cache_enabled=False, + ) + + config = get_global_config() + self.assertEqual(config.log_level, "DEBUG") + self.assertEqual(config.request_timeout, 120.0) + self.assertEqual(config.max_retries, 0) + self.assertFalse(config.cache_enabled) + + def test_context_manager_multiple_options(self): + """Test that config context manager can set multiple options.""" + from langextract._config import get_global_config + from langextract._logging import config, get_context_config + + with config( + log_level="DEBUG", + request_timeout=30.0, + max_retries=5, + ): + ctx_config = get_context_config() + self.assertEqual(ctx_config.log_level, "DEBUG") + self.assertEqual(ctx_config.request_timeout, 30.0) + self.assertEqual(ctx_config.max_retries, 5) + + global_config = get_global_config() + self.assertEqual(global_config.log_level, "WARNING") + self.assertEqual(global_config.request_timeout, 60.0) + self.assertEqual(global_config.max_retries, 3) + + +class ProgressConfigTest(parameterized.TestCase): + """Tests for progress_enabled configuration.""" + + def setUp(self): + super().setUp() + import 
langextract._config as config_module + import langextract._logging as logging_module + + self._original_global = config_module._global_config + self._original_logger_cache = logging_module._LOGGER_CACHE.copy() + + config_module._global_config = None + logging_module._LOGGER_CACHE.clear() + + self._saved_env = {} + for key in list(os.environ.keys()): + if key.startswith("LANGEXTRACT_"): + self._saved_env[key] = os.environ.pop(key) + + def tearDown(self): + import langextract._config as config_module + import langextract._logging as logging_module + + config_module._global_config = self._original_global + logging_module._LOGGER_CACHE.clear() + logging_module._LOGGER_CACHE.update(self._original_logger_cache) + + for key, value in self._saved_env.items(): + os.environ[key] = value + super().tearDown() + + def test_progress_enabled_default(self): + """Test that progress_enabled is True by default.""" + from langextract._config import Config + + config = Config() + self.assertTrue(config.progress_enabled) + + def test_progress_enabled_constructor(self): + """Test that progress_enabled can be set via constructor.""" + from langextract._config import Config + + config = Config(progress_enabled=False) + self.assertFalse(config.progress_enabled) + + def test_progress_enabled_environment_variable(self): + """Test that progress_enabled can be set via environment variable.""" + from langextract._config import Config + + os.environ["LANGEXTRACT_PROGRESS_ENABLED"] = "0" + config = Config() + self.assertFalse(config.progress_enabled) + + def test_progress_enabled_environment_variable_true(self): + """Test that progress_enabled can be set to True via environment variable.""" + from langextract._config import Config + + os.environ["LANGEXTRACT_PROGRESS_ENABLED"] = "1" + config = Config() + self.assertTrue(config.progress_enabled) + + def test_progress_module_checks_config(self): + """Test that progress module checks the config.""" + from langextract._logging import configure + from 
langextract.progress import _is_progress_enabled + + configure(progress_enabled=False) + self.assertFalse(_is_progress_enabled()) + + configure(progress_enabled=True) + self.assertTrue(_is_progress_enabled()) + + +class AbslVerbosityCompatibilityTest(parameterized.TestCase): + """Tests for absl verbosity compatibility. + + These tests verify that langextract.configure(log_level='DEBUG') + works correctly even when absl flags are in use. + """ + + def setUp(self): + super().setUp() + import langextract._config as config_module + import langextract._logging as logging_module + + self._original_global = config_module._global_config + self._original_logger_cache = logging_module._LOGGER_CACHE.copy() + + config_module._global_config = None + logging_module._LOGGER_CACHE.clear() + + self._saved_env = {} + for key in list(os.environ.keys()): + if key.startswith("LANGEXTRACT_"): + self._saved_env[key] = os.environ.pop(key) + + def tearDown(self): + import langextract._config as config_module + import langextract._logging as logging_module + + config_module._global_config = self._original_global + logging_module._LOGGER_CACHE.clear() + logging_module._LOGGER_CACHE.update(self._original_logger_cache) + + for key, value in self._saved_env.items(): + os.environ[key] = value + super().tearDown() + + def test_configure_works_without_absl_flags_initialized(self): + """Test that configure works when absl flags haven't been initialized.""" + from langextract._logging import configure, get_logger + + logger = get_logger("test_absl") + root_logger = logging.getLogger("langextract") + + configure(log_level="DEBUG") + self.assertEqual(root_logger.level, logging.DEBUG) + + def test_configure_overrides_effective_log_level(self): + """Test that configure directly sets the effective log level. + + Priority: langextract.configure() > environment variables > defaults. 
+ """ + from langextract._logging import configure, get_logger + + logger = get_logger("test_override") + handler = logging.handlers.MemoryHandler(capacity=100) + logger.addHandler(handler) + logger.propagate = False + + configure(log_level="WARNING") + logger.debug("should not appear") + self.assertEqual(len(handler.buffer), 0) + + configure(log_level="DEBUG") + logger.debug("should appear") + self.assertEqual(len(handler.buffer), 1) + self.assertEqual(handler.buffer[0].getMessage(), "should appear") + + def test_configure_sets_all_cached_loggers(self): + """Test that configure updates all cached loggers.""" + from langextract._logging import configure, get_logger + + logger1 = get_logger("logger1") + logger2 = get_logger("logger2") + logger3 = get_logger("logger3") + + configure(log_level="ERROR") + + self.assertEqual(logger1.level, logging.ERROR) + self.assertEqual(logger2.level, logging.ERROR) + self.assertEqual(logger3.level, logging.ERROR) + + configure(log_level="DEBUG") + + self.assertEqual(logger1.level, logging.DEBUG) + self.assertEqual(logger2.level, logging.DEBUG) + self.assertEqual(logger3.level, logging.DEBUG) + + +class ProductionWorkflowTest(parameterized.TestCase): + """Integration tests for production workflows from README. + + These tests verify the three configuration methods from README: + 1. Using configure() (global setting) + 2. Using config() context manager (temporary setting) + 3. 
Using environment variables + """ + + def setUp(self): + super().setUp() + import langextract._config as config_module + import langextract._logging as logging_module + + self._original_global = config_module._global_config + self._original_logger_cache = logging_module._LOGGER_CACHE.copy() + + config_module._global_config = None + logging_module._LOGGER_CACHE.clear() + + self._saved_env = {} + for key in list(os.environ.keys()): + if key.startswith("LANGEXTRACT_"): + self._saved_env[key] = os.environ.pop(key) + + def tearDown(self): + import langextract._config as config_module + import langextract._logging as logging_module + + config_module._global_config = self._original_global + logging_module._LOGGER_CACHE.clear() + logging_module._LOGGER_CACHE.update(self._original_logger_cache) + + for key, value in self._saved_env.items(): + os.environ[key] = value + super().tearDown() + + def test_workflow_1_configure_global(self): + """Test workflow 1: Using configure() for global setting. + + From README: + import langextract as lx + lx.configure(log_level="DEBUG") + """ + from langextract._logging import configure, get_logger + + resolver_logger = get_logger("resolver") + providers_logger = get_logger("providers") + annotation_logger = get_logger("annotation") + chunking_logger = get_logger("chunking") + + root_logger = logging.getLogger("langextract") + self.assertEqual(root_logger.level, logging.WARNING) + + configure(log_level="DEBUG") + + self.assertEqual(root_logger.level, logging.DEBUG) + self.assertEqual(resolver_logger.level, logging.DEBUG) + self.assertEqual(providers_logger.level, logging.DEBUG) + self.assertEqual(annotation_logger.level, logging.DEBUG) + self.assertEqual(chunking_logger.level, logging.DEBUG) + + def test_workflow_2_context_manager_temporary(self): + """Test workflow 2: Using config() context manager for temporary setting. + + From README: + with lx.config(log_level="DEBUG"): + result = lx.extract(...) 
# Debug logs enabled here + # Back to previous log level + """ + from langextract._config import get_global_config + from langextract._logging import config, get_logger + + resolver_logger = get_logger("resolver") + root_logger = logging.getLogger("langextract") + + self.assertEqual(root_logger.level, logging.WARNING) + global_config = get_global_config() + self.assertEqual(global_config.log_level, "WARNING") + + with config(log_level="DEBUG"): + self.assertEqual(root_logger.level, logging.DEBUG) + + self.assertEqual(root_logger.level, logging.WARNING) + + def test_workflow_3_environment_variables(self): + """Test workflow 3: Using environment variables. + + From README: + export LANGEXTRACT_LOG_LEVEL="DEBUG" + export LANGEXTRACT_REQUEST_TIMEOUT="120.0" + export LANGEXTRACT_MAX_RETRIES="5" + """ + import langextract._config as config_module + + config_module._global_config = None + + os.environ["LANGEXTRACT_LOG_LEVEL"] = "DEBUG" + os.environ["LANGEXTRACT_REQUEST_TIMEOUT"] = "120.0" + os.environ["LANGEXTRACT_MAX_RETRIES"] = "5" + os.environ["LANGEXTRACT_CACHE_ENABLED"] = "false" + + from langextract._config import get_global_config + + global_config = get_global_config() + self.assertEqual(global_config.log_level, "DEBUG") + self.assertEqual(global_config.request_timeout, 120.0) + self.assertEqual(global_config.max_retries, 5) + self.assertFalse(global_config.cache_enabled) + + def test_all_modules_use_consistent_logger(self): + """Test that all key modules use the unified logger. + + Verify that resolver, providers, annotation, and chunking modules + all use get_logger() with proper naming. 
+ """ + from langextract._logging import get_logger + + module_loggers = [ + ("langextract.resolver", get_logger("resolver")), + ("langextract.providers", get_logger("providers")), + ("langextract.annotation", get_logger("annotation")), + ("langextract.chunking", get_logger("chunking")), + ("langextract.progress", get_logger("progress")), + ] + + for expected_name, logger in module_loggers: + self.assertEqual(logger.name, expected_name) + self.assertTrue(logger.name.startswith("langextract.")) + + def test_configuration_priority_order(self): + """Test configuration priority order. + + Priority (highest to lowest): + 1. Explicit configure() call + 2. Environment variables + 3. Default values + """ + import langextract._config as config_module + + os.environ["LANGEXTRACT_LOG_LEVEL"] = "ERROR" + os.environ["LANGEXTRACT_REQUEST_TIMEOUT"] = "30.0" + + config_module._global_config = None + + from langextract._config import get_global_config + from langextract._logging import configure + + env_config = get_global_config() + self.assertEqual(env_config.log_level, "ERROR") + self.assertEqual(env_config.request_timeout, 30.0) + + configure(log_level="DEBUG", request_timeout=120.0) + + explicit_config = get_global_config() + self.assertEqual(explicit_config.log_level, "DEBUG") + self.assertEqual(explicit_config.request_timeout, 120.0) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/provider_abstraction_test.py b/tests/provider_abstraction_test.py new file mode 100644 index 00000000..7c338e34 --- /dev/null +++ b/tests/provider_abstraction_test.py @@ -0,0 +1,649 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the new provider abstraction layer. + +This module tests: +- ProviderRegistry registration and lookup +- MockProvider returning preset values +- Backward compatibility with old API +- Multiple providers coexisting +- Entry points with priority suffix +- Context manager resource cleanup +- Async context manager +""" + +from __future__ import annotations + +import asyncio +from importlib import metadata +from types import SimpleNamespace +from typing import Any, Sequence +from unittest import mock + +import pytest + +from langextract.core import base_model +from langextract.core import types as core_types +from langextract.providers import ( + GenerateResult, + LLMProvider, + MockProvider, + ProviderRegistry, + Usage, +) +from langextract.providers import _parse_entry_point_value, _reset_for_testing, router + + +class TestGenerateResult: + """Tests for GenerateResult dataclass.""" + + def test_GenerateResult_creation(self): + """Test creating a GenerateResult instance.""" + result = GenerateResult( + text="Generated text", + usage=Usage(input_tokens=10, output_tokens=20, total_tokens=30), + raw_response={"key": "value"}, + ) + + assert result.text == "Generated text" + assert result.usage is not None + assert result.usage.input_tokens == 10 + assert result.usage.output_tokens == 20 + assert result.usage.total_tokens == 30 + assert result.raw_response == {"key": "value"} + + def test_GenerateResult_defaults(self): + """Test GenerateResult with default values.""" + result = GenerateResult(text="Simple text") + + assert result.text == "Simple 
text" + assert result.usage is None + assert result.raw_response is None + + +class TestUsage: + """Tests for Usage dataclass.""" + + def test_Usage_creation(self): + """Test creating a Usage instance.""" + usage = Usage(input_tokens=5, output_tokens=10, total_tokens=15) + + assert usage.input_tokens == 5 + assert usage.output_tokens == 10 + assert usage.total_tokens == 15 + + def test_Usage_defaults(self): + """Test Usage with default values.""" + usage = Usage() + + assert usage.input_tokens is None + assert usage.output_tokens is None + assert usage.total_tokens is None + + +class TestLLMProviderInterface: + """Tests for the LLMProvider abstract interface.""" + + def test_BaseLanguageModel_implements_LLMProvider(self): + """Test that BaseLanguageModel implements LLMProvider interface.""" + + class ConcreteModel(base_model.BaseLanguageModel): + def infer( + self, batch_prompts: Sequence[str], **kwargs: Any + ): + for prompt in batch_prompts: + yield [core_types.ScoredOutput(score=1.0, output=f"Response to: {prompt}")] + + model = ConcreteModel() + + # Verify it's an instance of LLMProvider + assert isinstance(model, LLMProvider) + + # Verify default implementations work + assert isinstance(model.name, str) + assert isinstance(model.supported_models, Sequence) + + # Test generate method + result = model.generate("Test prompt") + assert isinstance(result, GenerateResult) + assert "Test prompt" in result.text + + # Test close method (no-op by default) + model.close() + + def test_LLMProvider_name_property(self): + """Test that name property returns correct value based on class name.""" + + class MyCustomLanguageModel(base_model.BaseLanguageModel): + def infer(self, batch_prompts, **kwargs): + for prompt in batch_prompts: + yield [core_types.ScoredOutput(score=1.0, output=prompt)] + + model = MyCustomLanguageModel() + # Should strip "LanguageModel" suffix and lowercase + assert model.name == "mycustom" + + +class TestMockProvider: + """Tests for MockProvider.""" + + 
def test_MockProvider_fixed_response(self): + """Test MockProvider with fixed response.""" + expected_response = '{"result": "test", "value": 42}' + mock = MockProvider(fixed_response=expected_response) + + result = mock.generate("Any prompt") + + assert result.text == expected_response + assert result.raw_response is not None + assert result.raw_response["prompt"] == "Any prompt" + + def test_MockProvider_response_fn(self): + """Test MockProvider with response function.""" + + def custom_response(prompt: str, **kwargs) -> str: + return f"Custom response to: {prompt}" + + mock = MockProvider(response_fn=custom_response) + + result1 = mock.generate("First prompt") + result2 = mock.generate("Second prompt") + + assert result1.text == "Custom response to: First prompt" + assert result2.text == "Custom response to: Second prompt" + + def test_MockProvider_default_response(self): + """Test MockProvider default response.""" + mock = MockProvider() + + result = mock.generate("Test") + + assert result.text is not None + assert len(result.text) > 0 + + def test_MockProvider_usage(self): + """Test MockProvider with usage information.""" + usage = Usage(input_tokens=100, output_tokens=50, total_tokens=150) + mock = MockProvider(fixed_response="Test", usage=usage) + + result = mock.generate("Prompt") + + assert result.usage is not None + assert result.usage.input_tokens == 100 + assert result.usage.output_tokens == 50 + assert result.usage.total_tokens == 150 + + def test_MockProvider_infer_backward_compatibility(self): + """Test that MockProvider supports infer() for backward compatibility.""" + mock = MockProvider(fixed_response="Fixed") + + results = list(mock.infer(["Prompt 1", "Prompt 2"])) + + assert len(results) == 2 + assert results[0][0].output == "Fixed" + assert results[1][0].output == "Fixed" + + def test_MockProvider_async_generate(self): + """Test MockProvider async generate method.""" + import asyncio + + mock = MockProvider(fixed_response="Async test") + + 
result = asyncio.run(mock.agenerate("Async prompt")) + + assert result.text == "Async test" + + def test_MockProvider_name_and_supported_models(self): + """Test MockProvider name and supported_models properties.""" + mock = MockProvider() + + assert mock.name == "mock" + assert len(mock.supported_models) > 0 + for pattern in mock.supported_models: + assert pattern.startswith("^mock") + + def test_MockProvider_close(self): + """Test that MockProvider.close() is a no-op.""" + mock = MockProvider() + mock.close() # Should not raise + + +class TestProviderRegistry: + """Tests for ProviderRegistry.""" + + def setup_method(self): + """Reset router and global registry before each test.""" + router.clear() + ProviderRegistry._reset_global() + + def teardown_method(self): + """Clean up after each test.""" + router.clear() + ProviderRegistry._reset_global() + + def test_ProviderRegistry_get_global(self): + """Test that get_global() returns a singleton instance.""" + registry1 = ProviderRegistry.get_global() + registry2 = ProviderRegistry.get_global() + + assert registry1 is registry2 + + def test_ProviderRegistry_register_custom_provider(self): + """Test registering a custom provider.""" + + @router.register(r"^custom") + class CustomProvider(base_model.BaseLanguageModel): + def infer(self, batch_prompts, **kwargs): + for prompt in batch_prompts: + yield [core_types.ScoredOutput(score=1.0, output=f"Custom: {prompt}")] + + registry = ProviderRegistry() + # The provider should be registered via decorator + + # Verify we can resolve it + provider_cls = router.resolve("custom-model") + assert provider_cls is CustomProvider + + def test_ProviderRegistry_get_for_model(self): + """Test getting provider by model ID.""" + + @router.register(r"^test-gemini") + class TestGeminiProvider(base_model.BaseLanguageModel): + def infer(self, batch_prompts, **kwargs): + for prompt in batch_prompts: + yield [core_types.ScoredOutput(score=1.0, output=prompt)] + + registry = ProviderRegistry() + 
registry.register(TestGeminiProvider, patterns=[r"^test-gemini"]) + + provider_cls = registry.get_for_model("test-gemini-pro") + assert provider_cls is TestGeminiProvider + + def test_ProviderRegistry_clear(self): + """Test clearing the registry.""" + + @router.register(r"^temp") + class TempProvider(base_model.BaseLanguageModel): + def infer(self, batch_prompts, **kwargs): + for prompt in batch_prompts: + yield [core_types.ScoredOutput(score=1.0, output=prompt)] + + registry = ProviderRegistry() + + # Should be able to resolve before clear + provider_cls = router.resolve("temp-model") + assert provider_cls is TempProvider + + # Clear registry + registry.clear() + + # Should fail after clear + from langextract import exceptions + + with pytest.raises(exceptions.InferenceConfigError): + router.resolve("temp-model") + + def test_multiple_providers_coexist(self): + """Test that multiple providers can coexist.""" + + @router.register(r"^provider-a", priority=10) + class ProviderA(base_model.BaseLanguageModel): + def infer(self, batch_prompts, **kwargs): + for prompt in batch_prompts: + yield [core_types.ScoredOutput(score=1.0, output="A")] + + @router.register(r"^provider-b", priority=10) + class ProviderB(base_model.BaseLanguageModel): + def infer(self, batch_prompts, **kwargs): + for prompt in batch_prompts: + yield [core_types.ScoredOutput(score=1.0, output="B")] + + registry = ProviderRegistry() + + # Resolve each provider + cls_a = registry.get_for_model("provider-a-model") + cls_b = registry.get_for_model("provider-b-model") + + assert cls_a is ProviderA + assert cls_b is ProviderB + + def test_provider_priority(self): + """Test that higher priority provider wins on conflict.""" + + @router.register(r"^conflict", priority=0) + class LowPriorityProvider(base_model.BaseLanguageModel): + def infer(self, batch_prompts, **kwargs): + for prompt in batch_prompts: + yield [core_types.ScoredOutput(score=1.0, output="low")] + + @router.register(r"^conflict", priority=10) + 
class HighPriorityProvider(base_model.BaseLanguageModel): + def infer(self, batch_prompts, **kwargs): + for prompt in batch_prompts: + yield [core_types.ScoredOutput(score=1.0, output="high")] + + registry = ProviderRegistry() + + # High priority should win + provider_cls = registry.get_for_model("conflict-model") + assert provider_cls is HighPriorityProvider + + +class TestBackwardCompatibility: + """Tests for backward compatibility.""" + + def setup_method(self): + """Reset router before each test.""" + router.clear() + + def teardown_method(self): + """Clean up after each test.""" + router.clear() + + def test_old_API_infer_still_works(self): + """Test that the old infer() API still works.""" + + class LegacyProvider(base_model.BaseLanguageModel): + def infer(self, batch_prompts, **kwargs): + for prompt in batch_prompts: + yield [core_types.ScoredOutput(score=1.0, output=f"Legacy: {prompt}")] + + provider = LegacyProvider() + + # Old API + results = list(provider.infer(["Prompt 1", "Prompt 2"])) + + assert len(results) == 2 + assert results[0][0].output == "Legacy: Prompt 1" + assert results[1][0].output == "Legacy: Prompt 2" + + def test_new_API_generate_works(self): + """Test that the new generate() API works.""" + + class NewProvider(base_model.BaseLanguageModel): + def infer(self, batch_prompts, **kwargs): + for prompt in batch_prompts: + yield [core_types.ScoredOutput(score=1.0, output=f"New: {prompt}")] + + provider = NewProvider() + + # New API + result = provider.generate("Test prompt") + + assert isinstance(result, GenerateResult) + assert result.text == "New: Test prompt" + + def test_LLMProvider_context_manager(self): + """Test that LLMProvider can be used as context manager.""" + + class ContextTestProvider(base_model.BaseLanguageModel): + def __init__(self): + super().__init__() + self.closed = False + + def infer(self, batch_prompts, **kwargs): + for prompt in batch_prompts: + yield [core_types.ScoredOutput(score=1.0, output=prompt)] + + def 
close(self): + self.closed = True + + with ContextTestProvider() as provider: + result = provider.generate("Test") + assert result.text == "Test" + assert not provider.closed + + # Should be closed after context exit + assert provider.closed + + +class TestMockProviderFixture: + """Tests for using MockProvider as a pytest fixture.""" + + @pytest.fixture + def mock_provider(self): + """Fixture that creates a MockProvider.""" + return MockProvider(fixed_response='{"fixture": "response"}') + + def test_use_mock_provider_fixture(self, mock_provider): + """Test using the MockProvider fixture.""" + result = mock_provider.generate("Fixture test") + + assert result.text == '{"fixture": "response"}' + + @pytest.fixture + def dynamic_mock_provider(self): + """Fixture that creates a MockProvider with dynamic response.""" + + def response_fn(prompt, **kwargs): + if "name" in prompt: + return '{"name": "Test"}' + elif "value" in prompt: + return '{"value": 42}' + return '{"default": true}' + + return MockProvider(response_fn=response_fn) + + def test_dynamic_mock_provider(self, dynamic_mock_provider): + """Test dynamic responses based on prompt content.""" + result1 = dynamic_mock_provider.generate("Get the name") + result2 = dynamic_mock_provider.generate("Get the value") + result3 = dynamic_mock_provider.generate("Something else") + + assert result1.text == '{"name": "Test"}' + assert result2.text == '{"value": 42}' + assert result3.text == '{"default": true}' + + +class TestEntryPointsPriority: + """Tests for entry_points priority suffix syntax.""" + + def test_parse_entry_point_value_basic(self): + """Test parsing basic entry point value without priority.""" + value = "my_pkg.provider:MyProvider" + target, priority = _parse_entry_point_value(value) + + assert target == "my_pkg.provider:MyProvider" + assert priority is None + + def test_parse_entry_point_value_with_priority(self): + """Test parsing entry point value with priority suffix.""" + value = 
"my_pkg.provider:MyProvider:priority=100" + target, priority = _parse_entry_point_value(value) + + assert target == "my_pkg.provider:MyProvider" + assert priority == 100 + + def test_parse_entry_point_value_with_zero_priority(self): + """Test parsing entry point value with priority=0.""" + value = "my_pkg.provider:MyProvider:priority=0" + target, priority = _parse_entry_point_value(value) + + assert target == "my_pkg.provider:MyProvider" + assert priority == 0 + + def test_parse_entry_point_value_with_high_priority(self): + """Test parsing entry point value with high priority.""" + value = "my_pkg.provider:MyProvider:priority=999" + target, priority = _parse_entry_point_value(value) + + assert target == "my_pkg.provider:MyProvider" + assert priority == 999 + + +class TestMockProviderRegistry: + """Tests for MockProvider being accessible via registry.get('mock').""" + + def setup_method(self): + """Reset router and global registry before each test.""" + router.clear() + ProviderRegistry._reset_global() + + def teardown_method(self): + """Clean up after each test.""" + router.clear() + ProviderRegistry._reset_global() + + def test_registry_get_mock(self): + """Test that registry.get('mock') returns MockProvider class.""" + registry = ProviderRegistry() + + mock_cls = registry.get("mock") + + assert mock_cls is MockProvider + + def test_registry_get_for_model_mock(self): + """Test that registry.get_for_model('mock-model') returns MockProvider.""" + registry = ProviderRegistry() + + mock_cls = registry.get_for_model("mock-model") + + assert mock_cls is MockProvider + + def test_registry_get_for_model_mock_custom(self): + """Test that registry.get_for_model('mock-custom') returns MockProvider.""" + registry = ProviderRegistry() + + mock_cls = registry.get_for_model("mock-custom-v1") + + assert mock_cls is MockProvider + + def test_mock_provider_via_registry_usage(self): + """Test using MockProvider obtained via registry.""" + registry = ProviderRegistry() + mock_cls = 
registry.get("mock") + + mock = mock_cls(fixed_response='{"from_registry": true}') + result = mock.generate("Test prompt") + + assert result.text == '{"from_registry": true}' + assert result.raw_response is not None + assert result.raw_response["prompt"] == "Test prompt" + + +class TestContextManagerClose: + """Tests for context manager calling close() and tracking close_called.""" + + def test_mock_provider_close_called_counter(self): + """Test that close() increments close_called counter.""" + mock = MockProvider() + + assert mock.close_called == 0 + + mock.close() + assert mock.close_called == 1 + + mock.close() + assert mock.close_called == 2 + + def test_context_manager_calls_close(self): + """Test that context manager calls close() on exit.""" + mock = MockProvider() + + assert mock.close_called == 0 + + with mock: + result = mock.generate("Test") + assert result.text is not None + assert mock.close_called == 0 + + assert mock.close_called == 1 + + def test_context_manager_with_exception(self): + """Test that close() is called even when exception occurs.""" + mock = MockProvider() + + assert mock.close_called == 0 + + try: + with mock: + raise ValueError("Test exception") + except ValueError: + pass + + assert mock.close_called == 1 + + def test_multiple_context_managers(self): + """Test multiple context manager usages increment counter.""" + mock = MockProvider() + + with mock: + pass + assert mock.close_called == 1 + + with mock: + pass + assert mock.close_called == 2 + + +class TestAsyncContextManager: + """Tests for async context manager support.""" + + def test_async_context_manager_calls_close(self): + """Test that async context manager calls close() on exit.""" + + async def test_async(): + mock = MockProvider() + + assert mock.close_called == 0 + + async with mock: + result = await mock.agenerate("Test") + assert result.text is not None + assert mock.close_called == 0 + + assert mock.close_called == 1 + return mock.close_called + + result = 
asyncio.run(test_async()) + assert result == 1 + + def test_async_context_manager_with_exception(self): + """Test that close() is called even when exception occurs in async context.""" + + async def test_async(): + mock = MockProvider() + + assert mock.close_called == 0 + + try: + async with mock: + raise ValueError("Test exception") + except ValueError: + pass + + assert mock.close_called == 1 + return mock.close_called + + result = asyncio.run(test_async()) + assert result == 1 + + def test_provider_implements_async_context_manager(self): + """Test that LLMProvider implements async context manager protocol.""" + mock = MockProvider() + + assert hasattr(mock, "__aenter__") + assert hasattr(mock, "__aexit__") + assert callable(mock.__aenter__) + assert callable(mock.__aexit__) + + def test_provider_implements_sync_context_manager(self): + """Test that LLMProvider implements sync context manager protocol.""" + mock = MockProvider() + + assert hasattr(mock, "__enter__") + assert hasattr(mock, "__exit__") + assert callable(mock.__enter__) + assert callable(mock.__exit__) diff --git a/tests/resolver/test_resolver_bug_lcs_alignment.py b/tests/resolver/test_resolver_bug_lcs_alignment.py new file mode 100644 index 00000000..9ffde934 --- /dev/null +++ b/tests/resolver/test_resolver_bug_lcs_alignment.py @@ -0,0 +1,139 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test for LCS fuzzy alignment IndexError handling bug. 
+ +This test verifies that _lcs_fuzzy_align_extraction handles IndexError +consistently with _fuzzy_align_extraction (the legacy algorithm). +""" + +from absl.testing import absltest + +from langextract import resolver as resolver_lib +from langextract.core import data +from langextract.core import tokenizer + + +class LcsFuzzyAlignmentIndexErrorTest(absltest.TestCase): + """Test that LCS fuzzy alignment handles index errors gracefully.""" + + def setUp(self): + super().setUp() + self.aligner = resolver_lib.WordAligner() + + def test_legacy_fuzzy_alignment_handles_index_error(self): + """Verify that _fuzzy_align_extraction handles IndexError gracefully. + + This is the control test - the legacy algorithm already has + proper exception handling. + """ + extraction = data.Extraction( + extraction_class="test", + extraction_text="hello", + ) + + source_tokens = ["hello", "world"] + + tokenized_text = tokenizer.TokenizedText( + text="hi", + tokens=[ + tokenizer.Token( + index=0, + token_type=tokenizer.TokenType.WORD, + char_interval=tokenizer.CharInterval(start_pos=0, end_pos=2), + ) + ], + ) + + result = self.aligner._fuzzy_align_extraction( + extraction=extraction, + source_tokens=source_tokens, + tokenized_text=tokenized_text, + token_offset=0, + char_offset=0, + fuzzy_alignment_threshold=0.5, + tokenizer_impl=None, + ) + + self.assertIsNone(result, "Legacy algorithm should return None on IndexError") + + def test_lcs_fuzzy_alignment_handles_index_error(self): + """Verify that _lcs_fuzzy_align_extraction handles IndexError gracefully. + + This is the bug test - _lcs_fuzzy_align_extraction should have + the same exception handling as _fuzzy_align_extraction. 
+ + Bug scenario: + - source_tokens_norm has length 3 (tokens: ["hello", "world", "target"]) + - extraction_text = "target" (matches at index 2 in source_tokens_norm) + - tokenized_text.tokens has length 2 (only indices 0 and 1) + - When LCS finds a match at index 2, trying to access + tokenized_text.tokens[2] will raise IndexError + + Before fix: This test will raise IndexError + After fix: This test should return None gracefully + """ + extraction = data.Extraction( + extraction_class="test", + extraction_text="target", + ) + + source_tokens_norm = ["hello", "world", "target"] + + tokenized_text = tokenizer.TokenizedText( + text="hi world", + tokens=[ + tokenizer.Token( + index=0, + token_type=tokenizer.TokenType.WORD, + char_interval=tokenizer.CharInterval(start_pos=0, end_pos=2), + ), + tokenizer.Token( + index=1, + token_type=tokenizer.TokenType.WORD, + char_interval=tokenizer.CharInterval(start_pos=3, end_pos=8), + ), + ], + ) + + try: + result = self.aligner._lcs_fuzzy_align_extraction( + extraction=extraction, + source_tokens_norm=source_tokens_norm, + tokenized_text=tokenized_text, + token_offset=0, + char_offset=0, + fuzzy_alignment_threshold=0.75, + fuzzy_alignment_min_density=1/3, + tokenizer_impl=None, + ) + + self.assertIsNone( + result, + "LCS algorithm should return None gracefully on IndexError, " + "consistent with legacy algorithm behavior" + ) + except IndexError as e: + self.fail( + f"_lcs_fuzzy_align_extraction raised IndexError instead of " + f"handling it gracefully: {e}\n" + f"This is inconsistent with _fuzzy_align_extraction which has " + f"try-except protection.\n" + f"See resolver.py:681-700 for the legacy implementation with " + f"proper error handling." 
+ ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/test_resolver_bug_lcs_alignment.py b/tests/test_resolver_bug_lcs_alignment.py new file mode 100644 index 00000000..9ffde934 --- /dev/null +++ b/tests/test_resolver_bug_lcs_alignment.py @@ -0,0 +1,139 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test for LCS fuzzy alignment IndexError handling bug. + +This test verifies that _lcs_fuzzy_align_extraction handles IndexError +consistently with _fuzzy_align_extraction (the legacy algorithm). +""" + +from absl.testing import absltest + +from langextract import resolver as resolver_lib +from langextract.core import data +from langextract.core import tokenizer + + +class LcsFuzzyAlignmentIndexErrorTest(absltest.TestCase): + """Test that LCS fuzzy alignment handles index errors gracefully.""" + + def setUp(self): + super().setUp() + self.aligner = resolver_lib.WordAligner() + + def test_legacy_fuzzy_alignment_handles_index_error(self): + """Verify that _fuzzy_align_extraction handles IndexError gracefully. + + This is the control test - the legacy algorithm already has + proper exception handling. 
+ """ + extraction = data.Extraction( + extraction_class="test", + extraction_text="hello", + ) + + source_tokens = ["hello", "world"] + + tokenized_text = tokenizer.TokenizedText( + text="hi", + tokens=[ + tokenizer.Token( + index=0, + token_type=tokenizer.TokenType.WORD, + char_interval=tokenizer.CharInterval(start_pos=0, end_pos=2), + ) + ], + ) + + result = self.aligner._fuzzy_align_extraction( + extraction=extraction, + source_tokens=source_tokens, + tokenized_text=tokenized_text, + token_offset=0, + char_offset=0, + fuzzy_alignment_threshold=0.5, + tokenizer_impl=None, + ) + + self.assertIsNone(result, "Legacy algorithm should return None on IndexError") + + def test_lcs_fuzzy_alignment_handles_index_error(self): + """Verify that _lcs_fuzzy_align_extraction handles IndexError gracefully. + + This is the bug test - _lcs_fuzzy_align_extraction should have + the same exception handling as _fuzzy_align_extraction. + + Bug scenario: + - source_tokens_norm has length 3 (tokens: ["hello", "world", "target"]) + - extraction_text = "target" (matches at index 2 in source_tokens_norm) + - tokenized_text.tokens has length 2 (only indices 0 and 1) + - When LCS finds a match at index 2, trying to access + tokenized_text.tokens[2] will raise IndexError + + Before fix: This test will raise IndexError + After fix: This test should return None gracefully + """ + extraction = data.Extraction( + extraction_class="test", + extraction_text="target", + ) + + source_tokens_norm = ["hello", "world", "target"] + + tokenized_text = tokenizer.TokenizedText( + text="hi world", + tokens=[ + tokenizer.Token( + index=0, + token_type=tokenizer.TokenType.WORD, + char_interval=tokenizer.CharInterval(start_pos=0, end_pos=2), + ), + tokenizer.Token( + index=1, + token_type=tokenizer.TokenType.WORD, + char_interval=tokenizer.CharInterval(start_pos=3, end_pos=8), + ), + ], + ) + + try: + result = self.aligner._lcs_fuzzy_align_extraction( + extraction=extraction, + 
source_tokens_norm=source_tokens_norm, + tokenized_text=tokenized_text, + token_offset=0, + char_offset=0, + fuzzy_alignment_threshold=0.75, + fuzzy_alignment_min_density=1/3, + tokenizer_impl=None, + ) + + self.assertIsNone( + result, + "LCS algorithm should return None gracefully on IndexError, " + "consistent with legacy algorithm behavior" + ) + except IndexError as e: + self.fail( + f"_lcs_fuzzy_align_extraction raised IndexError instead of " + f"handling it gracefully: {e}\n" + f"This is inconsistent with _fuzzy_align_extraction which has " + f"try-except protection.\n" + f"See resolver.py:681-700 for the legacy implementation with " + f"proper error handling." + ) + + +if __name__ == "__main__": + absltest.main()