Skip to content

Commit 07cb8c7

Browse files
committed
Refactor & simplify names in configuration
Hopefully this will be easier to explain. The previous names and structure weren't at all intuitive. In the process also simplify the architecture somewhat. It still doesn't feel completely clean but "baby steps".
1 parent 7600795 commit 07cb8c7

11 files changed

Lines changed: 243 additions & 216 deletions

File tree

examples/docstub.toml

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,16 @@
11
[tool.docstub]
22

3-
# TODO not implemented and used yet
4-
extend_grammar = """
5-
6-
"""
7-
8-
# Import information for type annotations, declared ahead of time.
3+
# Prefixes for external modules to match types in docstrings.
4+
# Docstub can't yet automatically discover where to import types from other
5+
# packages. Instead, you can provide this information explicitly.
6+
# Any type in a docstring whose prefix matches the name given on the left side
7+
# will be associated with the given "module" on the right side.
98
#
10-
# Each item maps an annotation name on the left side to a dictionary on the
11-
# right side.
9+
# Examples:
10+
# np = "numpy"
11+
# Will match `np.uint8` and `np.typing.NDArray` and use `import numpy as np`.
1212
#
13-
# Import information can be declared with the following fields:
14-
# from : Indicate that the DocType can be imported from this path.
15-
# import : Import this object, defaults to the DocType.
16-
# as : Use this alias for the imported object
17-
# is_builtin : Indicate that this DocType doesn't need to be imported,
18-
# defaults to "false"
19-
[tool.docstub.known_imports]
20-
configparser = {import = "configparser"}
13+
# plt = "matplotlib.pyplot"
14+
# Will match `plt.Figure` and use `import matplotlib.pyplot as plt`.
15+
[tool.docstub.type_prefixes]
16+
configparser = "configparser"

