Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ Click>=7.0
yt>=4.1.0
traitlets>=5.0.5
pyOpenGL>=3.1.5
traittypes>=0.2.1
matplotlib>=3.0
numpy>=1.18.0
pyglet>=2.0.dev0
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ install_requires =
pyglet>=2.0.dev0
pyyaml>=5.3.1
traitlets>=5.0.5
traittypes>=0.2.1
yt>=4.1.0
python_requires = >=3.9
include_package_data = True
Expand Down
22 changes: 11 additions & 11 deletions yt_idv/cameras/base_camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@

import numpy as np
import traitlets
import traittypes
from OpenGL import GL

from yt_idv.traitlets_support import YTPositionTrait, ndarray_ro, ndarray_shape
from yt_idv.traitlets_support import (
ArrayTrait,
YTPositionTrait,
ndarray_ro,
ndarray_shape,
)


class BaseCamera(traitlets.HasTraits):
Expand Down Expand Up @@ -35,9 +39,8 @@ class BaseCamera(traitlets.HasTraits):
# operations won't trigger our observation.
position = YTPositionTrait([0.0, 0.0, 1.0])
focus = YTPositionTrait([0.0, 0.0, 0.0])
up = traittypes.Array(np.array([0.0, 0.0, 1.0])).valid(
ndarray_shape(3), ndarray_ro()
)

up = ArrayTrait(np.array([0.0, 0.0, 1.0])).valid(ndarray_shape(3), ndarray_ro())
scroll_delta = traitlets.Float(0.1)
fov = traitlets.Float(45.0)
near_plane = traitlets.Float(0.001)
Expand All @@ -46,14 +49,11 @@ class BaseCamera(traitlets.HasTraits):
1.0
) # This was 8.0/6.0 for a long time. I don't know why.

projection_matrix = traittypes.Array(np.zeros((4, 4))).valid(
projection_matrix = ArrayTrait(np.zeros((4, 4))).valid(
ndarray_shape(4, 4), ndarray_ro()
)
view_matrix = traittypes.Array(np.zeros((4, 4))).valid(
ndarray_shape(4, 4), ndarray_ro()
)
orientation = traittypes.Array(np.zeros(4)).valid(ndarray_shape(4), ndarray_ro())

view_matrix = ArrayTrait(np.zeros((4, 4))).valid(ndarray_shape(4, 4), ndarray_ro())
orientation = ArrayTrait(np.zeros(4)).valid(ndarray_shape(4), ndarray_ro())
held = traitlets.Bool(False)

@contextlib.contextmanager
Expand Down
8 changes: 4 additions & 4 deletions yt_idv/opengl_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
import matplotlib.pyplot as plt
import numpy as np
import traitlets
import traittypes
from OpenGL import GL

# Set up a mapping from numbers to names
from yt.utilities.math_utils import get_scale_matrix, get_translate_matrix

from yt_idv._cmyt_utilities import validate_cmyt_name
from yt_idv.constants import bbox_vertices
from yt_idv.traitlets_support import ArrayTrait

const_types = (
GL.constant.IntConstant,
Expand Down Expand Up @@ -164,7 +164,7 @@ def validate(self, obj, value):

class Texture(traitlets.HasTraits):
texture_name = traitlets.CInt(-1)
data = traittypes.Array(None, allow_none=True)
data = ArrayTrait(None, allow_none=True)
channels = GLValue("r32f")
min_filter = GLValue("linear")
mag_filter = GLValue("linear")
Expand Down Expand Up @@ -335,7 +335,7 @@ def _set_data(self, change):
class VertexAttribute(traitlets.HasTraits):
name = traitlets.CUnicode("attr")
id = traitlets.CInt(-1)
data = traittypes.Array(None, allow_none=True)
data = ArrayTrait(None, allow_none=True)
each = traitlets.CInt(-1)
opengl_type = traitlets.CInt(GL.GL_FLOAT)
divisor = traitlets.CInt(0)
Expand Down Expand Up @@ -372,7 +372,7 @@ def _set_data(self, change):
class VertexArray(traitlets.HasTraits):
name = traitlets.CUnicode("vertex")
id = traitlets.CInt(-1)
indices = traittypes.Array(None, allow_none=True)
indices = ArrayTrait(None, allow_none=True)
index_id = traitlets.CInt(-1)
attributes = traitlets.List(trait=traitlets.Instance(VertexAttribute))
each = traitlets.CInt(-1)
Expand Down
12 changes: 6 additions & 6 deletions yt_idv/scene_data/curve.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import numpy as np
import traitlets
from traittypes import Array

from yt_idv.opengl_support import VertexArray, VertexAttribute
from yt_idv.scene_data.base_data import SceneData
from yt_idv.traitlets_support import ArrayTrait


class CurveData(SceneData):
Expand All @@ -12,7 +12,7 @@ class CurveData(SceneData):
"""

name = "curve_data"
data = Array()
data = ArrayTrait(allow_none=True)
n_vertices = traitlets.CInt()

@traitlets.default("vertex_array")
Expand Down Expand Up @@ -54,7 +54,7 @@ class CurveCollection(CurveData):
"""Data component for a collection of curves"""

name = "curve_collection"
data = Array()
data = ArrayTrait(allow_none=True)
n_vertices = traitlets.CInt()

def add_curve(self, curve):
Expand All @@ -79,10 +79,10 @@ def add_curve(self, curve):
data = curve[line_indices]
data = np.column_stack([data, np.ones((data.shape[0],))])

if self.data.shape:
self.data = np.concatenate([self.data, data])
else:
if self.data is None:
self.data = data
else:
self.data = np.concatenate([self.data, data])

def add_data(self):
"""
Expand Down
29 changes: 29 additions & 0 deletions yt_idv/tests/test_array_traits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import numpy as np
import traitlets

from yt_idv.traitlets_support import ArrayTrait, ndarray_shape


def test_array_trait():

shp = (4, 3)
x = np.ones(shp)

def extra(obj, value):
# arbitrary to make sure the logic is working
return value * 2.0

class UsefulTestClass(traitlets.HasTraits):

array_no_args = ArrayTrait(allow_none=True)
array_with_default = ArrayTrait(x)
array_with_default_valid = ArrayTrait(x).valid(ndarray_shape(*shp))
two_x = ArrayTrait(x).valid(extra)
two_x_chained = ArrayTrait(x).valid(ndarray_shape(*shp), extra)

utc = UsefulTestClass()
assert utc.array_no_args is None
assert np.all(utc.array_with_default == x)
assert np.all(utc.array_with_default_valid == x)
assert np.all(utc.two_x == 2 * x)
assert np.all(utc.two_x_chained == 2 * x)
36 changes: 36 additions & 0 deletions yt_idv/traitlets_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,39 @@ def validate(self, obj, value):
except FileNotFoundError:
self.error(obj, value)
return value


class ArrayTrait(traitlets.TraitType):

# a replacement for the un-maintained traittypes.Array, loosely
# based off of the implementation at https://github.com/jupyter-widgets/traittypes
info_text = "A numpy array"

def __init__(self, default_value=None, **kwargs):
if default_value is not None:
default_value = np.asarray(default_value)
super().__init__(default_value=default_value, **kwargs)
self.validators = []

def valid(self, *args):
self.validators.extend(args)
return self

def validate(self, obj, value):
if self.allow_none and value is None:
return value

if self.allow_none is False and value is None:
self.error(obj, value)

if not isinstance(value, np.ndarray):
# try to coerce whatever it is
value = np.asarray(value)

for validator in self.validators:
value = validator(obj, value)

if not isinstance(value, np.ndarray):
self.error(obj, value)

return value