Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions .github/workflows/references.yml
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,91 @@ jobs:
pip install -e .[viz,html] --upgrade
- name: Benchmark latency
run: python references/detection/latency.py db_mobilenet_v3_large --it 5 --size 512


train-layout-analysis:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python: ["3.10"]
steps:
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python }}
architecture: x64
- name: Cache python modules
uses: actions/cache@v5
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-${{ hashFiles('references/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[viz,html] --upgrade
pip install -r references/requirements.txt
- name: Download and extract toy set
run: |
wget https://github.com/mindee/doctr/releases/download/v0.3.1/toy_detection_set-bbbb4243.zip
sudo apt-get update && sudo apt-get install unzip -y
unzip toy_detection_set-bbbb4243.zip -d det_set
- name: Train for a short epoch
run: python references/layout/train.py lw_detr_s --train_path ./det_set --val_path ./det_set -b 2 --epochs 1

evaluate-layout-analysis:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python: ["3.10"]
steps:
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python }}
architecture: x64
- name: Cache python modules
uses: actions/cache@v5
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[viz,html] --upgrade
pip install -r references/requirements.txt
- name: Evaluate layout analysis
run: python references/layout/evaluate.py lw_detr_s

latency-layout-analysis:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python: ["3.10"]
steps:
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python }}
architecture: x64
- name: Cache python modules
uses: actions/cache@v5
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[viz,html] --upgrade
- name: Benchmark latency
run: python references/layout/latency.py lw_detr_s --it 5 --size 512
15 changes: 15 additions & 0 deletions docs/source/using_doctr/custom_models_training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ For details on the training process and the necessary data and data format, refe

- `detection <https://github.com/mindee/doctr/tree/main/references/detection#readme>`_
- `recognition <https://github.com/mindee/doctr/tree/main/references/recognition#readme>`_
- `layout <https://github.com/mindee/doctr/tree/main/references/layout#readme>`_

If you’re looking for a lightweight yet efficient tool to annotate small amounts of data, especially tailored for docTR,
check out the `docTR Labeling Tool <https://github.com/text2knowledge/docTR-Labeler>`_.
Expand Down Expand Up @@ -52,6 +53,20 @@ Load a custom recognition model trained on another vocabulary as the default one

predictor = ocr_predictor(det_arch='linknet_resnet18', reco_arch=reco_model, pretrained=True)


Load a custom layout analysis model trained on another set of classes as the default one:

.. code:: python3

import torch
from doctr.models import layout_predictor, lw_detr_s
from doctr.datasets import VOCABS

layout_model = lw_detr_s(pretrained=False, class_names=["class_name_1", "class_name_2", ...])
layout_model.from_pretrained('<path_to_pt>')

predictor = layout_predictor(layout_arch=layout_model, pretrained=True)

Load a custom trained KIE detection model:

.. code:: python3
Expand Down
60 changes: 60 additions & 0 deletions docs/source/using_doctr/using_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,66 @@ Recognition predictors
out = model([dummy_img])


Layout Analysis
---------------

The task consists of localizing and classifying visual elements in a given image.
This is a more general task than text detection, as it can be used to detect and classify any type of visual element in a document, such as tables, figures, headers, footers, etc.
Our latest layout models works with rotated and skewed documents!

Available architectures
^^^^^^^^^^^^^^^^^^^^^^^

The following architectures are currently supported:

* :py:meth:`lw_detr_s <doctr.models.layout.lw_detr_s>`
* :py:meth:`lw_detr_m <doctr.models.layout.lw_detr_m>`

For a comprehensive comparison, we have compiled a detailed benchmark:

+--------------------------------------------------+-----------------+---------------+------------------+-------------+--------------+--------------------+
| | | | | | | |
+==================================================+=================+===============+==================+=============+==============+====================+
| **Architecture** | **Input shape** | **# params** | **mAP@[.5:.95]** | **AP@[.5]** | **AP@[.75]** | **sec/it (B: 1)** |
+--------------------------------------------------+-----------------+---------------+------------------+-------------+--------------+--------------------+
| lw_detr_s | (1024, 1024, 3) | 15.1 M | | | | 0.5 |
+--------------------------------------------------+-----------------+---------------+------------------+-------------+--------------+--------------------+
| lw_detr_m | (1024, 1024, 3) | 29.5 M | | | | 0.7 |
+--------------------------------------------------+-----------------+---------------+------------------+-------------+--------------+--------------------+


Explanations about the metrics being used are available in :ref:`metrics`.

Seconds per iteration (with a batch size of 1) is computed after a warmup phase of 100 tensors, by measuring the average number of processed tensors per second over 1000 samples. Those results were obtained on a `11th Gen Intel(R) Core(TM) i7-11800H @ 2.30GHz`.


Layout predictors
^^^^^^^^^^^^^^^^^

:py:meth:`layout_predictor <doctr.models.layout.layout_predictor>` wraps your layout model to make it easily useable with your favorite deep learning framework seamlessly.

.. code:: python3

import numpy as np
from doctr.models import layout_predictor
model = layout_predictor('lw_detr_s')
dummy_img = (255 * np.random.rand(800, 600, 3)).astype(np.uint8)
out = model([dummy_img])

