diff --git a/docs/requirements.txt b/docs/requirements.txt index e233fb9a..d13c98df 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -2,7 +2,6 @@ Click>=7.0 yt>=4.1.0 traitlets>=5.0.5 pyOpenGL>=3.1.5 -traittypes>=0.2.1 matplotlib>=3.0 numpy>=1.18.0 pyglet>=2.0.dev0 diff --git a/setup.cfg b/setup.cfg index edc4eeed..458c29e5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,7 +34,6 @@ install_requires = pyglet>=2.0.dev0 pyyaml>=5.3.1 traitlets>=5.0.5 - traittypes>=0.2.1 yt>=4.1.0 python_requires = >=3.9 include_package_data = True diff --git a/yt_idv/cameras/base_camera.py b/yt_idv/cameras/base_camera.py index 2f087605..9c9d7bc5 100644 --- a/yt_idv/cameras/base_camera.py +++ b/yt_idv/cameras/base_camera.py @@ -2,10 +2,14 @@ import numpy as np import traitlets -import traittypes from OpenGL import GL -from yt_idv.traitlets_support import YTPositionTrait, ndarray_ro, ndarray_shape +from yt_idv.traitlets_support import ( + ArrayTrait, + YTPositionTrait, + ndarray_ro, + ndarray_shape, +) class BaseCamera(traitlets.HasTraits): @@ -35,9 +39,8 @@ class BaseCamera(traitlets.HasTraits): # operations won't trigger our observation. position = YTPositionTrait([0.0, 0.0, 1.0]) focus = YTPositionTrait([0.0, 0.0, 0.0]) - up = traittypes.Array(np.array([0.0, 0.0, 1.0])).valid( - ndarray_shape(3), ndarray_ro() - ) + + up = ArrayTrait(np.array([0.0, 0.0, 1.0])).valid(ndarray_shape(3), ndarray_ro()) scroll_delta = traitlets.Float(0.1) fov = traitlets.Float(45.0) near_plane = traitlets.Float(0.001) @@ -46,14 +49,11 @@ class BaseCamera(traitlets.HasTraits): 1.0 ) # This was 8.0/6.0 for a long time. I don't know why. - projection_matrix = traittypes.Array(np.zeros((4, 4))).valid( + projection_matrix = ArrayTrait(np.zeros((4, 4))).valid( ndarray_shape(4, 4), ndarray_ro() ) - view_matrix = traittypes.Array(np.zeros((4, 4))).valid( - ndarray_shape(4, 4), ndarray_ro() - ) - orientation = traittypes.Array(np.zeros(4)).valid(ndarray_shape(4), ndarray_ro()) - + view_matrix = ArrayTrait(np.zeros((4, 4))).valid(ndarray_shape(4, 4), ndarray_ro()) + orientation = ArrayTrait(np.zeros(4)).valid(ndarray_shape(4), ndarray_ro()) held = traitlets.Bool(False) @contextlib.contextmanager diff --git a/yt_idv/opengl_support.py b/yt_idv/opengl_support.py index f9cb400b..0f60d3f3 100644 --- a/yt_idv/opengl_support.py +++ b/yt_idv/opengl_support.py @@ -18,7 +18,6 @@ import matplotlib.pyplot as plt import numpy as np import traitlets -import traittypes from OpenGL import GL # Set up a mapping from numbers to names @@ -26,6 +25,7 @@ from yt_idv._cmyt_utilities import validate_cmyt_name from yt_idv.constants import bbox_vertices +from yt_idv.traitlets_support import ArrayTrait const_types = ( GL.constant.IntConstant, @@ -164,7 +164,7 @@ def validate(self, obj, value): class Texture(traitlets.HasTraits): texture_name = traitlets.CInt(-1) - data = traittypes.Array(None, allow_none=True) + data = ArrayTrait(None, allow_none=True) channels = GLValue("r32f") min_filter = GLValue("linear") mag_filter = GLValue("linear") @@ -335,7 +335,7 @@ def _set_data(self, change): class VertexAttribute(traitlets.HasTraits): name = traitlets.CUnicode("attr") id = traitlets.CInt(-1) - data = traittypes.Array(None, allow_none=True) + data = ArrayTrait(None, allow_none=True) each = traitlets.CInt(-1) opengl_type = traitlets.CInt(GL.GL_FLOAT) divisor = traitlets.CInt(0) @@ -372,7 +372,7 @@ def _set_data(self, change): class VertexArray(traitlets.HasTraits): name = traitlets.CUnicode("vertex") id = traitlets.CInt(-1) - indices = traittypes.Array(None, allow_none=True) + indices = ArrayTrait(None, allow_none=True) index_id = traitlets.CInt(-1) attributes = traitlets.List(trait=traitlets.Instance(VertexAttribute)) each = traitlets.CInt(-1) diff --git a/yt_idv/scene_data/curve.py b/yt_idv/scene_data/curve.py index 035481d4..9a17be66 100644 --- a/yt_idv/scene_data/curve.py +++ b/yt_idv/scene_data/curve.py @@ -1,9 +1,9 @@ import numpy as np import traitlets -from traittypes import Array from yt_idv.opengl_support import VertexArray, VertexAttribute from yt_idv.scene_data.base_data import SceneData +from yt_idv.traitlets_support import ArrayTrait class CurveData(SceneData): @@ -12,7 +12,7 @@ class CurveData(SceneData): """ name = "curve_data" - data = Array() + data = ArrayTrait(allow_none=True) n_vertices = traitlets.CInt() @traitlets.default("vertex_array") @@ -54,7 +54,7 @@ class CurveCollection(CurveData): """Data component for a collection of curves""" name = "curve_collection" - data = Array() + data = ArrayTrait(allow_none=True) n_vertices = traitlets.CInt() def add_curve(self, curve): @@ -79,10 +79,10 @@ def add_curve(self, curve): data = curve[line_indices] data = np.column_stack([data, np.ones((data.shape[0],))]) - if self.data.shape: - self.data = np.concatenate([self.data, data]) - else: + if self.data is None: self.data = data + else: + self.data = np.concatenate([self.data, data]) def add_data(self): """ diff --git a/yt_idv/tests/test_array_traits.py b/yt_idv/tests/test_array_traits.py new file mode 100644 index 00000000..af430d97 --- /dev/null +++ b/yt_idv/tests/test_array_traits.py @@ -0,0 +1,29 @@ +import numpy as np +import traitlets + +from yt_idv.traitlets_support import ArrayTrait, ndarray_shape + + +def test_array_trait(): + + shp = (4, 3) + x = np.ones(shp) + + def extra(obj, value): + # arbitrary to make sure the logic is working + return value * 2.0 + + class UsefulTestClass(traitlets.HasTraits): + + array_no_args = ArrayTrait(allow_none=True) + array_with_default = ArrayTrait(x) + array_with_default_valid = ArrayTrait(x).valid(ndarray_shape(*shp)) + two_x = ArrayTrait(x).valid(extra) + two_x_chained = ArrayTrait(x).valid(ndarray_shape(*shp), extra) + + utc = UsefulTestClass() + assert utc.array_no_args is None + assert np.all(utc.array_with_default == x) + assert np.all(utc.array_with_default_valid == x) + assert np.all(utc.two_x == 2 * x) + assert np.all(utc.two_x_chained == 2 * x) diff --git a/yt_idv/traitlets_support.py b/yt_idv/traitlets_support.py index 327cfe28..8a56479c 100644 --- a/yt_idv/traitlets_support.py +++ b/yt_idv/traitlets_support.py @@ -52,3 +52,39 @@ def validate(self, obj, value): except FileNotFoundError: self.error(obj, value) return value + + +class ArrayTrait(traitlets.TraitType): + + # a replacement for the un-maintained traittypes.Array, loosely + # based off of the implementation at https://github.com/jupyter-widgets/traittypes + info_text = "A numpy array" + + def __init__(self, default_value=None, **kwargs): + if default_value is not None: + default_value = np.asarray(default_value) + super().__init__(default_value=default_value, **kwargs) + self.validators = [] + + def valid(self, *args): + self.validators.extend(args) + return self + + def validate(self, obj, value): + if self.allow_none and value is None: + return value + + if self.allow_none is False and value is None: + self.error(obj, value) + + if not isinstance(value, np.ndarray): + # try to coerce whatever it is + value = np.asarray(value) + + for validator in self.validators: + value = validator(obj, value) + + if not isinstance(value, np.ndarray): + self.error(obj, value) + + return value