examples/example_pkg/_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def func_use_from_elsewhere(a1, a2, a3, a4):
5858
----------
5959
a1 : example_pkg.CustomException
6060
a2 : ExampleClass
61-
a3 : example_pkg.CustomException.NestedClass
61+
a3 : example_pkg._basic.ExampleClass.NestedClass
6262
a4 : ExampleClass.NestedClass
6363
6464
Returns

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,10 @@ testpaths = [
105105
[tool.coverage]
106106
run.source = ["docstub"]
107107

108-
[tool.docstub.known_imports]
109-
cst = {import = "libcst", as="cst"}
110-
lark = {import = "lark"}
111-
numpydoc = {import = "numpydoc"}
108+
[tool.docstub.type_prefixes]
109+
cst = "libcst"
110+
lark = "lark"
111+
numpydoc = "numpydoc"
112112

113113

114114
[tool.mypy]

src/docstub/_analysis.py

Lines changed: 62 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -410,69 +410,63 @@ def _collect_type_annotation(self, stack):
410410
self.known_imports[qualname] = known_import
411411

412412

413-
class TypesDatabase:
414-
"""A static database of collected types usable as an annotation.
413+
class TypeMatcher:
414+
"""Match strings to collected type information.
415415
416416
Attributes
417417
----------
418-
current_source : Path | None
419-
source_pkgs : list[Path]
420-
known_imports : dict[str, KnownImport]
421-
stats : dict[str, Any]
418+
types : dict[str, KnownImport]
419+
prefixes : dict[str, KnownImport]
420+
aliases : dict[str, str]
421+
successful_queries : int
422+
unknown_qualnames : list
423+
current_module : Path | None
422424
423425
Examples
424426
--------
425-
>>> from docstub._analysis import TypesDatabase, common_known_imports
426-
>>> db = TypesDatabase(known_imports=common_known_imports())
427-
>>> db.query("Any")
427+
>>> from docstub._analysis import TypeMatcher, common_known_imports
428+
>>> db = TypeMatcher()
429+
>>> db.match("Any")
428430
('Any', <KnownImport 'from typing import Any'>)
429431
"""
430432

431433
def __init__(
432434
self,
433435
*,
434-
source_pkgs=None,
435-
known_imports=None,
436+
types=None,
437+
prefixes=None,
438+
aliases=None,
436439
):
437440
"""
438441
Parameters
439442
----------
440-
source_pkgs : list[Path], optional
441-
known_imports : dict[str, KnownImport], optional
442-
If not provided, defaults to imports returned by
443-
:func:`common_known_imports`.
443+
types : dict[str, KnownImport]
444+
prefixes : dict[str, KnownImport]
445+
aliases : dict[str, str]
444446
"""
445-
if source_pkgs is None:
446-
source_pkgs = []
447-
if known_imports is None:
448-
known_imports = common_known_imports()
449-
450-
self.current_source = None
451-
self.source_pkgs = source_pkgs
447+
self.types = types or common_known_imports()
448+
self.prefixes = prefixes or {}
449+
self.aliases = aliases or {}
450+
self.successful_queries = 0
451+
self.unknown_qualnames = []
452452

453-
self.known_imports = known_imports
453+
self.current_module = None
454454

455-
self.stats = {
456-
"successful_queries": 0,
457-
"unknown_doctypes": [],
458-
}
459-
460-
def query(self, search_name):
455+
def match(self, search_name):
461456
"""Search for a known annotation name.
462457
463458
Parameters
464459
----------
465460
search_name : str
461+
current_module : Path, optional
466462
467463
Returns
468464
-------
469-
annotation_name : str | None
470-
If it was found, the name of the annotation that matches the `known_import`.
471-
known_import : KnownImport | None
472-
If it was found, import information matching the `annotation_name`.
465+
type_name : str | None
466+
type_origin : KnownImport | None
473467
"""
474-
annotation_name = None
475-
known_import = None
468+
type_name = None
469+
type_origin = None
476470

477471
if search_name.startswith("~."):
478472
# Sphinx like matching with abbreviated name
@@ -481,63 +475,64 @@ def query(self, search_name):
481475
regex = re.compile(pattern + "$")
482476
# Might be slow, but works for now
483477
matches = {
484-
key: value
485-
for key, value in self.known_imports.items()
486-
if regex.match(key)
478+
key: value for key, value in self.types.items() if regex.match(key)
487479
}
488480
if len(matches) > 1:
489481
shortest_key = sorted(matches.keys(), key=lambda x: len(x))[0]
490-
known_import = matches[shortest_key]
491-
annotation_name = shortest_key
482+
type_origin = matches[shortest_key]
483+
type_name = shortest_key
492484
logger.warning(
493485
"%r in %s matches multiple types %r, using %r",
494486
search_name,
495-
self.current_source,
487+
self.current_module or "<file not known>",
496488
matches.keys(),
497489
shortest_key,
498490
)
499491
elif len(matches) == 1:
500-
annotation_name, known_import = matches.popitem()
492+
type_name, type_origin = matches.popitem()
501493
else:
502494
search_name = search_name[2:]
503495
logger.debug(
504-
"couldn't match %r in %s", search_name, self.current_source
496+
"couldn't match %r in %s",
497+
search_name,
498+
self.current_module or "<file not known>",
505499
)
506500

507-
if known_import is None and self.current_source:
501+
# Replace alias
502+
search_name = self.aliases.get(search_name, search_name)
503+
504+
if type_origin is None and self.current_module:
508505
# Try scope of current module
509-
module_name = module_name_from_path(self.current_source)
506+
module_name = module_name_from_path(self.current_module)
510507
try_qualname = f"{module_name}.{search_name}"
511-
known_import = self.known_imports.get(try_qualname)
512-
if known_import:
513-
annotation_name = search_name
508+
type_origin = self.types.get(try_qualname)
509+
if type_origin:
510+
type_name = search_name
511+
512+
if type_origin is None and search_name in self.types:
513+
type_name = search_name
514+
type_origin = self.types[search_name]
514515

515-
if known_import is None:
516+
if type_origin is None:
516517
# Try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a')
517518
for partial_qualname in reversed(accumulate_qualname(search_name)):
518-
known_import = self.known_imports.get(partial_qualname)
519-
if known_import:
520-
annotation_name = search_name
519+
type_origin = self.prefixes.get(partial_qualname)
520+
if type_origin:
521+
type_name = search_name
521522
break
522523

523524
if (
524-
known_import is not None
525-
and annotation_name is not None
526-
and annotation_name != known_import.target
527-
and not annotation_name.startswith(known_import.target)
525+
type_origin is not None
526+
and type_name is not None
527+
and type_name != type_origin.target
528+
and not type_name.startswith(type_origin.target)
528529
):
529530
# Ensure that the annotation matches the import target
530-
annotation_name = annotation_name[
531-
annotation_name.find(known_import.target) :
532-
]
531+
type_name = type_name[type_name.find(type_origin.target) :]
533532

534-
if annotation_name is not None:
535-
self.stats["successful_queries"] += 1
533+
if type_name is not None:
534+
self.successful_queries += 1
536535
else:
537-
self.stats["unknown_doctypes"].append(search_name)
536+
self.unknown_qualnames.append(search_name)
538537

539-
return annotation_name, known_import
540-
541-
def __repr__(self) -> str:
542-
repr = f"{type(self).__name__}({self.source_pkgs})"
543-
return repr
538+
return type_name, type_origin

src/docstub/_cli.py

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ._analysis import (
1111
KnownImport,
1212
TypeCollector,
13-
TypesDatabase,
13+
TypeMatcher,
1414
common_known_imports,
1515
)
1616
from ._cache import FileCache
@@ -76,19 +76,18 @@ def _setup_logging(*, verbose):
7676
)
7777

7878

79-
def _build_import_map(config, root_path):
80-
"""Build a map of known imports.
79+
def _collect_types(root_path):
80+
"""Collect types.
8181
8282
Parameters
8383
----------
84-
config : ~.Config
8584
root_path : Path
8685
8786
Returns
8887
-------
89-
imports : dict[str, ~.KnownImport]
88+
types : dict[str, ~.KnownImport]
9089
"""
91-
known_imports = common_known_imports()
90+
types = common_known_imports()
9291

9392
collect_cached_types = FileCache(
9493
func=TypeCollector.collect,
@@ -99,12 +98,10 @@ def _build_import_map(config, root_path):
9998
if root_path.is_dir():
10099
for source_path in walk_python_package(root_path):
101100
logger.info("collecting types in %s", source_path)
102-
known_imports_in_source = collect_cached_types(source_path)
103-
known_imports.update(known_imports_in_source)
104-
105-
known_imports.update(KnownImport.many_from_config(config.known_imports))
101+
types_in_source = collect_cached_types(source_path)
102+
types.update(types_in_source)
106103

107-
return known_imports
104+
return types
108105

109106

110107
@contextmanager
@@ -195,15 +192,26 @@ def main(root_path, out_dir, config_path, group_errors, allow_errors, verbose):
195192
)
196193

197194
config = _load_configuration(config_path)
198-
known_imports = _build_import_map(config, root_path)
195+
196+
types = common_known_imports()
197+
types |= _collect_types(root_path)
198+
types |= {
199+
type_name: KnownImport(import_path=module, import_name=type_name)
200+
for type_name, module in config.types.items()
201+
}
202+
203+
prefixes = {
204+
prefix: (
205+
KnownImport(import_name=module, import_alias=prefix)
206+
if module != prefix
207+
else KnownImport(import_name=prefix)
208+
)
209+
for prefix, module in config.type_prefixes.items()
210+
}
199211

200212
reporter = GroupedErrorReporter() if group_errors else ErrorReporter()
201-
types_db = TypesDatabase(
202-
source_pkgs=[root_path.parent.resolve()], known_imports=known_imports
203-
)
204-
stub_transformer = Py2StubTransformer(
205-
types_db=types_db, replace_doctypes=config.replace_doctypes, reporter=reporter
206-
)
213+
matcher = TypeMatcher(types=types, prefixes=prefixes, aliases=config.type_aliases)
214+
stub_transformer = Py2StubTransformer(matcher=matcher, reporter=reporter)
207215

208216
if not out_dir:
209217
if root_path.is_file():
@@ -246,22 +254,22 @@ def main(root_path, out_dir, config_path, group_errors, allow_errors, verbose):
246254
reporter.print_grouped()
247255

248256
# Report basic statistics
249-
successful_queries = types_db.stats["successful_queries"]
257+
successful_queries = matcher.successful_queries
250258
click.secho(f"{successful_queries} matched annotations", fg="green")
251259

252260
syntax_error_count = stub_transformer.transformer.stats["syntax_errors"]
253261
if syntax_error_count:
254262
click.secho(f"{syntax_error_count} syntax errors", fg="red")
255263

256-
unknown_doctypes = types_db.stats["unknown_doctypes"]
257-
if unknown_doctypes:
258-
click.secho(f"{len(unknown_doctypes)} unknown doctypes", fg="red")
259-
counter = Counter(unknown_doctypes)
264+
unknown_qualnames = matcher.unknown_qualnames
265+
if unknown_qualnames:
266+
click.secho(f"{len(unknown_qualnames)} unknown type names", fg="red")
267+
counter = Counter(unknown_qualnames)
260268
sorted_item_counts = sorted(counter.items(), key=lambda x: x[1], reverse=True)
261269
for item, count in sorted_item_counts:
262270
click.echo(f" {item} (x{count})")
263271

264-
total_errors = len(unknown_doctypes) + syntax_error_count
272+
total_errors = len(unknown_qualnames) + syntax_error_count
265273
total_msg = f"{total_errors} total errors"
266274
if allow_errors:
267275
total_msg = f"{total_msg} (allowed {allow_errors})"

0 commit comments

Comments
 (0)