Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions unyt/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,25 @@ def wrapper(self, ufunc, method, *inputs, **kwargs):
LARGE_INPUT = {4: 16777217, 8: 9007199254740993}


def _update_array_dtype_inplace(array, *, dtype) -> None:
# change the dtypes in-place, this does not change the
# underlying memory buffer
dt = np.dtype(dtype)
assert dt.kind == "f"
if NUMPY_VERSION >= Version("2.5.0dev0"):
array._set_dtype(dt)
else:
array.dtype = dt


def _update_array_values_inplace(array, *, dtype) -> None:
new_values = array.astype(dtype)
_update_array_dtype_inplace(array, dtype=dtype)
# actually fill in the new float values now that our
# dtype is correct
np.copyto(array, new_values)


class unyt_array(np.ndarray):
"""
An ndarray subclass that attaches a symbolic unit object to the array data.
Expand Down Expand Up @@ -795,22 +814,15 @@ def convert_to_units(self, units, equivalence=None, **kwargs):
f"Input dtype ({self.dtype}) has a smaller itemsize than the "
"smallest floating point representation possible."
)
new_dtype = "f" + str(dsize)
large = LARGE_INPUT.get(dsize, 0)
if large and np.any(np.abs(values) > large):
warnings.warn(
f"Overflow encountered while converting to units '{new_units}'",
RuntimeWarning,
stacklevel=2,
)
float_values = values.astype(new_dtype)
# change the dtypes in-place, this does not change the
# underlying memory buffer
values.dtype = new_dtype
self.dtype = new_dtype
# actually fill in the new float values now that our
# dtype is correct
np.copyto(values, float_values)
_update_array_values_inplace(values, dtype=f"f{dsize}")
_update_array_dtype_inplace(self, dtype=values.dtype)
values *= conv_factor

if offset:
Expand Down Expand Up @@ -1884,10 +1896,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
else:
out = out[0]
if out.dtype.kind in ("u", "i"):
new_dtype = "f" + str(out.dtype.itemsize)
float_values = out.astype(new_dtype)
out.dtype = new_dtype
np.copyto(out, float_values)
_update_array_values_inplace(out, dtype=f"f{out.dtype.itemsize}")
out_func = out.view(np.ndarray)
if len(inputs) == 1:
# Unary ufuncs
Expand Down