Skip to content

Commit ab036ba

Browse files
committed
Deal with shape only natlang arrays
and add doctests and documentation
1 parent 7b4ea91 commit ab036ba

2 files changed

Lines changed: 60 additions & 10 deletions

File tree

src/docstub/_doctype.py

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Parsing of doctypes."""
1+
"""Parsing & transformation of doctypes into Python-compatible syntax."""
22

33
import enum
44
import keyword
@@ -34,10 +34,18 @@ def flatten_recursive(iterable):
3434
Parameters
3535
----------
3636
iterable : Iterable[Iterable or str]
37+
An iterable containing nested iterables or strings. Only strings are
38+
supported as "leaves" for now.
3739
3840
Yields
3941
------
4042
item : str
43+
44+
Examples
45+
--------
46+
>>> nested = ["only", ["strings", ("and", "iterables"), "are", ["allowed"]]]
47+
>>> list(flatten_recursive(nested))
48+
['only', 'strings', 'and', 'iterables', 'are', 'allowed']
4149
"""
4250
for item in iterable:
4351
if isinstance(item, str):
@@ -59,6 +67,12 @@ def insert_between(iterable, *, sep):
5967
Returns
6068
-------
6169
out : list[Any]
70+
71+
Examples
72+
--------
73+
>>> code = ["a", "b", "c", ]
74+
>>> list(insert_between(code, sep=" | "))
75+
['a', ' | ', 'b', ' | ', 'c']
6276
"""
6377
out = []
6478
for item in iterable:
@@ -68,6 +82,8 @@ def insert_between(iterable, *, sep):
6882

6983

7084
class TermKind(enum.StrEnum):
85+
"""Encodes the different kinds of :class:`Term`."""
86+
7187
# docstub: off
7288
NAME = enum.auto()
7389
LITERAL = enum.auto()
@@ -83,6 +99,17 @@ class Term(str):
8399
kind : TermKind
84100
pos : tuple of (int, int) or None
85101
__slots__ : Final[tuple[str, ...]]
102+
103+
Examples
104+
--------
105+
>>> ''.join(
106+
... [
107+
... Term("int", kind="name"),
108+
... Term(" | ", kind="syntax"),
109+
... Term("float", kind="name")
110+
... ]
111+
... )
112+
'int | float'
86113
"""
87114

88115
__slots__ = ("kind", "pos")
@@ -216,6 +243,16 @@ class BlacklistedQualname(DocstubError):
216243

217244
@lark.visitors.v_args(tree=True)
218245
class DoctypeTransformer(lark.visitors.Transformer):
246+
"""Transform parsed doctypes into Python-compatible syntax.
247+
248+
Examples
249+
--------
250+
>>> tree = _lark.parse("int or tuple of (int, ...)")
251+
>>> transformer = DoctypeTransformer()
252+
>>> str(transformer.transform(tree=tree))
253+
'int | tuple[int, ...]'
254+
"""
255+
219256
def __init__(self, *, reporter=None):
220257
"""
221258
Parameters
@@ -310,6 +347,7 @@ def subscription(self, tree):
310347
-------
311348
out : Expr
312349
"""
350+
assert len(tree.children) > 1
313351
return self._format_subscription(tree.children, rule="subscription")
314352

315353
def param_spec(self, tree):
@@ -341,6 +379,7 @@ def callable(self, tree):
341379
-------
342380
out : Expr
343381
"""
382+
assert len(tree.children) > 1
344383
return self._format_subscription(tree.children, rule="callable")
345384

346385
def literal(self, tree):
@@ -353,6 +392,7 @@ def literal(self, tree):
353392
-------
354393
out : Expr
355394
"""
395+
assert len(tree.children) > 1
356396
out = self._format_subscription(tree.children, rule="literal")
357397
return out
358398

@@ -372,6 +412,7 @@ def natlang_literal(self, tree):
372412
]
373413
out = self._format_subscription(items, rule="natlang_literal")
374414

415+
assert len(tree.children) >= 1
375416
if len(tree.children) == 1:
376417
details = ("Consider using `%s` to improve readability", "".join(out))
377418
self.reporter.warn(
@@ -409,6 +450,7 @@ def natlang_container(self, tree):
409450
-------
410451
out : Expr
411452
"""
453+
assert len(tree.children) >= 1
412454
return self._format_subscription(tree.children, rule="natlang_container")
413455

414456
def natlang_array(self, tree):
@@ -490,7 +532,8 @@ def extra_info(self, tree):
490532
return lark.Discard
491533

492534
def _format_subscription(self, sequence, *, rule):
493-
"""
535+
"""Format a `name[...]` style expression.
536+
494537
Parameters
495538
----------
496539
sequence : Sequence[str]
@@ -502,17 +545,20 @@ def _format_subscription(self, sequence, *, rule):
502545
"""
503546
sep = Term(", ", kind=TermKind.SYNTAX)
504547
container, *content = sequence
505-
content = insert_between(content, sep=sep)
506-
assert content
507-
expr = Expr(
508-
rule=rule,
509-
children=[
548+
assert container
549+
550+
if content:
551+
content = insert_between(content, sep=sep)
552+
children = [
510553
container,
511554
Term("[", kind=TermKind.SYNTAX),
512555
*content,
513556
Term("]", kind=TermKind.SYNTAX),
514-
],
515-
)
557+
]
558+
else:
559+
children = [container]
560+
561+
expr = Expr(rule=rule, children=children)
516562
return expr
517563

518564

@@ -534,6 +580,8 @@ def parse_doctype(doctype, *, reporter=None):
534580
lark.exceptions.VisitError
535581
Raised when the transformation is interrupted by an exception.
536582
See :class:`lark.exceptions.VisitError`.
583+
BlacklistedQualname
584+
Raised when a qualname is a forbidden keyword.
537585
538586
Examples
539587
--------

tests/test_doctype.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,6 @@ def test_rst_role(self, doctype, expected):
270270
[
271271
("{name} of shape {shape} and dtype {dtype}", "{name}[{dtype}]"),
272272
("{name} of dtype {dtype} and shape {shape}", "{name}[{dtype}]"),
273-
("{name} of {dtype}", "{name}[{dtype}]"),
274273
],
275274
)
276275
@pytest.mark.parametrize("name", ["array", "ndarray", "array-like", "array_like"])
@@ -283,17 +282,20 @@ def test_natlang_array(self, fmt, expected_fmt, name, dtype, shape):
283282
expected = expected_fmt.format(name=name, dtype=dtype, shape=shape)
284283
expr = parse_doctype(doctype)
285284
assert expr.as_code() == expected
285+
assert "natlang_array" in [e.rule for e in expr.sub_expressions]
286286
# fmt: on
287287

288288
@pytest.mark.parametrize(
289289
("doctype", "expected"),
290290
[
291291
("ndarray of dtype (int or float)", "ndarray[int | float]"),
292+
("ndarray of shape (M, N)", "ndarray"),
292293
],
293294
)
294295
def test_natlang_array_specific(self, doctype, expected):
295296
expr = parse_doctype(doctype)
296297
assert expr.as_code() == expected
298+
assert "natlang_array" in [e.rule for e in expr.sub_expressions]
297299

298300
@pytest.mark.parametrize("shape", ["(-1, 3)", "(1.0, 2)", "-3D", "-2-D"])
299301
def test_natlang_array_invalid_shape(self, shape):

0 commit comments

Comments
 (0)