diff --git a/unyt/array.py b/unyt/array.py index fb5c0fde..6317732a 100644 --- a/unyt/array.py +++ b/unyt/array.py @@ -505,6 +505,25 @@ def wrapper(self, ufunc, method, *inputs, **kwargs): LARGE_INPUT = {4: 16777217, 8: 9007199254740993} +def _update_array_dtype_inplace(array, *, dtype) -> None: + # change the dtypes in-place, this does not change the + # underlying memory buffer + dt = np.dtype(dtype) + assert dt.kind == "f" + if NUMPY_VERSION >= Version("2.5.0dev0"): + array._set_dtype(dt) + else: + array.dtype = dt + + +def _update_array_values_inplace(array, *, dtype) -> None: + new_values = array.astype(dtype) + _update_array_dtype_inplace(array, dtype=dtype) + # actually fill in the new float values now that our + # dtype is correct + np.copyto(array, new_values) + + class unyt_array(np.ndarray): """ An ndarray subclass that attaches a symbolic unit object to the array data. @@ -795,7 +814,6 @@ def convert_to_units(self, units, equivalence=None, **kwargs): f"Input dtype ({self.dtype}) has a smaller itemsize than the " "smallest floating point representation possible." ) - new_dtype = "f" + str(dsize) large = LARGE_INPUT.get(dsize, 0) if large and np.any(np.abs(values) > large): warnings.warn( @@ -803,14 +821,8 @@ def convert_to_units(self, units, equivalence=None, **kwargs): RuntimeWarning, stacklevel=2, ) - float_values = values.astype(new_dtype) - # change the dtypes in-place, this does not change the - # underlying memory buffer - values.dtype = new_dtype - self.dtype = new_dtype - # actually fill in the new float values now that our - # dtype is correct - np.copyto(values, float_values) + _update_array_values_inplace(values, dtype=f"f{dsize}") + _update_array_dtype_inplace(self, dtype=values.dtype) values *= conv_factor if offset: @@ -1884,10 +1896,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): else: out = out[0] if out.dtype.kind in ("u", "i"): - new_dtype = "f" + str(out.dtype.itemsize) - float_values = out.astype(new_dtype) - out.dtype = new_dtype - np.copyto(out, float_values) + _update_array_values_inplace(out, dtype=f"f{out.dtype.itemsize}") out_func = out.view(np.ndarray) if len(inputs) == 1: # Unary ufuncs