Skip to content

Commit b7c8521

Browse files
committed
Add Radial and initial Grid
1 parent ae909a3 commit b7c8521

15 files changed

Lines changed: 996 additions & 68 deletions

File tree

MDANSE/Src/MDANSE/Core/SubclassFactory.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
# You should have received a copy of the GNU General Public License
1414
# along with this program. If not, see <https://www.gnu.org/licenses/>.
1515
#
16-
17-
from typing import TypeVar
1816
import difflib
19-
17+
from abc import ABC
18+
from collections.abc import Callable, Sequence
19+
from typing import Dict, Generic, TypeVar, Union
2020

2121
Self = TypeVar("Self", bound="SubclassFactory")
22+
T = TypeVar("T")
23+
2224
# The Self TypeVar is a typing hint indicating that
2325
# a method of a class A will be returning an object
2426
# of type A as well. Since we don't know for which class
@@ -201,3 +203,75 @@ def indirect_subclass_dictionary(cls):
201203
dict(str:type) -- a dictionary of the subclasses of this class
202204
"""
203205
return recursive_dict(cls)
206+
207+
class RegisterFactory(ABC, Generic[T]):
208+
"""Alternative factory which uses explicit registration."""
209+
210+
_registered_subclasses: Dict[str, type[T]] = {}
211+
212+
@classmethod
213+
def get(cls, name: str) -> type[T]:
214+
"""Get the class from the registry.
215+
216+
Parameters
217+
----------
218+
name : str
219+
Name of class to retrieve.
220+
221+
Returns
222+
-------
223+
type[T]
224+
Class ready to instantiate.
225+
226+
Raises
227+
------
228+
KeyError
229+
Name not found in registry.
230+
"""
231+
if name not in cls._registered_subclasses:
232+
raise KeyError(f"Cannot instantiate class {name}, "
233+
f"not in registered subclasses ({','.join(cls._registered_subclasses)})")
234+
return cls._registered_subclasses[name]
235+
236+
@classmethod
237+
def create(cls, name: str, *args, **kwargs) -> T:
238+
"""Instantiate child class from registry.
239+
240+
Parameters
241+
----------
242+
name : str
243+
Name to instantiate.
244+
"""
245+
return cls.get(name)(*args, **kwargs)
246+
247+
@classmethod
248+
@property
249+
def subclasses(cls):
250+
return list(cls._registered_subclasses.keys())
251+
252+
@classmethod
253+
def register(cls, names: Union[str, Sequence[str]]):
254+
"""Register class as factory member.
255+
256+
Parameters
257+
----------
258+
names : str | Sequence[str]
259+
Internal name(s) to use.
260+
261+
Examples
262+
--------
263+
.. code-block:: python
264+
@ClassFactory.register("elem")
265+
class MyClass(): ...
266+
267+
"""
268+
def class_wrapper(wrapped_class: type) -> Callable[..., T]:
269+
270+
if isinstance(names, str):
271+
cls._registered_subclasses[names] = wrapped_class
272+
else:
273+
for name in names:
274+
cls._registered_subclasses[name] = wrapped_class
275+
return wrapped_class
276+
277+
return class_wrapper
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from typing import Optional
2+
3+
import numpy as np
4+
from MDANSE.Framework.NewQVectors.QVector import (QVecGen,
5+
QVecGeneratorProtocol,
6+
QVectorData,
7+
QVectorGenerator)
8+
from MDANSE.MolecularDynamics.UnitCell import UnitCell
9+
from numpy.typing import ArrayLike
10+
11+
@QVectorGenerator.register("ListQVectors")
12+
class ListQVectors(QVectorGenerator):
13+
"""Return Q vectors from a provided list of vectors.
14+
15+
Parameters
16+
----------
17+
qvectors : ArrayLike
18+
Sequence of vectors to return.
19+
hkl : bool
20+
Whether vectors are provided in reciprocal lattice units.
21+
lattice : Optional[UnitCell]
22+
Lattice to generate within.
23+
24+
Raises
25+
------
26+
ValueError
27+
If hkl and no lattice provided.
28+
29+
Examples
30+
--------
31+
>>> qvec = ListQVectors([[1, 2, 3], [2, 0, 5]])
32+
>>> for i in qvec:
33+
... print(i.q)
34+
[1. 2. 3.]
35+
[2. 0. 5.]
36+
"""
37+
38+
def __init__(self,
39+
qvectors: ArrayLike,
40+
*,
41+
lattice: Optional[UnitCell] = None,
42+
**kwargs):
43+
44+
super().__init__(lattice=lattice, **kwargs)
45+
self.qvectors = np.array(qvectors)
46+
47+
if self.qvectors.shape != (len(self.qvectors), 3):
48+
raise ValueError(f"`qvectors` ({qvectors.shape}) must be an (N, 3) sequence.")
49+
50+
51+
def generate(self, lattice: Optional[UnitCell] = None) -> QVecGen:
52+
lattice = lattice if lattice is not None else self.lattice
53+
constructor = self.qvec_gen
54+
55+
while self._ind < len(self):
56+
new_ind = yield constructor(self.qvectors[self._ind], lattice)
57+
58+
self._ind += 1
59+
if new_ind is not None:
60+
self.reset(new_ind)
61+
62+
def __len__(self) -> int:
63+
return len(self.qvectors)
64+
65+
def reset(self, value: int = 0):
66+
super().reset(value)
67+
68+
69+
@QVectorGenerator.register("GeneratorQVectors")
70+
class GeneratorQVectors(QVectorGenerator):
71+
"""Return Q Vectors as generated from a generator function.
72+
73+
Parameters
74+
----------
75+
qvectors : QVecGen | Generator[ArrayLike, int, None]
76+
Generator which returns QVectorData.
77+
lattice : Optional[UnitCell]
78+
Lattice in which to generate Q Vectors.
79+
80+
Examples
81+
--------
82+
FIXME: Add docs.
83+
84+
"""
85+
def __init__(self,
86+
qvectors: QVecGeneratorProtocol,
87+
*args,
88+
lattice: Optional[UnitCell] = None,
89+
returns_qvec_data: bool = True,
90+
**kwargs):
91+
super().__init__(lattice=lattice, **kwargs)
92+
93+
self.generator = qvectors
94+
self.args = args
95+
self.returns_qvec_data = returns_qvec_data
96+
97+
def generate(self, lattice: Optional[UnitCell] = None) -> QVecGen:
98+
lattice = lattice if lattice is not None else self.lattice
99+
constructor = self.qvec_gen
100+
101+
for qvec in self.generator(*self.args, lattice):
102+
yield qvec if self.returns_qvec_data else constructor(qvec, lattice)
103+
104+
def reset(self, val: None = None):
105+
raise NotImplementedError("Cannot reset this generator")

0 commit comments

Comments
 (0)