You can pass specific boolean arguments to the predictor:
* `pretrained`: if you want to use a model that has been pretrained on a specific dataset, setting `pretrained=True` this will load the corresponding weights. If `pretrained=False`, which is the default, would otherwise lead to a random initialization and would lead to no/useless results.
* `assume_straight_pages`: if you work with straight documents only, it will fit straight bounding boxes to the text areas.
* `preserve_aspect_ratio`: if you want to preserve the aspect ratio of your documents while resizing before sending them to the model.
* `symmetric_pad`: if you choose to preserve the aspect ratio, it will pad the image symmetrically and not from the bottom-right.

For instance, this snippet will instantiates a layout predictor able to detect text on rotated documents while preserving the aspect ratio:

.. code:: python3

from doctr.models import layout_predictor
predictor = layout_predictor('lw_detr_s', pretrained=True, assume_straight_pages=False, preserve_aspect_ratio=True)


End-to-End OCR
--------------

Expand Down
14 changes: 11 additions & 3 deletions doctr/datasets/datasets/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,19 @@ def _read_sample(self, index: int) -> tuple[torch.Tensor, Any]:
return img, deepcopy(target)

@staticmethod
def collate_fn(samples: list[tuple[torch.Tensor, Any]]) -> tuple[torch.Tensor, list[Any]]:
def collate_fn(
samples: list[tuple[torch.Tensor, Any]] | list[tuple[tuple[torch.Tensor, torch.Tensor], Any]],
) -> tuple[torch.Tensor | tuple[torch.Tensor, torch.Tensor], list[Any]]:
images, targets = zip(*samples)
images = torch.stack(images, dim=0) # type: ignore[assignment]
if isinstance(images[0], tuple):
images, padding_masks = zip(*images)
images = torch.stack(images, dim=0) # type: ignore[assignment]
padding_masks = torch.stack(padding_masks, dim=0) # type: ignore[assignment]
images = (images, padding_masks)
else:
images = torch.stack(images, dim=0) # type: ignore[assignment]

return images, list(targets) # type: ignore[return-value]
return images, list(targets)


class VisionDataset(AbstractDataset, _VisionDataset): # noqa: D101
Expand Down
6 changes: 3 additions & 3 deletions doctr/datasets/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,17 @@ def __init__(
raise FileNotFoundError(f"unable to locate {img_path}")

polygons = label.get("polygons")
class_names = label.get("class_names")
class_names = label.get("classes")

if polygons is None:
raise KeyError(f"missing 'polygons' for image: {img_name}")
if class_names is None:
raise KeyError(f"missing 'class_names' for image: {img_name}")
raise KeyError(f"missing 'classes' for image: {img_name}")

if len(polygons) != len(class_names):
raise ValueError(
f"number of polygons ({len(polygons)}) does not match "
f"number of class_names ({len(class_names)}) for image: {img_name}"
f"number of classes ({len(class_names)}) for image: {img_name}"
)

geoms, polygon_classes = self.format_polygons(
Expand Down
5 changes: 3 additions & 2 deletions doctr/models/layout/lw_detr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np

from doctr.models.core import BaseModel
from doctr.utils import order_points

__all__ = ["_LWDETR", "LWDETRPostProcessor"]

Expand Down Expand Up @@ -57,7 +58,7 @@ def _decode_boxes(self, boxes: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
for i in range(len(boxes)):
rect = ((float(cx[i]), float(cy[i])), (float(w[i]), float(h[i])), float(np.degrees(angles[i])))

poly = cv2.boxPoints(rect)
poly = order_points(cv2.boxPoints(rect))
polys.append(poly)

return np.asarray(polys, dtype=np.float32), angles
Expand Down Expand Up @@ -237,7 +238,7 @@ def _quad_to_obb(poly: np.ndarray):
continue

for cls_id, box in zip(np.asarray(class_ids), np.asarray(boxes)):
poly = box.reshape(4, 2)
poly = order_points(box.reshape(4, 2))
obb = _quad_to_obb(poly)

if obb[2] <= 1e-3 or obb[3] <= 1e-3:
Expand Down
12 changes: 2 additions & 10 deletions doctr/models/layout/lw_detr/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,8 @@
"Table",
"Text",
"Title",
"Document Index",
"Code",
"Checkbox-Selected",
"Checkbox-Unselected",
"Form",
"Key-Value Region",
],
"url": None,
},
Expand All @@ -64,12 +60,8 @@
"Table",
"Text",
"Title",
"Document Index",
"Code",
"Checkbox-Selected",
"Checkbox-Unselected",
"Form",
"Key-Value Region",
],
"url": None,
},
Expand Down Expand Up @@ -735,7 +727,7 @@ def _lw_detr(
kwargs["class_names"] = kwargs.get("class_names", default_cfgs[arch].get("class_names", []))

_cfg = deepcopy(default_cfgs[arch])
_cfg["class_names"] = kwargs["class_names"]
_cfg["class_names"] = sorted(kwargs["class_names"])
kwargs.pop("class_names")

# Build the feature extractor
Expand All @@ -758,7 +750,7 @@ def _lw_detr(
if pretrained:
# The number of class_names is not the same as the number of classes in the pretrained model =>
# remove the layer weights
_ignore_keys = ignore_keys if _cfg["class_names"] != default_cfgs[arch].get("class_names") else None
_ignore_keys = ignore_keys if _cfg["class_names"] != sorted(default_cfgs[arch].get("class_names", [])) else None
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

return model
Expand Down
Loading
Loading