Skip to content
12 changes: 4 additions & 8 deletions Doc/docstrings_to_rst.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@

import numpy as np

from MDANSE.Framework.Configurators.IConfigurator import IConfigurator
from MDANSE.Framework.Jobs.IJob import IJob
from MDANSE.Framework.Converters.Converter import Converter
from MDANSE.Framework.QVectors.IQVectors import IQVectors
from MDANSE.MolecularDynamics.UnitCell import UnitCell


configurators = sorted(IConfigurator.indirect_subclasses())
converters = sorted(Converter.indirect_subclasses())
jobs = sorted(IJob.indirect_subclasses())
generators = sorted(set(IQVectors.indirect_subclasses()) - {"IQVectors", "LatticeQVectors"})
configurators = sorted(IConfigurator.available_names())
converters = sorted(Converter.available_names())
jobs = sorted(IJob.available_names())

job_inputs = {}
converter_inputs = {}
Expand Down Expand Up @@ -62,7 +58,7 @@ def make_configurator_doc(conf: str, parent: str) -> str:
job_inputs[conf] = make_configurator_doc(conf, "analysis")
result += f"- {iname}: :ref:`configurator-analysis-{conf}` default={defval}\n"
job_page.append(result)


for job in converters:
result = ""
Expand Down
283 changes: 121 additions & 162 deletions MDANSE/Src/MDANSE/Core/SubclassFactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
from __future__ import annotations

import difflib
from typing import TypeVar
from collections.abc import Callable, Iterable, Sequence
from typing import Any, ClassVar, Generic, ParamSpec, TypeVar, get_args

from more_itertools import always_iterable

Self = TypeVar("Self", bound="SubclassFactory")
# The Self TypeVar is a typing hint indicating that
# a method of a class A will be returning an object
# of type A as well. Since we don't know for which class
Expand All @@ -27,177 +29,134 @@
# NOTE: the later versions of Python (3.11) define Self
# as a type explicitly, but for now we have to define it
# ourselves.
Self = TypeVar("Self", bound="RegisterFactory")
P = ParamSpec("P")


def single_search(parent_class: type, name: str, case_sensitive: bool = False):
"""Finds a subclass of a parent class in the
by searching the _registered_subclasses dictionary.
class RegisterFactory:
"""
Factory requiring manual registration to data.

Arguments:
parent_class (type) -- a class with SubclassFactory metaclass
name (str) -- name of the child class to be found
Attributes
----------
registry : dict[str, Callable]
Dictionary of keys to names.

Returns:
A class (type) or None
"""
for skey in parent_class._registered_subclasses:
if case_sensitive:
lhand = skey
rhand = name
else:
lhand = str(skey).lower()
rhand = name.lower()
if lhand == rhand:
return parent_class._registered_subclasses[skey]

return None


def recursive_search(parent_class: type, name: str):
"""Recursively searches _registered_subclasses dictionaries,
allowing the parent class to find a subclass of a subclass as
well as direct subclasses.

Arguments:
parent_class (type) -- a class with SubclassFactory metaclass
name (str) -- name of the child class to be found

Returns:
A class (type) or None
"""
return_type = single_search(parent_class, name)
if return_type is not None:
return return_type
else:
for child in parent_class._registered_subclasses:
return_type = recursive_search(
parent_class._registered_subclasses[child], name
)
if return_type is not None:
return return_type


def recursive_keys(parent_class: type) -> list:
"""Returns a list of class names of all the subclasses
of a class created with SubclassFactory metaclass.
This includes subclasses of subclasses.

Arguments:
parent_class (type) -- a class with SubclassFactory metaclass

Returns:
A list of class names (str)
"""
try:
results = parent_class.subclasses()
except Exception:
return []
else:
for child in parent_class.subclasses():
results += recursive_keys(parent_class._registered_subclasses[child])
return results


def recursive_dict(parent_class: type) -> dict:
"""Returns a dictionary of {str: type}
of classes derived from the parent_class. The class name (str)
is the key, and the class itself is a value.
This way all the subclasses of a class built with SubclassFactory
can be found, even if they are not _directly_ derived from
the parent class.

Arguments:
parent_class (type) -- a class with SubclassFactory metaclass

Returns:
A dictionary {str: type} of class_name:class pairs
"""
try:
results = {
ckey: parent_class._registered_subclasses[ckey]
for ckey in parent_class.subclasses()
}
except Exception:
return {}
else:
for child in parent_class.subclasses():
newdict = recursive_dict(parent_class._registered_subclasses[child])
results = {**results, **newdict}
return results


class SubclassFactory(type):
"""A metaclass which gives a class the ability to keep track of
its subclasses, and to work as a factory.
See Also
--------
RegisterFactory.register : Registration mechanism.
"""

def __init__(cls, name, base, dct, **kwargs):
super().__init__(name, base, dct)
# Add the registry attribute to the each new child class.
cls._registered_subclasses = {}

@classmethod
def __init_subclass__(cls, **kwargs):
regkey = cls.__name__
super().__init_subclass__(**kwargs)
cls._registered_subclasses[regkey] = cls

# Assign the nested classmethod to the "__init_subclass__" attribute
# of each child class.
# It isn't needed in the terminal children too.
# May be there is a way to avoid adding these needless attributes
# (registry, __init_subclass__) to there. I don't think about it yet.
cls.__init_subclass__ = __init_subclass__

def create(cls, name: str, *args, **kwargs) -> Self:
"""Finds the class called 'name' in the _registered_subclasses
dictionary of the parent class, and returns an instance
of that class with the *args, **kwargs passed to the constructor.

