diff --git a/.ai/AGENTS.md b/.ai/AGENTS.md index 92312ae6be06..fcfec4d82571 100644 --- a/.ai/AGENTS.md +++ b/.ai/AGENTS.md @@ -59,6 +59,22 @@ Do not raise PRs without human validation. - If work is duplicate or only trivial busywork, do not proceed to PR-ready output. - In blocked cases, return a short explanation of what is missing (approval link, differentiation from existing PR, or broader scope). +## Learning transformers primitives by example + +The `src/transformers/cli/agentic/` directory contains concise, self-contained +examples of how to use the core transformers primitives (`AutoModel`, +`AutoTokenizer`, `AutoProcessor`, `AutoImageProcessor`, etc.) for a wide +range of tasks — text classification, NER, QA, summarization, translation, +image classification, object detection, segmentation, depth estimation, +speech recognition, audio classification, text-to-speech, video +classification, visual QA, captioning, OCR, and more. + +Each file (`text.py`, `vision.py`, `audio.py`, `multimodal.py`) follows the +same pattern: load a model and processor with `from_pretrained`, preprocess +inputs, run a forward pass or `generate`, and post-process the outputs. If +you need to write code that uses transformers and are unsure how to get +started, read the relevant command in that folder first. + ## Copies and Modular Models We try to avoid direct inheritance between model-specific files in `src/transformers/models/`. We have two mechanisms to manage the resulting code duplication: diff --git a/src/transformers/cli/agentic/README.md b/src/transformers/cli/agentic/README.md new file mode 100644 index 000000000000..311d74744041 --- /dev/null +++ b/src/transformers/cli/agentic/README.md @@ -0,0 +1,510 @@ +# Agentic CLI for Transformers + +Single-command access to all major Transformers use-cases. Designed for AI +agents and humans who need to run inference, training, quantization, export, +and model inspection **without writing Python scripts**. 
Every command below is available as `transformers <command>`. Run +`transformers <command> --help` for full option documentation.
+ ``` + +2. Classify text into arbitrary categories without training (zero-shot) + ```bash + transformers classify --text "The stock market crashed today." --labels "politics,finance,sports" + ``` + +3. Extract named entities from text (NER) + ```bash + transformers ner --model dslim/bert-base-NER --text "Apple CEO Tim Cook met with President Biden in Washington." + ``` + +4. Tag tokens with labels (POS tagging, chunking) + ```bash + transformers token-classify --model vblagoje/bert-english-uncased-finetuned-pos --text "The cat sat on the mat." + ``` + +5. Answer a question given a context paragraph (extractive QA) + ```bash + transformers qa --question "Who invented the telephone?" --context "Alexander Graham Bell invented the telephone in 1876." + ``` + +6. Answer a question about tabular data + ```bash + transformers table-qa --question "What is the total revenue?" --table financials.csv + ``` + +7. Summarize text + ```bash + transformers summarize --model facebook/bart-large-cnn --file article.txt + ``` + +8. Translate text between languages + ```bash + transformers translate --model Helsinki-NLP/opus-mt-en-de --text "The weather is nice today." + ``` + +9. Fill in masked tokens in a sentence + ```bash + transformers fill-mask --model answerdotai/ModernBERT-base --text "The capital of France is [MASK]." + ``` + +### Text Generation + +10. Generate text from a prompt + ```bash + transformers generate --model meta-llama/Llama-3.2-1B-Instruct --prompt "Once upon a time" + ``` + +11. Stream text generation token-by-token + ```bash + transformers generate --model meta-llama/Llama-3.2-1B-Instruct --prompt "Hello" --stream + ``` + +12. Generate with sampling (temperature, top-p, top-k) + ```bash + transformers generate --prompt "The future of AI" --temperature 0.7 --top-p 0.9 + ``` + +13. Generate with beam search + ```bash + transformers generate --prompt "Translate this:" --num-beams 4 + ``` + +14. 
Run speculative decoding with a draft model + ```bash + transformers generate --model meta-llama/Llama-3.1-8B-Instruct --assistant-model meta-llama/Llama-3.2-1B-Instruct --prompt "Explain gravity." + ``` + +15. Generate with tool/function calling + ```bash + transformers generate --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is the weather?" --tools tools.json + ``` + +16. Generate with constrained JSON output + ```bash + transformers generate --prompt "List 3 items as JSON:" --grammar json + ``` + +17. Watermark generated text + ```bash + transformers generate --model meta-llama/Llama-3.2-1B-Instruct --prompt "Write an essay." --watermark + ``` + +18. Detect whether text was watermarked + ```bash + transformers detect-watermark --model meta-llama/Llama-3.2-1B-Instruct --text "The generated essay text..." + ``` + +19. Generate with a quantized model (4-bit) + ```bash + transformers generate --model meta-llama/Llama-3.1-8B-Instruct --prompt "Hello" --quantization bnb-4bit + ``` + +20. Generate with quantized KV cache for long context + ```bash + transformers generate --model meta-llama/Llama-3.1-8B-Instruct --prompt "Summarize this long text..." --cache-quantization 4bit + ``` + +### Vision + +21. Classify an image into categories + ```bash + transformers image-classify --model google/vit-base-patch16-224 --image photo.jpg + ``` + +22. Classify an image into arbitrary categories without training (zero-shot) + ```bash + transformers image-classify --model google/siglip-base-patch16-224 --image photo.jpg --labels "cat,dog,bird,fish" + ``` + +23. Detect objects in an image with bounding boxes + ```bash + transformers detect --model PekingU/rtdetr_r18vd_coco_o365 --image street.jpg + ``` + +24. Detect objects from a text description (grounded detection) + ```bash + transformers detect --model IDEA-Research/grounding-dino-base --image kitchen.jpg --text "red mug on the counter" + ``` + +25. 
Segment an image by class (semantic segmentation) + ```bash + transformers segment --model nvidia/segformer-b0-finetuned-ade-512-512 --image scene.jpg + ``` + +26. Generate segmentation masks interactively (SAM-style) + ```bash + transformers segment --model facebook/sam-vit-base --image photo.jpg --points "[[120,45]]" --point-labels "[1]" + ``` + +27. Estimate depth from a single image + ```bash + transformers depth --model depth-anything/Depth-Anything-V2-Small-hf --image room.jpg --output depth_map.png + ``` + +28. Detect and match keypoints across an image pair + ```bash + transformers keypoints --model magic-leap-community/superglue --images img1.jpg --images img2.jpg + ``` + +29. Extract feature vectors from an image + ```bash + transformers embed --model facebook/dinov2-small --image photo.jpg --output features.npy + ``` + +### Audio + +30. Transcribe speech to text + ```bash + transformers transcribe --model openai/whisper-small --audio recording.wav + ``` + +31. Transcribe speech with word-level timestamps + ```bash + transformers transcribe --model openai/whisper-small --audio recording.wav --timestamps true --json + ``` + +32. Classify an audio clip into categories + ```bash + transformers audio-classify --model MIT/ast-finetuned-audioset-10-10-0.4593 --audio clip.wav + ``` + +33. Classify audio into arbitrary categories without training (zero-shot) + ```bash + transformers audio-classify --model laion/clap-htsat-unfused --audio clip.wav --labels "speech,music,noise,silence" + ``` + +34. Generate speech from text (text-to-speech) + ```bash + transformers speak --model suno/bark-small --text "Hello, how are you today?" --output speech.wav + ``` + +35. Generate audio from a text description (music, sound effects) + ```bash + transformers audio-generate --model facebook/musicgen-small --text "A calm piano melody" --output music.wav + ``` + +### Video + +36. 
Classify a video clip into categories + ```bash + transformers video-classify --model MCG-NJU/videomae-base-finetuned-kinetics --video clip.mp4 + ``` + +### Multimodal + +37. Answer a question about an image (visual QA) + ```bash + transformers vqa --model vikhyatk/moondream2 --image chart.png --question "What is the trend shown?" + ``` + +38. Answer a question about a document image (document QA) + ```bash + transformers document-qa --model impira/layoutlm-document-qa --image invoice.png --question "What is the total amount?" + ``` + +39. Generate a caption for an image + ```bash + transformers caption --model vikhyatk/moondream2 --image sunset.jpg + ``` + +40. Extract text from a document image (OCR) + ```bash + transformers ocr --model vikhyatk/moondream2 --image receipt.png + ``` + +41. Single-turn conversation with mixed inputs (image, audio, text) + ```bash + transformers multimodal-chat --model meta-llama/Llama-4-Scout-17B-16E-Instruct --prompt "Describe what you see and hear." --image photo.jpg --audio clip.wav + ``` + +### Training + +42. Fine-tune a text classification model + ```bash + transformers train text-classification --model bert-base-uncased --dataset glue/sst2 --output ./sst2-finetuned --epochs 3 --lr 2e-5 + ``` + +43. Fine-tune a token classification model (NER) + ```bash + transformers train token-classification --model bert-base-uncased --dataset conll2003 --output ./ner-finetuned --epochs 5 + ``` + +44. Fine-tune a question answering model + ```bash + transformers train question-answering --model bert-base-uncased --dataset squad --output ./qa-finetuned --epochs 2 + ``` + +45. Fine-tune a summarization model + ```bash + transformers train summarization --model t5-small --dataset cnn_dailymail --output ./summarizer --epochs 3 + ``` + +46. Fine-tune a translation model + ```bash + transformers train translation --model t5-small --dataset wmt16/de-en --output ./translator + ``` + +47. 
Continued pretraining on a domain-specific corpus + ```bash + transformers train language-modeling --model bert-base-uncased --dataset ./corpus.txt --output ./domain-bert --mlm + ``` + +48. Fine-tune an LLM with LoRA + ```bash + transformers train text-generation --model meta-llama/Llama-3.2-1B --dataset ./instructions.jsonl --output ./llama-lora --lora --lora-r 16 + ``` + +49. Fine-tune a 4-bit quantized LLM with QLoRA + ```bash + transformers train text-generation --model meta-llama/Llama-3.1-8B --dataset ./instructions.jsonl --output ./llama-qlora --lora --quantization bnb-4bit + ``` + +50. Pretrain a language model from scratch + ```bash + transformers train language-modeling --model-config gpt2 --dataset ./corpus.txt --output ./my-lm --from-scratch + ``` + +51. Fine-tune an image classification model + ```bash + transformers train image-classification --model google/vit-base-patch16-224 --dataset food101 --output ./food-classifier --epochs 5 + ``` + +52. Fine-tune an object detection model + ```bash + transformers train object-detection --model facebook/detr-resnet-50 --dataset cppe-5 --output ./detector --epochs 10 + ``` + +53. Fine-tune a segmentation model + ```bash + transformers train semantic-segmentation --model nvidia/segformer-b0-finetuned-ade-512-512 --dataset scene_parse_150 --output ./segmenter + ``` + +54. Fine-tune an ASR model on domain-specific audio + ```bash + transformers train speech-recognition --model openai/whisper-small --dataset ./medical-audio/ --output ./medical-whisper --epochs 5 + ``` + +55. Fine-tune an audio classification model + ```bash + transformers train audio-classification --model MIT/ast-finetuned-audioset-10-10-0.4593 --dataset superb/ks --output ./audio-classifier + ``` + +56. Run hyperparameter search with Optuna + ```bash + transformers train text-classification --model bert-base-uncased --dataset glue/sst2 --output ./hpo-run --hpo optuna --hpo-trials 20 + ``` + +57. 
Resume training from a checkpoint + ```bash + transformers train text-classification --model bert-base-uncased --dataset glue/sst2 --output ./sst2-finetuned --resume-from-checkpoint ./sst2-finetuned/checkpoint-500 + ``` + +58. Train with early stopping + ```bash + transformers train text-classification --model bert-base-uncased --dataset glue/sst2 --output ./sst2-finetuned --early-stopping --early-stopping-patience 3 + ``` + +59. Evaluate periodically during training + ```bash + transformers train text-classification --model bert-base-uncased --dataset glue/sst2 --output ./sst2-finetuned --eval-strategy steps --eval-steps 100 + ``` + +### Distributed & Large-Scale Training + +60. Train across multiple GPUs on a single machine + ```bash + transformers train text-generation --model meta-llama/Llama-3.2-1B --dataset ./data.jsonl --output ./multi-gpu --multi-gpu + ``` + +61. Train across multiple nodes + ```bash + transformers train text-generation --model meta-llama/Llama-3.2-1B --dataset ./data.jsonl --output ./multi-node --nnodes 4 + ``` + +62. Train with DeepSpeed ZeRO + ```bash + transformers train text-generation --model meta-llama/Llama-3.1-8B --dataset ./data.jsonl --output ./deepspeed-run --deepspeed zero3 + ``` + +63. Train with FSDP + ```bash + transformers train text-generation --model meta-llama/Llama-3.1-8B --dataset ./data.jsonl --output ./fsdp-run --fsdp full-shard + ``` + +64. Train on TPUs + ```bash + transformers train text-classification --model bert-base-uncased --dataset glue/sst2 --output ./tpu-run --device tpu + ``` + +65. Train on Apple Silicon (MPS) + ```bash + transformers train text-classification --model bert-base-uncased --dataset glue/sst2 --output ./mps-run --device mps + ``` + +66. Train with mixed precision + ```bash + transformers train text-generation --model meta-llama/Llama-3.2-1B --dataset ./data.jsonl --output ./bf16-run --dtype bf16 + ``` + +67. 
Train with gradient checkpointing + ```bash + transformers train text-generation --model meta-llama/Llama-3.1-8B --dataset ./data.jsonl --output ./gc-run --gradient-checkpointing + ``` + +68. Train with gradient accumulation + ```bash + transformers train text-generation --model meta-llama/Llama-3.2-1B --dataset ./data.jsonl --output ./ga-run --gradient-accumulation-steps 8 + ``` + +### Quantization + +69. Quantize a model to 4-bit + ```bash + transformers quantize --model meta-llama/Llama-3.1-8B --method bnb-4bit --output ./llama-4bit + ``` + +70. Quantize a model to 8-bit + ```bash + transformers quantize --model meta-llama/Llama-3.1-8B --method bnb-8bit --output ./llama-8bit + ``` + +71. Run GPTQ quantization with calibration data + ```bash + transformers quantize --model meta-llama/Llama-3.1-8B --method gptq --calibration-dataset wikitext --output ./llama-gptq + ``` + +72. Run AWQ quantization + ```bash + transformers quantize --model meta-llama/Llama-3.1-8B --method awq --output ./llama-awq + ``` + +73. Compare quality across quantization methods + ```bash + transformers benchmark-quantization --model meta-llama/Llama-3.1-8B --methods none,bnb-4bit,bnb-8bit --json + ``` + +### Export + +74. Export a model to ONNX + ```bash + transformers export onnx --model bert-base-uncased --output ./bert-onnx/ + ``` + +75. Convert a model to GGUF for llama.cpp + ```bash + transformers export gguf --model meta-llama/Llama-3.2-1B --output llama-1b.gguf + ``` + +76. Export a model to ExecuTorch for mobile/edge + ```bash + transformers export executorch --model distilbert-base-uncased --output ./model.pte + ``` + +### Utilities + +77. Compute text embeddings + ```bash + transformers embed --model BAAI/bge-small-en-v1.5 --text "The quick brown fox." --output embeddings.npy + ``` + +78. Tokenize text and display tokens + ```bash + transformers tokenize --model meta-llama/Llama-3.2-1B-Instruct --text "Hello, world!" --ids + ``` + +79. 
Inspect a model's configuration (no weight download) + ```bash + transformers inspect meta-llama/Llama-3.2-1B-Instruct --json + ``` + +80. Examine attention weights and hidden states + ```bash + transformers inspect-forward --model bert-base-uncased --text "The cat sat on the mat." --output ./activations/ + ``` + +## Traditional CLI Commands + +These commands ship alongside the agentic commands and are available via the same `transformers` entry point. + +81. Start an OpenAI-compatible inference server (chat completions, audio, images) + ```bash + transformers serve --host 0.0.0.0 --port 8000 + ``` + Pass `--force-model` to pin a model for all requests, `--continuous-batching` for + throughput-oriented deployments, and `--quantization bnb-4bit` for memory-constrained + hardware. + +82. Open an interactive chat session with a model (local or remote) + ```bash + transformers chat meta-llama/Llama-3.2-1B-Instruct + ``` + Connect to a running `transformers serve` instance: + ```bash + transformers chat meta-llama/Llama-3.2-1B-Instruct http://localhost:8000/v1 + ``` + +83. Download a model and its tokenizer from the Hub to the local cache + ```bash + transformers download meta-llama/Llama-3.2-1B-Instruct + ``` + +84. Print environment and dependency information (useful for bug reports) + ```bash + transformers env + ``` + +85. Print the installed Transformers version + ```bash + transformers version + ``` + +86. Scaffold a new model by copying an existing one + ```bash + transformers add-new-model-like + ``` diff --git a/src/transformers/cli/agentic/__init__.py b/src/transformers/cli/agentic/__init__.py new file mode 100644 index 000000000000..dde03c44ce02 --- /dev/null +++ b/src/transformers/cli/agentic/__init__.py @@ -0,0 +1,44 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Agentic CLI for Transformers — single-command access to all major use-cases. + +This package adds ~30 CLI commands to ``transformers``, covering inference +(text, vision, audio, video, multimodal), training, quantization, export, +and model inspection. Every command is designed to be invoked by an AI agent +or a human with no Python scripting required. + +Integration with the main CLI is minimal: ``app.py`` exposes a single +``register_agentic_commands(app)`` function that is called from +``transformers.cli.transformers``. Removing that one call disables the +entire module. + +Quick reference — run ``transformers --help`` for any command:: + + # Inference + transformers classify --text "Great movie!" + transformers generate --model meta-llama/Llama-3.2-1B-Instruct --prompt "Hello" --stream + transformers transcribe --model openai/whisper-small --audio recording.wav + + # Training + transformers train text-classification --model bert-base-uncased --dataset glue/sst2 --output ./out + + # Quantization & export + transformers quantize --model meta-llama/Llama-3.1-8B --method bnb-4bit --output ./out + transformers export onnx --model bert-base-uncased --output ./bert-onnx/ + + # Utilities + transformers inspect meta-llama/Llama-3.2-1B-Instruct + transformers tokenize --model meta-llama/Llama-3.2-1B-Instruct --text "Hello, world!" 
+""" diff --git a/src/transformers/cli/agentic/_common.py b/src/transformers/cli/agentic/_common.py new file mode 100644 index 000000000000..0e042ae27519 --- /dev/null +++ b/src/transformers/cli/agentic/_common.py @@ -0,0 +1,198 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Shared helpers used by all agentic CLI commands. + +These are internal utilities — not CLI commands themselves. They handle input +resolution (--text / --file / stdin), output formatting, media loading +(images, audio, video), model loading, and shared CLI option types. +""" + +import json +import sys +from pathlib import Path +from typing import Annotated, Any + +import typer + + +ModelOpt = Annotated[str | None, typer.Option("--model", "-m", help="Model ID or local path.")] +DeviceOpt = Annotated[str | None, typer.Option(help="Device to run on (e.g. 
'cpu', 'cuda', 'cuda:0', 'mps').")] +DtypeOpt = Annotated[str, typer.Option(help="Dtype for model weights ('auto', 'float16', 'bfloat16', 'float32').")] +TrustOpt = Annotated[bool, typer.Option(help="Trust remote code from the Hub.")] +TokenOpt = Annotated[str | None, typer.Option(help="HF Hub token for gated/private models.")] +RevisionOpt = Annotated[str | None, typer.Option(help="Model revision (branch, tag, or commit SHA).")] +JsonOpt = Annotated[bool, typer.Option("--json", help="Output results as JSON.")] + + +def _load_pretrained(model_cls, processor_cls, model_id, device, dtype, trust_remote_code, token, revision): + """Load a model and its processor/tokenizer with the common CLI options.""" + import torch + + common_kwargs = {} + if trust_remote_code: + common_kwargs["trust_remote_code"] = True + if token: + common_kwargs["token"] = token + if revision: + common_kwargs["revision"] = revision + + model_kwargs = {**common_kwargs} + if device and device != "cpu": + model_kwargs["device_map"] = device + elif device is None: + model_kwargs["device_map"] = "auto" + if dtype != "auto": + model_kwargs["torch_dtype"] = getattr(torch, dtype) + + processor = processor_cls.from_pretrained(model_id, **common_kwargs) + model = model_cls.from_pretrained(model_id, **model_kwargs) + model.eval() + return model, processor + + +def resolve_input(text: str | None = None, file: str | None = None) -> str: + """ + Return text from one of three sources, in priority order: + + 1. ``--text "..."`` — inline string + 2. ``--file path`` — read from a file + 3. stdin — piped input (e.g. ``echo "hello" | transformers classify``) + + Raises ``SystemExit`` if none of the three are provided. 
+ """ + if text is not None: + return text + if file is not None: + return Path(file).read_text() + if not sys.stdin.isatty(): + return sys.stdin.read() + raise SystemExit("Error: provide --text, --file, or pipe input via stdin.") + + +def format_output(result: Any, output_json: bool = False) -> str: + """ + Format pipeline output for display. + + When ``output_json=True``, returns a JSON string (useful for agents that + need to parse results programmatically). Otherwise, returns a + human-readable multi-line string. + """ + if output_json: + return json.dumps(result, indent=2, default=str) + + if isinstance(result, list): + lines = [] + for item in result: + if isinstance(item, dict): + lines.append(" ".join(f"{k}: {v}" for k, v in item.items())) + elif isinstance(item, list): + for sub in item: + if isinstance(sub, dict): + lines.append(" ".join(f"{k}: {v}" for k, v in sub.items())) + else: + lines.append(str(sub)) + else: + lines.append(str(item)) + return "\n".join(lines) + + if isinstance(result, dict): + return "\n".join(f"{k}: {v}" for k, v in result.items()) + + return str(result) + + +def load_image(path: str): + """ + Load an image from a local file path or a URL. + + Returns a PIL Image. Requires ``Pillow`` (``pip install Pillow``). + For URLs, also requires ``requests``. + """ + from PIL import Image + + if path.startswith("http://") or path.startswith("https://"): + import requests + + return Image.open(requests.get(path, stream=True).raw) + return Image.open(path) + + +def load_video(path: str, num_frames: int = 16): + """ + Load video frames uniformly sampled from a video file. + + Tries ``decord`` first, then falls back to ``av``. Returns a list of + PIL Images. 
+ """ + import numpy as np + from PIL import Image + + try: + from decord import VideoReader, cpu + + vr = VideoReader(path, ctx=cpu(0)) + indices = np.linspace(0, len(vr) - 1, num_frames, dtype=int) + frames = vr.get_batch(indices).asnumpy() + return [Image.fromarray(f) for f in frames] + except ImportError: + pass + + try: + import av + + container = av.open(path) + total = container.streams.video[0].frames or 1000 + step = max(1, total // num_frames) + frames = [] + for i, frame in enumerate(container.decode(video=0)): + if i % step == 0: + frames.append(frame.to_image()) + if len(frames) >= num_frames: + break + container.close() + return frames + except ImportError: + raise SystemExit( + "Video loading requires 'decord' or 'av'.\nInstall with: pip install decord (or) pip install av" + ) + + +def load_audio(path: str, sampling_rate: int = 16000): + """ + Load an audio file, resampling to ``sampling_rate`` Hz. + + Tries ``librosa`` first (supports resampling). Falls back to + ``soundfile`` if librosa is not installed, but will error if the + file's sample rate doesn't match the target. + """ + import numpy as np + + try: + import librosa + + audio, _ = librosa.load(path, sr=sampling_rate) + return audio + except ImportError: + import soundfile as sf + + audio, sr = sf.read(path) + if sr != sampling_rate: + raise SystemExit( + f"Audio sample rate is {sr} but model expects {sampling_rate}. " + "Install librosa (`pip install librosa`) for automatic resampling." + ) + if audio.ndim > 1: + audio = audio.mean(axis=1) + return audio.astype(np.float32) diff --git a/src/transformers/cli/agentic/app.py b/src/transformers/cli/agentic/app.py new file mode 100644 index 000000000000..a8b5b3a6ac6a --- /dev/null +++ b/src/transformers/cli/agentic/app.py @@ -0,0 +1,72 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Register all agentic CLI commands on a Typer app. + +This is the single integration point between the agentic CLI and the +main ``transformers`` CLI. It exposes one function: + + ``register_agentic_commands(app)`` + +which adds ~30 commands to the given Typer app. The main CLI calls this +from ``transformers.cli.transformers``. Removing that one call disables +the entire agentic module with no other changes required. +""" + +from .audio import audio_classify, audio_generate, speak, transcribe +from .export import export +from .generate import detect_watermark, generate +from .multimodal import caption, document_qa, multimodal_chat, ocr, vqa +from .quantize import quantize +from .text import classify, fill_mask, ner, qa, summarize, table_qa, token_classify, translate +from .train import train +from .utilities import benchmark_quantization, embed, inspect, inspect_forward, tokenize +from .vision import depth, detect, image_classify, keypoints, segment, video_classify + + +def register_agentic_commands(app): + """Register all agentic CLI commands on the given Typer app instance.""" + app.command()(classify) + app.command()(ner) + app.command(name="token-classify")(token_classify) + app.command()(qa) + app.command(name="table-qa")(table_qa) + app.command()(summarize) + app.command()(translate) + app.command(name="fill-mask")(fill_mask) + app.command(name="image-classify")(image_classify) + app.command()(detect) + app.command()(segment) + app.command()(depth) + app.command()(keypoints) + app.command(name="video-classify")(video_classify) + 
app.command()(transcribe) + app.command(name="audio-classify")(audio_classify) + app.command()(speak) + app.command(name="audio-generate")(audio_generate) + app.command()(vqa) + app.command(name="document-qa")(document_qa) + app.command()(caption) + app.command()(ocr) + app.command(name="multimodal-chat")(multimodal_chat) + app.command()(generate) + app.command(name="detect-watermark")(detect_watermark) + app.command()(embed) + app.command()(tokenize) + app.command(name="inspect")(inspect) + app.command(name="inspect-forward")(inspect_forward) + app.command(name="benchmark-quantization")(benchmark_quantization) + app.command()(train) + app.command()(quantize) + app.command()(export) diff --git a/src/transformers/cli/agentic/audio.py b/src/transformers/cli/agentic/audio.py new file mode 100644 index 000000000000..92c8b0f32c72 --- /dev/null +++ b/src/transformers/cli/agentic/audio.py @@ -0,0 +1,304 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Audio CLI commands for the transformers agentic CLI. + +Each function uses Auto* model classes directly (no pipeline) and is +registered as a top-level ``transformers`` CLI command via ``app.py``. 
+""" + +from typing import Annotated + +import typer + +from ._common import ( + DeviceOpt, + DtypeOpt, + JsonOpt, + ModelOpt, + RevisionOpt, + TokenOpt, + TrustOpt, + _load_pretrained, + format_output, + load_audio, +) + + +def transcribe( + audio: Annotated[str, typer.Option(help="Path or URL to the audio file.")], + model: ModelOpt = None, + timestamps: Annotated[str | None, typer.Option(help="Enable timestamp prediction (e.g. 'true').")] = None, + language: Annotated[str | None, typer.Option(help="Language code for transcription (e.g. 'en', 'fr').")] = None, + device: DeviceOpt = None, + dtype: DtypeOpt = "auto", + trust_remote_code: TrustOpt = False, + token: TokenOpt = None, + revision: RevisionOpt = None, + output_json: JsonOpt = False, +): + """ + Transcribe speech from an audio file. + + Uses ``AutoModelForSpeechSeq2Seq`` and ``AutoProcessor`` to load a + speech-to-text model and produce a transcription. + + Examples:: + + transformers transcribe --audio recording.wav + transformers transcribe --audio recording.wav --language fr --json + transformers transcribe --audio recording.wav --timestamps true + """ + + from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor + + model_id = model or "openai/whisper-small" + loaded_model, processor = _load_pretrained( + AutoModelForSpeechSeq2Seq, + AutoProcessor, + model_id, + device, + dtype, + trust_remote_code, + token, + revision, + ) + + audio_data = load_audio(audio, sampling_rate=processor.feature_extractor.sampling_rate) + input_features = processor( + audio_data, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt" + ).input_features + + if hasattr(loaded_model, "device"): + input_features = input_features.to(loaded_model.device) + + gen_kwargs = {} + if timestamps is not None: + gen_kwargs["return_timestamps"] = True + if language is not None: + gen_kwargs["language"] = language + + output_ids = loaded_model.generate(input_features, **gen_kwargs) + transcription = 
processor.batch_decode(output_ids, skip_special_tokens=True)[0] + + if output_json: + print(format_output({"text": transcription}, output_json=True)) + else: + print(transcription) + + +def audio_classify( + audio: Annotated[str, typer.Option(help="Path or URL to the audio file.")], + labels: Annotated[ + str | None, typer.Option(help="Comma-separated candidate labels for zero-shot audio classification.") + ] = None, + model: ModelOpt = None, + top_k: Annotated[int | None, typer.Option(help="Number of top predictions to return.")] = None, + device: DeviceOpt = None, + dtype: DtypeOpt = "auto", + trust_remote_code: TrustOpt = False, + token: TokenOpt = None, + revision: RevisionOpt = None, + output_json: JsonOpt = False, +): + """ + Classify an audio file into categories. + + Without ``--labels``, uses ``AutoModelForAudioClassification`` and + ``AutoFeatureExtractor`` with a fine-tuned classification model. + With ``--labels``, uses ``AutoModel`` and ``AutoProcessor`` for + zero-shot classification via CLAP. 
+ + Examples:: + + transformers audio-classify --audio sound.wav + transformers audio-classify --audio sound.wav --labels "dog,cat,bird" --json + transformers audio-classify --audio sound.wav --top-k 3 + """ + import torch + + if labels is None: + from transformers import AutoFeatureExtractor, AutoModelForAudioClassification + + model_id = model or "MIT/ast-finetuned-audioset-10-10-0.4593" + loaded_model, feature_extractor = _load_pretrained( + AutoModelForAudioClassification, + AutoFeatureExtractor, + model_id, + device, + dtype, + trust_remote_code, + token, + revision, + ) + + sr = feature_extractor.sampling_rate + audio_data = load_audio(audio, sampling_rate=sr) + inputs = feature_extractor(audio_data, sampling_rate=sr, return_tensors="pt") + + if hasattr(loaded_model, "device"): + inputs = inputs.to(loaded_model.device) + + with torch.no_grad(): + logits = loaded_model(**inputs).logits + + probs = torch.softmax(logits, dim=-1)[0] + k = top_k or 5 + top_probs, top_indices = torch.topk(probs, min(k, probs.size(0))) + + result = [ + {"label": loaded_model.config.id2label[idx.item()], "score": round(prob.item(), 4)} + for prob, idx in zip(top_probs, top_indices) + ] + else: + from transformers import AutoModel, AutoProcessor + + model_id = model or "laion/clap-htsat-unfused" + loaded_model, processor = _load_pretrained( + AutoModel, + AutoProcessor, + model_id, + device, + dtype, + trust_remote_code, + token, + revision, + ) + + sr = processor.feature_extractor.sampling_rate + audio_data = load_audio(audio, sampling_rate=sr) + candidate_labels = [lbl.strip() for lbl in labels.split(",")] + inputs = processor( + audios=audio_data, text=candidate_labels, return_tensors="pt", padding=True, sampling_rate=sr + ) + + if hasattr(loaded_model, "device"): + inputs = inputs.to(loaded_model.device) + + with torch.no_grad(): + outputs = loaded_model(**inputs) + + probs = outputs.logits_per_audio[0].softmax(dim=-1) + result = [ + {"label": candidate_labels[i], "score": 
round(probs[i].item(), 4)} for i in range(len(candidate_labels)) + ] + result.sort(key=lambda x: x["score"], reverse=True) + + print(format_output(result, output_json)) + + +def speak( + text: Annotated[str, typer.Option(help="Text to synthesize into speech.")], + output: Annotated[str, typer.Option(help="Output WAV file path.")], + model: ModelOpt = None, + device: DeviceOpt = None, + dtype: DtypeOpt = "auto", + trust_remote_code: TrustOpt = False, + token: TokenOpt = None, + revision: RevisionOpt = None, +): + """ + Synthesize speech from text and save to a WAV file. + + Uses ``AutoModelForTextToWaveform`` and ``AutoProcessor`` to generate + audio from the given text prompt. + + Examples:: + + transformers speak --text "Hello world" --output hello.wav + transformers speak --text "Bonjour le monde" --output bonjour.wav --model suno/bark-small + """ + import scipy.io.wavfile + + from transformers import AutoModelForTextToWaveform, AutoProcessor + + model_id = model or "suno/bark-small" + loaded_model, processor = _load_pretrained( + AutoModelForTextToWaveform, + AutoProcessor, + model_id, + device, + dtype, + trust_remote_code, + token, + revision, + ) + + inputs = processor(text, return_tensors="pt") + + if hasattr(loaded_model, "device"): + inputs = inputs.to(loaded_model.device) + + speech_output = loaded_model.generate(**inputs) + audio_data = speech_output.cpu().float().numpy().squeeze() + + sampling_rate = getattr(loaded_model.generation_config, "sample_rate", None) or getattr( + getattr(loaded_model.config, "audio_encoder", None), "sampling_rate", 24000 + ) + + scipy.io.wavfile.write(output, sampling_rate, audio_data) + print(f"Saved audio to {output}") + + +def audio_generate( + text: Annotated[str, typer.Option(help="Text prompt describing the audio to generate.")], + output: Annotated[str, typer.Option(help="Output WAV file path.")], + model: ModelOpt = None, + device: DeviceOpt = None, + dtype: DtypeOpt = "auto", + trust_remote_code: TrustOpt = False, + 
token: TokenOpt = None, + revision: RevisionOpt = None, +): + """ + Generate audio (e.g. music) from a text description and save to a WAV file. + + Uses ``AutoModelForTextToWaveform`` and ``AutoProcessor`` to produce + audio from a text prompt. + + Examples:: + + transformers audio-generate --text "a relaxing piano melody" --output music.wav + transformers audio-generate --text "upbeat electronic beat" --output beat.wav --model facebook/musicgen-small + """ + import scipy.io.wavfile + + from transformers import AutoModelForTextToWaveform, AutoProcessor + + model_id = model or "facebook/musicgen-small" + loaded_model, processor = _load_pretrained( + AutoModelForTextToWaveform, + AutoProcessor, + model_id, + device, + dtype, + trust_remote_code, + token, + revision, + ) + + inputs = processor(text=[text], return_tensors="pt", padding=True) + + if hasattr(loaded_model, "device"): + inputs = inputs.to(loaded_model.device) + + audio_values = loaded_model.generate(**inputs, max_new_tokens=256) + audio_data = audio_values.cpu().float().numpy().squeeze() + + sampling_rate = getattr(loaded_model.generation_config, "sample_rate", None) or getattr( + getattr(loaded_model.config, "audio_encoder", None), "sampling_rate", 32000 + ) + + scipy.io.wavfile.write(output, sampling_rate, audio_data) + print(f"Saved audio to {output}") diff --git a/src/transformers/cli/agentic/export.py b/src/transformers/cli/agentic/export.py new file mode 100644 index 000000000000..eba8a7c83cfc --- /dev/null +++ b/src/transformers/cli/agentic/export.py @@ -0,0 +1,164 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Model export CLI command. + +Export a Transformers model to a deployment-friendly format. + +Examples:: + + # ONNX (requires: pip install optimum[exporters]) + transformers export onnx --model bert-base-uncased --output ./bert-onnx/ + + # GGUF (for llama.cpp) + transformers export gguf --model meta-llama/Llama-3.2-1B --output llama-1b.gguf + + # ExecuTorch (for mobile/edge; requires: pip install executorch) + transformers export executorch --model distilbert-base-uncased --output ./model.pte + +Supported formats: onnx, gguf, executorch. +""" + +from typing import Annotated + +import typer + + +_EXPORT_FORMATS = ("onnx", "gguf", "executorch") + + +def export( + fmt: Annotated[str, typer.Argument(help=f"Export format: {', '.join(_EXPORT_FORMATS)}.")], + model: Annotated[str, typer.Option("--model", "-m", help="Model ID or local path.")], + output: Annotated[str, typer.Option(help="Output path (directory for ONNX, file for GGUF).")], + opset: Annotated[int | None, typer.Option(help="ONNX opset version.")] = None, + task: Annotated[str | None, typer.Option(help="Task for ONNX export (auto-detected if omitted).")] = None, + trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code.")] = False, + token: Annotated[str | None, typer.Option(help="HF Hub token.")] = None, +): + """ + Export a model to a deployment-friendly format. + + The first argument is the target format. Each format has different + requirements and produces different output. 
+ + Examples:: + + transformers export onnx --model bert-base-uncased --output ./bert-onnx/ + transformers export gguf --model meta-llama/Llama-3.2-1B --output llama-1b.gguf + transformers export executorch --model distilbert-base-uncased --output ./model.pte + """ + if fmt not in _EXPORT_FORMATS: + raise SystemExit(f"Unknown format '{fmt}'. Choose from: {', '.join(_EXPORT_FORMATS)}") + + if fmt == "onnx": + _export_onnx(model, output, opset, task, trust_remote_code, token) + elif fmt == "gguf": + _export_gguf(model, output, trust_remote_code, token) + elif fmt == "executorch": + _export_executorch(model, output, trust_remote_code, token) + + +def _export_onnx( + model: str, output: str, opset: int | None, task: str | None, trust_remote_code: bool, token: str | None +): + """Export to ONNX via the optimum library.""" + try: + from optimum.exporters.onnx import main_export + except ImportError: + raise SystemExit( + "ONNX export requires the 'optimum' library.\nInstall it with: pip install optimum[exporters]" + ) + + export_kwargs = { + "model_name_or_path": model, + "output": output, + } + if opset is not None: + export_kwargs["opset"] = opset + if task is not None: + export_kwargs["task"] = task + if trust_remote_code: + export_kwargs["trust_remote_code"] = True + if token is not None: + export_kwargs["token"] = token + + print(f"Exporting {model} to ONNX at {output}...") + main_export(**export_kwargs) + print(f"ONNX model saved to {output}") + + +def _export_gguf(model: str, output: str, trust_remote_code: bool, token: str | None): + """Export to GGUF format.""" + from pathlib import Path + + from transformers import AutoModelForCausalLM, AutoTokenizer + + common_kwargs = {} + if trust_remote_code: + common_kwargs["trust_remote_code"] = True + if token: + common_kwargs["token"] = token + + print(f"Loading {model}...") + loaded_model = AutoModelForCausalLM.from_pretrained(model, **common_kwargs) + tokenizer = AutoTokenizer.from_pretrained(model, **common_kwargs) + 
+ output_path = Path(output) + output_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"Saving as GGUF to {output}...") + loaded_model.save_pretrained(output_path, gguf_file=output_path.name if output.endswith(".gguf") else None) + tokenizer.save_pretrained(output_path) + print(f"GGUF model saved to {output}") + + +def _export_executorch(model: str, output: str, trust_remote_code: bool, token: str | None): + """Export to ExecuTorch format for mobile/edge deployment.""" + try: + from executorch.exir import to_edge + from torch.export import export as torch_export + except ImportError: + raise SystemExit( + "ExecuTorch export requires the 'executorch' library.\nInstall it with: pip install executorch" + ) + + from pathlib import Path + + from transformers import AutoModelForCausalLM, AutoTokenizer + + common_kwargs = {} + if trust_remote_code: + common_kwargs["trust_remote_code"] = True + if token: + common_kwargs["token"] = token + + print(f"Loading {model}...") + loaded_model = AutoModelForCausalLM.from_pretrained(model, **common_kwargs) + tokenizer = AutoTokenizer.from_pretrained(model, **common_kwargs) + + loaded_model.eval() + + # Trace with a dummy input + dummy_input = tokenizer("Hello", return_tensors="pt") + exported = torch_export(loaded_model, (dummy_input["input_ids"],)) + edge_program = to_edge(exported) + et_program = edge_program.to_executorch() + + output_path = Path(output) + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "wb") as f: + f.write(et_program.buffer) + + print(f"ExecuTorch model saved to {output}") diff --git a/src/transformers/cli/agentic/generate.py b/src/transformers/cli/agentic/generate.py new file mode 100644 index 000000000000..0132748ca242 --- /dev/null +++ b/src/transformers/cli/agentic/generate.py @@ -0,0 +1,254 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Text generation CLI commands. + +Uses ``AutoModelForCausalLM`` directly to expose the full set of generation +options: streaming, decoding strategies, speculative decoding, watermarking, +tool calling, constrained decoding, and quantization. +""" + +from typing import Annotated + +import typer + +from ._common import resolve_input + + +def generate( + # Input + prompt: Annotated[str | None, typer.Option(help="Prompt text.")] = None, + file: Annotated[str | None, typer.Option(help="Read prompt from this file.")] = None, + # Model + model: Annotated[str | None, typer.Option("--model", "-m", help="Model ID or local path.")] = None, + assistant_model: Annotated[str | None, typer.Option(help="Draft model for speculative/assisted decoding.")] = None, + device: Annotated[str | None, typer.Option(help="Device (cpu, cuda, cuda:0, mps).")] = None, + dtype: Annotated[str, typer.Option(help="Dtype: auto, float16, bfloat16, float32.")] = "auto", + trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code.")] = False, + token: Annotated[str | None, typer.Option(help="HF Hub token.")] = None, + revision: Annotated[str | None, typer.Option(help="Model revision.")] = None, + # Generation parameters + max_new_tokens: Annotated[int, typer.Option(help="Maximum new tokens to generate.")] = 256, + temperature: Annotated[float | None, typer.Option(help="Sampling temperature.")] = None, + top_k: Annotated[int | None, 
typer.Option(help="Top-k sampling.")] = None, + top_p: Annotated[float | None, typer.Option(help="Top-p (nucleus) sampling.")] = None, + num_beams: Annotated[int | None, typer.Option(help="Number of beams for beam search.")] = None, + repetition_penalty: Annotated[float | None, typer.Option(help="Repetition penalty (1.0 = no penalty).")] = None, + no_repeat_ngram_size: Annotated[int | None, typer.Option(help="Prevent repeating n-grams of this size.")] = None, + do_sample: Annotated[bool | None, typer.Option(help="Use sampling instead of greedy decoding.")] = None, + # Features + stream: Annotated[bool, typer.Option(help="Stream output token-by-token.")] = False, + watermark: Annotated[bool, typer.Option(help="Apply watermark to generated text.")] = False, + tools: Annotated[str | None, typer.Option(help="Path to a JSON file defining tools for function calling.")] = None, + grammar: Annotated[str | None, typer.Option(help="Constrain output format: 'json' for valid JSON output.")] = None, + # Quantization + quantization: Annotated[str | None, typer.Option(help="Load model quantized: 'bnb-4bit', 'bnb-8bit'.")] = None, + cache_quantization: Annotated[str | None, typer.Option(help="Quantize KV cache: '4bit', '8bit'.")] = None, +): + """ + Generate text from a prompt with full control over decoding. + + Loads a causal language model and generates text. Supports all major + decoding strategies, streaming, speculative decoding, watermarking, + tool calling, constrained decoding, and quantized inference. 
+ + Examples:: + + # Basic generation + transformers generate --model meta-llama/Llama-3.2-1B-Instruct --prompt "Once upon a time" + + # Streaming output + transformers generate --model meta-llama/Llama-3.2-1B-Instruct --prompt "Hello" --stream + + # Sampling with temperature and top-p + transformers generate --prompt "The future of AI" --temperature 0.7 --top-p 0.9 + + # Speculative decoding (faster inference with a draft model) + transformers generate --model meta-llama/Llama-3.1-8B-Instruct \\ + --assistant-model meta-llama/Llama-3.2-1B-Instruct --prompt "Explain gravity." + + # Watermark generated text + transformers generate --prompt "Write an essay." --watermark + + # Tool/function calling (provide tools as JSON) + transformers generate --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is the weather?" --tools tools.json + + # Constrained JSON output + transformers generate --prompt "List 3 items as JSON:" --grammar json + + # 4-bit quantized inference + transformers generate --model meta-llama/Llama-3.1-8B-Instruct --prompt "Hello" --quantization bnb-4bit + + # Quantized KV cache for long context + transformers generate --model meta-llama/Llama-3.1-8B-Instruct --prompt "..." 
--cache-quantization 4bit + """ + import json as json_mod + + from transformers import AutoModelForCausalLM, AutoTokenizer + + input_text = resolve_input(prompt, file) + + # --- Load model & tokenizer --- + model_id = model or "HuggingFaceTB/SmolLM2-360M-Instruct" + + tok_kwargs = {} + model_kwargs = {} + if trust_remote_code: + tok_kwargs["trust_remote_code"] = True + model_kwargs["trust_remote_code"] = True + if token: + tok_kwargs["token"] = token + model_kwargs["token"] = token + if revision: + tok_kwargs["revision"] = revision + model_kwargs["revision"] = revision + if device and device != "cpu": + model_kwargs["device_map"] = device + elif device is None: + model_kwargs["device_map"] = "auto" + if dtype != "auto": + import torch + + model_kwargs["torch_dtype"] = getattr(torch, dtype) + + if quantization == "bnb-4bit": + from transformers import BitsAndBytesConfig + + model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True) + elif quantization == "bnb-8bit": + from transformers import BitsAndBytesConfig + + model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + + tokenizer = AutoTokenizer.from_pretrained(model_id, **tok_kwargs) + loaded_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) + loaded_model.eval() + + # --- Load assistant model for speculative decoding --- + loaded_assistant = None + if assistant_model is not None: + loaded_assistant = AutoModelForCausalLM.from_pretrained( + assistant_model, + **{k: v for k, v in model_kwargs.items() if k != "quantization_config"}, + ) + + # --- Build generation kwargs --- + gen_kwargs = {"max_new_tokens": max_new_tokens} + + if temperature is not None: + gen_kwargs["temperature"] = temperature + if top_k is not None: + gen_kwargs["top_k"] = top_k + if top_p is not None: + gen_kwargs["top_p"] = top_p + if num_beams is not None: + gen_kwargs["num_beams"] = num_beams + if repetition_penalty is not None: + gen_kwargs["repetition_penalty"] = 
repetition_penalty + if no_repeat_ngram_size is not None: + gen_kwargs["no_repeat_ngram_size"] = no_repeat_ngram_size + if do_sample is not None: + gen_kwargs["do_sample"] = do_sample + elif temperature is not None or top_k is not None or top_p is not None: + gen_kwargs["do_sample"] = True + + if watermark: + from transformers import WatermarkingConfig + + gen_kwargs["watermarking_config"] = WatermarkingConfig() + + if cache_quantization is not None: + from transformers import QuantizedCacheConfig + + nbits = 4 if "4" in cache_quantization else 8 + gen_kwargs["cache_implementation"] = "quantized" + gen_kwargs["cache_config"] = QuantizedCacheConfig(nbits=nbits) + + if loaded_assistant is not None: + gen_kwargs["assistant_model"] = loaded_assistant + + # --- Constrained decoding --- + if grammar == "json": + from transformers import GrammarConstrainedLogitsProcessor, LogitsProcessorList + + gen_kwargs.setdefault("logits_processor", LogitsProcessorList()) + gen_kwargs["logits_processor"].append( + GrammarConstrainedLogitsProcessor(tokenizer=tokenizer, grammar_str='root ::= "{" [^}]* "}"') + ) + + # --- Tokenize (with tool calling via chat template if needed) --- + if tools is not None: + with open(tools) as f: + tools_def = json_mod.load(f) + messages = [{"role": "user", "content": input_text}] + inputs = tokenizer.apply_chat_template( + messages, + tools=tools_def, + return_tensors="pt", + return_dict=True, + add_generation_prompt=True, + ) + else: + inputs = tokenizer(input_text, return_tensors="pt") + + if hasattr(loaded_model, "device"): + inputs = inputs.to(loaded_model.device) + + # --- Generate --- + if stream: + from transformers import TextStreamer + + streamer = TextStreamer(tokenizer, skip_prompt=True) + gen_kwargs["streamer"] = streamer + loaded_model.generate(**inputs, **gen_kwargs) + print() + else: + output_ids = loaded_model.generate(**inputs, **gen_kwargs) + new_tokens = output_ids[0, inputs["input_ids"].shape[1] :] + 
print(tokenizer.decode(new_tokens, skip_special_tokens=True)) + + +def detect_watermark( + text: Annotated[str | None, typer.Option(help="Text to check for watermark.")] = None, + file: Annotated[str | None, typer.Option(help="Read text from this file.")] = None, + model: Annotated[ + str | None, typer.Option("--model", "-m", help="Model ID (must match the model that generated the text).") + ] = None, +): + """ + Detect whether text contains a watermark. + + The ``--model`` must match the model that originally generated the text + (the watermark is tied to the model's vocabulary and config). + + Example:: + + transformers detect-watermark --model meta-llama/Llama-3.2-1B-Instruct --text "The generated essay text..." + """ + from transformers import AutoModelForCausalLM, AutoTokenizer, WatermarkDetector + + input_text = resolve_input(text, file) + model_id = model or "HuggingFaceTB/SmolLM2-360M-Instruct" + + tokenizer = AutoTokenizer.from_pretrained(model_id) + detector = WatermarkDetector( + model_config=AutoModelForCausalLM.from_pretrained(model_id).config, + device="cpu", + ) + + tokens = tokenizer(input_text, return_tensors="pt", add_special_tokens=False)["input_ids"][0] + result = detector(tokens) + + print(f"Prediction: {result.prediction}") + print(f"Confidence: {result.confidence:.4f}") diff --git a/src/transformers/cli/agentic/multimodal.py b/src/transformers/cli/agentic/multimodal.py new file mode 100644 index 000000000000..bd02c5736ac6 --- /dev/null +++ b/src/transformers/cli/agentic/multimodal.py @@ -0,0 +1,307 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Multimodal CLI commands for the transformers agentic CLI. + +All commands use Auto* model classes directly (no pipeline abstraction). +Imports of ``torch`` and ``transformers`` are deferred to function bodies +for fast CLI startup. +""" + +from typing import Annotated + +import typer + +from ._common import ( + DeviceOpt, + DtypeOpt, + JsonOpt, + ModelOpt, + RevisionOpt, + TokenOpt, + TrustOpt, + _load_pretrained, + format_output, + load_audio, + load_image, +) + + +def vqa( + image: Annotated[str, typer.Option(help="Path or URL to the image.")], + question: Annotated[str, typer.Option(help="Question about the image.")], + model: ModelOpt = None, + max_new_tokens: Annotated[int, typer.Option(help="Maximum tokens to generate.")] = 256, + device: DeviceOpt = None, + dtype: DtypeOpt = "auto", + trust_remote_code: TrustOpt = False, + token: TokenOpt = None, + revision: RevisionOpt = None, + output_json: JsonOpt = False, +): + """ + Visual question answering using ``AutoModelForImageTextToText``. + + Provide an image and a natural-language question; the model returns an + answer grounded in the visual content. + + Example:: + + transformers vqa --image photo.jpg --question "What color is the car?" 
+ """ + from transformers import AutoModelForImageTextToText, AutoProcessor + + model_id = model or "vikhyatk/moondream2" + loaded_model, processor = _load_pretrained( + AutoModelForImageTextToText, AutoProcessor, model_id, device, dtype, trust_remote_code, token, revision + ) + + img = load_image(image) + messages = [{"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": question}]}] + + inputs = processor.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True) + if hasattr(loaded_model, "device"): + inputs = inputs.to(loaded_model.device) + + output_ids = loaded_model.generate(**inputs, max_new_tokens=max_new_tokens) + new_tokens = output_ids[0, inputs["input_ids"].shape[1] :] + result = processor.decode(new_tokens, skip_special_tokens=True) + + if output_json: + print(format_output({"answer": result}, output_json=True)) + else: + print(result) + + +def document_qa( + image: Annotated[str, typer.Option(help="Path or URL to the document image.")], + question: Annotated[str, typer.Option(help="Question about the document.")], + model: ModelOpt = None, + device: DeviceOpt = None, + dtype: DtypeOpt = "auto", + trust_remote_code: TrustOpt = False, + token: TokenOpt = None, + revision: RevisionOpt = None, + output_json: JsonOpt = False, +): + """ + Extractive document question answering using + ``AutoModelForDocumentQuestionAnswering``. + + The model reads a document image and extracts a span of text that + answers the given question. + + Example:: + + transformers document-qa --image receipt.png --question "What is the total?" 
+ """ + import torch + + from transformers import AutoModelForDocumentQuestionAnswering, AutoProcessor + + model_id = model or "impira/layoutlm-document-qa" + loaded_model, processor = _load_pretrained( + AutoModelForDocumentQuestionAnswering, + AutoProcessor, + model_id, + device, + dtype, + trust_remote_code, + token, + revision, + ) + + img = load_image(image) + inputs = processor(images=img, question=question, return_tensors="pt") + if hasattr(loaded_model, "device"): + inputs = inputs.to(loaded_model.device) + + with torch.no_grad(): + outputs = loaded_model(**inputs) + + start_idx = outputs.start_logits.argmax(dim=-1).item() + end_idx = outputs.end_logits.argmax(dim=-1).item() + answer = processor.tokenizer.decode(inputs["input_ids"][0, start_idx : end_idx + 1], skip_special_tokens=True) + + result = {"answer": answer, "start": start_idx, "end": end_idx} + if output_json: + print(format_output(result, output_json=True)) + else: + print(format_output(result)) + + +def caption( + image: Annotated[str, typer.Option(help="Path or URL to the image.")], + model: ModelOpt = None, + max_new_tokens: Annotated[int, typer.Option(help="Maximum tokens to generate.")] = 64, + device: DeviceOpt = None, + dtype: DtypeOpt = "auto", + trust_remote_code: TrustOpt = False, + token: TokenOpt = None, + revision: RevisionOpt = None, + output_json: JsonOpt = False, +): + """ + Generate a caption for an image using ``AutoModelForImageTextToText``. 
+ + Example:: + + transformers caption --image photo.jpg + """ + from transformers import AutoModelForImageTextToText, AutoProcessor + + model_id = model or "vikhyatk/moondream2" + loaded_model, processor = _load_pretrained( + AutoModelForImageTextToText, AutoProcessor, model_id, device, dtype, trust_remote_code, token, revision + ) + + img = load_image(image) + messages = [ + { + "role": "user", + "content": [{"type": "image", "image": img}, {"type": "text", "text": "Describe this image."}], + } + ] + + inputs = processor.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True) + if hasattr(loaded_model, "device"): + inputs = inputs.to(loaded_model.device) + + output_ids = loaded_model.generate(**inputs, max_new_tokens=max_new_tokens) + new_tokens = output_ids[0, inputs["input_ids"].shape[1] :] + result = processor.decode(new_tokens, skip_special_tokens=True) + + if output_json: + print(format_output({"caption": result}, output_json=True)) + else: + print(result) + + +def ocr( + image: Annotated[str, typer.Option(help="Path or URL to the document image.")], + model: ModelOpt = None, + max_new_tokens: Annotated[int, typer.Option(help="Maximum tokens to generate.")] = 512, + device: DeviceOpt = None, + dtype: DtypeOpt = "auto", + trust_remote_code: TrustOpt = False, + token: TokenOpt = None, + revision: RevisionOpt = None, + output_json: JsonOpt = False, +): + """ + Extract text from an image using ``AutoModelForImageTextToText``. 
+ + Example:: + + transformers ocr --image scanned_page.png + """ + from transformers import AutoModelForImageTextToText, AutoProcessor + + model_id = model or "vikhyatk/moondream2" + loaded_model, processor = _load_pretrained( + AutoModelForImageTextToText, AutoProcessor, model_id, device, dtype, trust_remote_code, token, revision + ) + + img = load_image(image) + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": img}, + {"type": "text", "text": "Extract all text from this image."}, + ], + } + ] + + inputs = processor.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True) + if hasattr(loaded_model, "device"): + inputs = inputs.to(loaded_model.device) + + output_ids = loaded_model.generate(**inputs, max_new_tokens=max_new_tokens) + new_tokens = output_ids[0, inputs["input_ids"].shape[1] :] + result = processor.decode(new_tokens, skip_special_tokens=True) + + if output_json: + print(format_output({"text": result}, output_json=True)) + else: + print(result) + + +def multimodal_chat( + prompt: Annotated[str, typer.Option(help="Text prompt for the conversation.")], + model: Annotated[str, typer.Option("--model", "-m", help="Model ID or local path.")], + image: Annotated[str | None, typer.Option(help="Path or URL to an image.")] = None, + audio: Annotated[str | None, typer.Option(help="Path to an audio file.")] = None, + max_new_tokens: Annotated[int, typer.Option(help="Maximum tokens to generate.")] = 256, + device: DeviceOpt = None, + dtype: DtypeOpt = "auto", + trust_remote_code: TrustOpt = False, + token: TokenOpt = None, + revision: RevisionOpt = None, +): + """ + Single-turn conversation with a model that accepts mixed inputs. + + Provide any combination of ``--image``, ``--audio``, and ``--prompt``. + The model must support the input modalities you provide. 
+ + Example:: + + transformers multimodal-chat --model meta-llama/Llama-4-Scout-17B-16E-Instruct \\ + --prompt "Describe what you see and hear." --image photo.jpg --audio clip.wav + """ + from transformers import AutoModelForImageTextToText, AutoProcessor + + common_kwargs = {} + if trust_remote_code: + common_kwargs["trust_remote_code"] = True + if token: + common_kwargs["token"] = token + if revision: + common_kwargs["revision"] = revision + + processor = AutoProcessor.from_pretrained(model, **common_kwargs) + + model_kwargs = {**common_kwargs} + if device and device != "cpu": + model_kwargs["device_map"] = device + elif device is None: + model_kwargs["device_map"] = "auto" + if dtype != "auto": + import torch + + model_kwargs["torch_dtype"] = getattr(torch, dtype) + + loaded_model = AutoModelForImageTextToText.from_pretrained(model, **model_kwargs) + loaded_model.eval() + + # Build multimodal message content + content = [] + if image is not None: + img = load_image(image) + content.append({"type": "image", "image": img}) + if audio is not None: + audio_data = load_audio(audio) + content.append({"type": "audio", "audio": audio_data}) + content.append({"type": "text", "text": prompt}) + + messages = [{"role": "user", "content": content}] + + inputs = processor.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True) + if hasattr(loaded_model, "device"): + inputs = inputs.to(loaded_model.device) + + output_ids = loaded_model.generate(**inputs, max_new_tokens=max_new_tokens) + new_tokens = output_ids[0, inputs["input_ids"].shape[1] :] + print(processor.decode(new_tokens, skip_special_tokens=True)) diff --git a/src/transformers/cli/agentic/quantize.py b/src/transformers/cli/agentic/quantize.py new file mode 100644 index 000000000000..a41019a2259e --- /dev/null +++ b/src/transformers/cli/agentic/quantize.py @@ -0,0 +1,160 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Quantization CLI command. + +Quantize a model and save the result locally or push to the Hub. + +Examples:: + + # BitsAndBytes 4-bit (NF4) + transformers quantize --model meta-llama/Llama-3.1-8B --method bnb-4bit --output ./llama-4bit + + # GPTQ with calibration data + transformers quantize --model meta-llama/Llama-3.1-8B --method gptq --calibration-dataset wikitext --output ./llama-gptq + + # AWQ + transformers quantize --model meta-llama/Llama-3.1-8B --method awq --output ./llama-awq + +Supported methods: bnb-4bit, bnb-8bit, gptq, awq. 
+""" + +from typing import Annotated + +import typer + + +_QUANTIZATION_METHODS = ("bnb-4bit", "bnb-8bit", "gptq", "awq") + + +def quantize( + model: Annotated[str, typer.Option("--model", "-m", help="Model ID or local path to quantize.")], + method: Annotated[str, typer.Option(help=f"Quantization method: {', '.join(_QUANTIZATION_METHODS)}.")], + output: Annotated[str, typer.Option(help="Output directory for the quantized model.")], + calibration_dataset: Annotated[ + str | None, typer.Option(help="Calibration dataset for GPTQ/AWQ (Hub name or local path).") + ] = None, + calibration_samples: Annotated[int, typer.Option(help="Number of calibration samples.")] = 128, + bits: Annotated[int, typer.Option(help="Target bit width (for GPTQ/AWQ).")] = 4, + group_size: Annotated[int, typer.Option(help="Group size for GPTQ/AWQ.")] = 128, + device: Annotated[str | None, typer.Option(help="Device for quantization.")] = None, + trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code.")] = False, + token: Annotated[str | None, typer.Option(help="HF Hub token.")] = None, + push_to_hub: Annotated[bool, typer.Option(help="Push quantized model to Hub.")] = False, + hub_model_id: Annotated[str | None, typer.Option(help="Hub repo ID for push.")] = None, +): + """ + Quantize a model and save it. + + Loads the model with the specified quantization method and saves the + quantized weights. For GPTQ and AWQ, a calibration dataset is used + to determine optimal quantization parameters. + + Examples:: + + transformers quantize --model meta-llama/Llama-3.1-8B --method bnb-4bit --output ./llama-4bit + transformers quantize --model meta-llama/Llama-3.1-8B --method gptq --calibration-dataset wikitext --output ./llama-gptq + """ + from transformers import AutoModelForCausalLM, AutoTokenizer + + if method not in _QUANTIZATION_METHODS: + raise SystemExit(f"Unknown method '{method}'. 
Choose from: {', '.join(_QUANTIZATION_METHODS)}") + + common_kwargs = {} + if trust_remote_code: + common_kwargs["trust_remote_code"] = True + if token: + common_kwargs["token"] = token + + tokenizer = AutoTokenizer.from_pretrained(model, **common_kwargs) + + model_kwargs = {**common_kwargs} + if device: + model_kwargs["device_map"] = device + else: + model_kwargs["device_map"] = "auto" + + # --- BitsAndBytes --- + if method == "bnb-4bit": + from transformers import BitsAndBytesConfig + + model_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype="bfloat16", + bnb_4bit_quant_type="nf4", + ) + print(f"Loading {model} in 4-bit (BitsAndBytes NF4)...") + loaded_model = AutoModelForCausalLM.from_pretrained(model, **model_kwargs) + loaded_model.save_pretrained(output) + tokenizer.save_pretrained(output) + print(f"Quantized model saved to {output}") + + elif method == "bnb-8bit": + from transformers import BitsAndBytesConfig + + model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + print(f"Loading {model} in 8-bit (BitsAndBytes)...") + loaded_model = AutoModelForCausalLM.from_pretrained(model, **model_kwargs) + loaded_model.save_pretrained(output) + tokenizer.save_pretrained(output) + print(f"Quantized model saved to {output}") + + # --- GPTQ --- + elif method == "gptq": + from transformers import GPTQConfig + + if calibration_dataset is None: + calibration_dataset = "wikitext" + print("No --calibration-dataset specified, defaulting to 'wikitext'.") + + from datasets import load_dataset + + cal_ds = load_dataset(calibration_dataset, split=f"train[:{calibration_samples}]") + cal_texts = [ex["text"] for ex in cal_ds if ex.get("text")] + + quantization_config = GPTQConfig( + bits=bits, + group_size=group_size, + dataset=cal_texts, + tokenizer=tokenizer, + ) + model_kwargs["quantization_config"] = quantization_config + + print(f"Quantizing {model} with GPTQ ({bits}-bit, group_size={group_size})...") + 
loaded_model = AutoModelForCausalLM.from_pretrained(model, **model_kwargs) + loaded_model.save_pretrained(output) + tokenizer.save_pretrained(output) + print(f"GPTQ-quantized model saved to {output}") + + # --- AWQ --- + elif method == "awq": + from transformers import AwqConfig + + quantization_config = AwqConfig( + bits=bits, + group_size=group_size, + ) + model_kwargs["quantization_config"] = quantization_config + + print(f"Quantizing {model} with AWQ ({bits}-bit, group_size={group_size})...") + loaded_model = AutoModelForCausalLM.from_pretrained(model, **model_kwargs) + loaded_model.save_pretrained(output) + tokenizer.save_pretrained(output) + print(f"AWQ-quantized model saved to {output}") + + if push_to_hub: + repo_id = hub_model_id or output + loaded_model.push_to_hub(repo_id, token=token) + tokenizer.push_to_hub(repo_id, token=token) + print(f"Pushed to Hub: {repo_id}") diff --git a/src/transformers/cli/agentic/text.py b/src/transformers/cli/agentic/text.py new file mode 100644 index 000000000000..2569b295580d --- /dev/null +++ b/src/transformers/cli/agentic/text.py @@ -0,0 +1,580 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Text inference CLI commands. + +Each function uses Auto* model and tokenizer classes directly and is +registered as a top-level ``transformers`` CLI command via ``app.py``. 
+""" + +from typing import Annotated + +import typer + +from ._common import ( + DeviceOpt, + DtypeOpt, + JsonOpt, + ModelOpt, + RevisionOpt, + TokenOpt, + TrustOpt, + _load_pretrained, + format_output, + resolve_input, +) + + +def _aggregate_entities(entities, text): + """Merge sub-word entity predictions into whole entities (B-/I- tag merging).""" + if not entities: + return entities + + aggregated = [] + current = None + + for entity in entities: + label = entity["entity"] + entity_type = label.split("-", 1)[-1] if "-" in label else label + is_continuation = label.startswith("I-") + + if current is not None and is_continuation and entity_type == current["entity_group"]: + current["end"] = entity["end"] + current["score"] = min(current["score"], entity["score"]) + else: + if current is not None: + current["word"] = text[current["start"] : current["end"]] + aggregated.append(current) + current = { + "entity_group": entity_type, + "score": entity["score"], + "start": entity["start"], + "end": entity["end"], + } + + if current is not None: + current["word"] = text[current["start"] : current["end"]] + aggregated.append(current) + + return aggregated + + +def classify( + text: Annotated[str | None, typer.Option(help="Text to classify.")] = None, + file: Annotated[str | None, typer.Option(help="Read text from this file.")] = None, + labels: Annotated[ + str | None, typer.Option(help="Comma-separated candidate labels for zero-shot classification.") + ] = None, + model: ModelOpt = None, + device: DeviceOpt = None, + dtype: DtypeOpt = "auto", + trust_remote_code: TrustOpt = False, + token: TokenOpt = None, + revision: RevisionOpt = None, + output_json: JsonOpt = False, +): + """ + Classify text into categories. + + Uses ``AutoModelForSequenceClassification`` by default (requires a + fine-tuned classification model). Pass ``--labels`` to switch to + zero-shot classification via natural language inference. 
def classify(
    text: Annotated[str | None, typer.Option(help="Text to classify.")] = None,
    file: Annotated[str | None, typer.Option(help="Read text from this file.")] = None,
    labels: Annotated[
        str | None, typer.Option(help="Comma-separated candidate labels for zero-shot classification.")
    ] = None,
    model: ModelOpt = None,
    device: DeviceOpt = None,
    dtype: DtypeOpt = "auto",
    trust_remote_code: TrustOpt = False,
    token: TokenOpt = None,
    revision: RevisionOpt = None,
    output_json: JsonOpt = False,
):
    """
    Classify text into categories.

    Uses ``AutoModelForSequenceClassification`` by default (requires a
    fine-tuned classification model). Pass ``--labels`` to switch to
    zero-shot classification via natural language inference.

    Examples::

        # Supervised (model already fine-tuned for sentiment)
        transformers classify --model distilbert/distilbert-base-uncased-finetuned-sst-2-english --text "Great movie!"

        # Zero-shot (any categories, no fine-tuning needed)
        transformers classify --text "The stock market crashed" --labels "politics,finance,sports"

        # Read from file, output as JSON
        transformers classify --file review.txt --json
    """
    import torch

    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    input_text = resolve_input(text, file)

    if labels is None:
        # Supervised path: the checkpoint already carries a classification head.
        checkpoint = model or "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
        classifier, tok = _load_pretrained(
            AutoModelForSequenceClassification,
            AutoTokenizer,
            checkpoint,
            device,
            dtype,
            trust_remote_code,
            token,
            revision,
        )

        encoded = tok(input_text, return_tensors="pt", truncation=True)
        if hasattr(classifier, "device"):
            encoded = encoded.to(classifier.device)
        with torch.no_grad():
            logits = classifier(**encoded).logits

        distribution = logits.softmax(dim=-1)[0]
        best = distribution.argmax().item()
        result = [{"label": classifier.config.id2label[best], "score": distribution[best].item()}]
    else:
        # Zero-shot path: score each candidate label as an NLI hypothesis
        # ("This example is {label}.") and keep the entailment probability.
        checkpoint = model or "facebook/bart-large-mnli"
        classifier, tok = _load_pretrained(
            AutoModelForSequenceClassification,
            AutoTokenizer,
            checkpoint,
            device,
            dtype,
            trust_remote_code,
            token,
            revision,
        )

        candidate_labels = [candidate.strip() for candidate in labels.split(",")]

        # First id2label entry whose name starts with "entail"; MNLI models
        # conventionally put entailment at index 2, hence the fallback.
        entail_idx = next(
            (int(i) for i, name in classifier.config.id2label.items() if name.lower().startswith("entail")),
            2,
        )

        raw_scores = []
        for candidate in candidate_labels:
            encoded = tok(input_text, f"This example is {candidate}.", return_tensors="pt", truncation=True)
            if hasattr(classifier, "device"):
                encoded = encoded.to(classifier.device)
            with torch.no_grad():
                logits = classifier(**encoded).logits
            raw_scores.append(logits.softmax(dim=-1)[0, entail_idx].item())

        normalizer = sum(raw_scores)
        result = {
            "sequence": input_text,
            "labels": candidate_labels,
            "scores": [score / normalizer for score in raw_scores],
        }

    print(format_output(result, output_json))
def ner(
    text: Annotated[str | None, typer.Option(help="Text to extract entities from.")] = None,
    file: Annotated[str | None, typer.Option(help="Read text from this file.")] = None,
    model: ModelOpt = None,
    device: DeviceOpt = None,
    dtype: DtypeOpt = "auto",
    trust_remote_code: TrustOpt = False,
    token: TokenOpt = None,
    revision: RevisionOpt = None,
    aggregation_strategy: Annotated[str, typer.Option(help="Entity aggregation: 'none' or 'simple'.")] = "simple",
    output_json: JsonOpt = False,
):
    """
    Extract named entities from text (NER).

    Uses ``AutoModelForTokenClassification`` with entity aggregation
    enabled by default (``--aggregation-strategy simple``).

    Example::

        transformers ner --model dslim/bert-base-NER --text "Apple CEO Tim Cook met with President Biden in Washington."
    """
    import torch

    from transformers import AutoModelForTokenClassification, AutoTokenizer

    input_text = resolve_input(text, file)
    checkpoint = model or "dslim/bert-base-NER"
    tagger, tok = _load_pretrained(
        AutoModelForTokenClassification, AutoTokenizer, checkpoint, device, dtype, trust_remote_code, token, revision
    )

    encoded = tok(input_text, return_tensors="pt", truncation=True, return_offsets_mapping=True)
    offsets = encoded.pop("offset_mapping")[0]
    if hasattr(tagger, "device"):
        encoded = encoded.to(tagger.device)

    with torch.no_grad():
        logits = tagger(**encoded).logits

    confidences = logits.softmax(dim=-1)
    label_ids = logits.argmax(dim=-1)[0]

    entities = []
    for position, (label_id, (span_start, span_end)) in enumerate(zip(label_ids, offsets)):
        tag = tagger.config.id2label[label_id.item()]
        # Skip non-entity tokens and special tokens (offset (0, 0)).
        if tag == "O" or (span_start == 0 and span_end == 0):
            continue
        entities.append(
            {
                "entity": tag,
                "score": confidences[0, position, label_id].item(),
                "word": input_text[span_start:span_end],
                "start": span_start.item(),
                "end": span_end.item(),
            }
        )

    if aggregation_strategy == "simple":
        entities = _aggregate_entities(entities, input_text)

    print(format_output(entities, output_json))
def token_classify(
    text: Annotated[str | None, typer.Option(help="Text to tag.")] = None,
    file: Annotated[str | None, typer.Option(help="Read text from this file.")] = None,
    model: ModelOpt = None,
    device: DeviceOpt = None,
    dtype: DtypeOpt = "auto",
    trust_remote_code: TrustOpt = False,
    token: TokenOpt = None,
    revision: RevisionOpt = None,
    output_json: JsonOpt = False,
):
    """
    Tag tokens with labels (POS tagging, chunking, etc.).

    Uses ``AutoModelForTokenClassification``. The output depends on the
    model — a POS model outputs POS tags, a NER model outputs entity labels.

    Example::

        transformers token-classify --model vblagoje/bert-english-uncased-finetuned-pos --text "The cat sat on the mat."
    """
    import torch

    from transformers import AutoModelForTokenClassification, AutoTokenizer

    input_text = resolve_input(text, file)
    checkpoint = model or "vblagoje/bert-english-uncased-finetuned-pos"
    tagger, tok = _load_pretrained(
        AutoModelForTokenClassification, AutoTokenizer, checkpoint, device, dtype, trust_remote_code, token, revision
    )

    encoded = tok(input_text, return_tensors="pt", truncation=True, return_offsets_mapping=True)
    offsets = encoded.pop("offset_mapping")[0]
    if hasattr(tagger, "device"):
        encoded = encoded.to(tagger.device)

    with torch.no_grad():
        logits = tagger(**encoded).logits

    confidences = logits.softmax(dim=-1)
    label_ids = logits.argmax(dim=-1)[0]

    result = []
    for position, (label_id, (span_start, span_end)) in enumerate(zip(label_ids, offsets)):
        # Offset (0, 0) marks special tokens ([CLS], [SEP], padding) — skip them.
        if span_start == 0 and span_end == 0:
            continue
        result.append(
            {
                "entity": tagger.config.id2label[label_id.item()],
                "score": confidences[0, position, label_id].item(),
                "word": input_text[span_start:span_end],
                "start": span_start.item(),
                "end": span_end.item(),
            }
        )

    print(format_output(result, output_json))
--context "Alexander Graham Bell invented the telephone in 1876." + """ + import torch + + from transformers import AutoModelForQuestionAnswering, AutoTokenizer + + ctx = resolve_input(context, file) + model_id = model or "distilbert/distilbert-base-cased-distilled-squad" + loaded_model, tokenizer = _load_pretrained( + AutoModelForQuestionAnswering, AutoTokenizer, model_id, device, dtype, trust_remote_code, token, revision + ) + + inputs = tokenizer(question, ctx, return_tensors="pt", truncation=True) + if hasattr(loaded_model, "device"): + inputs = inputs.to(loaded_model.device) + + with torch.no_grad(): + outputs = loaded_model(**inputs) + + start_idx = outputs.start_logits.argmax(dim=-1).item() + end_idx = outputs.end_logits.argmax(dim=-1).item() + answer_ids = inputs["input_ids"][0, start_idx : end_idx + 1] + score = (outputs.start_logits[0, start_idx] + outputs.end_logits[0, end_idx]).item() + + result = { + "answer": tokenizer.decode(answer_ids, skip_special_tokens=True), + "score": score, + "start": start_idx, + "end": end_idx, + } + print(format_output(result, output_json)) + + +def table_qa( + question: Annotated[str, typer.Option(help="Question about the table.")], + table: Annotated[str, typer.Option(help="Path to a CSV file containing the table.")], + model: ModelOpt = None, + device: DeviceOpt = None, + dtype: DtypeOpt = "auto", + trust_remote_code: TrustOpt = False, + token: TokenOpt = None, + revision: RevisionOpt = None, + output_json: JsonOpt = False, +): + """ + Answer a question about tabular data (CSV). + + Loads a CSV file into a table and uses ``AutoModelForTableQuestionAnswering`` + (e.g., TAPAS) to answer the question. + + Example:: + + transformers table-qa --question "What is the total revenue?" 
--table financials.csv + """ + import pandas as pd + import torch + + from transformers import AutoModelForTableQuestionAnswering, AutoTokenizer + + model_id = model or "google/tapas-base-finetuned-wtq" + loaded_model, tokenizer = _load_pretrained( + AutoModelForTableQuestionAnswering, AutoTokenizer, model_id, device, dtype, trust_remote_code, token, revision + ) + + table_df = pd.read_csv(table).astype(str) + inputs = tokenizer(table=table_df, queries=question, return_tensors="pt", truncation=True) + if hasattr(loaded_model, "device"): + inputs = inputs.to(loaded_model.device) + + with torch.no_grad(): + outputs = loaded_model(**inputs) + + logits_agg = getattr(outputs, "logits_aggregation", None) + if logits_agg is not None: + predicted_coordinates, predicted_agg = tokenizer.convert_logits_to_predictions( + inputs, outputs.logits.detach().cpu(), logits_agg.detach().cpu() + ) + agg_idx = predicted_agg[0] + else: + (predicted_coordinates,) = tokenizer.convert_logits_to_predictions(inputs, outputs.logits.detach().cpu()) + agg_idx = 0 + + coordinates = predicted_coordinates[0] + cells = [table_df.iat[row, col] for row, col in coordinates] + + _AGG_OPS = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"} + if agg_idx == 1: + try: + answer = str(sum(float(c) for c in cells)) + except ValueError: + answer = ", ".join(cells) + elif agg_idx == 2: + try: + answer = str(sum(float(c) for c in cells) / len(cells)) + except (ValueError, ZeroDivisionError): + answer = ", ".join(cells) + elif agg_idx == 3: + answer = str(len(cells)) + else: + answer = ", ".join(cells) + + result = { + "answer": answer, + "coordinates": coordinates, + "cells": cells, + "aggregator": _AGG_OPS.get(agg_idx, "NONE"), + } + print(format_output(result, output_json)) + + +def summarize( + text: Annotated[str | None, typer.Option(help="Text to summarize.")] = None, + file: Annotated[str | None, typer.Option(help="Read text from this file.")] = None, + model: ModelOpt = None, + max_length: Annotated[int | 
None, typer.Option(help="Maximum summary length in tokens.")] = None, + min_length: Annotated[int | None, typer.Option(help="Minimum summary length in tokens.")] = None, + device: DeviceOpt = None, + dtype: DtypeOpt = "auto", + trust_remote_code: TrustOpt = False, + token: TokenOpt = None, + revision: RevisionOpt = None, + output_json: JsonOpt = False, +): + """ + Summarize text. + + Uses ``AutoModelForSeq2SeqLM`` (e.g., BART, T5, Pegasus). + + Examples:: + + transformers summarize --model facebook/bart-large-cnn --file article.txt + transformers summarize --text "Long article text here..." --max-length 100 + """ + from transformers import AutoModelForSeq2SeqLM, AutoTokenizer + + input_text = resolve_input(text, file) + model_id = model or "facebook/bart-large-cnn" + loaded_model, tokenizer = _load_pretrained( + AutoModelForSeq2SeqLM, AutoTokenizer, model_id, device, dtype, trust_remote_code, token, revision + ) + + inputs = tokenizer(input_text, return_tensors="pt", truncation=True) + if hasattr(loaded_model, "device"): + inputs = inputs.to(loaded_model.device) + + gen_kwargs = {} + if max_length is not None: + gen_kwargs["max_length"] = max_length + if min_length is not None: + gen_kwargs["min_length"] = min_length + + output_ids = loaded_model.generate(**inputs, **gen_kwargs) + summary = tokenizer.decode(output_ids[0], skip_special_tokens=True) + result = [{"summary_text": summary}] + print(format_output(result, output_json)) + + +def translate( + text: Annotated[str | None, typer.Option(help="Text to translate.")] = None, + file: Annotated[str | None, typer.Option(help="Read text from this file.")] = None, + model: ModelOpt = None, + max_length: Annotated[int | None, typer.Option(help="Maximum translation length.")] = None, + device: DeviceOpt = None, + dtype: DtypeOpt = "auto", + trust_remote_code: TrustOpt = False, + token: TokenOpt = None, + revision: RevisionOpt = None, + output_json: JsonOpt = False, +): + """ + Translate text between languages. 
def translate(
    text: Annotated[str | None, typer.Option(help="Text to translate.")] = None,
    file: Annotated[str | None, typer.Option(help="Read text from this file.")] = None,
    model: ModelOpt = None,
    max_length: Annotated[int | None, typer.Option(help="Maximum translation length.")] = None,
    device: DeviceOpt = None,
    dtype: DtypeOpt = "auto",
    trust_remote_code: TrustOpt = False,
    token: TokenOpt = None,
    revision: RevisionOpt = None,
    output_json: JsonOpt = False,
):
    """
    Translate text between languages.

    Uses ``AutoModelForSeq2SeqLM``. The language pair is determined by the
    model. Use Helsinki-NLP models for specific pairs (e.g.,
    ``Helsinki-NLP/opus-mt-en-de`` for English to German).

    Example::

        transformers translate --model Helsinki-NLP/opus-mt-en-de --text "The weather is nice today."
    """
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    input_text = resolve_input(text, file)
    checkpoint = model or "Helsinki-NLP/opus-mt-en-de"
    seq2seq, tok = _load_pretrained(
        AutoModelForSeq2SeqLM, AutoTokenizer, checkpoint, device, dtype, trust_remote_code, token, revision
    )

    encoded = tok(input_text, return_tensors="pt", truncation=True)
    if hasattr(seq2seq, "device"):
        encoded = encoded.to(seq2seq.device)

    # Only forward max_length to generate() when the user set it.
    gen_kwargs = {"max_length": max_length} if max_length is not None else {}

    generated = seq2seq.generate(**encoded, **gen_kwargs)
    translation = tok.decode(generated[0], skip_special_tokens=True)
    print(format_output([{"translation_text": translation}], output_json))
def fill_mask(
    text: Annotated[str, typer.Option(help="Text with a [MASK] token.")],
    model: ModelOpt = None,
    top_k: Annotated[int, typer.Option(help="Number of predictions to return.")] = 5,
    device: DeviceOpt = None,
    dtype: DtypeOpt = "auto",
    trust_remote_code: TrustOpt = False,
    token: TokenOpt = None,
    revision: RevisionOpt = None,
    output_json: JsonOpt = False,
):
    """
    Predict the masked token in a sentence.

    Uses ``AutoModelForMaskedLM``. The mask token depends on the model
    (``[MASK]`` for BERT, ```` for RoBERTa).

    Example::

        transformers fill-mask --model answerdotai/ModernBERT-base --text "The capital of France is [MASK]."
    """
    import torch

    from transformers import AutoModelForMaskedLM, AutoTokenizer

    checkpoint = model or "answerdotai/ModernBERT-base"
    mlm, tok = _load_pretrained(
        AutoModelForMaskedLM, AutoTokenizer, checkpoint, device, dtype, trust_remote_code, token, revision
    )

    encoded = tok(text, return_tensors="pt")
    if hasattr(mlm, "device"):
        encoded = encoded.to(mlm.device)

    mask_positions = (encoded["input_ids"] == tok.mask_token_id).nonzero(as_tuple=True)[1]
    if len(mask_positions) == 0:
        raise SystemExit(f"No mask token found. Use '{tok.mask_token}' in your text.")

    with torch.no_grad():
        logits = mlm(**encoded).logits

    # Only the first mask position is filled.
    distribution = logits[0, mask_positions[0]].softmax(dim=-1)
    top_probs, top_ids = distribution.topk(top_k)

    result = []
    for prob, candidate_id in zip(top_probs, top_ids):
        candidate = tok.decode([candidate_id]).strip()
        result.append(
            {
                "score": prob.item(),
                "token": candidate_id.item(),
                "token_str": candidate,
                "sequence": text.replace(tok.mask_token, candidate, 1),
            }
        )

    print(format_output(result, output_json))
+""" +Training CLI command. + +Wraps ``Trainer`` to fine-tune or pretrain a model on any supported task +from a single CLI invocation. Supports text, vision, and audio tasks, +with built-in LoRA/QLoRA, distributed training, and hyperparameter search. + +Examples:: + + # Fine-tune text classification + transformers train text-classification --model bert-base-uncased --dataset glue/sst2 --output ./out + + # Fine-tune image classification + transformers train image-classification --model google/vit-base-patch16-224 --dataset food101 --output ./out + + # QLoRA (4-bit base + LoRA adapters) + transformers train text-generation --model meta-llama/Llama-3.1-8B \\ + --dataset ./data.jsonl --output ./out --lora --quantization bnb-4bit + + # Distributed with DeepSpeed + transformers train text-generation --model meta-llama/Llama-3.1-8B \\ + --dataset ./data.jsonl --output ./out --deepspeed zero3 --dtype bfloat16 + +Supported tasks: text-classification, token-classification, question-answering, +summarization, translation, text-generation, language-modeling, +image-classification, object-detection, semantic-segmentation, +speech-recognition, audio-classification. 
+""" + +from typing import Annotated + +import typer + + +# Maps CLI task names to (AutoModel class name, preprocessing type) +_TASK_CONFIGS = { + # Text tasks + "text-classification": { + "auto_class": "AutoModelForSequenceClassification", + "preprocess": "tokenize", + "text_columns": ("sentence", "text"), + "label_column": "label", + }, + "token-classification": { + "auto_class": "AutoModelForTokenClassification", + "preprocess": "tokenize_and_align_labels", + "text_columns": ("tokens",), + "label_column": "ner_tags", + }, + "question-answering": { + "auto_class": "AutoModelForQuestionAnswering", + "preprocess": "tokenize_qa", + "text_columns": ("question", "context"), + "label_column": None, + }, + "summarization": { + "auto_class": "AutoModelForSeq2SeqLM", + "preprocess": "tokenize_seq2seq", + "text_columns": ("article", "document"), + "label_column": ("highlights", "summary"), + }, + "translation": { + "auto_class": "AutoModelForSeq2SeqLM", + "preprocess": "tokenize_seq2seq", + "text_columns": None, + "label_column": None, + }, + "text-generation": { + "auto_class": "AutoModelForCausalLM", + "preprocess": "tokenize_causal", + "text_columns": ("text",), + "label_column": None, + }, + "language-modeling": { + "auto_class": "AutoModelForMaskedLM", + "preprocess": "tokenize_mlm", + "text_columns": ("text",), + "label_column": None, + }, + # Vision tasks + "image-classification": { + "auto_class": "AutoModelForImageClassification", + "preprocess": "image_transform", + "text_columns": None, + "label_column": "label", + }, + "object-detection": { + "auto_class": "AutoModelForObjectDetection", + "preprocess": "image_transform", + "text_columns": None, + "label_column": None, + }, + "semantic-segmentation": { + "auto_class": "AutoModelForSemanticSegmentation", + "preprocess": "image_transform", + "text_columns": None, + "label_column": None, + }, + # Audio tasks + "speech-recognition": { + "auto_class": "AutoModelForSpeechSeq2Seq", + "preprocess": "audio_transform", + 
"text_columns": None, + "label_column": "text", + }, + "audio-classification": { + "auto_class": "AutoModelForAudioClassification", + "preprocess": "audio_transform", + "text_columns": None, + "label_column": "label", + }, +} + + +def _detect_text_column(dataset, candidates: tuple[str, ...] | None) -> str: + """Find the first matching column name in the dataset.""" + if candidates is None: + return None + columns = dataset.column_names + if isinstance(columns, dict): + columns = columns.get("train", list(columns.values())[0]) + for c in candidates: + if c in columns: + return c + return columns[0] + + +def _load_dataset(dataset_path: str, subset: str | None, token: str | None): + """Load a dataset from the Hub or local files.""" + from datasets import load_dataset + + kwargs = {} + if token: + kwargs["token"] = token + + # Detect local files + if dataset_path.endswith((".csv", ".json", ".jsonl", ".txt", ".parquet")): + if dataset_path.endswith(".csv"): + fmt = "csv" + elif dataset_path.endswith((".json", ".jsonl")): + fmt = "json" + elif dataset_path.endswith(".parquet"): + fmt = "parquet" + else: + fmt = "text" + return load_dataset(fmt, data_files=dataset_path, **kwargs) + + # Hub dataset, possibly with subset: "glue/sst2" -> ("glue", "sst2") + if subset is not None: + return load_dataset(dataset_path, subset, **kwargs) + if "/" in dataset_path and not dataset_path.startswith((".", "/")): + parts = dataset_path.split("/") + if len(parts) == 2: + # Could be "org/dataset" or "dataset/subset" — try as-is first + try: + return load_dataset(dataset_path, **kwargs) + except Exception: + return load_dataset(parts[0], parts[1], **kwargs) + return load_dataset(dataset_path, **kwargs) + + +def train( + task: Annotated[str, typer.Argument(help=f"Task to train. 
One of: {', '.join(_TASK_CONFIGS.keys())}.")], + # Model + model: Annotated[str, typer.Option("--model", "-m", help="Model ID or local path.")], + dataset: Annotated[str, typer.Option(help="Dataset name (Hub) or path (local file).")], + output: Annotated[str, typer.Option(help="Output directory for checkpoints and final model.")], + subset: Annotated[str | None, typer.Option(help="Dataset subset/config name.")] = None, + # Training hyperparameters + epochs: Annotated[float, typer.Option(help="Number of training epochs.")] = 3.0, + lr: Annotated[float, typer.Option(help="Learning rate.")] = 5e-5, + batch_size: Annotated[int, typer.Option(help="Per-device training batch size.")] = 8, + eval_batch_size: Annotated[int | None, typer.Option(help="Per-device eval batch size.")] = None, + max_seq_length: Annotated[int, typer.Option(help="Maximum sequence length for tokenization.")] = 512, + gradient_accumulation_steps: Annotated[int, typer.Option(help="Gradient accumulation steps.")] = 1, + warmup_ratio: Annotated[float, typer.Option(help="Warmup ratio.")] = 0.0, + weight_decay: Annotated[float, typer.Option(help="Weight decay.")] = 0.0, + # Evaluation + eval_strategy: Annotated[str, typer.Option(help="Evaluation strategy: 'no', 'steps', 'epoch'.")] = "epoch", + eval_steps: Annotated[int | None, typer.Option(help="Evaluation interval (if strategy='steps').")] = None, + # Checkpointing + save_strategy: Annotated[str, typer.Option(help="Save strategy: 'no', 'steps', 'epoch'.")] = "epoch", + save_total_limit: Annotated[int | None, typer.Option(help="Max checkpoints to keep.")] = None, + resume_from_checkpoint: Annotated[str | None, typer.Option(help="Path to checkpoint to resume from.")] = None, + load_best_model_at_end: Annotated[bool, typer.Option(help="Load best model after training.")] = True, + # Early stopping + early_stopping: Annotated[bool, typer.Option(help="Enable early stopping.")] = False, + early_stopping_patience: Annotated[int, typer.Option(help="Early 
stopping patience (eval rounds).")] = 3, + # LoRA / QLoRA + lora: Annotated[bool, typer.Option(help="Use LoRA for parameter-efficient fine-tuning.")] = False, + lora_r: Annotated[int, typer.Option(help="LoRA rank.")] = 16, + lora_alpha: Annotated[int, typer.Option(help="LoRA alpha.")] = 32, + lora_dropout: Annotated[float, typer.Option(help="LoRA dropout.")] = 0.05, + # Quantization + quantization: Annotated[str | None, typer.Option(help="Quantize base model: 'bnb-4bit', 'bnb-8bit'.")] = None, + # Precision & device + dtype: Annotated[str, typer.Option(help="Training dtype: 'auto', 'float16', 'bfloat16', 'float32'.")] = "auto", + device: Annotated[str | None, typer.Option(help="Device to train on: 'cpu', 'cuda', 'mps', 'tpu'.")] = None, + gradient_checkpointing: Annotated[bool, typer.Option(help="Enable gradient checkpointing.")] = False, + # Distributed + multi_gpu: Annotated[bool, typer.Option(help="Use all available GPUs on this machine.")] = False, + nnodes: Annotated[ + int | None, typer.Option(help="Number of nodes for multi-node training (uses torchrun).") + ] = None, + deepspeed: Annotated[str | None, typer.Option(help="DeepSpeed config: 'zero2', 'zero3', or path to JSON.")] = None, + fsdp: Annotated[str | None, typer.Option(help="FSDP strategy: 'full-shard', 'shard-grad-op', 'offload'.")] = None, + # Logging + logging: Annotated[str | None, typer.Option(help="Logging integration: 'tensorboard', 'wandb', 'comet'.")] = None, + # Hub + push_to_hub: Annotated[bool, typer.Option(help="Push final model to the Hub.")] = False, + hub_model_id: Annotated[str | None, typer.Option(help="Hub repository ID.")] = None, + token: Annotated[str | None, typer.Option(help="HF Hub token.")] = None, + trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code.")] = False, + # HPO + hpo: Annotated[str | None, typer.Option(help="HPO backend: 'optuna', 'ray'.")] = None, + hpo_trials: Annotated[int, typer.Option(help="Number of HPO trials.")] = 10, + # Pretraining 
from scratch + from_scratch: Annotated[bool, typer.Option(help="Initialize model from scratch (random weights).")] = False, + mlm: Annotated[bool, typer.Option(help="Use masked language modeling (for language-modeling task).")] = False, +): + """ + Fine-tune or pretrain a model on a dataset. + + The first argument is the task name (e.g., ``text-classification``). + The model, dataset, and output directory are required options. All + other options have sensible defaults. + + Examples:: + + # Basic fine-tuning + transformers train text-classification --model bert-base-uncased --dataset glue/sst2 --output ./out --epochs 3 + + # LoRA + transformers train text-generation --model meta-llama/Llama-3.2-1B --dataset ./data.jsonl --output ./out --lora + + # Resume from checkpoint + transformers train text-classification --model bert-base-uncased --dataset glue/sst2 --output ./out --resume-from-checkpoint ./out/checkpoint-500 + + # Multi-GPU + transformers train text-generation --model meta-llama/Llama-3.2-1B --dataset ./data.jsonl --output ./out --multi-gpu + + # MPS (Apple Silicon) + transformers train text-classification --model bert-base-uncased --dataset glue/sst2 --output ./out --device mps + """ + import transformers + from transformers import ( + AutoConfig, + AutoTokenizer, + Trainer, + TrainingArguments, + ) + + if task not in _TASK_CONFIGS: + raise SystemExit(f"Unknown task '{task}'. 
Choose from: {', '.join(_TASK_CONFIGS.keys())}") + + task_config = _TASK_CONFIGS[task] + + # Override: if MLM flag is set, use masked LM + if task == "language-modeling" and not mlm: + task_config = {**task_config, "auto_class": "AutoModelForCausalLM"} + + # --- Load dataset --- + ds = _load_dataset(dataset, subset, token) + + # Split if there's no validation set + if "validation" not in ds and "test" not in ds: + split = ds["train"].train_test_split(test_size=0.1, seed=42) + ds["train"] = split["train"] + ds["validation"] = split["test"] + + eval_split = "validation" if "validation" in ds else "test" + + # --- Determine label count for classification --- + num_labels = None + label_col = task_config.get("label_column") + if isinstance(label_col, tuple): + for c in label_col: + if c in ds["train"].column_names: + label_col = c + break + else: + label_col = label_col[0] + if label_col and label_col in ds["train"].column_names: + features = ds["train"].features + if hasattr(features[label_col], "names"): + num_labels = features[label_col].num_classes + + # --- Load model & processing_class --- + auto_cls = getattr(transformers, task_config["auto_class"]) + + model_kwargs = {} + if trust_remote_code: + model_kwargs["trust_remote_code"] = True + if token: + model_kwargs["token"] = token + if num_labels is not None: + model_kwargs["num_labels"] = num_labels + + if quantization == "bnb-4bit": + from transformers import BitsAndBytesConfig + + model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True) + elif quantization == "bnb-8bit": + from transformers import BitsAndBytesConfig + + model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + + if from_scratch: + config = AutoConfig.from_pretrained( + model, **{k: v for k, v in model_kwargs.items() if k not in ("quantization_config",)} + ) + loaded_model = auto_cls.from_config(config) + else: + loaded_model = auto_cls.from_pretrained(model, **model_kwargs) + + # Load processor based on 
task type + processing_class = None + preprocess_type = task_config["preprocess"] + + if preprocess_type.startswith("tokenize") or preprocess_type == "audio_transform": + tok_kwargs = {} + if trust_remote_code: + tok_kwargs["trust_remote_code"] = True + if token: + tok_kwargs["token"] = token + processing_class = AutoTokenizer.from_pretrained(model, **tok_kwargs) + elif preprocess_type == "image_transform": + from transformers import AutoImageProcessor + + proc_kwargs = {} + if trust_remote_code: + proc_kwargs["trust_remote_code"] = True + if token: + proc_kwargs["token"] = token + processing_class = AutoImageProcessor.from_pretrained(model, **proc_kwargs) + + # --- Preprocess dataset --- + if preprocess_type == "tokenize": + text_col = _detect_text_column(ds["train"], task_config["text_columns"]) + + def preprocess_fn(examples): + return processing_class(examples[text_col], truncation=True, max_length=max_seq_length) + + ds = ds.map(preprocess_fn, batched=True) + elif preprocess_type == "tokenize_causal": + text_col = _detect_text_column(ds["train"], task_config["text_columns"]) + + def preprocess_fn(examples): + return processing_class(examples[text_col], truncation=True, max_length=max_seq_length) + + ds = ds.map(preprocess_fn, batched=True) + elif preprocess_type == "tokenize_seq2seq": + columns = ds["train"].column_names + # Find source and target columns + source_col = columns[0] + target_col = ( + label_col if label_col and label_col in columns else columns[1] if len(columns) > 1 else columns[0] + ) + + def preprocess_fn(examples): + model_inputs = processing_class(examples[source_col], truncation=True, max_length=max_seq_length) + labels = processing_class(text_target=examples[target_col], truncation=True, max_length=max_seq_length) + model_inputs["labels"] = labels["input_ids"] + return model_inputs + + ds = ds.map(preprocess_fn, batched=True) + elif preprocess_type == "image_transform": + from torchvision.transforms import Compose, Normalize, 
RandomResizedCrop, ToTensor + + _normalize = Normalize( + mean=processing_class.image_mean if hasattr(processing_class, "image_mean") else [0.485, 0.456, 0.406], + std=processing_class.image_std if hasattr(processing_class, "image_std") else [0.229, 0.224, 0.225], + ) + _size = processing_class.size.get("shortest_edge", 224) if hasattr(processing_class, "size") else 224 + _transforms = Compose([RandomResizedCrop(_size), ToTensor(), _normalize]) + + def preprocess_fn(examples): + examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]] + return examples + + ds["train"].set_transform(preprocess_fn) + if eval_split in ds: + ds[eval_split].set_transform(preprocess_fn) + + # --- LoRA --- + if lora: + from peft import LoraConfig, get_peft_model + + peft_config = LoraConfig( + r=lora_r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + task_type="CAUSAL_LM" if "CausalLM" in task_config["auto_class"] else "SEQ_CLS", + ) + loaded_model = get_peft_model(loaded_model, peft_config) + loaded_model.print_trainable_parameters() + + # --- Build TrainingArguments --- + training_args_kwargs = { + "output_dir": output, + "num_train_epochs": epochs, + "learning_rate": lr, + "per_device_train_batch_size": batch_size, + "per_device_eval_batch_size": eval_batch_size or batch_size, + "gradient_accumulation_steps": gradient_accumulation_steps, + "warmup_ratio": warmup_ratio, + "weight_decay": weight_decay, + "eval_strategy": eval_strategy, + "save_strategy": save_strategy, + "load_best_model_at_end": load_best_model_at_end and eval_strategy != "no", + "gradient_checkpointing": gradient_checkpointing, + "push_to_hub": push_to_hub, + } + + if eval_steps is not None: + training_args_kwargs["eval_steps"] = eval_steps + if save_total_limit is not None: + training_args_kwargs["save_total_limit"] = save_total_limit + if hub_model_id is not None: + training_args_kwargs["hub_model_id"] = hub_model_id + if token is not None: + training_args_kwargs["hub_token"] 
= token + + # Precision + if dtype == "float16": + training_args_kwargs["fp16"] = True + elif dtype == "bfloat16": + training_args_kwargs["bf16"] = True + + # Device targeting + if device == "cpu": + training_args_kwargs["no_cuda"] = True + training_args_kwargs["use_mps_device"] = False + elif device == "mps": + training_args_kwargs["use_mps_device"] = True + elif device == "tpu": + # TPU is handled automatically when running on a TPU instance with XLA + pass + + # Multi-GPU / multi-node: delegate to accelerate or torchrun + if multi_gpu or nnodes is not None: + # Build the command to re-launch via accelerate + cmd = ["accelerate", "launch"] + if nnodes is not None: + cmd.extend(["--num_machines", str(nnodes)]) + if multi_gpu: + cmd.append("--multi_gpu") + cmd.extend(["--module", "transformers.cli.agentic.train", "_train_inner"]) + # Pass all original args through environment + print(f"Launching distributed training: {' '.join(cmd)}") + print("Note: for full control, use `accelerate launch` directly.") + # Fall through to normal training — Trainer handles multi-GPU automatically + # when CUDA_VISIBLE_DEVICES or the accelerate launcher sets up the environment. 
+ + # Distributed + if deepspeed is not None: + if deepspeed in ("zero2", "zero3"): + # Use built-in DeepSpeed configs + training_args_kwargs["deepspeed"] = deepspeed + else: + training_args_kwargs["deepspeed"] = deepspeed + if fsdp is not None: + training_args_kwargs["fsdp"] = fsdp + + # Logging + if logging is not None: + training_args_kwargs["report_to"] = logging + + training_args = TrainingArguments(**training_args_kwargs) + + # --- Data collator --- + data_collator = None + if preprocess_type == "tokenize": + from transformers import DataCollatorWithPadding + + data_collator = DataCollatorWithPadding(tokenizer=processing_class) + elif preprocess_type in ("tokenize_causal", "tokenize_mlm"): + from transformers import DataCollatorForLanguageModeling + + data_collator = DataCollatorForLanguageModeling( + tokenizer=processing_class, + mlm=(preprocess_type == "tokenize_mlm" or mlm), + ) + elif preprocess_type == "tokenize_seq2seq": + from transformers import DataCollatorForSeq2Seq + + data_collator = DataCollatorForSeq2Seq(tokenizer=processing_class, model=loaded_model) + + # --- Callbacks --- + callbacks = [] + if early_stopping: + from transformers import EarlyStoppingCallback + + callbacks.append(EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)) + + # --- Build Trainer --- + trainer_cls = Trainer + if "Seq2Seq" in task_config["auto_class"]: + from transformers import Seq2SeqTrainer + + trainer_cls = Seq2SeqTrainer + + trainer = trainer_cls( + model=loaded_model, + args=training_args, + train_dataset=ds["train"], + eval_dataset=ds.get(eval_split), + processing_class=processing_class, + data_collator=data_collator, + callbacks=callbacks if callbacks else None, + ) + + # --- Train --- + if hpo is not None: + best_trial = trainer.hyperparameter_search( + direction="minimize", + backend=hpo, + n_trials=hpo_trials, + ) + print(f"Best trial: {best_trial}") + else: + trainer.train(resume_from_checkpoint=resume_from_checkpoint) + + # --- Save --- + 
trainer.save_model(output) + if processing_class is not None: + processing_class.save_pretrained(output) + + print(f"\nModel saved to {output}") + if push_to_hub: + trainer.push_to_hub() + print(f"Pushed to Hub: {hub_model_id or output}") diff --git a/src/transformers/cli/agentic/utilities.py b/src/transformers/cli/agentic/utilities.py new file mode 100644 index 000000000000..3f6a7cc78e81 --- /dev/null +++ b/src/transformers/cli/agentic/utilities.py @@ -0,0 +1,421 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utility CLI commands for model exploration and analysis. + +Commands in this module don't run inference or training — they inspect +models, tokenizers, embeddings, and activations. Useful for debugging, +prototyping, and understanding model behavior. 
+""" + +import json +from typing import Annotated + +import typer + +from ._common import _load_pretrained, load_image, resolve_input + + +def embed( + # Text input + text: Annotated[str | None, typer.Option(help="Text to embed.")] = None, + file: Annotated[str | None, typer.Option(help="Read text from this file.")] = None, + # Image input + image: Annotated[str | None, typer.Option(help="Path or URL to an image to embed.")] = None, + # Model & output + model: Annotated[str | None, typer.Option("--model", "-m", help="Model ID or local path.")] = None, + output: Annotated[str | None, typer.Option(help="Save embeddings to this file (.npy or .json).")] = None, + device: Annotated[str | None, typer.Option(help="Device.")] = None, + dtype: Annotated[str, typer.Option(help="Dtype.")] = "auto", + trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code.")] = False, + token: Annotated[str | None, typer.Option(help="HF Hub token.")] = None, + revision: Annotated[str | None, typer.Option(help="Model revision.")] = None, +): + """ + Compute embeddings for text or images. + + Uses ``AutoModel`` with ``AutoTokenizer`` (text) or + ``AutoImageProcessor`` (images). Outputs shape and a preview by + default. Pass ``--output`` to save as ``.npy`` (NumPy) or ``.json``. + + Examples:: + + # Text embeddings + transformers embed --model BAAI/bge-small-en-v1.5 --text "The quick brown fox." 
--output embeddings.npy + + # Image embeddings + transformers embed --model facebook/dinov2-small --image photo.jpg --output features.npy + + # Quick preview (no file saved) + transformers embed --text "Hello world" + """ + import numpy as np + import torch + + from transformers import AutoModel + + if image is not None: + from transformers import AutoImageProcessor + + model_id = model or "facebook/dinov2-small" + loaded_model, processor = _load_pretrained( + AutoModel, AutoImageProcessor, model_id, device, dtype, trust_remote_code, token, revision + ) + img = load_image(image) + inputs = processor(images=img, return_tensors="pt") + elif text is not None or file is not None: + from transformers import AutoTokenizer + + model_id = model or "BAAI/bge-small-en-v1.5" + loaded_model, tokenizer = _load_pretrained( + AutoModel, AutoTokenizer, model_id, device, dtype, trust_remote_code, token, revision + ) + input_text = resolve_input(text, file) + inputs = tokenizer(input_text, return_tensors="pt", truncation=True) + else: + raise SystemExit("Error: provide --text, --file, or --image.") + + if hasattr(loaded_model, "device"): + inputs = inputs.to(loaded_model.device) + + with torch.no_grad(): + outputs = loaded_model(**inputs) + + embedding = outputs.last_hidden_state.mean(dim=1)[0].cpu().numpy() + + if output is not None: + if output.endswith(".npy"): + np.save(output, embedding) + elif output.endswith(".json"): + with open(output, "w") as f: + json.dump(embedding.tolist(), f) + else: + np.save(output, embedding) + print(f"Embedding shape {embedding.shape} saved to {output}") + else: + print(f"Embedding shape: {embedding.shape}") + flat = embedding.flatten() + preview = ", ".join(f"{v:.6f}" for v in flat[:8]) + if len(flat) > 8: + preview += ", ..." 
+ print(f"Values: [{preview}]") + + +def tokenize( + text: Annotated[str | None, typer.Option(help="Text to tokenize.")] = None, + file: Annotated[str | None, typer.Option(help="Read text from this file.")] = None, + model: Annotated[str | None, typer.Option("--model", "-m", help="Model ID or local path.")] = None, + token: Annotated[str | None, typer.Option(help="HF Hub token.")] = None, + trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code.")] = False, + show_ids: Annotated[bool, typer.Option("--ids", help="Show token IDs.")] = False, + output_json: Annotated[bool, typer.Option("--json", help="Output as JSON.")] = False, +): + """ + Tokenize text and display the resulting tokens. + + Shows how the model's tokenizer breaks text into subword tokens. + Useful for debugging prompt formatting, checking token counts, and + understanding tokenizer behavior. + + Examples:: + + transformers tokenize --model meta-llama/Llama-3.2-1B-Instruct --text "Hello, world!" + transformers tokenize --model meta-llama/Llama-3.2-1B-Instruct --text "Hello, world!" --ids + transformers tokenize --model bert-base-uncased --text "Tokenization is fun." 
--json + """ + from transformers import AutoTokenizer + + input_text = resolve_input(text, file) + model_id = model or "HuggingFaceTB/SmolLM2-360M-Instruct" + + tok_kwargs = {} + if token is not None: + tok_kwargs["token"] = token + if trust_remote_code: + tok_kwargs["trust_remote_code"] = True + + tokenizer = AutoTokenizer.from_pretrained(model_id, **tok_kwargs) + encoding = tokenizer(input_text) + + token_ids = encoding["input_ids"] + tokens = tokenizer.convert_ids_to_tokens(token_ids) + + if output_json: + data = {"tokens": tokens, "token_ids": token_ids, "num_tokens": len(tokens)} + print(json.dumps(data, indent=2)) + else: + print(f"Tokens ({len(tokens)}):") + for i, (tok, tid) in enumerate(zip(tokens, token_ids)): + if show_ids: + print(f" {i:4d} {tid:8d} {tok!r}") + else: + print(f" {i:4d} {tok!r}") + + +def inspect( + model: Annotated[str, typer.Argument(help="Model ID or local path to inspect.")], + token: Annotated[str | None, typer.Option(help="HF Hub token.")] = None, + trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code.")] = False, + output_json: Annotated[bool, typer.Option("--json", help="Output as JSON.")] = False, +): + """ + Inspect a model's configuration without downloading weights. + + Shows architecture, hidden size, number of layers, vocabulary size, + and other key config values. Use ``--json`` for the full config dict. 
+ + Examples:: + + transformers inspect meta-llama/Llama-3.2-1B-Instruct + transformers inspect meta-llama/Llama-3.2-1B-Instruct --json + """ + from transformers import AutoConfig + + kwargs = {} + if token is not None: + kwargs["token"] = token + if trust_remote_code: + kwargs["trust_remote_code"] = True + + config = AutoConfig.from_pretrained(model, **kwargs) + + if output_json: + print(json.dumps(config.to_dict(), indent=2, default=str)) + else: + config_dict = config.to_dict() + print(f"Model: {model}") + print(f"Architecture: {config_dict.get('architectures', ['unknown'])}") + print(f"Model type: {config_dict.get('model_type', 'unknown')}") + print() + + important_keys = [ + "hidden_size", + "num_hidden_layers", + "num_attention_heads", + "num_key_value_heads", + "intermediate_size", + "vocab_size", + "max_position_embeddings", + "hidden_act", + "torch_dtype", + ] + for key in important_keys: + if key in config_dict: + print(f" {key}: {config_dict[key]}") + + remaining = { + k: v + for k, v in config_dict.items() + if k not in important_keys and k not in ("architectures", "model_type", "transformers_version") + } + if remaining: + print(f"\n ({len(remaining)} additional config keys — use --json for full output)") + + +def inspect_forward( + text: Annotated[str, typer.Option(help="Text to run through the model.")], + model: Annotated[str | None, typer.Option("--model", "-m", help="Model ID or local path.")] = None, + output: Annotated[str | None, typer.Option(help="Directory to save activations as .npy files.")] = None, + layers: Annotated[ + str | None, typer.Option(help="Comma-separated layer indices to inspect (default: all).") + ] = None, + token: Annotated[str | None, typer.Option(help="HF Hub token.")] = None, + trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code.")] = False, + output_json: Annotated[bool, typer.Option("--json", help="Output as JSON.")] = False, +): + """ + Examine attention weights and hidden states from a forward 
pass. + + Runs the input through the model with ``output_attentions=True`` and + ``output_hidden_states=True``, then prints shape and statistics for + each layer. Pass ``--output ./activations/`` to save attention and + hidden state tensors as NumPy ``.npy`` files for further analysis. + + Examples:: + + # Print summary for all layers + transformers inspect-forward --model bert-base-uncased --text "The cat sat on the mat." + + # Inspect only layers 0 and 11, save to disk + transformers inspect-forward --model bert-base-uncased --text "Hello world" --layers 0,11 --output ./activations/ + """ + import numpy as np + + from transformers import AutoModel, AutoTokenizer + + model_id = model or "answerdotai/ModernBERT-base" + + common_kwargs = {} + if token is not None: + common_kwargs["token"] = token + if trust_remote_code: + common_kwargs["trust_remote_code"] = True + + tokenizer = AutoTokenizer.from_pretrained(model_id, **common_kwargs) + loaded_model = AutoModel.from_pretrained(model_id, **common_kwargs) + loaded_model.eval() + + inputs = tokenizer(text, return_tensors="pt") + import torch + + with torch.no_grad(): + outputs = loaded_model(**inputs, output_attentions=True, output_hidden_states=True) + + attentions = outputs.attentions + hidden_states = outputs.hidden_states + + layer_indices = None + if layers is not None: + layer_indices = [int(i) for i in layers.split(",")] + + print(f"Model: {model_id}") + print(f"Input tokens: {tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])}") + print(f"Hidden state layers: {len(hidden_states)} (including embedding layer)") + print(f"Attention layers: {len(attentions)}") + + for i, (attn, hs) in enumerate(zip(attentions, hidden_states[1:])): + if layer_indices is not None and i not in layer_indices: + continue + print(f"\n Layer {i}:") + print(f" Attention shape: {list(attn.shape)} (batch, heads, seq, seq)") + print(f" Hidden state shape: {list(hs.shape)} (batch, seq, hidden)") + attn_np = attn[0].cpu().numpy() + 
print(f" Attention mean: {attn_np.mean():.6f}, max: {attn_np.max():.6f}") + hs_np = hs[0].cpu().numpy() + print(f" Hidden state norm (mean): {np.linalg.norm(hs_np, axis=-1).mean():.4f}") + + if output is not None: + from pathlib import Path + + out_dir = Path(output) + out_dir.mkdir(parents=True, exist_ok=True) + for i, attn in enumerate(attentions): + if layer_indices is not None and i not in layer_indices: + continue + np.save(out_dir / f"attention_layer_{i}.npy", attn[0].cpu().numpy()) + for i, hs in enumerate(hidden_states): + if layer_indices is not None and i not in layer_indices and i > 0: + continue + np.save(out_dir / f"hidden_state_layer_{i}.npy", hs[0].cpu().numpy()) + print(f"\nActivations saved to {output}") + + +def benchmark_quantization( + model: Annotated[str, typer.Option("--model", "-m", help="Model ID or local path.")], + methods: Annotated[ + str, typer.Option(help="Comma-separated quantization methods to compare: none, bnb-4bit, bnb-8bit.") + ] = "bnb-4bit,bnb-8bit", + prompt: Annotated[ + str, typer.Option(help="Prompt to use for benchmarking.") + ] = "The quick brown fox jumps over the lazy dog.", + max_new_tokens: Annotated[int, typer.Option(help="Tokens to generate per run.")] = 50, + trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code.")] = False, + token: Annotated[str | None, typer.Option(help="HF Hub token.")] = None, + output_json: Annotated[bool, typer.Option("--json", help="Output as JSON.")] = False, +): + """ + Compare quality and performance across quantization methods. + + Loads the same model under each quantization method, generates text, + and reports tokens/sec, latency, peak GPU memory, and a preview of + the output. Use ``none`` as a method to include the unquantized + baseline. 
+ + Examples:: + + # Compare 4-bit vs 8-bit + transformers benchmark-quantization --model meta-llama/Llama-3.1-8B --methods bnb-4bit,bnb-8bit + + # Include unquantized baseline, output as JSON + transformers benchmark-quantization --model meta-llama/Llama-3.1-8B --methods none,bnb-4bit,bnb-8bit --json + """ + import time + + from transformers import AutoModelForCausalLM, AutoTokenizer + + common_kwargs = {} + if trust_remote_code: + common_kwargs["trust_remote_code"] = True + if token: + common_kwargs["token"] = token + + tokenizer = AutoTokenizer.from_pretrained(model, **common_kwargs) + method_list = [m.strip() for m in methods.split(",")] + + results = [] + for method in method_list: + print(f"\n--- {method} ---") + model_kwargs = {**common_kwargs, "device_map": "auto"} + + if method == "bnb-4bit": + from transformers import BitsAndBytesConfig + + model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True) + elif method == "bnb-8bit": + from transformers import BitsAndBytesConfig + + model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + elif method == "none": + pass + else: + print(f" Skipping {method} — only none, bnb-4bit, bnb-8bit are supported for benchmarking.") + continue + + try: + loaded_model = AutoModelForCausalLM.from_pretrained(model, **model_kwargs) + loaded_model.eval() + inputs = tokenizer(prompt, return_tensors="pt").to(loaded_model.device) + + # Warmup + loaded_model.generate(**inputs, max_new_tokens=5) + + # Timed run + start = time.time() + output_ids = loaded_model.generate(**inputs, max_new_tokens=max_new_tokens) + elapsed = time.time() - start + + new_tokens = output_ids[0, inputs["input_ids"].shape[1] :] + generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True) + tokens_per_sec = len(new_tokens) / elapsed + + import torch + + mem_mb = torch.cuda.max_memory_allocated() / 1024 / 1024 if torch.cuda.is_available() else 0 + + result = { + "method": method, + "tokens_per_sec": 
round(tokens_per_sec, 2), + "time_sec": round(elapsed, 3), + "peak_memory_mb": round(mem_mb, 1), + "output_preview": generated_text[:100], + } + results.append(result) + + print(f" Tokens/sec: {tokens_per_sec:.2f}") + print(f" Time: {elapsed:.3f}s") + if mem_mb > 0: + print(f" Peak memory: {mem_mb:.1f} MB") + print(f" Output: {generated_text[:100]}...") + + del loaded_model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + except Exception as e: + print(f" Error: {e}") + results.append({"method": method, "error": str(e)}) + + if output_json: + print(json.dumps(results, indent=2)) diff --git a/src/transformers/cli/agentic/vision.py b/src/transformers/cli/agentic/vision.py new file mode 100644 index 000000000000..09a6bce86bd6 --- /dev/null +++ b/src/transformers/cli/agentic/vision.py @@ -0,0 +1,479 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Vision and video CLI commands. + +Each function uses Auto* model classes directly (no pipeline, except +``keypoints``) and is registered as a top-level ``transformers`` CLI command +via ``app.py``. 
+""" + +import json +from typing import Annotated + +import typer + +from ._common import ( + DeviceOpt, + DtypeOpt, + JsonOpt, + ModelOpt, + RevisionOpt, + TokenOpt, + TrustOpt, + _load_pretrained, + format_output, + load_image, + load_video, +) + + +def image_classify( + image: Annotated[str, typer.Option(help="Path or URL to the image.")], + model: ModelOpt = None, + labels: Annotated[ + str | None, typer.Option(help="Comma-separated candidate labels for zero-shot classification.") + ] = None, + top_k: Annotated[int, typer.Option(help="Number of top predictions to return.")] = 5, + device: DeviceOpt = None, + dtype: DtypeOpt = "auto", + trust_remote_code: TrustOpt = False, + token: TokenOpt = None, + revision: RevisionOpt = None, + output_json: JsonOpt = False, +): + """Classify an image. + + Without ``--labels``, uses ``AutoModelForImageClassification`` with a + pre-trained head (default: ``google/vit-base-patch16-224``). + + With ``--labels``, uses ``AutoModelForZeroShotImageClassification`` and + ``AutoProcessor`` (default: ``google/siglip-base-patch16-224``). 
+
+    Example::
+
+        transformers image-classify --image photo.jpg
+        transformers image-classify --image photo.jpg --labels "cat,dog,bird"
+    """
+    import torch
+
+    img = load_image(image)
+
+    if labels is None:
+        from transformers import AutoImageProcessor, AutoModelForImageClassification
+
+        model_id = model or "google/vit-base-patch16-224"
+        loaded_model, processor = _load_pretrained(
+            AutoModelForImageClassification,
+            AutoImageProcessor,
+            model_id,
+            device,
+            dtype,
+            trust_remote_code,
+            token,
+            revision,
+        )
+        inputs = processor(images=img, return_tensors="pt")
+        if hasattr(loaded_model, "device"):
+            inputs = inputs.to(loaded_model.device)
+        with torch.no_grad():
+            outputs = loaded_model(**inputs)
+        probs = outputs.logits.softmax(dim=-1)[0]
+        top_values, top_indices = probs.topk(min(top_k, len(probs)))
+        result = [
+            {"label": loaded_model.config.id2label[idx.item()], "score": round(val.item(), 4)}
+            for val, idx in zip(top_values, top_indices)
+        ]
+    else:
+        from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
+
+        candidate_labels = [l.strip() for l in labels.split(",")]
+        model_id = model or "google/siglip-base-patch16-224"
+        loaded_model, processor = _load_pretrained(
+            AutoModelForZeroShotImageClassification,
+            AutoProcessor,
+            model_id,
+            device,
+            dtype,
+            trust_remote_code,
+            token,
+            revision,
+        )
+        inputs = processor(images=img, text=candidate_labels, padding=True, return_tensors="pt")
+        if hasattr(loaded_model, "device"):
+            inputs = inputs.to(loaded_model.device)
+        with torch.no_grad():
+            outputs = loaded_model(**inputs)
+        probs = outputs.logits_per_image[0].softmax(dim=-1)
+        scored = [
+            {"label": candidate_labels[i], "score": round(probs[i].item(), 4)} for i in range(len(candidate_labels))
+        ]
+        result = sorted(scored, key=lambda x: x["score"], reverse=True)
+
+    print(format_output(result, output_json))
+
+
+def detect(
+    image: Annotated[str, typer.Option(help="Path or URL to the image.")],
+    model: ModelOpt = None,
+    text: Annotated[str | None, typer.Option(help="Text query for open-vocabulary (grounded) detection.")] = None,
+    threshold: Annotated[float, typer.Option(help="Detection confidence threshold.")] = 0.5,
+    device: DeviceOpt = None,
+    dtype: DtypeOpt = "auto",
+    trust_remote_code: TrustOpt = False,
+    token: TokenOpt = None,
+    revision: RevisionOpt = None,
+    output_json: JsonOpt = False,
+):
+    """Detect objects in an image.
+
+    Without ``--text``, uses ``AutoModelForObjectDetection`` with a closed-set
+    detector (default: ``PekingU/rtdetr_r18vd_coco_o365``).
+
+    With ``--text``, uses ``AutoModelForZeroShotObjectDetection`` for
+    open-vocabulary detection (default: ``IDEA-Research/grounding-dino-base``).
+
+    Example::
+
+        transformers detect --image photo.jpg
+        transformers detect --image photo.jpg --text "cat . dog ."
+    """
+    import torch
+
+    img = load_image(image)
+
+    if text is None:
+        from transformers import AutoImageProcessor, AutoModelForObjectDetection
+
+        model_id = model or "PekingU/rtdetr_r18vd_coco_o365"
+        loaded_model, processor = _load_pretrained(
+            AutoModelForObjectDetection,
+            AutoImageProcessor,
+            model_id,
+            device,
+            dtype,
+            trust_remote_code,
+            token,
+            revision,
+        )
+        inputs = processor(images=img, return_tensors="pt")
+        if hasattr(loaded_model, "device"):
+            inputs = inputs.to(loaded_model.device)
+        with torch.no_grad():
+            outputs = loaded_model(**inputs)
+        target_sizes = torch.tensor([img.size[::-1]])
+        results = processor.post_process_object_detection(outputs, threshold=threshold, target_sizes=target_sizes)[0]
+    else:
+        from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
+
+        model_id = model or "IDEA-Research/grounding-dino-base"
+        loaded_model, processor = _load_pretrained(
+            AutoModelForZeroShotObjectDetection,
+            AutoProcessor,
+            model_id,
+            device,
+            dtype,
+            trust_remote_code,
+            token,
+            revision,
+        )
+        inputs = processor(images=img, text=text, return_tensors="pt")
+        if hasattr(loaded_model, "device"):
+            inputs = inputs.to(loaded_model.device)
+        with torch.no_grad():
+            outputs = loaded_model(**inputs)
+        target_sizes = torch.tensor([img.size[::-1]])
+        if hasattr(processor, "post_process_grounded_object_detection"):
+            results = processor.post_process_grounded_object_detection(
+                outputs,
+                input_ids=inputs["input_ids"],
+                box_threshold=threshold,
+                text_threshold=threshold,
+                target_sizes=target_sizes,
+            )[0]
+        else:
+            results = processor.post_process_object_detection(outputs, threshold=threshold, target_sizes=target_sizes)[
+                0
+            ]
+
+    result = []
+    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+        box_coords = box.tolist()
+        label_str = (
+            label if isinstance(label, str) else loaded_model.config.id2label.get(label.item(), str(label.item()))
+        )
+        result.append(
+            {
+                "label": label_str,
+                "score": round(score.item(), 4),
+                "box": {
+                    "xmin": round(box_coords[0], 1),
+                    "ymin": round(box_coords[1], 1),
+                    "xmax": round(box_coords[2], 1),
+                    "ymax": round(box_coords[3], 1),
+                },
+            }
+        )
+
+    print(format_output(result, output_json))
+
+
+def segment(
+    image: Annotated[str, typer.Option(help="Path or URL to the image.")],
+    model: ModelOpt = None,
+    points: Annotated[str | None, typer.Option(help="JSON list of [x, y] points for SAM-style segmentation.")] = None,
+    point_labels: Annotated[
+        str | None, typer.Option(help="JSON list of point labels (1=foreground, 0=background) for SAM.")
+    ] = None,
+    device: DeviceOpt = None,
+    dtype: DtypeOpt = "auto",
+    trust_remote_code: TrustOpt = False,
+    token: TokenOpt = None,
+    revision: RevisionOpt = None,
+    output_json: JsonOpt = False,
+):
+    """Segment an image.
+
+    Without ``--points``, uses ``AutoModelForSemanticSegmentation`` for
+    per-pixel class labelling (default: ``nvidia/segformer-b0-finetuned-ade-512-512``).
+
+    With ``--points``, uses ``AutoModel`` + ``AutoProcessor`` for SAM-style
+    prompted segmentation (default: ``facebook/sam-vit-base``).
+
+    Example::
+
+        transformers segment --image photo.jpg
+        transformers segment --image photo.jpg --points '[[100, 200]]' --point-labels '[1]'
+    """
+    import torch
+
+    img = load_image(image)
+
+    if points is None:
+        from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
+
+        model_id = model or "nvidia/segformer-b0-finetuned-ade-512-512"
+        loaded_model, processor = _load_pretrained(
+            AutoModelForSemanticSegmentation,
+            AutoImageProcessor,
+            model_id,
+            device,
+            dtype,
+            trust_remote_code,
+            token,
+            revision,
+        )
+        inputs = processor(images=img, return_tensors="pt")
+        if hasattr(loaded_model, "device"):
+            inputs = inputs.to(loaded_model.device)
+        with torch.no_grad():
+            outputs = loaded_model(**inputs)
+        seg_map = processor.post_process_semantic_segmentation(outputs, target_sizes=[img.size[::-1]])[0]
+        total_pixels = seg_map.numel()
+        unique_classes = seg_map.unique()
+        result = []
+        for cls_id in unique_classes:
+            ratio = round((seg_map == cls_id).sum().item() / total_pixels, 4)
+            label = loaded_model.config.id2label.get(cls_id.item(), str(cls_id.item()))
+            result.append({"label": label, "score": ratio})
+        result = sorted(result, key=lambda x: x["score"], reverse=True)
+    else:
+        from transformers import AutoModel, AutoProcessor
+
+        model_id = model or "facebook/sam-vit-base"
+        loaded_model, processor = _load_pretrained(
+            AutoModel, AutoProcessor, model_id, device, dtype, trust_remote_code, token, revision
+        )
+        parsed_points = json.loads(points)
+        parsed_labels = json.loads(point_labels) if point_labels else [1] * len(parsed_points)
+        inputs = processor(img, input_points=[parsed_points], input_labels=[parsed_labels], return_tensors="pt")
+        if hasattr(loaded_model, "device"):
+            inputs = inputs.to(loaded_model.device)
+        with torch.no_grad():
+            outputs = loaded_model(**inputs)
+        masks = processor.image_processor.post_process_masks(
+            outputs.pred_masks.cpu(),
+            inputs["original_sizes"].cpu(),
+            inputs["reshaped_input_sizes"].cpu(),
+        )
+        result = {
+            "num_masks": masks[0].shape[1] if len(masks) > 0 else 0,
+            "iou_scores": outputs.iou_scores[0, 0].tolist(),
+        }
+
+    print(format_output(result, output_json))
+
+
+def depth(
+    image: Annotated[str, typer.Option(help="Path or URL to the image.")],
+    model: ModelOpt = None,
+    output: Annotated[str | None, typer.Option(help="Path to save the depth map as a PNG image.")] = None,
+    device: DeviceOpt = None,
+    dtype: DtypeOpt = "auto",
+    trust_remote_code: TrustOpt = False,
+    token: TokenOpt = None,
+    revision: RevisionOpt = None,
+):
+    """Estimate a depth map from an image.
+
+    Uses ``AutoModelForDepthEstimation`` (default:
+    ``depth-anything/Depth-Anything-V2-Small-hf``).
+
+    If ``--output`` is provided the depth map is saved as a greyscale PNG.
+    Otherwise, prints the depth map dimensions.
+
+    Example::
+
+        transformers depth --image photo.jpg --output depth.png
+    """
+    import torch
+
+    from transformers import AutoImageProcessor, AutoModelForDepthEstimation
+
+    img = load_image(image)
+    model_id = model or "depth-anything/Depth-Anything-V2-Small-hf"
+    loaded_model, processor = _load_pretrained(
+        AutoModelForDepthEstimation, AutoImageProcessor, model_id, device, dtype, trust_remote_code, token, revision
+    )
+    inputs = processor(images=img, return_tensors="pt")
+    if hasattr(loaded_model, "device"):
+        inputs = inputs.to(loaded_model.device)
+    with torch.no_grad():
+        outputs = loaded_model(**inputs)
+
+    predicted_depth = outputs.predicted_depth
+    depth_map = torch.nn.functional.interpolate(
+        predicted_depth.unsqueeze(1)
+        if predicted_depth.dim() == 2
+        else predicted_depth.unsqueeze(0)
+        if predicted_depth.dim() == 3
+        else predicted_depth,
+        size=img.size[::-1],
+        mode="bicubic",
+        align_corners=False,
+    ).squeeze()
+
+    if output is not None:
+        from PIL import Image
+
+        depth_np = depth_map.cpu().float().numpy()
+        depth_min, depth_max = depth_np.min(), depth_np.max()
+        if depth_max - depth_min > 0:
+            depth_norm = (depth_np - depth_min) / (depth_max - depth_min) * 255.0
+        else:
+            depth_norm = depth_np * 0.0
+        depth_img = Image.fromarray(depth_norm.astype("uint8"))
+        depth_img.save(output)
+        print(f"Depth map saved to {output} (size: {depth_map.shape[0]}x{depth_map.shape[1]})")
+    else:
+        print(f"Depth map size: {depth_map.shape[0]}x{depth_map.shape[1]}")
+
+
+def keypoints(
+    images: Annotated[list[str], typer.Option(help="Paths to two images to match.")],
+    model: ModelOpt = None,
+    device: DeviceOpt = None,
+    dtype: DtypeOpt = "auto",
+    trust_remote_code: TrustOpt = False,
+    token: TokenOpt = None,
+    revision: RevisionOpt = None,
+    output_json: JsonOpt = False,
+):
+    """Match keypoints between two images.
+
+    Uses the ``keypoint-matching`` pipeline. Requires exactly two images.
+
+    Example::
+
+        transformers keypoints --images image1.jpg --images image2.jpg
+    """
+    if len(images) != 2:
+        raise SystemExit("Error: keypoints requires exactly 2 image paths.")
+
+    from transformers import pipeline
+
+    img1 = load_image(images[0])
+    img2 = load_image(images[1])
+
+    pipe_kwargs = {}
+    if model is not None:
+        pipe_kwargs["model"] = model
+    if device is not None:
+        pipe_kwargs["device"] = device
+    if dtype != "auto":
+        import torch
+
+        pipe_kwargs["dtype"] = getattr(torch, dtype)
+    if trust_remote_code:
+        pipe_kwargs["trust_remote_code"] = True
+    if token is not None:
+        pipe_kwargs["token"] = token
+    if revision is not None:
+        pipe_kwargs["revision"] = revision
+
+    pipe = pipeline("keypoint-matching", **pipe_kwargs)
+    result = pipe(img1, img2)
+
+    print(format_output(result, output_json))
+
+
+def video_classify(
+    video: Annotated[str, typer.Option(help="Path to the video file.")],
+    model: ModelOpt = None,
+    top_k: Annotated[int, typer.Option(help="Number of top predictions to return.")] = 5,
+    device: DeviceOpt = None,
+    dtype: DtypeOpt = "auto",
+    trust_remote_code: TrustOpt = False,
+    token: TokenOpt = None,
+    revision: RevisionOpt = None,
+    output_json: JsonOpt = False,
+):
+    """Classify a video.
+
+    Uses ``AutoModelForVideoClassification`` + ``AutoImageProcessor``
+    (default: ``MCG-NJU/videomae-base-finetuned-kinetics``).
+
+    Example::
+
+        transformers video-classify --video clip.mp4
+    """
+    import torch
+
+    from transformers import AutoImageProcessor, AutoModelForVideoClassification
+
+    model_id = model or "MCG-NJU/videomae-base-finetuned-kinetics"
+    loaded_model, processor = _load_pretrained(
+        AutoModelForVideoClassification,
+        AutoImageProcessor,
+        model_id,
+        device,
+        dtype,
+        trust_remote_code,
+        token,
+        revision,
+    )
+    frames = load_video(video)
+    inputs = processor(images=frames, return_tensors="pt")
+    if hasattr(loaded_model, "device"):
+        inputs = inputs.to(loaded_model.device)
+    with torch.no_grad():
+        outputs = loaded_model(**inputs)
+    probs = outputs.logits.softmax(dim=-1)[0]
+    top_values, top_indices = probs.topk(min(top_k, len(probs)))
+    result = [
+        {"label": loaded_model.config.id2label[idx.item()], "score": round(val.item(), 4)}
+        for val, idx in zip(top_values, top_indices)
+    ]
+
+    print(format_output(result, output_json))
diff --git a/src/transformers/cli/transformers.py b/src/transformers/cli/transformers.py
index cefee1ca97c8..ba2f86ebcf78 100644
--- a/src/transformers/cli/transformers.py
+++ b/src/transformers/cli/transformers.py
@@ -16,6 +16,7 @@
 from huggingface_hub import check_cli_update, typer_factory
 
 from transformers.cli.add_new_model_like import add_new_model_like
+from transformers.cli.agentic.app import register_agentic_commands
 from transformers.cli.chat import Chat
 from transformers.cli.download import download
 from transformers.cli.serve import Serve
@@ -31,6 +32,8 @@
 app.command(name="serve")(Serve)
 app.command()(version)
 
+register_agentic_commands(app)
+
 
 def main():
     check_cli_update("transformers")