From 30e6686d8b516731593c988dba5a80850df0a091 Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 13 May 2026 11:51:08 +0200 Subject: [PATCH 1/2] [Feat] Add layout scripts and minor fixes --- .github/workflows/references.yml | 88 +++ .../using_doctr/custom_models_training.rst | 15 + docs/source/using_doctr/using_models.rst | 60 ++ doctr/datasets/datasets/pytorch.py | 12 +- doctr/datasets/layout.py | 6 +- doctr/models/layout/lw_detr/base.py | 5 +- doctr/models/layout/lw_detr/pytorch.py | 12 +- references/layout/README.md | 104 +++ references/layout/evaluate.py | 200 +++++ references/layout/latency.py | 59 ++ references/layout/train.py | 726 ++++++++++++++++++ references/layout/utils.py | 101 +++ tests/conftest.py | 2 +- tests/pytorch/test_datasets_pt.py | 25 +- 14 files changed, 1389 insertions(+), 26 deletions(-) create mode 100644 references/layout/README.md create mode 100644 references/layout/evaluate.py create mode 100644 references/layout/latency.py create mode 100644 references/layout/train.py create mode 100644 references/layout/utils.py diff --git a/.github/workflows/references.yml b/.github/workflows/references.yml index e4c9503136..9585eaac5a 100644 --- a/.github/workflows/references.yml +++ b/.github/workflows/references.yml @@ -251,3 +251,91 @@ jobs: pip install -e .[viz,html] --upgrade - name: Benchmark latency run: python references/detection/latency.py db_mobilenet_v3_large --it 5 --size 512 + + + train-layout-analysis: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: ["3.10"] + steps: + - uses: actions/checkout@v6 + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v5 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-${{ hashFiles('references/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}- + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[viz,html] --upgrade + pip install -r references/requirements.txt + - name: Download and extract toy set + run: | + wget https://github.com/mindee/doctr/releases/download/v0.3.1/toy_detection_set-bbbb4243.zip + sudo apt-get update && sudo apt-get install unzip -y + unzip toy_detection_set-bbbb4243.zip -d det_set + - name: Train for a short epoch + run: python references/layout/train.py lw_detr_s --train_path ./det_set --val_path ./det_set -b 2 --epochs 1 + + evaluate-layout-analysis: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: ["3.10"] + steps: + - uses: actions/checkout@v6 + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v5 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[viz,html] --upgrade + pip install -r references/requirements.txt + - name: Evaluate layout analysis + run: python references/layout/evaluate.py lw_detr_s + + latency-layout-analysis: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: ["3.10"] + steps: + - uses: actions/checkout@v6 + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v5 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[viz,html] --upgrade + - name: Benchmark latency + run: python references/layout/latency.py lw_detr_s --it 5 --size 512 diff --git a/docs/source/using_doctr/custom_models_training.rst b/docs/source/using_doctr/custom_models_training.rst index c67f6c2d70..9b28df0fbb 100644 --- a/docs/source/using_doctr/custom_models_training.rst +++ b/docs/source/using_doctr/custom_models_training.rst @@ -6,6 +6,7 @@ For details on the training process and the necessary data and data format, refe - `detection `_ - `recognition `_ +- `layout `_ If you’re looking for a lightweight yet efficient tool to annotate small amounts of data, especially tailored for docTR, check out the `docTR Labeling Tool `_. @@ -52,6 +53,20 @@ Load a custom recognition model trained on another vocabulary as the default one predictor = ocr_predictor(det_arch='linknet_resnet18', reco_arch=reco_model, pretrained=True) + +Load a custom layout analysis model trained on another set of classes as the default one: + +.. code:: python3 + + import torch + from doctr.models import layout_predictor, lw_detr_s + from doctr.datasets import VOCABS + + layout_model = lw_detr_s(pretrained=False, class_names=["class_name_1", "class_name_2", ...]) + layout_model.from_pretrained('') + + predictor = layout_predictor(layout_arch=layout_model, pretrained=True) + Load a custom trained KIE detection model: .. code:: python3 diff --git a/docs/source/using_doctr/using_models.rst b/docs/source/using_doctr/using_models.rst index b37434092e..18b3b5dab0 100644 --- a/docs/source/using_doctr/using_models.rst +++ b/docs/source/using_doctr/using_models.rst @@ -174,6 +174,66 @@ Recognition predictors out = model([dummy_img]) +Layout Analysis +--------------- + +The task consists of localizing and classifying visual elements in a given image. +This is a more general task than text detection, as it can be used to detect and classify any type of visual element in a document, such as tables, figures, headers, footers, etc. +Our latest layout models works with rotated and skewed documents! + +Available architectures +^^^^^^^^^^^^^^^^^^^^^^^ + +The following architectures are currently supported: + +* :py:meth:`lw_detr_s ` +* :py:meth:`lw_detr_m ` + +For a comprehensive comparison, we have compiled a detailed benchmark: + ++--------------------------------------------------+-----------------+---------------+------------------+-------------+--------------+--------------------+ +| | | | | | | | ++==================================================+=================+===============+==================+=============+==============+====================+ +| **Architecture** | **Input shape** | **# params** | **mAP@[.5:.95]** | **AP@[.5]** | **AP@[.75]** | **sec/it (B: 1)** | ++--------------------------------------------------+-----------------+---------------+------------------+-------------+--------------+--------------------+ +| lw_detr_s | (1024, 1024, 3) | 15.1 M | | | | 0.5 | ++--------------------------------------------------+-----------------+---------------+------------------+-------------+--------------+--------------------+ +| lw_detr_m | (1024, 1024, 3) | 29.5 M | | | | 0.7 | ++--------------------------------------------------+-----------------+---------------+------------------+-------------+--------------+--------------------+ + + +Explanations about the metrics being used are available in :ref:`metrics`. + +Seconds per iteration (with a batch size of 1) is computed after a warmup phase of 100 tensors, by measuring the average number of processed tensors per second over 1000 samples. Those results were obtained on a `11th Gen Intel(R) Core(TM) i7-11800H @ 2.30GHz`. + + +Layout predictors +^^^^^^^^^^^^^^^^^ + +:py:meth:`layout_predictor ` wraps your layout model to make it easily useable with your favorite deep learning framework seamlessly. + +.. code:: python3 + + import numpy as np + from doctr.models import layout_predictor + model = layout_predictor('lw_detr_s') + dummy_img = (255 * np.random.rand(800, 600, 3)).astype(np.uint8) + out = model([dummy_img]) + +You can pass specific boolean arguments to the predictor: +* `pretrained`: if you want to use a model that has been pretrained on a specific dataset, setting `pretrained=True` this will load the corresponding weights. If `pretrained=False`, which is the default, would otherwise lead to a random initialization and would lead to no/useless results. +* `assume_straight_pages`: if you work with straight documents only, it will fit straight bounding boxes to the text areas. +* `preserve_aspect_ratio`: if you want to preserve the aspect ratio of your documents while resizing before sending them to the model. +* `symmetric_pad`: if you choose to preserve the aspect ratio, it will pad the image symmetrically and not from the bottom-right. + +For instance, this snippet will instantiates a layout predictor able to detect text on rotated documents while preserving the aspect ratio: + +.. code:: python3 + + from doctr.models import layout_predictor + predictor = layout_predictor('lw_detr_s', pretrained=True, assume_straight_pages=False, preserve_aspect_ratio=True) + + End-to-End OCR -------------- diff --git a/doctr/datasets/datasets/pytorch.py b/doctr/datasets/datasets/pytorch.py index a5df0dd7f5..6418906dd7 100644 --- a/doctr/datasets/datasets/pytorch.py +++ b/doctr/datasets/datasets/pytorch.py @@ -48,9 +48,17 @@ def _read_sample(self, index: int) -> tuple[torch.Tensor, Any]: return img, deepcopy(target) @staticmethod - def collate_fn(samples: list[tuple[torch.Tensor, Any]]) -> tuple[torch.Tensor, list[Any]]: + def collate_fn( + samples: list[tuple[torch.Tensor, Any]] | list[tuple[tuple[torch.Tensor, torch.Tensor], Any]], + ) -> tuple[torch.Tensor, list[Any]]: images, targets = zip(*samples) - images = torch.stack(images, dim=0) # type: ignore[assignment] + if isinstance(images[0], tuple): + images, padding_masks = zip(*images) + images = torch.stack(images, dim=0) # type: ignore[assignment] + padding_masks = torch.stack(padding_masks, dim=0) # type: ignore[assignment] + images = (images, padding_masks) + else: + images = torch.stack(images, dim=0) # type: ignore[assignment] return images, list(targets) # type: ignore[return-value] diff --git a/doctr/datasets/layout.py b/doctr/datasets/layout.py index 0e7f5df9de..2d2b7d18cf 100644 --- a/doctr/datasets/layout.py +++ b/doctr/datasets/layout.py @@ -61,17 +61,17 @@ def __init__( raise FileNotFoundError(f"unable to locate {img_path}") polygons = label.get("polygons") - class_names = label.get("class_names") + class_names = label.get("classes") if polygons is None: raise KeyError(f"missing 'polygons' for image: {img_name}") if class_names is None: - raise KeyError(f"missing 'class_names' for image: {img_name}") + raise KeyError(f"missing 'classes' for image: {img_name}") if len(polygons) != len(class_names): raise ValueError( f"number of polygons ({len(polygons)}) does not match " - f"number of class_names ({len(class_names)}) for image: {img_name}" + f"number of classes ({len(class_names)}) for image: {img_name}" ) geoms, polygon_classes = self.format_polygons( diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index e103021f5f..fe237a05dc 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -9,6 +9,7 @@ import numpy as np from doctr.models.core import BaseModel +from doctr.utils import order_points __all__ = ["_LWDETR", "LWDETRPostProcessor"] @@ -57,7 +58,7 @@ def _decode_boxes(self, boxes: np.ndarray) -> tuple[np.ndarray, np.ndarray]: for i in range(len(boxes)): rect = ((float(cx[i]), float(cy[i])), (float(w[i]), float(h[i])), float(np.degrees(angles[i]))) - poly = cv2.boxPoints(rect) + poly = order_points(cv2.boxPoints(rect)) polys.append(poly) return np.asarray(polys, dtype=np.float32), angles @@ -237,7 +238,7 @@ def _quad_to_obb(poly: np.ndarray): continue for cls_id, box in zip(np.asarray(class_ids), np.asarray(boxes)): - poly = box.reshape(4, 2) + poly = order_points(box.reshape(4, 2)) obb = _quad_to_obb(poly) if obb[2] <= 1e-3 or obb[3] <= 1e-3: diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 8c97fc626b..fcaf52c444 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -39,12 +39,8 @@ "Table", "Text", "Title", - "Document Index", - "Code", "Checkbox-Selected", "Checkbox-Unselected", - "Form", - "Key-Value Region", ], "url": None, }, @@ -64,12 +60,8 @@ "Table", "Text", "Title", - "Document Index", - "Code", "Checkbox-Selected", "Checkbox-Unselected", - "Form", - "Key-Value Region", ], "url": None, }, @@ -735,7 +727,7 @@ def _lw_detr( kwargs["class_names"] = kwargs.get("class_names", default_cfgs[arch].get("class_names", [])) _cfg = deepcopy(default_cfgs[arch]) - _cfg["class_names"] = kwargs["class_names"] + _cfg["class_names"] = sorted(kwargs["class_names"]) kwargs.pop("class_names") # Build the feature extractor @@ -758,7 +750,7 @@ def _lw_detr( if pretrained: # The number of class_names is not the same as the number of classes in the pretrained model => # remove the layer weights - _ignore_keys = ignore_keys if _cfg["class_names"] != default_cfgs[arch].get("class_names") else None + _ignore_keys = ignore_keys if _cfg["class_names"] != sorted(default_cfgs[arch].get("class_names", [])) else None model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys) return model diff --git a/references/layout/README.md b/references/layout/README.md new file mode 100644 index 0000000000..dfd080f0bb --- /dev/null +++ b/references/layout/README.md @@ -0,0 +1,104 @@ +# Layout detection + +The sample training script was made to train layout detection model with docTR. + +## Setup + +First, you need to install `doctr` (with pip, for instance) + +```shell +pip install -e . --upgrade +pip install -r references/requirements.txt +``` + +## Usage + +You can start your training in PyTorch: + +```shell +python references/layout/train.py lw_detr_s --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 +``` + +### Multi-GPU support + +We now use the built-in [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html) launcher to spawn your DDP workers. `torchrun` will set all the necessary environment variables (`LOCAL_RANK`, `RANK`, etc.) for you. Arguments are the same than the ones from single GPU, except: + +- `--backend`: you can specify another `backend` for `DistributedDataParallel` if the default one is not available on +your operating system. Fastest one is `nccl` according to [PyTorch Documentation](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). + +#### Key `torchrun` parameters + +- `--nproc_per_node=` + Spawn `` processes on the local machine (typically equal to the number of GPUs you want to use). +- `--nnodes=` + (Optional) Total number of nodes in your job. Default is 1. +- `--rdzv_backend`, `--rdzv_endpoint`, `--rdzv_id` + (Optional) Rendezvous settings for multi-node jobs. See the [torchrun docs](https://pytorch.org/docs/stable/elastic/run.html) for details. + +#### GPU selection + +By default all visible GPUs will be used. To limit which GPUs participate, set the `CUDA_VISIBLE_DEVICES` environment variable **before** running `torchrun`. For example, to use only CUDA devices 0 and 2: + +```shell +CUDA_VISIBLE_DEVICES=0,2 \ +torchrun --nproc_per_node=2 references/layout/train.py \ + lw_detr_s \ + --train_path path/to/train \ + --val_path path/to/val \ + --epochs 5 \ + --backend nccl + ``` + +## Data format + +You need to provide both `train_path` and `val_path` arguments to start training. +Each path must lead to folder with 1 subfolder and 1 file: + +```shell +├── images +│ ├── sample_img_01.png +│ ├── sample_img_02.png +│ ├── sample_img_03.png +│ └── ... +└── labels.json +``` + +Each JSON file must be a dictionary, where the keys are the image file names and the value is a dictionary with 4 entries: `img_dimensions` (spatial shape of the image), `img_hash` (SHA256 of the image file), `polygons` (the set of 2D points forming the localization polygon), `classes` (list of class names for each polygon). +The order of the points does not matter inside a polygon. Points are (x, y) absolutes coordinates. + +labels.json + +```shell +{ + "sample_img_01.png" = { + 'img_dimensions': (900, 600), + 'img_hash': "theimagedumpmyhash", + 'polygons': [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], ...], + 'classes': ["class_name_1", "class_name_2", ...] + }, + "sample_img_02.png" = { + 'img_dimensions': (900, 600), + 'img_hash': "thisisahash", + 'polygons': [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], ...], + 'classes': ["class_name_1", "class_name_2", ...] + } + ... +} +``` + +## Slack Logging with tqdm + +To enable Slack logging using `tqdm`, you need to set the following environment variables: + +- `TQDM_SLACK_TOKEN`: the Slack Bot Token +- `TQDM_SLACK_CHANNEL`: you can retrieve it using `Right Click on Channel > Copy > Copy link`. You should get something like `https://xxxxxx.slack.com/archives/yyyyyyyy`. Keep only the `yyyyyyyy` part. + +You can follow this page on [how to create a Slack App](https://api.slack.com/quickstart). + +## Advanced options + +Feel free to inspect the multiple script option to customize your training to your own needs! + +```python +python references/layout/train.py --help +``` diff --git a/references/layout/evaluate.py b/references/layout/evaluate.py new file mode 100644 index 0000000000..7888b9c816 --- /dev/null +++ b/references/layout/evaluate.py @@ -0,0 +1,200 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import multiprocessing as mp +import os +import time + +import torch +from torch.utils.data import DataLoader, SequentialSampler +from torchvision.transforms import Normalize + +if os.getenv("TQDM_SLACK_TOKEN") and os.getenv("TQDM_SLACK_CHANNEL"): + from tqdm.contrib.slack import tqdm +else: + from tqdm.auto import tqdm + +from doctr import transforms as T +from doctr.datasets import LayoutDataset +from doctr.models import layout +from doctr.utils.metrics import ObjectDetectionMetric + + +@torch.inference_mode() +def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): + # Model in eval mode + model.eval() + # Reset val metric + val_metric.reset() + # Validation loop + val_loss, batch_cnt = 0, 0 + for images, targets in tqdm(val_loader): + imgs, padding_masks = images + if torch.cuda.is_available(): + imgs = imgs.cuda() + padding_masks = padding_masks.cuda() + imgs = batch_transforms(imgs) + if amp: + with torch.cuda.amp.autocast(): + out = model(imgs, padding_masks, targets, return_preds=True) + else: + out = model(imgs, padding_masks, targets, return_preds=True) + # Compute metric + loc_preds = out["preds"] + for target, pred in zip(targets, loc_preds): + assert pred["boxes"].shape[0] == pred["scores"].shape[0] + assert pred["boxes"].shape[0] == pred["labels"].shape[0] + val_metric.update( + gt_boxes=target["boxes"], + pred_boxes=pred["boxes"], + gt_labels=target["labels"], + pred_labels=pred["labels"], + pred_scores=pred["scores"], + ) + + val_loss += out["loss"].item() + batch_cnt += 1 + + val_loss /= batch_cnt + metrics = val_metric.summary() + return ( + val_loss, + metrics["mAP@[.5:.95]"], + metrics["AP@[.5]"], + metrics["AP@[.75]"], + ) + + +def main(args): + slack_token = os.getenv("TQDM_SLACK_TOKEN") + slack_channel = os.getenv("TQDM_SLACK_CHANNEL") + pbar = tqdm(disable=False if slack_token and slack_channel else True) + if slack_token and slack_channel: + # Monkey patch tqdm write method to send messages directly to Slack + pbar.write = lambda msg: pbar.sio.client.chat_postMessage( + channel=slack_channel, + text=msg, + ) + pbar.write(str(args)) + + if not isinstance(args.workers, int): + args.workers = min(16, mp.cpu_count()) + + torch.backends.cudnn.benchmark = True + + # Temporary model to recover configuration + tmp_model = layout.__dict__[args.arch]( + pretrained=False, + assume_straight_pages=not args.rotation, + ) + + if isinstance(args.size, int): + input_shape = (args.size, args.size) + else: + input_shape = tmp_model.cfg["input_shape"][-2:] + mean, std = tmp_model.cfg["mean"], tmp_model.cfg["std"] + + st = time.time() + ds = LayoutDataset( + img_folder=os.path.join(args.dataset_path, "images"), + label_path=os.path.join(args.dataset_path, "labels.json"), + use_polygons=args.rotation, + sample_transforms=T.Resize( + input_shape, + preserve_aspect_ratio=args.keep_ratio, + symmetric_pad=args.symmetric_pad, + return_padding_mask=True, + ), + ) + class_names = ds.class_names + + test_loader = DataLoader( + ds, + batch_size=args.batch_size, + drop_last=False, + num_workers=args.workers, + sampler=SequentialSampler(ds), + pin_memory=torch.cuda.is_available(), + collate_fn=ds.collate_fn, + ) + + pbar.write(f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in {len(test_loader)} batches)") + + # Load docTR model + model = layout.__dict__[args.arch]( + pretrained=not isinstance(args.resume, str), + assume_straight_pages=not args.rotation, + class_names=class_names, + ).eval() + + batch_transforms = Normalize(mean=mean, std=std) + + # Resume weights + if isinstance(args.resume, str): + pbar.write(f"Resuming {args.resume}") + model.from_pretrained(args.resume) + + # GPU + if isinstance(args.device, int): + if not torch.cuda.is_available(): + raise AssertionError("PyTorch cannot access your GPU. Please investigate!") + if args.device >= torch.cuda.device_count(): + raise ValueError("Invalid device index") + # Silent default switch to GPU if available + elif torch.cuda.is_available(): + args.device = 0 + else: + pbar.write("No accessible GPU, target device set to CPU.") + + if torch.cuda.is_available(): + torch.cuda.set_device(args.device) + model = model.cuda() + + # Metrics + metric = ObjectDetectionMetric( + num_classes=len(class_names), + use_polygons=args.rotation, + ) + + pbar.write("Running evaluation") + val_loss, map5095, ap50, ap75 = evaluate( + model, + test_loader, + batch_transforms, + metric, + amp=args.amp, + ) + pbar.write( + f"Validation loss: {val_loss:.6f} | mAP@[.5:.95]: {map5095:.2%} | AP@[.5]: {ap50:.2%} | AP@[.75]: {ap75:.2%}" + ) + + +def parse_args(): + import argparse + + parser = argparse.ArgumentParser( + description="docTR evaluation script for text detection (PyTorch)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument("arch", type=str, help="text-detection model to evaluate") + parser.add_argument("dataset_path", type=str, help="path to the dataset to evaluate on") + parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for evaluation") + parser.add_argument("--device", default=None, type=int, help="device") + parser.add_argument("--size", type=int, default=None, help="model input size, H = W") + parser.add_argument("--keep_ratio", action="store_true", help="keep the aspect ratio of the input image") + parser.add_argument("--symmetric_pad", action="store_true", help="pad the image symmetrically") + parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading") + parser.add_argument("--rotation", dest="rotation", action="store_true", help="inference with rotated bbox") + parser.add_argument("--resume", type=str, default=None, help="Checkpoint to resume") + parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/references/layout/latency.py b/references/layout/latency.py new file mode 100644 index 0000000000..14c73720c5 --- /dev/null +++ b/references/layout/latency.py @@ -0,0 +1,59 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +"""Layout detection latency benchmark""" + +import argparse +import time + +import numpy as np +import torch + +from doctr.models import layout + + +@torch.inference_mode() +def main(args): + device = torch.device("cuda:0" if args.gpu else "cpu") + + # Pretrained imagenet model + model = layout.__dict__[args.arch](pretrained=args.pretrained).eval().to(device=device) + + # Input + img_tensor = torch.rand((1, 3, args.size, args.size)).to(device=device) + padding_masks = torch.zeros((1, args.size, args.size), dtype=torch.bool).to(device=device) + + # Warmup + for _ in range(10): + _ = model(input=img_tensor, masks=padding_masks) + + timings = [] + + # Evaluation runs + for _ in range(args.it): + start_ts = time.perf_counter() + _ = model(input=img_tensor, masks=padding_masks) + timings.append(time.perf_counter() - start_ts) + + _timings = np.array(timings) + print(f"{args.arch} ({args.it} runs on ({args.size}, {args.size}) inputs)") + print(f"mean {1000 * _timings.mean():.2f}ms, std {1000 * _timings.std():.2f}ms") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="docTR latency benchmark for layout detection (PyTorch)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("arch", type=str, help="Architecture to use") + parser.add_argument("--size", type=int, default=1024, help="The image input size") + parser.add_argument("--gpu", dest="gpu", help="Should the benchmark be performed on GPU", action="store_true") + parser.add_argument("--it", type=int, default=100, help="Number of iterations to run") + parser.add_argument( + "--pretrained", dest="pretrained", help="Use pre-trained models from the modelzoo", action="store_true" + ) + args = parser.parse_args() + + main(args) diff --git a/references/layout/train.py b/references/layout/train.py new file mode 100644 index 0000000000..0ef5b4fd4c --- /dev/null +++ b/references/layout/train.py @@ -0,0 +1,726 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import datetime +import hashlib +import logging +import multiprocessing +import os +import time +from pathlib import Path + +import numpy as np +import torch + +# The following import is required for DDP +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler +from torchvision.transforms.v2 import Compose, Normalize, RandomGrayscale, RandomPhotometricDistort + +if os.getenv("TQDM_SLACK_TOKEN") and os.getenv("TQDM_SLACK_CHANNEL"): + from tqdm.contrib.slack import tqdm +else: + from tqdm.auto import tqdm + +from doctr import transforms as T +from doctr.datasets import LayoutDataset +from doctr.models import layout, login_to_hub, push_to_hf_hub +from doctr.utils.metrics import ObjectDetectionMetric +from utils import EarlyStopper, plot_recorder, plot_samples + + +def record_lr( + model: torch.nn.Module, + train_loader: DataLoader, + batch_transforms, + optimizer, + start_lr: float = 1e-7, + end_lr: float = 1, + num_it: int = 100, + amp: bool = False, +): + """Gridsearch the optimal learning rate for the training. + Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py + """ + if num_it > len(train_loader): + raise ValueError("the value of `num_it` needs to be lower than the number of available batches") + + model = model.train() + # Update param groups & LR + optimizer.defaults["lr"] = start_lr + for pgroup in optimizer.param_groups: + pgroup["lr"] = start_lr + + gamma = (end_lr / start_lr) ** (1 / (num_it - 1)) + scheduler = MultiplicativeLR(optimizer, lambda step: gamma) + + lr_recorder = [start_lr * gamma**idx for idx in range(num_it)] + loss_recorder = [] + + if amp: + scaler = torch.cuda.amp.GradScaler() + + for batch_idx, (images, targets) in enumerate(train_loader): + imgs, padding_masks = images + + if torch.cuda.is_available(): + imgs = imgs.cuda() + padding_masks = padding_masks.cuda() + + imgs = batch_transforms(imgs) + + # Forward, Backward & update + optimizer.zero_grad() + if amp: + with torch.cuda.amp.autocast(): + train_loss = model(imgs, padding_masks, targets)["loss"] + scaler.scale(train_loss).backward() + # Gradient clipping + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + # Update the params + scaler.step(optimizer) + scaler.update() + else: + train_loss = model(imgs, padding_masks, targets)["loss"] + train_loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + optimizer.step() + # Update LR + scheduler.step() + + # Record + if not torch.isfinite(train_loss): + if batch_idx == 0: + raise ValueError("loss value is NaN or inf.") + else: + break + loss_recorder.append(train_loss.item()) + # Stop after the number of iterations + if batch_idx + 1 == num_it: + break + + return lr_recorder[: len(loss_recorder)], loss_recorder + + +def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0): + if amp: + scaler = torch.cuda.amp.GradScaler() + + model.train() + # Iterate over the batches of the dataset + epoch_train_loss, batch_cnt = 0, 0 + pbar = tqdm(train_loader, dynamic_ncols=True, disable=(rank != 0)) + for images, targets in pbar: + imgs, padding_masks = images + if torch.cuda.is_available(): + imgs = imgs.cuda() + padding_masks = padding_masks.cuda() + imgs = batch_transforms(imgs) + + optimizer.zero_grad() + if amp: + with torch.cuda.amp.autocast(): + train_loss = model(imgs, padding_masks, targets)["loss"] + scaler.scale(train_loss).backward() + # Gradient clipping + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + # Update the params + scaler.step(optimizer) + scaler.update() + else: + train_loss = model(imgs, padding_masks, targets)["loss"] + train_loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + optimizer.step() + + scheduler.step() + last_lr = scheduler.get_last_lr()[0] + + pbar.set_description(f"Training loss: {train_loss.item():.6f} | LR: {last_lr:.6f}") + if log: + log(train_loss=train_loss.item(), lr=last_lr) + + epoch_train_loss += train_loss.item() + batch_cnt += 1 + + epoch_train_loss /= batch_cnt + return epoch_train_loss, last_lr + + +@torch.no_grad() +def evaluate(model, val_loader, batch_transforms, val_metric, amp=False, log=None): + # Model in eval mode + model.eval() + # Reset val metric + val_metric.reset() + # Validation loop + val_loss, batch_cnt = 0, 0 + pbar = tqdm(val_loader, dynamic_ncols=True) + for images, targets in pbar: + imgs, padding_masks = images + if torch.cuda.is_available(): + imgs = imgs.cuda() + padding_masks = padding_masks.cuda() + imgs = batch_transforms(imgs) + if amp: + with torch.cuda.amp.autocast(): + out = model(imgs, padding_masks, targets, return_preds=True) + else: + out = model(imgs, padding_masks, targets, return_preds=True) + # Compute metric + loc_preds = out["preds"] + for target, pred in zip(targets, loc_preds): + assert pred["boxes"].shape[0] == pred["scores"].shape[0] + assert pred["boxes"].shape[0] == pred["labels"].shape[0] + + val_metric.update( + gt_boxes=target["boxes"], + pred_boxes=pred["boxes"], + gt_labels=target["labels"], + pred_labels=pred["labels"], + pred_scores=pred["scores"], + ) + + pbar.set_description(f"Validation loss: {out['loss'].item():.6f}") + if log: + log(val_loss=out["loss"].item()) + + val_loss += out["loss"].item() + batch_cnt += 1 + + val_loss /= batch_cnt + metrics = val_metric.summary() + return ( + val_loss, + metrics["mAP@[.5:.95]"], + metrics["AP@[.5]"], + metrics["AP@[.75]"], + ) + + +def main(args): + # Detect distributed setup + # variable is set by torchrun + world_size = int(os.environ.get("WORLD_SIZE", 1)) + distributed = world_size > 1 + + # GPU setup + if distributed: + rank = int(os.environ.get("LOCAL_RANK", 0)) + dist.init_process_group(backend=args.backend) + device = torch.device("cuda", rank) + torch.cuda.set_device(device) + else: + # single process + rank = 0 + if isinstance(args.device, int): + if not torch.cuda.is_available(): + raise AssertionError("PyTorch cannot access your GPU. Please investigate!") + if args.device >= torch.cuda.device_count(): + raise ValueError("Invalid device index") + device = torch.device("cuda", args.device) + # Silent default switch to GPU if available + elif torch.cuda.is_available(): + device = torch.device("cuda", 0) + else: + logging.warning("No accessible GPU, target device set to CPU.") + device = torch.device("cpu") + + slack_token = os.getenv("TQDM_SLACK_TOKEN") + slack_channel = os.getenv("TQDM_SLACK_CHANNEL") + + pbar = tqdm(disable=False if (slack_token and slack_channel) and (rank == 0) else True) + if slack_token and slack_channel: + # Monkey patch tqdm write method to send messages directly to Slack + pbar.write = lambda msg: pbar.sio.client.chat_postMessage(channel=slack_channel, text=msg) + pbar.write(str(args)) + + if rank == 0 and args.push_to_hub: + login_to_hub() + + if not isinstance(args.workers, int): + args.workers = min(16, multiprocessing.cpu_count()) + + torch.backends.cudnn.benchmark = True + + # Temporary model to recover configuration + tmp_model = layout.__dict__[args.arch]( + pretrained=False, + assume_straight_pages=not args.rotation, + ) + + mean, std = tmp_model.cfg["mean"], tmp_model.cfg["std"] + + # placeholder for class names + cls_container = [None] + if rank == 0: + # validation dataset related code + st = time.time() + val_set = LayoutDataset( + img_folder=os.path.join(args.val_path, "images"), + label_path=os.path.join(args.val_path, "labels.json"), + sample_transforms=T.SampleCompose( + ( + # Important to return padding masks for layout models + [ + T.Resize( + (args.input_size, args.input_size), + preserve_aspect_ratio=True, + symmetric_pad=True, + return_padding_mask=True, + ) + ] + if not args.rotation or args.eval_straight + else [] + ) + + ( + [ + T.Resize(args.input_size, preserve_aspect_ratio=True), # This does not pad + T.RandomApply(T.RandomRotate(90, expand=True), 0.5), + T.Resize( + (args.input_size, args.input_size), + preserve_aspect_ratio=True, + symmetric_pad=True, + return_padding_mask=True, + ), + ] + if args.rotation and not args.eval_straight + else [] + ) + ), + use_polygons=args.rotation and not args.eval_straight, + ) + val_loader = DataLoader( + val_set, + batch_size=args.batch_size, + drop_last=False, + num_workers=args.workers, + sampler=SequentialSampler(val_set), + pin_memory=torch.cuda.is_available(), + collate_fn=val_set.collate_fn, + ) + pbar.write( + f"Validation set loaded in {time.time() - st:.4f}s ({len(val_set)} samples in {len(val_loader)} batches)" + ) + with open(os.path.join(args.val_path, "labels.json"), "rb") as f: + val_hash = hashlib.sha256(f.read()).hexdigest() + + cls_container[0] = val_set.class_names + if distributed: + # broadcast class names to all ranks + dist.broadcast_object_list(cls_container, src=0) + # unpack class names on all ranks + class_names = cls_container[0] + + batch_transforms = Normalize(mean=mean, std=std) + + # Load docTR model + model = layout.__dict__[args.arch]( + pretrained=args.pretrained, + assume_straight_pages=not args.rotation, + class_names=class_names, + ) + + # Resume weights + if isinstance(args.resume, str): + pbar.write(f"Resuming {args.resume}") + model.from_pretrained(args.resume) + + if rank == 0: + # Metrics + val_metric = ObjectDetectionMetric( + num_classes=len(class_names), + use_polygons=args.rotation and not args.eval_straight, + ) + + if rank == 0 and args.test_only: + pbar.write("Running evaluation") + val_loss, map5095, ap50, ap75 = evaluate( + model, + val_loader, + batch_transforms, + val_metric, + amp=args.amp, + ) + pbar.write( + f"Validation loss: {val_loss:.6f} | " + f"mAP@[.5:.95]: {map5095:.2%} | " + f"AP@[.5]: {ap50:.2%} | " + f"AP@[.75]: {ap75:.2%}" + ) + return + + st = time.time() + # Augmentations + # Image augmentations + img_transforms = T.OneOf([ + Compose([ + T.RandomApply(T.ColorInversion(), 0.3), + T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.2), + ]), + Compose([ + T.RandomApply(T.RandomShadow(), 0.3), + T.RandomApply(T.GaussianNoise(), 0.1), + T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.3), + RandomGrayscale(p=0.15), + ]), + RandomPhotometricDistort(p=0.3), + lambda x: x, # Identity no transformation + ]) + # Image + target augmentations + sample_transforms = T.SampleCompose( + ( + [ + T.RandomHorizontalFlip(0.15), + T.OneOf([ + T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25), + T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), + ]), + T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), + ] + if not args.rotation + else [ + T.RandomHorizontalFlip(0.15), + T.OneOf([ + T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25), + T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), + ]), + # Rotation augmentation + T.Resize(args.input_size, preserve_aspect_ratio=True), + T.RandomApply(T.RandomRotate(90, expand=True), 0.5), + # Important to return padding masks for layout models + T.Resize( + (args.input_size, args.input_size), + preserve_aspect_ratio=True, + symmetric_pad=True, + return_padding_mask=True, + ), + ] + ) + ) + + # Load both train and val data generators + train_set = LayoutDataset( + img_folder=os.path.join(args.train_path, "images"), + label_path=os.path.join(args.train_path, "labels.json"), + img_transforms=img_transforms, + sample_transforms=sample_transforms, + use_polygons=args.rotation, + ) + + if distributed: + sampler = DistributedSampler(train_set, rank=rank, shuffle=False, drop_last=True) + else: + sampler = RandomSampler(train_set) + + train_loader = DataLoader( + train_set, + batch_size=args.batch_size, + drop_last=True, + num_workers=args.workers, + sampler=sampler, + pin_memory=torch.cuda.is_available(), + collate_fn=train_set.collate_fn, + ) + + # Sanity class names check between train and val sets + if set(class_names) != set(train_set.class_names): + raise ValueError( + "Class names mismatch between train and val sets. " + f"Train classes: {train_set.class_names}, Val classes: {class_names}" + ) + + if rank == 0: + pbar.write( + f"Train set loaded in {time.time() - st:.4f}s ({len(train_set)} samples in {len(train_loader)} batches)" + ) + + with open(os.path.join(args.train_path, "labels.json"), "rb") as f: + train_hash = hashlib.sha256(f.read()).hexdigest() + + if rank == 0 and args.show_samples: + x, target = next(iter(train_loader)) + plot_samples(x, target) + return + + # Backbone freezing + if args.freeze_backbone: + for p in model.feat_extractor.parameters(): + p.requires_grad = False + + if torch.cuda.is_available(): + torch.cuda.set_device(device) + model = model.to(device) + + if distributed: + # construct DDP model + model = DDP(model, device_ids=[rank]) + + # Optimizer + if args.optim == "adam": + optimizer = torch.optim.Adam( + [p for p in model.parameters() if p.requires_grad], + args.lr, + betas=(0.95, 0.999), + eps=1e-6, + weight_decay=args.weight_decay, + ) + + elif args.optim == "adamw": + optimizer = torch.optim.AdamW( + [p for p in model.parameters() if p.requires_grad], + args.lr, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=args.weight_decay or 1e-4, + ) + + # LR Finder + if rank == 0 and args.find_lr: + lrs, losses = record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp) + plot_recorder(lrs, losses) + return + + # Scheduler + if args.sched == "cosine": + scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4) + elif args.sched == "onecycle": + scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader)) + elif args.sched == "poly": + scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader)) + + # Training monitoring + current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name + + if rank == 0: + config = { + "learning_rate": args.lr, + "epochs": args.epochs, + "weight_decay": args.weight_decay, + "batch_size": args.batch_size, + "architecture": args.arch, + "input_size": args.input_size, + "optimizer": args.optim, + "framework": "pytorch", + "scheduler": args.sched, + "train_hash": train_hash, + "val_hash": val_hash, + "pretrained": args.pretrained, + "rotation": args.rotation, + "amp": args.amp, + } + + global global_step + global_step = 0 # Shared global step counter + + # W&B + if args.wb: + import wandb + + run = wandb.init(name=exp_name, project="layout-detection", config=config) + + def wandb_log_at_step(train_loss=None, val_loss=None, lr=None): + wandb.log({ + **({"train_loss_step": train_loss} if train_loss is not None else {}), + **({"val_loss_step": val_loss} if val_loss is not None else {}), + **({"step_lr": lr} if lr is not None else {}), + }) + + # ClearML + if args.clearml: + from clearml import Logger, Task + + task = Task.init(project_name="docTR/layout-detection", task_name=exp_name, reuse_last_task_id=False) + task.upload_artifact("config", config) + + def clearml_log_at_step(train_loss=None, val_loss=None, lr=None): + logger = Logger.current_logger() + + if train_loss is not None: + logger.report_scalar( + title="Training Step Loss", + series="train_loss_step", + iteration=global_step, + value=train_loss, + ) + if val_loss is not None: + logger.report_scalar( + title="Validation Step Loss", + series="val_loss_step", + iteration=global_step, + value=val_loss, + ) + if lr is not None: + logger.report_scalar( + title="Step Learning Rate", + series="step_lr", + iteration=global_step, + value=lr, + ) + + # Unified logger + def log_at_step(train_loss=None, val_loss=None, lr=None): + global global_step + if args.wb: + wandb_log_at_step(train_loss, val_loss, lr) + if args.clearml: + clearml_log_at_step(train_loss, val_loss, lr) + global_step += 1 # Increment the shared global step counter + + # Create loss queue + min_loss = np.inf + if args.early_stop: + early_stopper = EarlyStopper(patience=args.early_stop_epochs, min_delta=args.early_stop_delta) + + # Training loop + for epoch in range(args.epochs): + train_loss, actual_lr = fit_one_epoch( + model, + train_loader, + batch_transforms, + optimizer, + scheduler, + amp=args.amp, + log=log_at_step, + rank=rank, + ) + + if rank == 0: + pbar.write(f"Epoch {epoch + 1}/{args.epochs} - Training loss: {train_loss:.6f} | LR: {actual_lr:.6f}") + + # Validation loop at the end of each epoch + val_loss, map5095, ap50, ap75 = evaluate( + model, + val_loader, + batch_transforms, + val_metric, + amp=args.amp, + log=log_at_step, + ) + params = model.module if hasattr(model, "module") else model + if val_loss < min_loss: + pbar.write(f"Validation loss decreased {min_loss:.6f} --> {val_loss:.6f}: saving state...") + torch.save(params.state_dict(), Path(args.output_dir) / f"{exp_name}.pt") + min_loss = val_loss + if args.save_interval_epoch: + pbar.write(f"Saving state at epoch: {epoch + 1}") + torch.save(params.state_dict(), Path(args.output_dir) / f"{exp_name}_epoch{epoch + 1}.pt") + log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " + if any(val is None for val in (map5095, ap50, ap75)): + log_msg += "(Undefined metric value, caused by empty GTs or predictions)" + else: + log_msg += f"| mAP@[.5:.95]: {map5095:.2%} | AP@[.5]: {ap50:.2%} | AP@[.75]: {ap75:.2%}" + pbar.write(log_msg) + # W&B + if args.wb: + wandb.log({ + "train_loss": train_loss, + "val_loss": val_loss, + "learning_rate": actual_lr, + "mAP@[.5:.95]": map5095, + "AP@[.5]": ap50, + "AP@[.75]": ap75, + }) + + # ClearML + if args.clearml: + from clearml import Logger + + logger = Logger.current_logger() + logger.report_scalar(title="Training Loss", series="train_loss", value=train_loss, iteration=epoch) + logger.report_scalar(title="Validation Loss", series="val_loss", value=val_loss, iteration=epoch) + logger.report_scalar(title="Learning Rate", series="lr", value=actual_lr, iteration=epoch) + logger.report_scalar(title="mAP@[.5:.95]", series="mAP@[.5:.95]", value=map5095, iteration=epoch) + logger.report_scalar(title="AP@[.5]", series="AP@[.5]", value=ap50, iteration=epoch) + logger.report_scalar(title="AP@[.75]", series="AP@[.75]", value=ap75, iteration=epoch) + + if args.early_stop and early_stopper.early_stop(val_loss): + pbar.write("Training halted early due to reaching patience limit.") + break + + if rank == 0: + if args.wb: + run.finish() + + if args.push_to_hub: + push_to_hf_hub(model, exp_name, task="layout", run_config=args) + + +def parse_args(): + import argparse + + parser = argparse.ArgumentParser( + description="DocTR training script for layout detection (PyTorch)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # DDP related args + parser.add_argument("--backend", default="nccl", type=str, help="Backend to use for torch.distributed") + parser.add_argument( + "--device", + default=None, + type=int, + help="Specify gpu device for single-gpu training. In destributed setting, this parameter is ignored", + ) + parser.add_argument("arch", type=str, help="text-detection model to train") + parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model") + parser.add_argument("--train_path", type=str, required=True, help="path to training data folder") + parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder") + parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") + parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on") + parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training") + parser.add_argument( + "--save-interval-epoch", dest="save_interval_epoch", action="store_true", help="Save model every epoch" + ) + parser.add_argument("--input_size", type=int, default=1024, help="model input size, H = W") + parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam or AdamW)") + parser.add_argument("--wd", "--weight-decay", default=0, type=float, help="weight decay", dest="weight_decay") + parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading") + parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint") + parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop") + parser.add_argument( + "--freeze-backbone", dest="freeze_backbone", action="store_true", help="freeze model backbone for fine-tuning" + ) + parser.add_argument( + "--show-samples", dest="show_samples", action="store_true", help="Display unormalized training samples" + ) + parser.add_argument("--wb", dest="wb", action="store_true", help="Log to Weights & Biases") + parser.add_argument("--clearml", dest="clearml", action="store_true", help="Log to ClearML") + parser.add_argument("--push-to-hub", dest="push_to_hub", action="store_true", help="Push to Huggingface Hub") + parser.add_argument( + "--pretrained", + dest="pretrained", + action="store_true", + help="Load pretrained parameters before starting the training", + ) + parser.add_argument("--rotation", dest="rotation", action="store_true", help="train with rotated documents") + parser.add_argument( + "--eval-straight", + action="store_true", + help="metrics evaluation with straight boxes instead of polygons to save time + memory", + ) + parser.add_argument("--optim", type=str, default="adam", choices=["adam", "adamw"], help="optimizer to use") + parser.add_argument( + "--sched", type=str, default="poly", choices=["cosine", "onecycle", "poly"], help="scheduler to use" + ) + parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") + parser.add_argument("--find-lr", action="store_true", help="Gridsearch the optimal LR") + parser.add_argument("--early-stop", action="store_true", help="Enable early stopping") + parser.add_argument("--early-stop-epochs", type=int, default=5, help="Patience for early stopping") + parser.add_argument("--early-stop-delta", type=float, default=0.01, help="Minimum Delta for early stopping") + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/references/layout/utils.py b/references/layout/utils.py new file mode 100644 index 0000000000..218d5548ea --- /dev/null +++ b/references/layout/utils.py @@ -0,0 +1,101 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +import cv2 +import matplotlib.pyplot as plt +import numpy as np + + +def plot_samples(images, targets: list[dict[str, np.ndarray]]) -> None: + # Unnormalize image + nb_samples = min(len(images), 4) + _, axes = plt.subplots(2, nb_samples, figsize=(20, 5)) + for idx in range(nb_samples): + img = (255 * images[idx].numpy()).round().clip(0, 255).astype(np.uint8) + if img.shape[0] == 3 and img.shape[2] != 3: + img = img.transpose(1, 2, 0) + + target = np.zeros(img.shape[:2], np.uint8) + tgts = targets[idx].copy() + for boxes in tgts.values(): + boxes[:, [0, 2]] = boxes[:, [0, 2]] * img.shape[1] + boxes[:, [1, 3]] = boxes[:, [1, 3]] * img.shape[0] + boxes[:, :4] = boxes[:, :4].round().astype(int) + + for box in boxes: + if boxes.ndim == 3: + cv2.fillPoly(target, [np.intp(box)], 1) + else: + target[int(box[1]) : int(box[3]) + 1, int(box[0]) : int(box[2]) + 1] = 1 + if nb_samples > 1: + axes[0][idx].imshow(img) + axes[1][idx].imshow(target.astype(bool)) + else: + axes[0].imshow(img) + axes[1].imshow(target.astype(bool)) + + # Disable axis + for ax in axes.ravel(): + ax.axis("off") + plt.show() + + +def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> None: + """Display the results of the LR grid search. + Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py + + Args: + lr_recorder: list of LR values + loss_recorder: list of loss values + beta (float, optional): smoothing factor + **kwargs: keyword arguments from `matplotlib.pyplot.show` + """ + if len(lr_recorder) != len(loss_recorder) or len(lr_recorder) == 0: + raise AssertionError("Both `lr_recorder` and `loss_recorder` should have the same length") + + # Exp moving average of loss + smoothed_losses = [] + avg_loss = 0.0 + for idx, loss in enumerate(loss_recorder): + avg_loss = beta * avg_loss + (1 - beta) * loss + smoothed_losses.append(avg_loss / (1 - beta ** (idx + 1))) + + # Properly rescale Y-axis + data_slice = slice( + min(len(loss_recorder) // 10, 10), + # -min(len(loss_recorder) // 20, 5) if len(loss_recorder) >= 20 else len(loss_recorder) + len(loss_recorder), + ) + vals = np.array(smoothed_losses[data_slice]) + min_idx = vals.argmin() + max_val = vals.max() if min_idx is None else vals[: min_idx + 1].max() # type: ignore[misc] + delta = max_val - vals[min_idx] + + plt.plot(lr_recorder[data_slice], smoothed_losses[data_slice]) + plt.xscale("log") + plt.xlabel("Learning Rate") + plt.ylabel("Training loss") + plt.ylim(vals[min_idx] - 0.1 * delta, max_val + 0.2 * delta) + plt.grid(True, linestyle="--", axis="x") + plt.show(**kwargs) + + +class EarlyStopper: + def __init__(self, patience: int = 5, min_delta: float = 0.01): + self.patience = patience + self.min_delta = min_delta + self.counter = 0 + self.min_validation_loss = float("inf") + + def early_stop(self, validation_loss: float) -> bool: + if validation_loss < self.min_validation_loss: + self.min_validation_loss = validation_loss + self.counter = 0 + elif validation_loss > (self.min_validation_loss + self.min_delta): + self.counter += 1 + if self.counter >= self.patience: + return True + return False diff --git a/tests/conftest.py b/tests/conftest.py index 7132c41e63..f12b604491 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -147,7 +147,7 @@ def mock_layout_label(tmpdir_factory): [[3, 2], [3, 3], [4, 1], [4, 3]], [[30, 20], [30, 30], [40, 10], [40, 30]], ], - "class_names": [ + "classes": [ "Table", "Header", "Footer", diff --git a/tests/pytorch/test_datasets_pt.py b/tests/pytorch/test_datasets_pt.py index a8008b5bc6..7c1e97d808 100644 --- a/tests/pytorch/test_datasets_pt.py +++ b/tests/pytorch/test_datasets_pt.py @@ -188,15 +188,20 @@ def test_layout_dataset(mock_image_folder, mock_layout_label, use_polygons): ds = datasets.LayoutDataset( img_folder=mock_image_folder, label_path=mock_layout_label, - img_transforms=Resize(input_size), + img_transforms=Resize(input_size, return_padding_mask=True), use_polygons=use_polygons, ) assert len(ds) == 5 - img, target_dict = ds[0] + inputs, target_dict = ds[0] + assert isinstance(inputs, tuple) and len(inputs) == 2 + img, padding_mask = inputs assert isinstance(img, torch.Tensor) assert img.dtype == torch.float32 assert img.shape[-2:] == input_size + assert isinstance(padding_mask, torch.Tensor) + assert padding_mask.dtype == torch.bool + assert padding_mask.shape == input_size assert isinstance(target_dict, dict) expected_classes = {"Table", "Header", "Footer", "Text"} assert set(target_dict.keys()) == expected_classes @@ -213,8 +218,12 @@ def test_layout_dataset(mock_image_folder, mock_layout_label, use_polygons): assert ds.class_names == sorted(expected_classes) loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn) images, targets = next(iter(loader)) - assert isinstance(images, torch.Tensor) - assert images.shape == (2, 3, *input_size) + assert isinstance(images, tuple) and len(images) == 2 + img, padding_mask = images + assert isinstance(img, torch.Tensor) + assert img.shape == (2, 3, *input_size) + assert isinstance(padding_mask, torch.Tensor) + assert padding_mask.shape == (2, *input_size) assert isinstance(targets, list) assert all(isinstance(target, dict) for target in targets) for target in targets: @@ -243,14 +252,14 @@ def test_layout_dataset(mock_image_folder, mock_layout_label, use_polygons): test_cases = [ ( - {"class_names": ["Text"]}, + {"classes": ["Text"]}, KeyError, "missing 'polygons'", ), ( {"polygons": [[[0, 0], [1, 0], [1, 1], [0, 1]]]}, KeyError, - "missing 'class_names'", + "missing 'classes'", ), ( { @@ -258,7 +267,7 @@ def test_layout_dataset(mock_image_folder, mock_layout_label, use_polygons): [[0, 0], [1, 0], [1, 1], [0, 1]], [[0, 0], [1, 0], [1, 1], [0, 1]], ], - "class_names": ["Text"], + "classes": ["Text"], }, ValueError, "number of polygons", @@ -266,7 +275,7 @@ def test_layout_dataset(mock_image_folder, mock_layout_label, use_polygons): ( { "polygons": [[[0, 0], [1, 0], [1, 1]]], # only 3 points - "class_names": ["Text"], + "classes": ["Text"], }, ValueError, "polygons are expected to have shape", From 320b4ae6e1b01b0ecf75bb7b00d2b6212ab8f3a3 Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 13 May 2026 11:55:27 +0200 Subject: [PATCH 2/2] typing and mypy --- doctr/datasets/datasets/pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doctr/datasets/datasets/pytorch.py b/doctr/datasets/datasets/pytorch.py index 6418906dd7..439555327b 100644 --- a/doctr/datasets/datasets/pytorch.py +++ b/doctr/datasets/datasets/pytorch.py @@ -50,7 +50,7 @@ def _read_sample(self, index: int) -> tuple[torch.Tensor, Any]: @staticmethod def collate_fn( samples: list[tuple[torch.Tensor, Any]] | list[tuple[tuple[torch.Tensor, torch.Tensor], Any]], - ) -> tuple[torch.Tensor, list[Any]]: + ) -> tuple[torch.Tensor | tuple[torch.Tensor, torch.Tensor], list[Any]]: images, targets = zip(*samples) if isinstance(images[0], tuple): images, padding_masks = zip(*images) @@ -60,7 +60,7 @@ def collate_fn( else: images = torch.stack(images, dim=0) # type: ignore[assignment] - return images, list(targets) # type: ignore[return-value] + return images, list(targets) class VisionDataset(AbstractDataset, _VisionDataset): # noqa: D101