Arguments:
name (str) -- Name of the subclass to be created
*args -- arguments for the subclass constructor
**kwargs -- keyword arguments for the subclass constructor

Returns:
Self-type object - an instance of the requested class.
@classmethod
def instance(cls, key: str) -> Callable[P, type[Self]]:
"""
Return a callable instance to construct given class.
"""
return cls.registry[key]

@classmethod
def create(cls, key: str, *args: P.args, **kwargs: P.kwargs) -> type[Self]:
"""
try:
specific_class = cls._registered_subclasses[name]
except KeyError:
specific_class = recursive_search(cls, name)
if specific_class is None:
subclasses = [i.lower() for i in cls.indirect_subclasses()]
closest = difflib.get_close_matches(name.lower(), subclasses)
err_str = f"Could not find {name} in {cls.__name__}."
if len(closest) > 0:
err_str += f" Did you mean: {closest[0]}?"
raise ValueError(err_str)
return specific_class(*args, **kwargs)

def subclasses(cls):
"""Returns a list of class names that are derived
from this class.

Returns:
list(str) -- a list of the subclasses of this class
Return an instance of given class.
"""
return list(cls._registered_subclasses.keys())
return cls.instance(key)(*args, **kwargs)

def indirect_subclasses(cls):
"""Returns an extended list of class names that are derived
from this class, including subclasses of subclasses
@classmethod
def available_names(cls) -> Sequence[str]:
"""
Known names supported by factory.

Returns
-------
~collections.abc.Sequence[str]
Available keys to load.
"""
return cls.registry.keys()

@classmethod
def raw_dict(cls) -> dict[str, type[Self]]:
"""
Get raw name dictionary.

Returns
-------
~collections.abc.Sequence[str]
Available keys to load.

Notes
-----
Only available on cases where registry is UCDict.
"""
if not hasattr(cls.registry, "raw_dict"):
raise TypeError("No raw names available for class")
return cls.registry.raw_dict

Returns:
list(str) -- a list of the subclasses of this class
@classmethod
def raw_names(cls) -> Sequence[str]:
"""
return recursive_keys(cls)
Get raw names of classes.

def indirect_subclass_dictionary(cls):
"""Returns a {name(str): class(type)} dictionary of classes derived
from this class, including subclasses of subclasses.
Returns
-------
~collections.abc.Sequence[str]
Available keys to load.

Returns:
dict(str:type) -- a dictionary of the subclasses of this class
Notes
-----
Only available on cases where registry is UCDict.
"""
return recursive_dict(cls)
if not hasattr(cls.registry, "raw_mapping"):
raise TypeError("No raw names available for class")
return cls.registry.raw_mapping.values()

@classmethod
def available_classes(cls) -> set[type[Self]]:
"""
Known classes supported by factory.

Returns
-------
set[type[Self]]
Available classes to load.
"""
return set(cls.registry.values())

@classmethod
def registered(cls):
return cls.registry.copy()

@classmethod
def register(cls, names: str | Iterable[str]) -> Callable[P, type[Self]]:
"""
A class level decorator for registering classes.

The names of the modules with which the class is registered
should be the parameter passed to the decorator.

Parameters
----------
names : str
The names of the modules with are registered

Example
-------
To register the ``SphericalQVectors`` class with ``IQVectors``:

.. code-block:: python

class IQVectors(RegisterFactory[IQVectors]): ...

@IQVectors.register('SphericalQVectors')
class SphericalQVectors(Observable): ...
"""

def class_wrapper(wrapped_class: type) -> Callable[P, type[Self]]:
for name in always_iterable(names):
if name in cls.registry:
raise KeyError(
f"{name!r} already in registry. Over-riding registry is forbidden."
)
cls.registry[name] = wrapped_class
return wrapped_class

return class_wrapper
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
from ase.io.formats import all_formats

from MDANSE import PLATFORM
from MDANSE.Framework.Configurators.IConfigurator import IConfigurator
from MDANSE.Framework.Configurators.InputFileConfigurator import InputFileConfigurator
from MDANSE.MLogging import LOG


@IConfigurator.register("AseInputFileConfigurator")
class AseInputFileConfigurator(InputFileConfigurator):
"""Sets an input file for the ASE-based converters."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from MDANSE.Framework.Configurators.IConfigurator import IConfigurator


@IConfigurator.register("AtomMappingConfigurator")
class AtomMappingConfigurator(IConfigurator):
"""The atom mapping configurator for trajectory converters.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from MDANSE.Framework.Configurators.IConfigurator import IConfigurator


@IConfigurator.register("AtomSelectionConfigurator")
class AtomSelectionConfigurator(IConfigurator):
"""Selects atoms in trajectory based on the input string.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def reset_setting(self) -> None:
self.selector.reset()


@IConfigurator.register("AtomTransmutationConfigurator")
class AtomTransmutationConfigurator(IConfigurator):
"""Assigns different chemical elements to selected atoms.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
#
from __future__ import annotations

from MDANSE.Framework.Configurators.IConfigurator import IConfigurator
from MDANSE.Framework.Configurators.MoleculeSelectionConfigurator import (
MoleculeSelectionConfigurator,
)


@IConfigurator.register("AxisSelectionConfigurator")
class AxisSelectionConfigurator(MoleculeSelectionConfigurator):
"""Defines a local axis in a molecule.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from MDANSE.Framework.Configurators.IConfigurator import IConfigurator


@IConfigurator.register("BooleanConfigurator")
class BooleanConfigurator(IConfigurator):
"""Sets a value to a logical True or False.

Expand Down
Loading