From d25abafa5f46914094c63ab2684264f4049fb489 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Thu, 12 Feb 2026 17:57:01 +0000 Subject: [PATCH 01/54] Added jpk-qi-data loading functionality. Two possible channels: by trigger point or by contact point --- AFMReader/general_loader.py | 4 +- AFMReader/jpk_qi.py | 116 ++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 1 deletion(-) create mode 100644 AFMReader/jpk_qi.py diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index be27937..b8ca35d 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -4,7 +4,7 @@ import numpy.typing as npt -from AFMReader import asd, gwy, h5_jpk, ibw, jpk, spm, stp, top, topostats +from AFMReader import asd, gwy, h5_jpk, ibw, jpk, spm, stp, top, topostats, jpk_qi from AFMReader.logging import logger logger.enable(__package__) @@ -66,6 +66,8 @@ def load(self) -> tuple[npt.NDArray | str, float | None]: # noqa: C901 image, pixel_to_nanometre_scaling_factor = spm.load_spm(self.filepath, self.channel) elif self.suffix == ".h5-jpk": image, pixel_to_nanometre_scaling_factor, _ = h5_jpk.load_h5jpk(self.filepath, self.channel) + elif self.suffix == ".jpk-qi-data": + image, pixel_to_nanometre_scaling_factor = jpk_qi.load_jpk_qi(self.filepath, self.channel) elif self.suffix == ".stp": image, pixel_to_nanometre_scaling_factor = stp.load_stp(self.filepath) elif self.suffix == ".top": diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py new file mode 100644 index 0000000..d59a8f3 --- /dev/null +++ b/AFMReader/jpk_qi.py @@ -0,0 +1,116 @@ +import afmformats +from pathlib import Path +import numpy as np +from AFMReader.logging import logger + + + +def load_jpk_qi( + file_path: Path | str, + channel: str, + flip_image: bool | None = True +) -> tuple[np.ndarray, float]: + + # Load the file path passed to the function + curves_data = _fetch_qi_data(file_path) + + if curves_data is None or len(curves_data) == 0: + print("No data found in the file.") + return + + image, px2nm = _process_jpk_qi_data(curves_data, channel) + + if flip_image: + image = np.flipud(image) + + return image, px2nm + + + +def _fetch_qi_data(file_path: Path | str): + return afmformats.load_data(file_path) + + +def _process_jpk_qi_data(curves_data: list, channel: str) -> tuple[np.ndarray, float]: + """ + Process the curves data from a JPK QI file to extract the image and pixel to nanometre scaling factor. + + """ + # Calculate pixel to nanometre scaling factor + metadata = curves_data[0].metadata + shape_x = metadata.get("grid shape x", 0) + shape_y = metadata.get("grid shape y", 0) + + size_x = metadata.get("grid size x", 0) + size_y = metadata.get("grid size y", 0) + + pixel_to_nm_scaling_factor_x = size_x / shape_x if shape_x > 0 else 1.0 + pixel_to_nm_scaling_factor_y = size_y / shape_y if shape_y > 0 else 1.0 + avg_pixel_to_nm_scaling_factor = (pixel_to_nm_scaling_factor_x + pixel_to_nm_scaling_factor_y) / 2 + + if channel == "Height (Trigger)": + z_heights = _find_trigger_points(curves_data) + elif channel == "Height (Contact)": + z_heights = [_find_contact_point(curve) for curve in curves_data] + else: + raise ValueError(f"Channel '{channel}' not recognized. Only 'Height (Trigger)' and 'Height (Contact)' are currently supported.") + + print(f"Grid shape: {shape_x} x {shape_y}") + print(f"Z heights: {z_heights}") + image = np.array(z_heights).reshape(shape_y, shape_x) + return image, avg_pixel_to_nm_scaling_factor + +def _find_trigger_points(curves_data): + max_points = _max_points_buffer(curves_data) + n_curves = len(curves_data) + + logger.info(f"Allocating arrays: {n_curves} curves x {max_points} max points.") + all_segments = np.full((n_curves, max_points), -1, dtype=np.int8) + all_heights = np.full((n_curves, max_points), np.nan, dtype=np.float32) + + for i, curve in enumerate(curves_data): + segment = curve["segment"] + height = curve["height (measured)"] + length = len(segment) + + all_segments[i, :length] = segment + all_heights[i, :length] = height + + logger.info("Data stacked. Calculating trigger points...") + is_approach = (all_segments == 0) + + transition_indices = np.sum(is_approach, axis=1) - 1 + + transition_indices = np.maximum(transition_indices, 0) + + row_indices = np.arange(n_curves) + trigger_values = all_heights[row_indices, transition_indices] + + return trigger_values + +def _find_contact_point(curve): + # find contact point in vertical deflection by peak in first derivative + derivative_vert_deflection = np.diff(curve["force"]) + # Doesn't look like this line is needed: peak_derivative_value = np.max(derivative_vert_deflection) + peak_derivative_index = np.argmax(derivative_vert_deflection) + + # find corresponding height value + corresponding_height_at_peak = curve["height (measured)"][peak_derivative_index] + + return corresponding_height_at_peak + +def _find_trigger_point(curve): + segments = curve["segment"] + approach_indices = np.where(segments == 0)[0] + turn_index = approach_indices[-1] + trigger_point = curve["height (measured)"][turn_index] + return trigger_point + +def _max_points_buffer(curves_data, samples=20, points_buffer=1.2): + + step = len(curves_data) // samples + max_points = np.max(len(curves_data[i]["segment"]) for i in range(0, len(curves_data), step)) + return max_points * points_buffer + + + From 9d28dec779ef0afa70d95cc60bc03d3da547da1e Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Mon, 16 Feb 2026 16:44:40 +0000 Subject: [PATCH 02/54] Adding get available channels function to each file format as well as a function to general loader --- AFMReader/asd.py | 19 +++++++++++++++++++ AFMReader/general_loader.py | 25 ++++++++++++++++++++++++- AFMReader/gwy.py | 11 +++++++++++ AFMReader/h5_jpk.py | 4 ++++ AFMReader/ibw.py | 11 +++++++++++ AFMReader/jpk.py | 24 ++++++++++++++++++++++++ AFMReader/spm.py | 13 +++++++++++++ 7 files changed, 106 insertions(+), 1 deletion(-) diff --git a/AFMReader/asd.py b/AFMReader/asd.py index 64727f5..fd42733 100644 --- a/AFMReader/asd.py +++ b/AFMReader/asd.py @@ -287,6 +287,25 @@ def load_asd(file_path: str | Path, channel: str): logger.info(f"[{filename}] : Extracted image.") return frames, pixel_to_nanometre_scaling_factor, header_dict +def get_asd_channels(file_path: Path): + with Path.open(file_path, "rb", encoding=None) as open_file: # pylint: disable=unspecified-encoding + file_version = read_file_version(open_file) + + if file_version == 0: + header_dict = read_header_file_version_0(open_file) + + elif file_version == 1: + header_dict = read_header_file_version_1(open_file) + + elif file_version == 2: + header_dict = read_header_file_version_2(open_file) + else: + raise ValueError( + f"File version {file_version} unknown. Please add support if you know how to decode this file version." + ) + channel_list = [header_dict["channel1"], header_dict["channel2"]] + return channel_list + def read_file_version(open_file: BinaryIO) -> int: """ diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index b8ca35d..61ff0c9 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -60,7 +60,7 @@ def load(self) -> tuple[npt.NDArray | str, float | None]: # noqa: C901 image, pixel_to_nanometre_scaling_factor = gwy.load_gwy(self.filepath, self.channel) elif self.suffix == ".ibw": image, pixel_to_nanometre_scaling_factor = ibw.load_ibw(self.filepath, self.channel) - elif self.suffix == ".jpk": + elif self.suffix in [".jpk", ".jpk-qi-image"]: image, pixel_to_nanometre_scaling_factor = jpk.load_jpk(self.filepath, self.channel) elif self.suffix == ".spm": image, pixel_to_nanometre_scaling_factor = spm.load_spm(self.filepath, self.channel) @@ -93,4 +93,27 @@ def load(self) -> tuple[npt.NDArray | str, float | None]: # noqa: C901 logger.error(f"{e}") return (e, None) # cheeky return of an image, px2nm-like tuple object to propagate error message to Napari + def get_available_channels(self): + if self.suffix == ".asd": + available_channels = asd.get_asd_channels(self.filepath) + elif self.suffix == ".gwy": + available_channels = gwy.get_gwy_channels(self.filepath) + elif self.suffix == ".ibw": + available_channels = ibw.get_ibw_channels(self.filepath) + elif self.suffix in [".jpk", ".jpk-qi-image"]: + available_channels = jpk.get_jpk_channels(self.filepath) + elif self.suffix == ".spm": + available_channels = spm.get_spm_channels(self.filepath) + elif self.suffix == ".h5-jpk": + available_channels = h5_jpk.get_h5jpk_channels(self.filepath) + elif self.suffix == ".jpk-qi-data": + # Implement this + available_channels = None + elif self.suffix in [".stp", ".top"]: + available_channels = stp.load_stp(self.filepath) + elif self.suffix == ".topostats": + available_channels = ["image", "image_original"] + else: + raise ValueError(f"File type '{self.suffix}' is not currently handled by AFMReader.") + return available_channels # scope for a "check what channels are available" function similar to above. diff --git a/AFMReader/gwy.py b/AFMReader/gwy.py index 834cd25..c52cd3f 100644 --- a/AFMReader/gwy.py +++ b/AFMReader/gwy.py @@ -9,6 +9,17 @@ from AFMReader.io import read_char, read_double, read_null_terminated_string, read_uint32 +def get_gwy_channels(file_path): + image_data_dict: dict[Any, Any] = {} + with Path.open(file_path, "rb") as open_file: # pylint: disable=unspecified-encoding + # Read header + header = open_file.read(4) + logger.debug(f"Gwy file header: {header.decode}") + + gwy_read_object(open_file, data_dict=image_data_dict) + channel_ids = gwy_get_channels(gwy_file_structure=image_data_dict) + + return list(channel_ids) def load_gwy(file_path: Path | str, channel: str) -> tuple[np.ndarray[Any, np.float64], float]: """ diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index 13c834b..8bb9027 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -267,6 +267,10 @@ def generate_timestamps(num_frames: int, line_rate: float, image_size: int) -> d # Compose a dictionary of timestamsps return {f"frame {i}": timestamp for i, timestamp in enumerate(timestamps)} +def get_h5jpk_channels(file_path: Path | str): + with h5py.File(file_path, "r") as f: + available_channels = list(_available_channels(f)) + return available_channels def load_h5jpk( file_path: Path | str, channel: str, flip_image: bool = True diff --git a/AFMReader/ibw.py b/AFMReader/ibw.py index f084403..51cd08a 100644 --- a/AFMReader/ibw.py +++ b/AFMReader/ibw.py @@ -38,6 +38,17 @@ def _ibw_pixel_to_nm_scaling(scan: dict) -> float: float(notes["FastScanSize"]) / scan["wave"]["wData"].shape[1] * 1e9, # as in m )[0] +def get_ibw_channels(file_path: Path | str): + filename = file_path.stem + scan = binarywave.load(file_path) + logger.info(f"[{filename}] : Loaded image from : {file_path}") + labels = [] + for label_list in scan["wave"]["labels"]: + for label in label_list: + if label: + labels.append(label.decode()) + return labels + def load_ibw(file_path: Path | str, channel: str) -> tuple[np.ndarray, float]: """ diff --git a/AFMReader/jpk.py b/AFMReader/jpk.py index 4a13a61..dc55d71 100644 --- a/AFMReader/jpk.py +++ b/AFMReader/jpk.py @@ -170,6 +170,30 @@ def _get_z_scaling(tif: tifffile.tifffile, channel_idx: int, jpk_tags: dict[str, raise ValueError(f"Scaling type {scaling_type} is not 'NullScaling' or 'LinearScaling'") return scaling, offset +def get_jpk_channels( + file_path: Path | str, config_path: Path | str | None = None +) -> list[str]: + + file_path = Path(file_path) + filename = file_path.stem + jpk_tags = _load_jpk_tags(config_path) + try: + tif = tifffile.TiffFile(file_path) + except FileNotFoundError: + logger.error(f"[{filename}] File not found : {file_path}") + raise + # Obtain channel list for all channels in file + channel_list = {} + for i, page in enumerate(tif.pages[1:]): # [0] is thumbnail + available_channel = page.tags[jpk_tags["channel_name"]].value # keys are hexadecimal values + if page.tags[jpk_tags["trace_retrace"]].value == 0: # whether img is trace or retrace + tr_rt = "trace" + else: + tr_rt = "retrace" + channel_list[f"{available_channel}_{tr_rt}"] = i + 1 + return channel_list + + def load_jpk( file_path: Path | str, channel: str, config_path: Path | str | None = None, flip_image: bool | None = True diff --git a/AFMReader/spm.py b/AFMReader/spm.py index 0a803e6..8fd7ff3 100644 --- a/AFMReader/spm.py +++ b/AFMReader/spm.py @@ -110,3 +110,16 @@ def load_spm(file_path: Path | str, channel: str) -> tuple: raise e return (image, spm_pixel_to_nm_scaling(filename, channel_data)) + +def get_spm_channels(file_path): + labels = [] + filename = file_path.stem + try: + scan = pySPM.Bruker(file_path) + except FileNotFoundError: + logger.error(f"[{filename}] File not found : {file_path}") + raise + for channel_option in [layer[b"@2:Image Data"][0] for layer in scan.layers]: + channel_name = channel_option.decode("latin1").split('"')[1] + labels.append(channel_name) + return labels From 4bd6bb3da2d256a2b28fe2810cb495336666ae4b Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Mon, 23 Feb 2026 11:01:20 +0000 Subject: [PATCH 03/54] Making load jpk qi data function use zipfile instead of afmformats --- AFMReader/jpk.py | 31 +++++--- AFMReader/jpk_qi.py | 169 +++++++++++++++++++++++++++++++++++++++----- 2 files changed, 172 insertions(+), 28 deletions(-) diff --git a/AFMReader/jpk.py b/AFMReader/jpk.py index dc55d71..62f2792 100644 --- a/AFMReader/jpk.py +++ b/AFMReader/jpk.py @@ -2,6 +2,7 @@ from importlib import resources from pathlib import Path +from io import BytesIO import numpy as np import tifffile @@ -170,15 +171,12 @@ def _get_z_scaling(tif: tifffile.tifffile, channel_idx: int, jpk_tags: dict[str, raise ValueError(f"Scaling type {scaling_type} is not 'NullScaling' or 'LinearScaling'") return scaling, offset -def get_jpk_channels( - file_path: Path | str, config_path: Path | str | None = None -) -> list[str]: - - file_path = Path(file_path) - filename = file_path.stem +def _get_jpk_channels( + file: Path | BytesIO, filename: str, file_path: Path | str, config_path: Path | str | None = None +): jpk_tags = _load_jpk_tags(config_path) try: - tif = tifffile.TiffFile(file_path) + tif = tifffile.TiffFile(file) except FileNotFoundError: logger.error(f"[{filename}] File not found : {file_path}") raise @@ -193,6 +191,13 @@ def get_jpk_channels( channel_list[f"{available_channel}_{tr_rt}"] = i + 1 return channel_list +def get_jpk_channels( + file_path: Path | str, config_path: Path | str | None = None +) -> list[str]: + + file_path = Path(file_path) + filename = file_path.stem + return _get_jpk_channels(file_path, filename, file_path, config_path) def load_jpk( @@ -237,11 +242,15 @@ def load_jpk( logger.info(f"Loading image from : {file_path}") file_path = Path(file_path) filename = file_path.stem + image, px2nm = _load_jpk(file=file_path, filename=filename, channel=channel, file_suffix=file_path.suffix, config_path=config_path, flip_image=flip_image) + return (image, px2nm) + +def _load_jpk(file: Path | BytesIO, filename: str, channel: str, file_suffix: str, config_path: Path | str | None = None, flip_image: bool = True): jpk_tags = _load_jpk_tags(config_path) try: - tif = tifffile.TiffFile(file_path) + tif = tifffile.TiffFile(file) except FileNotFoundError: - logger.error(f"[{filename}] File not found : {file_path}") + logger.error(f"[{filename}] File not found : {file}") raise # Obtain channel list for all channels in file channel_list = {} @@ -255,8 +264,8 @@ def load_jpk( try: channel_idx = channel_list[channel] except KeyError as e: - logger.error(f"'{channel}' not in {file_path.suffix} channel list: {channel_list}") - raise ValueError(f"'{channel}' not in {file_path.suffix} channel list: {channel_list}") from e + logger.error(f"'{channel}' not in {file_suffix} channel list: {channel_list}") + raise ValueError(f"'{channel}' not in {file_suffix} channel list: {channel_list}") from e # Get image and if applicable, scale it channel_page = tif.pages[channel_idx] diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index d59a8f3..5522b7d 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -1,34 +1,169 @@ -import afmformats from pathlib import Path import numpy as np from AFMReader.logging import logger +from AFMReader import jpk +import zipfile +import io +import javaproperties +import h5py +ADDITIONAL_CHANNELS = ["contact_point", "manual_trigger_point"] + +def _get_channel_scaling(props, channel_index): + """ + Parses the JPK properties dictionary to find the cumulative multiplier + and offset for a specific channel index (e.g., '1' for vDeflection). + """ + prefix = f"lcd-info.{channel_index}." + + current_slot = props.get(f"{prefix}conversion-set.conversions.default") + + if not current_slot: + mult = float(props[f"{prefix}encoder.scaling.multiplier"]) + off = float(props[f"{prefix}encoder.scaling.offset"]) + return mult, off + + cumulative_multiplier = 1.0 + cumulative_offset = 0.0 + + while current_slot: + slot_prefix = f"{prefix}conversion-set.conversion.{current_slot}." + + if f"{slot_prefix}scaling.multiplier" in props: + m = float(props[f"{slot_prefix}scaling.multiplier"]) + c = float(props[f"{slot_prefix}scaling.offset"]) + + cumulative_offset = (cumulative_multiplier * c) + cumulative_offset + cumulative_multiplier *= m + + current_slot = props.get(f"{slot_prefix}base-calibration-slot") + + if current_slot == props.get(f"{prefix}conversion-set.conversions.base"): + break + else: + break + + enc_m = float(props.get(f"{prefix}encoder.scaling.multiplier", "1.0")) + enc_c = float(props.get(f"{prefix}encoder.scaling.offset", "0.0")) + + final_multiplier = cumulative_multiplier * enc_m + final_offset = (cumulative_multiplier * enc_c) + cumulative_offset + + return final_multiplier, final_offset + def load_jpk_qi( file_path: Path | str, channel: str, + config_path: Path | str | None = None, flip_image: bool | None = True ) -> tuple[np.ndarray, float]: # Load the file path passed to the function - curves_data = _fetch_qi_data(file_path) - - if curves_data is None or len(curves_data) == 0: - print("No data found in the file.") - return + file_path = Path(file_path) + with zipfile.ZipFile(file_path, "r") as qi_archive: + if channel not in ADDITIONAL_CHANNELS: + path_to_image = None + for file_name in qi_archive.namelist(): + if file_name.endswith(".jpk-qi-image"): + path_to_image = file_name + if path_to_image not in qi_archive.namelist(): + raise FileNotFoundError(f"{path_to_image} not found in JPK archive") + + tif_bytes = qi_archive.read(path_to_image) + + virtual_file = io.BytesIO(tif_bytes) + image, px2nm = jpk._load_jpk(virtual_file, path_to_image, channel=channel, file_suffix=".jpk-qi-data", config_path=config_path, flip_image=False) + + else: + with qi_archive.open("header.properties") as archive_meta_file: + props = javaproperties.load(archive_meta_file) + size_x, size_y, shape_x, shape_y = None, None, None, None + for key, value in props.items(): + if key.endswith(".ulength"): + size_x = float(value) + if key.endswith(".vlength"): + size_y = float(value) + if key.endswith(".ilength"): + shape_x = int(value) + if key.endswith(".jlength"): + shape_y = int(value) + + if None in [size_x, size_y, shape_x, shape_y]: + logger.error(f"Incomplete dimension data in {file_path}") + + image = np.zeros((shape_y, shape_x), dtype=np.float32) + + pixel_to_nm_scaling_factor_x = size_x / shape_x * 1e9 if shape_x > 0 else 1.0 + pixel_to_nm_scaling_factor_y = size_y / shape_y * 1e9 if shape_y > 0 else 1.0 + px2nm = (pixel_to_nm_scaling_factor_x + pixel_to_nm_scaling_factor_y) / 2 + + segment_channels = [] + with qi_archive.open("shared-data/header.properties") as shared_data_file: + shared_meta = javaproperties.load(shared_data_file) + channel_i = 0 + while f"lcd-info.{channel_i}.channel.name" in shared_meta: + channel_dict = {} + channel_dict["name"] = shared_meta[f"lcd-info.{channel_i}.channel.name"] + multiplier, offset = _get_channel_scaling(shared_meta, channel_i) + channel_dict["offset"] = offset + channel_dict["multiplier"] = multiplier + segment_channels.append(channel_dict) + channel_i += 1 + + if len(segment_channels) == 0: + logger.error("Could not find channels for segments") + + with h5py.File(file_path.parent / f"{file_path.stem}.h5-jpk", "a") as h5file: + vlen_type = h5py.vlen_dtype(np.float32) + num_of_curves = shape_x * shape_y + master_dataset = h5file.create_dataset("measured_height", shape=(num_of_curves,), dtype=vlen_type) + for y in range(shape_y): + for x in range(shape_x): + curve_num = shape_x * y + x + print(f"Curve {curve_num}") + # with qi_archive.open(f"index/{i}/header.properties") as curve_meta_file: + # curve_meta = javaproperties.load(curve_meta_file) + for direction in range(0, 2): + # with qi_archive.open(f"index/{i}/segments/{direction}/segment-header.properties") as segment_meta_file: + # segments_meta = javaproperties.load(segment_meta_file) + curve_data = {} + for segment_channel in segment_channels: + try: + with qi_archive.open(f"index/{curve_num}/segments/{direction}/channels/{segment_channel['name']}.dat") as segment_data: + dtype_str = '>i4' + raw_bytes = segment_data.read() + raw_array = np.frombuffer(raw_bytes, dtype=dtype_str) + metres_array = (raw_array * segment_channel["multiplier"]) + segment_channel["offset"] + master_dataset[curve_num] = metres_array + curve_data[segment_channel['name']] = metres_array + except KeyError: + break + if channel == "contact_point": + if direction == 0: + image[y, x] = _find_contact_point(curve_data) + elif channel == "manual_trigger_point": + if direction == 0: + image[y, x] = _find_trigger_point(curve_data) + + + + # Need to include flip image as _load_jpk flip image is set to false + if flip_image: + image = np.flipud(image) - image, px2nm = _process_jpk_qi_data(curves_data, channel) + return image, px2nm - if flip_image: - image = np.flipud(image) +def get_jpk_qi_channels(file_path: Path | str): + jpk._get_jpk_channels() - return image, px2nm def _fetch_qi_data(file_path: Path | str): - return afmformats.load_data(file_path) + qi_data = zipfile.ZipFile(file_path, "r") + return qi_data def _process_jpk_qi_data(curves_data: list, channel: str) -> tuple[np.ndarray, float]: @@ -90,20 +225,20 @@ def _find_trigger_points(curves_data): def _find_contact_point(curve): # find contact point in vertical deflection by peak in first derivative - derivative_vert_deflection = np.diff(curve["force"]) + vdef = curve["vDeflection"] + if len(vdef) < 2: + return np.nan + derivative_vert_deflection = np.diff(vdef) # Doesn't look like this line is needed: peak_derivative_value = np.max(derivative_vert_deflection) peak_derivative_index = np.argmax(derivative_vert_deflection) # find corresponding height value - corresponding_height_at_peak = curve["height (measured)"][peak_derivative_index] + corresponding_height_at_peak = curve["measuredHeight"][peak_derivative_index] return corresponding_height_at_peak def _find_trigger_point(curve): - segments = curve["segment"] - approach_indices = np.where(segments == 0)[0] - turn_index = approach_indices[-1] - trigger_point = curve["height (measured)"][turn_index] + trigger_point = curve["measuredHeight"][-1] return trigger_point def _max_points_buffer(curves_data, samples=20, points_buffer=1.2): From 28f7633608d271d6a4b91ca4a3e02faf2eaa4a69 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Mon, 23 Feb 2026 11:42:44 +0000 Subject: [PATCH 04/54] Fixing scaling --- AFMReader/jpk_qi.py | 80 +++++++-------------------------------------- 1 file changed, 12 insertions(+), 68 deletions(-) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index 5522b7d..2230e34 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -146,8 +146,8 @@ def load_jpk_qi( elif channel == "manual_trigger_point": if direction == 0: image[y, x] = _find_trigger_point(curve_data) - - + if channel in ["manual_trigger_point", "contact_point"]: + image = image * 1e9 # Need to include flip image as _load_jpk flip image is set to false if flip_image: @@ -156,72 +156,16 @@ def load_jpk_qi( return image, px2nm def get_jpk_qi_channels(file_path: Path | str): - jpk._get_jpk_channels() - - - - -def _fetch_qi_data(file_path: Path | str): - qi_data = zipfile.ZipFile(file_path, "r") - return qi_data - - -def _process_jpk_qi_data(curves_data: list, channel: str) -> tuple[np.ndarray, float]: - """ - Process the curves data from a JPK QI file to extract the image and pixel to nanometre scaling factor. - - """ - # Calculate pixel to nanometre scaling factor - metadata = curves_data[0].metadata - shape_x = metadata.get("grid shape x", 0) - shape_y = metadata.get("grid shape y", 0) - - size_x = metadata.get("grid size x", 0) - size_y = metadata.get("grid size y", 0) - - pixel_to_nm_scaling_factor_x = size_x / shape_x if shape_x > 0 else 1.0 - pixel_to_nm_scaling_factor_y = size_y / shape_y if shape_y > 0 else 1.0 - avg_pixel_to_nm_scaling_factor = (pixel_to_nm_scaling_factor_x + pixel_to_nm_scaling_factor_y) / 2 - - if channel == "Height (Trigger)": - z_heights = _find_trigger_points(curves_data) - elif channel == "Height (Contact)": - z_heights = [_find_contact_point(curve) for curve in curves_data] - else: - raise ValueError(f"Channel '{channel}' not recognized. Only 'Height (Trigger)' and 'Height (Contact)' are currently supported.") - - print(f"Grid shape: {shape_x} x {shape_y}") - print(f"Z heights: {z_heights}") - image = np.array(z_heights).reshape(shape_y, shape_x) - return image, avg_pixel_to_nm_scaling_factor - -def _find_trigger_points(curves_data): - max_points = _max_points_buffer(curves_data) - n_curves = len(curves_data) - - logger.info(f"Allocating arrays: {n_curves} curves x {max_points} max points.") - all_segments = np.full((n_curves, max_points), -1, dtype=np.int8) - all_heights = np.full((n_curves, max_points), np.nan, dtype=np.float32) - - for i, curve in enumerate(curves_data): - segment = curve["segment"] - height = curve["height (measured)"] - length = len(segment) - - all_segments[i, :length] = segment - all_heights[i, :length] = height - - logger.info("Data stacked. Calculating trigger points...") - is_approach = (all_segments == 0) - - transition_indices = np.sum(is_approach, axis=1) - 1 - - transition_indices = np.maximum(transition_indices, 0) - - row_indices = np.arange(n_curves) - trigger_values = all_heights[row_indices, transition_indices] - - return trigger_values + file_path = Path(file_path) + channels = [] + with zipfile.ZipFile(file_path, "r") as qi_archive: + for file_name in qi_archive.namelist(): + if file_name.endswith(".jpk-qi-image"): + path_to_image = file_name + with qi_archive.open(path_to_image, "r") as image_file: + channels += jpk._get_jpk_channels(file=image_file, filename=file_path.stem, file_path=file_path / Path(path_to_image)) + channels += ADDITIONAL_CHANNELS + return channels def _find_contact_point(curve): # find contact point in vertical deflection by peak in first derivative From a67b11f943061c61338e54dbfd7c1fad55a67ab3 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Mon, 23 Feb 2026 12:31:17 +0000 Subject: [PATCH 05/54] Updating general_loader and dependencies --- AFMReader/general_loader.py | 2 +- pyproject.toml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index 61ff0c9..6dfe002 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -108,7 +108,7 @@ def get_available_channels(self): available_channels = h5_jpk.get_h5jpk_channels(self.filepath) elif self.suffix == ".jpk-qi-data": # Implement this - available_channels = None + available_channels = jpk_qi.get_jpk_qi_channels(self.filepath) elif self.suffix in [".stp", ".top"]: available_channels = stp.load_stp(self.filepath) elif self.suffix == ".topostats": diff --git a/pyproject.toml b/pyproject.toml index f4cb025..cc0051e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "pySPM", "tifffile", "ruamel.yaml", + "javaproperties" ] [project.optional-dependencies] From 37e1ae2e9c29593d21f25341d19adab38b584a99 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Mon, 23 Feb 2026 16:37:40 +0000 Subject: [PATCH 06/54] Adjusting h5_jpk so it works for different shaped images --- AFMReader/h5_jpk.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index 8bb9027..f5be7b3 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -327,13 +327,14 @@ def load_h5jpk( images = (images * scaling) + offset # Select and reshape a flattened frame - image_size = measurement_group.attrs["position-pattern.grid.ilength"] # number of pixels + size_x = measurement_group.attrs["position-pattern.grid.ilength"] + size_y = measurement_group.attrs.get("position-pattern.grid.jlength", size_x) # number of pixels # Reshape each column vector (height, width) to get (num_frames, height, width) num_frames = images.shape[1] - image_stack = np.empty((num_frames, image_size, image_size), dtype=images.dtype) + image_stack = np.empty((num_frames, size_y, size_x), dtype=images.dtype) for i in range(num_frames): - frame = images[:, i].reshape((image_size, image_size)) + frame = images[:, i].reshape((size_y, size_x)) # Flip images if flip_image: @@ -346,7 +347,7 @@ def load_h5jpk( # Generate a dictionary of timestamps line_rate = _get_line_rate(measurement_group) - timestamps = generate_timestamps(num_frames, line_rate, image_size) + timestamps = generate_timestamps(num_frames, line_rate, size_y) logger.info(f"[{file_path.stem}] : Extracted {num_frames} frames from channel '{channel}'") return (image_stack, _jpk_pixel_to_nm_scaling_h5(measurement_group), timestamps) From 16a827ca698984b7a829cdb3f29a1d6b0590e481 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Mon, 23 Feb 2026 16:38:44 +0000 Subject: [PATCH 07/54] Making the jpk-qi-data processing save the h5 jpk file in the correct format --- AFMReader/jpk_qi.py | 86 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 81 insertions(+), 5 deletions(-) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index 2230e34..70d2f49 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -57,7 +57,8 @@ def load_jpk_qi( file_path: Path | str, channel: str, config_path: Path | str | None = None, - flip_image: bool | None = True + flip_image: bool | None = True, + save_as_h5: bool | None = False ) -> tuple[np.ndarray, float]: # Load the file path passed to the function @@ -118,14 +119,27 @@ def load_jpk_qi( with h5py.File(file_path.parent / f"{file_path.stem}.h5-jpk", "a") as h5file: vlen_type = h5py.vlen_dtype(np.float32) num_of_curves = shape_x * shape_y - master_dataset = h5file.create_dataset("measured_height", shape=(num_of_curves,), dtype=vlen_type) + + if save_as_h5: + qi_group = h5file.require_group("QI_Data") + + curve_datasets = {} + for direction in range(2): + dir_group = qi_group.require_group(f"Segment_{direction}") + for seg_chan in segment_channels: + ds_name = seg_chan["name"] + if ds_name not in dir_group: + curve_datasets[f"{direction}_{ds_name}"] = dir_group.create_dataset( + ds_name, shape=(num_of_curves,), dtype=vlen_type + ) + else: + curve_datasets[f"{direction}_{ds_name}"] = dir_group[ds_name] for y in range(shape_y): for x in range(shape_x): curve_num = shape_x * y + x - print(f"Curve {curve_num}") # with qi_archive.open(f"index/{i}/header.properties") as curve_meta_file: # curve_meta = javaproperties.load(curve_meta_file) - for direction in range(0, 2): + for direction in range(2): # with qi_archive.open(f"index/{i}/segments/{direction}/segment-header.properties") as segment_meta_file: # segments_meta = javaproperties.load(segment_meta_file) curve_data = {} @@ -136,8 +150,9 @@ def load_jpk_qi( raw_bytes = segment_data.read() raw_array = np.frombuffer(raw_bytes, dtype=dtype_str) metres_array = (raw_array * segment_channel["multiplier"]) + segment_channel["offset"] - master_dataset[curve_num] = metres_array curve_data[segment_channel['name']] = metres_array + if save_as_h5: + curve_datasets[f"{direction}_{segment_channel['name']}"][curve_num] = metres_array except KeyError: break if channel == "contact_point": @@ -149,12 +164,73 @@ def load_jpk_qi( if channel in ["manual_trigger_point", "contact_point"]: image = image * 1e9 + if save_as_h5: + with h5py.File(file_path.parent / f"{file_path.stem}.h5-jpk", "a") as h5file: + meas_grp = h5file.require_group("Measurement_000") + meas_grp.attrs["position-pattern.grid.ulength"] = size_x + meas_grp.attrs["position-pattern.grid.ilength"] = shape_x + meas_grp.attrs["position-pattern.grid.vlength"] = size_y + meas_grp.attrs["position-pattern.grid.jlength"] = shape_y + meas_grp.attrs["timing-settings.scanRate"] = 1.0 # Dummy value to satisfy reader + + h5_channels = [channel] + for file_name in qi_archive.namelist(): + if file_name.endswith(".jpk-qi-image"): + path_to_image = file_name + with qi_archive.open(path_to_image, "r") as image_file: + h5_channels += jpk._get_jpk_channels(file=image_file, filename=file_path.stem, file_path=file_path / Path(path_to_image)) + for i, h5_channel in enumerate(h5_channels): + chan_grp = meas_grp.require_group(f"Channel_{_make_num_min_characters(i)}") + chan_grp.attrs["channel.name"] = h5_channel.encode("utf-8") + chan_grp.attrs["retrace"] = "false".encode("utf-8") + chan_grp.attrs["net-encoder.scaling.multiplier"] = 1.0 + chan_grp.attrs["net-encoder.scaling.offset"] = 0.0 + + # Format name and reshape image (flattened frame stack) + dataset_name = h5_channel.split("_")[0].capitalize() + frame_stack = image.flatten().reshape(-1, 1) + + if dataset_name in chan_grp: + del chan_grp[dataset_name] + chan_grp.create_dataset(dataset_name, data=frame_stack) + # Need to include flip image as _load_jpk flip image is set to false if flip_image: image = np.flipud(image) return image, px2nm +def load_fdcurves_from_h5(file_path: Path | str): + file_path = Path(file_path) + + with h5py.File(file_path, "r") as h5file: + meas_grp = h5file["Measurement_000"] + shape_x = meas_grp.attrs["position-pattern.grid.ilength"] + shape_y = meas_grp.attrs["position-pattern.grid.jlength"] + size_x = meas_grp.attrs["position-pattern.grid.ulength"] + size_y = meas_grp.attrs["position-pattern.grid.vlength"] + + image = np.zeros((shape_y, shape_x), dtype=np.float32) + segment_0_group = h5file["QI_Data"]["Segment_0"] + channel_datasets = {name: segment_0_group[name] for name in segment_0_group.keys()} + for y in range(shape_y): + for x in range(shape_x): + curve_num = (shape_x * y) + x + curve_dict = {} + for chan_name, dataset in channel_datasets.items(): + curve_dict[chan_name] = dataset[curve_num] + image[y, x] = _find_contact_point(curve=curve_dict) + + return image + + +def _make_num_min_characters(num : int, min_chars: int = 3): + string_num = str(num) + if len(string_num) >= min_chars: + return string_num + string_num = "0" * (min_chars - len(string_num)) + string_num + return string_num + def get_jpk_qi_channels(file_path: Path | str): file_path = Path(file_path) channels = [] From c780de5c5d151a356c9a260d9be6a28e59148256 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Wed, 25 Feb 2026 10:54:17 +0000 Subject: [PATCH 08/54] Add the ability to save the metadata to h5 from jpk qi. Additionally make the save an adjustable option --- AFMReader/general_loader.py | 2 - AFMReader/jpk_qi.py | 97 +++++++++++++++++++++++++++++++------ 2 files changed, 83 insertions(+), 16 deletions(-) diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index 6dfe002..36b5ecc 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -107,7 +107,6 @@ def get_available_channels(self): elif self.suffix == ".h5-jpk": available_channels = h5_jpk.get_h5jpk_channels(self.filepath) elif self.suffix == ".jpk-qi-data": - # Implement this available_channels = jpk_qi.get_jpk_qi_channels(self.filepath) elif self.suffix in [".stp", ".top"]: available_channels = stp.load_stp(self.filepath) @@ -116,4 +115,3 @@ def get_available_channels(self): else: raise ValueError(f"File type '{self.suffix}' is not currently handled by AFMReader.") return available_channels - # scope for a "check what channels are available" function similar to above. diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index 70d2f49..d053f47 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -1,14 +1,18 @@ from pathlib import Path -import numpy as np -from AFMReader.logging import logger -from AFMReader import jpk -import zipfile +from contextlib import nullcontext import io +import zipfile + +import numpy as np import javaproperties import h5py +from AFMReader.logging import logger +from AFMReader import jpk -ADDITIONAL_CHANNELS = ["contact_point", "manual_trigger_point"] + +ADDITIONAL_CHANNELS = ["contactPoint_trace", "manualTriggerPoint_trace"] +ADDITIONAL_CHANNELS_IN_M = ["contactPoint_trace", "manualTriggerPoint_trace"] def _get_channel_scaling(props, channel_index): """ @@ -78,8 +82,18 @@ def load_jpk_qi( image, px2nm = jpk._load_jpk(virtual_file, path_to_image, channel=channel, file_suffix=".jpk-qi-data", config_path=config_path, flip_image=False) else: + if save_as_h5: + top_level_meta = {} + changing_curve_keys = set() + changing_segment_keys = set() + all_curve_keys = set() + all_segment_keys = set() + with qi_archive.open("header.properties") as archive_meta_file: props = javaproperties.load(archive_meta_file) + if save_as_h5: + for key, value in props.items(): + top_level_meta[f"shared-data.{key}"] = value size_x, size_y, shape_x, shape_y = None, None, None, None for key, value in props.items(): if key.endswith(".ulength"): @@ -91,6 +105,7 @@ def load_jpk_qi( if key.endswith(".jlength"): shape_y = int(value) + if None in [size_x, size_y, shape_x, shape_y]: logger.error(f"Incomplete dimension data in {file_path}") @@ -104,6 +119,9 @@ def load_jpk_qi( with qi_archive.open("shared-data/header.properties") as shared_data_file: shared_meta = javaproperties.load(shared_data_file) channel_i = 0 + if save_as_h5: + for key, value in shared_meta.items(): + top_level_meta[f"shared-data.{key}"] = value while f"lcd-info.{channel_i}.channel.name" in shared_meta: channel_dict = {} channel_dict["name"] = shared_meta[f"lcd-info.{channel_i}.channel.name"] @@ -116,12 +134,17 @@ def load_jpk_qi( if len(segment_channels) == 0: logger.error("Could not find channels for segments") - with h5py.File(file_path.parent / f"{file_path.stem}.h5-jpk", "a") as h5file: + h5_context = h5py.File(file_path.parent / f"{file_path.stem}.h5-jpk", "a") if save_as_h5 else nullcontext() + with h5_context as h5file: vlen_type = h5py.vlen_dtype(np.float32) num_of_curves = shape_x * shape_y if save_as_h5: + curve_meta = [{} for _ in range(num_of_curves)] + segment_meta = [{} for _ in range(num_of_curves * 2)] qi_group = h5file.require_group("QI_Data") + global_meta_group = qi_group.require_group("Global_Metadata") + curves_meta_group = qi_group.require_group("Curve_Metadata") curve_datasets = {} for direction in range(2): @@ -134,14 +157,30 @@ def load_jpk_qi( ) else: curve_datasets[f"{direction}_{ds_name}"] = dir_group[ds_name] + for y in range(shape_y): for x in range(shape_x): curve_num = shape_x * y + x - # with qi_archive.open(f"index/{i}/header.properties") as curve_meta_file: - # curve_meta = javaproperties.load(curve_meta_file) + if save_as_h5: + with qi_archive.open(f"index/{curve_num}/header.properties") as curve_meta_file: + curve_meta_raw = javaproperties.load(curve_meta_file) + for key, value in curve_meta_raw.items(): + key = ".".join(key.split(".")[1:]) + curve_meta[curve_num][key] = value + all_curve_keys.add(key) + if curve_num != 0 and (key not in curve_meta[0] or curve_meta[0][key] != value): + changing_curve_keys.add(key) + for direction in range(2): - # with qi_archive.open(f"index/{i}/segments/{direction}/segment-header.properties") as segment_meta_file: - # segments_meta = javaproperties.load(segment_meta_file) + if save_as_h5: + with qi_archive.open(f"index/{curve_num}/segments/{direction}/segment-header.properties") as segment_meta_file: + segment_meta_raw = javaproperties.load(segment_meta_file) + for key, value in segment_meta_raw.items(): + key = ".".join(key.split(".")[1:]) + segment_meta[curve_num * 2 + direction][key] = value + all_segment_keys.add(key) + if curve_num != 0 and (key not in segment_meta[0] or segment_meta[0][key] != value): + changing_segment_keys.add(key) curve_data = {} for segment_channel in segment_channels: try: @@ -155,13 +194,37 @@ def load_jpk_qi( curve_datasets[f"{direction}_{segment_channel['name']}"][curve_num] = metres_array except KeyError: break - if channel == "contact_point": + if channel == "contactPoint_trace": if direction == 0: image[y, x] = _find_contact_point(curve_data) - elif channel == "manual_trigger_point": + elif channel == "manualTriggerPoint_trace": if direction == 0: image[y, x] = _find_trigger_point(curve_data) - if channel in ["manual_trigger_point", "contact_point"]: + + if save_as_h5: + # Move all the duplicated metadata to the top level metadata dict + for key in all_curve_keys - changing_curve_keys: + top_level_meta[f"curve.{key}"] = curve_meta[0][key] + for curve_metadata in curve_meta: + curve_metadata.pop(key) + for key in all_segment_keys - changing_segment_keys: + top_level_meta[f"segment.{key}"] = segment_meta[0][key] + for segment_metadata in segment_meta: + segment_metadata.pop(key) + for key, value in top_level_meta.items(): + global_meta_group.attrs[key] = str(value).encode('utf-8') + for i, curve_metadata in enumerate(curve_meta): + curve_meta_group = curves_meta_group.require_group(f"{i}") + for key, value in curve_metadata.items(): + curve_meta_group.attrs[key] = str(value).encode('utf-8') + + for d in range(2): + segment_meta_group = curve_meta_group.require_group(f"{d}") + for key, value in segment_meta[i*2+d].items(): + segment_meta_group.attrs[key] = str(value).encode('utf-8') + + + if channel in ADDITIONAL_CHANNELS_IN_M: image = image * 1e9 if save_as_h5: @@ -194,6 +257,7 @@ def load_jpk_qi( del chan_grp[dataset_name] chan_grp.create_dataset(dataset_name, data=frame_stack) + # Need to include flip image as _load_jpk flip image is set to false if flip_image: image = np.flipud(image) @@ -210,6 +274,10 @@ def load_fdcurves_from_h5(file_path: Path | str): size_x = meas_grp.attrs["position-pattern.grid.ulength"] size_y = meas_grp.attrs["position-pattern.grid.vlength"] + pixel_to_nm_scaling_factor_x = size_x / shape_x * 1e9 if shape_x > 0 else 1.0 + pixel_to_nm_scaling_factor_y = size_y / shape_y * 1e9 if shape_y > 0 else 1.0 + px2nm = (pixel_to_nm_scaling_factor_x + pixel_to_nm_scaling_factor_y) / 2 + image = np.zeros((shape_y, shape_x), dtype=np.float32) segment_0_group = h5file["QI_Data"]["Segment_0"] channel_datasets = {name: segment_0_group[name] for name in segment_0_group.keys()} @@ -220,8 +288,9 @@ def load_fdcurves_from_h5(file_path: Path | str): for chan_name, dataset in channel_datasets.items(): curve_dict[chan_name] = dataset[curve_num] image[y, x] = _find_contact_point(curve=curve_dict) + image = image * 1e9 - return image + return image, px2nm def _make_num_min_characters(num : int, min_chars: int = 3): From a764ca662eaa6b5298f2c21a95aadfd977e580a5 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Fri, 27 Feb 2026 17:51:29 +0000 Subject: [PATCH 09/54] Adding ability to load force curves from h5 file --- AFMReader/general_loader.py | 11 +++++++- AFMReader/h5_jpk.py | 37 +++++++++++++++++++++++---- AFMReader/jpk_qi.py | 50 ++++++++++++++++++------------------- 3 files changed, 66 insertions(+), 32 deletions(-) diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index 36b5ecc..ba21271 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -46,6 +46,7 @@ def load(self) -> tuple[npt.NDArray | str, float | None]: # noqa: C901 ------- tuple The image data (stack if ''.asd'' or ''.h5-jpk'') and the pixel to nanometre scaling ratio. + If curve data is found, also return the curve data (a large dict of all the curves) Raises ------ @@ -65,7 +66,14 @@ def load(self) -> tuple[npt.NDArray | str, float | None]: # noqa: C901 elif self.suffix == ".spm": image, pixel_to_nanometre_scaling_factor = spm.load_spm(self.filepath, self.channel) elif self.suffix == ".h5-jpk": - image, pixel_to_nanometre_scaling_factor, _ = h5_jpk.load_h5jpk(self.filepath, self.channel) + h5_returned = h5_jpk.load_h5jpk(self.filepath, self.channel) + if len(h5_returned) == 3: + image, pixel_to_nanometre_scaling_factor, _ = h5_returned + elif len(h5_returned) == 4: + image, pixel_to_nanometre_scaling_factor, curve_data, _ = h5_returned + return image, pixel_to_nanometre_scaling_factor, curve_data + else: + logger.error(f"Loading h5-jpk file returned unexpected number of items: {len(h5_returned)}") elif self.suffix == ".jpk-qi-data": image, pixel_to_nanometre_scaling_factor = jpk_qi.load_jpk_qi(self.filepath, self.channel) elif self.suffix == ".stp": @@ -91,6 +99,7 @@ def load(self) -> tuple[npt.NDArray | str, float | None]: # noqa: C901 except ValueError as e: logger.error(f"{e}") + raise e return (e, None) # cheeky return of an image, px2nm-like tuple object to propagate error message to Napari def get_available_channels(self): diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index f5be7b3..6bf3cdf 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -321,20 +321,21 @@ def load_h5jpk( channel_group, measurement_group, dataset_name = _get_channel_info(f, channel) + # Load images and scaling factors from channel dataset images = channel_group[dataset_name][:] scaling, offset = _get_z_scaling_h5(channel_group) images = (images * scaling) + offset # Select and reshape a flattened frame - size_x = measurement_group.attrs["position-pattern.grid.ilength"] - size_y = measurement_group.attrs.get("position-pattern.grid.jlength", size_x) # number of pixels + shape_x = measurement_group.attrs["position-pattern.grid.ilength"] + shape_y = measurement_group.attrs.get("position-pattern.grid.jlength", shape_x) # number of pixels # Reshape each column vector (height, width) to get (num_frames, height, width) num_frames = images.shape[1] - image_stack = np.empty((num_frames, size_y, size_x), dtype=images.dtype) + image_stack = np.empty((num_frames, shape_y, shape_x), dtype=images.dtype) for i in range(num_frames): - frame = images[:, i].reshape((size_y, size_x)) + frame = images[:, i].reshape((shape_y, shape_x)) # Flip images if flip_image: @@ -347,7 +348,33 @@ def load_h5jpk( # Generate a dictionary of timestamps line_rate = _get_line_rate(measurement_group) - timestamps = generate_timestamps(num_frames, line_rate, size_y) + timestamps = generate_timestamps(num_frames, line_rate, shape_y) logger.info(f"[{file_path.stem}] : Extracted {num_frames} frames from channel '{channel}'") + + if "QI_Curve_Data" in f: + logger.info(f"[{file_path.stem}] : Found Force Curves QI data in file.") + qi_data_group = f["QI_Curve_Data"] + all_curve_data = [] + for y in range(shape_y): + row = [] + for x in range(shape_x): + curve_num = shape_x * y + x + curve_data = {} + for direction in ["Segment_0", "Segment_1"]: + if direction not in qi_data_group: + continue + direction_group = qi_data_group[direction] + for channel, channel_group in direction_group.items(): + if channel == "error": + continue + if channel not in curve_data: + curve_data[channel] = {} + print(f"Curve num: {curve_num} channel {channel} direction {direction}") + + curve_data[channel][direction] = channel_group[curve_num] + row.append(curve_data) + all_curve_data.append(row) + return (image_stack, _jpk_pixel_to_nm_scaling_h5(measurement_group), all_curve_data, timestamps) + return (image_stack, _jpk_pixel_to_nm_scaling_h5(measurement_group), timestamps) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index d053f47..0f241ef 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -11,8 +11,8 @@ from AFMReader import jpk -ADDITIONAL_CHANNELS = ["contactPoint_trace", "manualTriggerPoint_trace"] -ADDITIONAL_CHANNELS_IN_M = ["contactPoint_trace", "manualTriggerPoint_trace"] +ADDITIONAL_CHANNELS = ["contactPoint", "manualTriggerPoint"] +ADDITIONAL_CHANNELS_IN_M = ["contactPoint", "manualTriggerPoint"] def _get_channel_scaling(props, channel_index): """ @@ -79,6 +79,7 @@ def load_jpk_qi( tif_bytes = qi_archive.read(path_to_image) virtual_file = io.BytesIO(tif_bytes) + logger.info(f"Looking for channel {channel} in ") image, px2nm = jpk._load_jpk(virtual_file, path_to_image, channel=channel, file_suffix=".jpk-qi-data", config_path=config_path, flip_image=False) else: @@ -142,7 +143,7 @@ def load_jpk_qi( if save_as_h5: curve_meta = [{} for _ in range(num_of_curves)] segment_meta = [{} for _ in range(num_of_curves * 2)] - qi_group = h5file.require_group("QI_Data") + qi_group = h5file.require_group("QI_Curve_Data") global_meta_group = qi_group.require_group("Global_Metadata") curves_meta_group = qi_group.require_group("Curve_Metadata") @@ -194,10 +195,10 @@ def load_jpk_qi( curve_datasets[f"{direction}_{segment_channel['name']}"][curve_num] = metres_array except KeyError: break - if channel == "contactPoint_trace": + if channel == "contactPoint": if direction == 0: image[y, x] = _find_contact_point(curve_data) - elif channel == "manualTriggerPoint_trace": + elif channel == "manualTriggerPoint": if direction == 0: image[y, x] = _find_trigger_point(curve_data) @@ -223,12 +224,13 @@ def load_jpk_qi( for key, value in segment_meta[i*2+d].items(): segment_meta_group.attrs[key] = str(value).encode('utf-8') - + # Convert to nanometers if in meters if channel in ADDITIONAL_CHANNELS_IN_M: image = image * 1e9 if save_as_h5: with h5py.File(file_path.parent / f"{file_path.stem}.h5-jpk", "a") as h5file: + # Save data required for reading the h5 file as a normal image file meas_grp = h5file.require_group("Measurement_000") meas_grp.attrs["position-pattern.grid.ulength"] = size_x meas_grp.attrs["position-pattern.grid.ilength"] = shape_x @@ -240,9 +242,12 @@ def load_jpk_qi( for file_name in qi_archive.namelist(): if file_name.endswith(".jpk-qi-image"): path_to_image = file_name + # Add the channels which exist in the jpk-qi-image file with qi_archive.open(path_to_image, "r") as image_file: h5_channels += jpk._get_jpk_channels(file=image_file, filename=file_path.stem, file_path=file_path / Path(path_to_image)) for i, h5_channel in enumerate(h5_channels): + # For each available channel, save the required data to the h5 file + # TODO make sure this metadata is accurate for the channels coming from the .jpk-qi-image file chan_grp = meas_grp.require_group(f"Channel_{_make_num_min_characters(i)}") chan_grp.attrs["channel.name"] = h5_channel.encode("utf-8") chan_grp.attrs["retrace"] = "false".encode("utf-8") @@ -271,26 +276,19 @@ def load_fdcurves_from_h5(file_path: Path | str): meas_grp = h5file["Measurement_000"] shape_x = meas_grp.attrs["position-pattern.grid.ilength"] shape_y = meas_grp.attrs["position-pattern.grid.jlength"] - size_x = meas_grp.attrs["position-pattern.grid.ulength"] - size_y = meas_grp.attrs["position-pattern.grid.vlength"] - - pixel_to_nm_scaling_factor_x = size_x / shape_x * 1e9 if shape_x > 0 else 1.0 - pixel_to_nm_scaling_factor_y = size_y / shape_y * 1e9 if shape_y > 0 else 1.0 - px2nm = (pixel_to_nm_scaling_factor_x + pixel_to_nm_scaling_factor_y) / 2 - - image = np.zeros((shape_y, shape_x), dtype=np.float32) - segment_0_group = h5file["QI_Data"]["Segment_0"] - channel_datasets = {name: segment_0_group[name] for name in segment_0_group.keys()} - for y in range(shape_y): - for x in range(shape_x): - curve_num = (shape_x * y) + x - curve_dict = {} - for chan_name, dataset in channel_datasets.items(): - curve_dict[chan_name] = dataset[curve_num] - image[y, x] = _find_contact_point(curve=curve_dict) - image = image * 1e9 - - return image, px2nm + num_of_curves = shape_x * shape_y + qi_data_group = h5file["QI_Curve_Data"] + all_curve_data = [] + for i in range(num_of_curves): + curve_data = {} + for direction, direction_group in qi_data_group.items(): + for channel, channel_group in direction_group.items(): + if channel not in curve_data: + curve_data[channel] = {} + curve_data[channel][direction] = channel_group[str(i)] + all_curve_data.append(curve_data) + + return all_curve_data def _make_num_min_characters(num : int, min_chars: int = 3): From 405a95dc043e4034f42b760b7f32760dddc8ff3e Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Fri, 27 Feb 2026 18:16:13 +0000 Subject: [PATCH 10/54] Improving speed of loading qi curve data by loading all the curves at once, then extracting separately --- AFMReader/h5_jpk.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index 6bf3cdf..425ce29 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -355,24 +355,27 @@ def load_h5jpk( if "QI_Curve_Data" in f: logger.info(f"[{file_path.stem}] : Found Force Curves QI data in file.") qi_data_group = f["QI_Curve_Data"] + loaded_channels_data = {} + for direction in ["Segment_0", "Segment_1"]: + if direction in qi_data_group: + loaded_channels_data[direction] = {} + for channel, channel_group in qi_data_group[direction].items(): + if channel != "error": + loaded_channels_data[direction][channel] = channel_group[:] + all_curve_data = [] for y in range(shape_y): row = [] for x in range(shape_x): curve_num = shape_x * y + x curve_data = {} - for direction in ["Segment_0", "Segment_1"]: - if direction not in qi_data_group: - continue - direction_group = qi_data_group[direction] - for channel, channel_group in direction_group.items(): - if channel == "error": - continue + + for direction, channels in loaded_channels_data.items(): + for channel, data_array in channels.items(): if channel not in curve_data: curve_data[channel] = {} - print(f"Curve num: {curve_num} channel {channel} direction {direction}") + curve_data[channel][direction] = data_array[curve_num] - curve_data[channel][direction] = channel_group[curve_num] row.append(curve_data) all_curve_data.append(row) return (image_stack, _jpk_pixel_to_nm_scaling_h5(measurement_group), all_curve_data, timestamps) From 31985675ac727fdbd447572f3f887437b031cb7e Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Fri, 27 Feb 2026 20:04:20 +0000 Subject: [PATCH 11/54] Making loading jpk-qi-data return all the curves --- AFMReader/general_loader.py | 9 ++++++++- AFMReader/jpk_qi.py | 28 ++++++++++++++++++++-------- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index ba21271..f3d89eb 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -75,7 +75,14 @@ def load(self) -> tuple[npt.NDArray | str, float | None]: # noqa: C901 else: logger.error(f"Loading h5-jpk file returned unexpected number of items: {len(h5_returned)}") elif self.suffix == ".jpk-qi-data": - image, pixel_to_nanometre_scaling_factor = jpk_qi.load_jpk_qi(self.filepath, self.channel) + jpk_qi_returned = jpk_qi.load_jpk_qi(self.filepath, self.channel) + if len(jpk_qi_returned) == 2: + image, pixel_to_nanometre_scaling_factor = jpk_qi_returned + elif len(jpk_qi_returned) == 3: + image, pixel_to_nanometre_scaling_factor, curve_data = jpk_qi_returned + return image, pixel_to_nanometre_scaling_factor, curve_data + else: + logger.error(f"Loading h5-jpk file returned unexpected number of items: {len(jpk_qi_returned)}") elif self.suffix == ".stp": image, pixel_to_nanometre_scaling_factor = stp.load_stp(self.filepath) elif self.suffix == ".top": diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index 0f241ef..8fc6d51 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -67,6 +67,7 @@ def load_jpk_qi( # Load the file path passed to the function file_path = Path(file_path) + all_curve_data = None with zipfile.ZipFile(file_path, "r") as qi_archive: if channel not in ADDITIONAL_CHANNELS: path_to_image = None @@ -139,6 +140,7 @@ def load_jpk_qi( with h5_context as h5file: vlen_type = h5py.vlen_dtype(np.float32) num_of_curves = shape_x * shape_y + all_curve_data = [] if save_as_h5: curve_meta = [{} for _ in range(num_of_curves)] @@ -160,8 +162,10 @@ def load_jpk_qi( curve_datasets[f"{direction}_{ds_name}"] = dir_group[ds_name] for y in range(shape_y): + row = [] for x in range(shape_x): curve_num = shape_x * y + x + curve_data = {} if save_as_h5: with qi_archive.open(f"index/{curve_num}/header.properties") as curve_meta_file: curve_meta_raw = javaproperties.load(curve_meta_file) @@ -182,25 +186,31 @@ def load_jpk_qi( all_segment_keys.add(key) if curve_num != 0 and (key not in segment_meta[0] or segment_meta[0][key] != value): changing_segment_keys.add(key) - curve_data = {} + segment_dict = {} for segment_channel in segment_channels: try: - with qi_archive.open(f"index/{curve_num}/segments/{direction}/channels/{segment_channel['name']}.dat") as segment_data: + with qi_archive.open(f"index/{curve_num}/segments/{direction}/channels/{segment_channel['name']}.dat") as segment_raw: dtype_str = '>i4' - raw_bytes = segment_data.read() + raw_bytes = segment_raw.read() raw_array = np.frombuffer(raw_bytes, dtype=dtype_str) - metres_array = (raw_array * segment_channel["multiplier"]) + segment_channel["offset"] - curve_data[segment_channel['name']] = metres_array + segment_array = (raw_array * segment_channel["multiplier"]) + segment_channel["offset"] + segment_dict[segment_channel['name']] = segment_array if save_as_h5: - curve_datasets[f"{direction}_{segment_channel['name']}"][curve_num] = metres_array + curve_datasets[f"{direction}_{segment_channel['name']}"][curve_num] = segment_array + if segment_channel['name'] not in curve_data: + curve_data[segment_channel['name']] = {} + curve_data[segment_channel['name']][f"Segment_{direction}"] = segment_array + except KeyError: break if channel == "contactPoint": if direction == 0: - image[y, x] = _find_contact_point(curve_data) + image[y, x] = _find_contact_point(segment_dict) elif channel == "manualTriggerPoint": if direction == 0: - image[y, x] = _find_trigger_point(curve_data) + image[y, x] = _find_trigger_point(segment_dict) + row.append(curve_data) + all_curve_data.append(row) if save_as_h5: # Move all the duplicated metadata to the top level metadata dict @@ -266,6 +276,8 @@ def load_jpk_qi( # Need to include flip image as _load_jpk flip image is set to false if flip_image: image = np.flipud(image) + if all_curve_data: + return (image, px2nm, all_curve_data) return image, px2nm From c25fb751bef867881c99e8c05c641d05b826c6dc Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Mon, 9 Mar 2026 10:56:58 +0000 Subject: [PATCH 12/54] Adding .bin files support --- AFMReader/bin.py | 57 +++++++++++++++++++++ AFMReader/general_loader.py | 27 +++++++--- AFMReader/h5_jpk.py | 11 +++-- AFMReader/jpk_qi.py | 99 +++++++++++++++++++++++++------------ 4 files changed, 152 insertions(+), 42 deletions(-) create mode 100644 AFMReader/bin.py diff --git a/AFMReader/bin.py b/AFMReader/bin.py new file mode 100644 index 0000000..3209215 --- /dev/null +++ b/AFMReader/bin.py @@ -0,0 +1,57 @@ +import numpy as np +import math +from .logging import logger +from pathlib import Path + +DTYPE_MAP = { + 'IEEE double': np.float64, + 'DBL': np.float64, + 'IEEE single': np.float32, + 'SGL': np.float32, + 'U32': np.uint32, + 'I32': np.int32, + 'U16': np.uint16, + 'I16': np.int16, + 'U8': np.uint8, + 'I8': np.int8, + 'float64': np.float64, + 'float32': np.float32, + 'int32': np.int32, +} + +def load_bin(filepath: str | Path, data_type: str, offset_bytes: int, size_x: float = None, size_y: float = None, shape_x: int = None, shape_y: int = None, z_scaling: float = 1.0): + filepath = Path(filepath) + dt_key = str(data_type).strip() + shape_x = None if shape_x == 0 else shape_x + shape_y = None if shape_y == 0 else shape_y + + if dt_key in DTYPE_MAP: + np_dtype = DTYPE_MAP[dt_key] + else: + logger.warning(f"Unknown data type '{dt_key}'. Defaulting to float64.") + np_dtype = np.float64 + with open(filepath, 'rb') as f: + f.seek(offset_bytes) + flat_data = np.fromfile(f, dtype=np_dtype) + if None in [shape_x, shape_y]: + dimension = int(math.sqrt(len(flat_data))) + shape_x, shape_y = dimension, dimension + if shape_x * shape_y != len(flat_data): + logger.error(f"Loading binary file {filepath.stem} did not receive a shape and is not square") + image = flat_data.reshape((shape_x, shape_y)) + image *= z_scaling + pixel_to_nm_scaling_factor_x = size_x / shape_x if shape_x > 0 else 1.0 + pixel_to_nm_scaling_factor_y = size_y / shape_y if shape_y > 0 else 1.0 + px2nm = (pixel_to_nm_scaling_factor_x + pixel_to_nm_scaling_factor_y) / 2 + return image, px2nm + +def get_bin_channels(): + kwarg_types = {"data_type" : (str, DTYPE_MAP.keys()), + "offset_bytes": int, + "size_x": float, + "size_y": float, + "shape_x": int, + "shape_y": int, + "z_scaling": float} + return [], kwarg_types + diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index f3d89eb..7bbd5c0 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -4,9 +4,10 @@ import numpy.typing as npt -from AFMReader import asd, gwy, h5_jpk, ibw, jpk, spm, stp, top, topostats, jpk_qi +from AFMReader import asd, gwy, h5_jpk, ibw, jpk, spm, stp, top, topostats, jpk_qi, bin from AFMReader.logging import logger + logger.enable(__package__) @@ -23,7 +24,7 @@ class LoadFile: Channel to extract from the AFM image. """ - def __init__(self, filepath: str | Path, channel: str): + def __init__(self, filepath: str | Path, channel: str, kwargs: dict = None): """ Initialise the general LoadFile class with a filepath and channel. @@ -37,8 +38,10 @@ def __init__(self, filepath: str | Path, channel: str): self.filepath = Path(filepath) self.channel = channel self.suffix = self.filepath.suffix + self.loaded_curves = False + self.kwargs = kwargs - def load(self) -> tuple[npt.NDArray | str, float | None]: # noqa: C901 + def load(self, channel: str | None = None, kwargs: dict = None) -> tuple[npt.NDArray | str, float | None]: # noqa: C901 """ Generally loads a file type that can be handled by AFMReader. @@ -54,6 +57,10 @@ def load(self) -> tuple[npt.NDArray | str, float | None]: # noqa: C901 Where the channel is not found, returned as a tuple of "error message" and "None" so that this can be propagated to Napari without outright failing. """ + if channel: + self.channel = channel + if kwargs: + self.kwargs = kwargs try: if self.suffix == ".asd": image, pixel_to_nanometre_scaling_factor, _ = asd.load_asd(self.filepath, self.channel) @@ -66,20 +73,22 @@ def load(self) -> tuple[npt.NDArray | str, float | None]: # noqa: C901 elif self.suffix == ".spm": image, pixel_to_nanometre_scaling_factor = spm.load_spm(self.filepath, self.channel) elif self.suffix == ".h5-jpk": - h5_returned = h5_jpk.load_h5jpk(self.filepath, self.channel) + h5_returned = h5_jpk.load_h5jpk(self.filepath, self.channel, load_curves=not self.loaded_curves) if len(h5_returned) == 3: image, pixel_to_nanometre_scaling_factor, _ = h5_returned elif len(h5_returned) == 4: image, pixel_to_nanometre_scaling_factor, curve_data, _ = h5_returned + self.loaded_curves = True return image, pixel_to_nanometre_scaling_factor, curve_data else: logger.error(f"Loading h5-jpk file returned unexpected number of items: {len(h5_returned)}") elif self.suffix == ".jpk-qi-data": - jpk_qi_returned = jpk_qi.load_jpk_qi(self.filepath, self.channel) + jpk_qi_returned = jpk_qi.load_jpk_qi(self.filepath, self.channel, **self.kwargs) if len(jpk_qi_returned) == 2: image, pixel_to_nanometre_scaling_factor = jpk_qi_returned elif len(jpk_qi_returned) == 3: image, pixel_to_nanometre_scaling_factor, curve_data = jpk_qi_returned + self.loaded_curves = True return image, pixel_to_nanometre_scaling_factor, curve_data else: logger.error(f"Loading h5-jpk file returned unexpected number of items: {len(jpk_qi_returned)}") @@ -99,6 +108,8 @@ def load(self) -> tuple[npt.NDArray | str, float | None]: # noqa: C901 f"'{self.channel}' not in available image keys: " f"{[im for im in image_keys if im in topostats_keys]}" ) from exc + elif self.suffix == ".bin": + image, pixel_to_nanometre_scaling_factor = bin.load_bin(self.filepath, **self.kwargs) else: raise ValueError(f"File type '{self.suffix}' is not currently handled by AFMReader.") @@ -124,10 +135,12 @@ def get_available_channels(self): available_channels = h5_jpk.get_h5jpk_channels(self.filepath) elif self.suffix == ".jpk-qi-data": available_channels = jpk_qi.get_jpk_qi_channels(self.filepath) - elif self.suffix in [".stp", ".top"]: - available_channels = stp.load_stp(self.filepath) elif self.suffix == ".topostats": available_channels = ["image", "image_original"] + elif self.suffix == ".bin": + available_channels = bin.get_bin_channels() + elif self.suffix in [".stp", ".top"]: + return [] else: raise ValueError(f"File type '{self.suffix}' is not currently handled by AFMReader.") return available_channels diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index 425ce29..5384758 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -273,7 +273,7 @@ def get_h5jpk_channels(file_path: Path | str): return available_channels def load_h5jpk( - file_path: Path | str, channel: str, flip_image: bool = True + file_path: Path | str, channel: str, flip_image: bool = True, load_curves: bool = True ) -> tuple[np.ndarray, float, dict[str, float]]: """ Load image from JPK Instruments .h5-jpk files. @@ -352,10 +352,15 @@ def load_h5jpk( logger.info(f"[{file_path.stem}] : Extracted {num_frames} frames from channel '{channel}'") - if "QI_Curve_Data" in f: + if load_curves and "QI_Curve_Data" in f: logger.info(f"[{file_path.stem}] : Found Force Curves QI data in file.") qi_data_group = f["QI_Curve_Data"] loaded_channels_data = {} + channels_units = {} + for key, value in qi_data_group["Global_Metadata"].attrs.items(): + if key.startswith("channel.unit."): + channels_units[key.split(".")[-1]] = value + for direction in ["Segment_0", "Segment_1"]: if direction in qi_data_group: loaded_channels_data[direction] = {} @@ -378,6 +383,6 @@ def load_h5jpk( row.append(curve_data) all_curve_data.append(row) - return (image_stack, _jpk_pixel_to_nm_scaling_h5(measurement_group), all_curve_data, timestamps) + return (image_stack, _jpk_pixel_to_nm_scaling_h5(measurement_group), (all_curve_data, channels_units), timestamps) return (image_stack, _jpk_pixel_to_nm_scaling_h5(measurement_group), timestamps) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index 8fc6d51..a2b4de1 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -24,12 +24,14 @@ def _get_channel_scaling(props, channel_index): current_slot = props.get(f"{prefix}conversion-set.conversions.default") if not current_slot: - mult = float(props[f"{prefix}encoder.scaling.multiplier"]) - off = float(props[f"{prefix}encoder.scaling.offset"]) - return mult, off + mult = float(props.get(f"{prefix}encoder.scaling.multiplier", "1.0")) + off = float(props.get(f"{prefix}encoder.scaling.offset", "0.0")) + unit = props.get(f"{prefix}encoder.scaling.unit.unit", "Unknown") + return mult, off, unit cumulative_multiplier = 1.0 cumulative_offset = 0.0 + unit = props.get(f"{prefix}conversion-set.conversion.{current_slot}.scaling.unit.unit") while current_slot: slot_prefix = f"{prefix}conversion-set.conversion.{current_slot}." @@ -53,8 +55,24 @@ def _get_channel_scaling(props, channel_index): final_multiplier = cumulative_multiplier * enc_m final_offset = (cumulative_multiplier * enc_c) + cumulative_offset + if not unit: + unit = props.get(f"{prefix}encoder.scaling.unit.unit", "Unknown") - return final_multiplier, final_offset + return final_multiplier, final_offset, unit + +def _load_preprocessed_image(qi_archive, channel, config_path=None): + path_to_image = None + for file_name in qi_archive.namelist(): + if file_name.endswith(".jpk-qi-image"): + path_to_image = file_name + if path_to_image not in qi_archive.namelist(): + raise FileNotFoundError(f"{path_to_image} not found in JPK archive") + + tif_bytes = qi_archive.read(path_to_image) + + virtual_file = io.BytesIO(tif_bytes) + logger.info(f"Looking for channel {channel} in ") + return jpk._load_jpk(virtual_file, path_to_image, channel=channel, file_suffix=".jpk-qi-data", config_path=config_path, flip_image=False) def load_jpk_qi( @@ -69,19 +87,8 @@ def load_jpk_qi( file_path = Path(file_path) all_curve_data = None with zipfile.ZipFile(file_path, "r") as qi_archive: - if channel not in ADDITIONAL_CHANNELS: - path_to_image = None - for file_name in qi_archive.namelist(): - if file_name.endswith(".jpk-qi-image"): - path_to_image = file_name - if path_to_image not in qi_archive.namelist(): - raise FileNotFoundError(f"{path_to_image} not found in JPK archive") - - tif_bytes = qi_archive.read(path_to_image) - - virtual_file = io.BytesIO(tif_bytes) - logger.info(f"Looking for channel {channel} in ") - image, px2nm = jpk._load_jpk(virtual_file, path_to_image, channel=channel, file_suffix=".jpk-qi-data", config_path=config_path, flip_image=False) + if channel not in ADDITIONAL_CHANNELS and not save_as_h5: + image, px2nm = _load_preprocessed_image(qi_archive=qi_archive, channel=channel, config_path=config_path) else: if save_as_h5: @@ -127,9 +134,10 @@ def load_jpk_qi( while f"lcd-info.{channel_i}.channel.name" in shared_meta: channel_dict = {} channel_dict["name"] = shared_meta[f"lcd-info.{channel_i}.channel.name"] - multiplier, offset = _get_channel_scaling(shared_meta, channel_i) + multiplier, offset, unit = _get_channel_scaling(shared_meta, channel_i) channel_dict["offset"] = offset channel_dict["multiplier"] = multiplier + channel_dict["unit"] = unit segment_channels.append(channel_dict) channel_i += 1 @@ -178,14 +186,17 @@ def load_jpk_qi( for direction in range(2): if save_as_h5: - with qi_archive.open(f"index/{curve_num}/segments/{direction}/segment-header.properties") as segment_meta_file: - segment_meta_raw = javaproperties.load(segment_meta_file) - for key, value in segment_meta_raw.items(): - key = ".".join(key.split(".")[1:]) - segment_meta[curve_num * 2 + direction][key] = value - all_segment_keys.add(key) - if curve_num != 0 and (key not in segment_meta[0] or segment_meta[0][key] != value): - changing_segment_keys.add(key) + try: + with qi_archive.open(f"index/{curve_num}/segments/{direction}/segment-header.properties") as segment_meta_file: + segment_meta_raw = javaproperties.load(segment_meta_file) + for key, value in segment_meta_raw.items(): + key = ".".join(key.split(".")[1:]) + segment_meta[curve_num * 2 + direction][key] = value + all_segment_keys.add(key) + if curve_num != 0 and (key not in segment_meta[0] or segment_meta[0][key] != value): + changing_segment_keys.add(key) + except KeyError: + pass segment_dict = {} for segment_channel in segment_channels: try: @@ -211,6 +222,11 @@ def load_jpk_qi( image[y, x] = _find_trigger_point(segment_dict) row.append(curve_data) all_curve_data.append(row) + if channel not in ADDITIONAL_CHANNELS: + image, px2nm = _load_preprocessed_image(qi_archive=qi_archive, channel=channel, config_path=config_path) + channels_units = {} + for segment_channel in segment_channels: + channels_units[segment_channel['name']] = segment_channel['unit'] if save_as_h5: # Move all the duplicated metadata to the top level metadata dict @@ -221,7 +237,12 @@ def load_jpk_qi( for key in all_segment_keys - changing_segment_keys: top_level_meta[f"segment.{key}"] = segment_meta[0][key] for segment_metadata in segment_meta: - segment_metadata.pop(key) + try: + segment_metadata.pop(key) + except KeyError: + pass + for segment_channel in segment_channels: + global_meta_group.attrs[f"channel.unit.{segment_channel['name']}"] = segment_channel['unit'] for key, value in top_level_meta.items(): global_meta_group.attrs[key] = str(value).encode('utf-8') for i, curve_metadata in enumerate(curve_meta): @@ -248,6 +269,8 @@ def load_jpk_qi( meas_grp.attrs["position-pattern.grid.jlength"] = shape_y meas_grp.attrs["timing-settings.scanRate"] = 1.0 # Dummy value to satisfy reader + logger.info(f"Saving a hdf5 copy of the data {file_path.parent / f'{file_path.stem}.h5-jpk'}") + h5_channels = [channel] for file_name in qi_archive.namelist(): if file_name.endswith(".jpk-qi-image"): @@ -259,14 +282,25 @@ def load_jpk_qi( # For each available channel, save the required data to the h5 file # TODO make sure this metadata is accurate for the channels coming from the .jpk-qi-image file chan_grp = meas_grp.require_group(f"Channel_{_make_num_min_characters(i)}") - chan_grp.attrs["channel.name"] = h5_channel.encode("utf-8") - chan_grp.attrs["retrace"] = "false".encode("utf-8") + if "_" in h5_channel: + base_name, trace_dir = h5_channel.rsplit("_", 1) + is_retrace = "true" if trace_dir.lower() == "retrace" else "false" + else: + base_name = h5_channel + is_retrace = "false" + + chan_grp.attrs["channel.name"] = base_name.encode("utf-8") + chan_grp.attrs["retrace"] = is_retrace.encode("utf-8") chan_grp.attrs["net-encoder.scaling.multiplier"] = 1.0 chan_grp.attrs["net-encoder.scaling.offset"] = 0.0 # Format name and reshape image (flattened frame stack) dataset_name = h5_channel.split("_")[0].capitalize() - frame_stack = image.flatten().reshape(-1, 1) + if h5_channel == channel: + channel_image = image + else: + channel_image, _ = _load_preprocessed_image(qi_archive=qi_archive, channel=h5_channel, config_path=config_path) + frame_stack = channel_image.flatten().reshape(-1, 1) if dataset_name in chan_grp: del chan_grp[dataset_name] @@ -277,10 +311,11 @@ def load_jpk_qi( if flip_image: image = np.flipud(image) if all_curve_data: - return (image, px2nm, all_curve_data) + return (image, px2nm, (all_curve_data, channels_units)) return image, px2nm + def load_fdcurves_from_h5(file_path: Path | str): file_path = Path(file_path) @@ -320,7 +355,7 @@ def get_jpk_qi_channels(file_path: Path | str): with qi_archive.open(path_to_image, "r") as image_file: channels += jpk._get_jpk_channels(file=image_file, filename=file_path.stem, file_path=file_path / Path(path_to_image)) channels += ADDITIONAL_CHANNELS - return channels + return channels, {"save_as_h5": bool} def _find_contact_point(curve): # find contact point in vertical deflection by peak in first derivative From 085b2527b69b0fd9baacd8104e768ad30c4bc975 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Mon, 9 Mar 2026 14:05:16 +0000 Subject: [PATCH 13/54] Adding returning of metadata --- AFMReader/h5_jpk.py | 36 +++++++ AFMReader/jpk_qi.py | 241 +++++++++++++++++++++++++------------------- 2 files changed, 172 insertions(+), 105 deletions(-) diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index 5384758..003bd3a 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -357,9 +357,45 @@ def load_h5jpk( qi_data_group = f["QI_Curve_Data"] loaded_channels_data = {} channels_units = {} + top_level_meta = {} for key, value in qi_data_group["Global_Metadata"].attrs.items(): if key.startswith("channel.unit."): channels_units[key.split(".")[-1]] = value + top_level_meta[key] = value + if "Curve_Metadata" in qi_data_group: + curves_group = qi_data_group["Curve_Metadata"] + num_of_curves = len(curves_group.keys()) + + # Pre-allocate the lists + curve_meta = [{} for _ in range(num_of_curves)] + segment_meta = [{} for _ in range(num_of_curves * 2)] + + # Iterate through the curve groups + for i_str in curves_group.keys(): + i = int(i_str) + c_group = curves_group[i_str] + + # Extract curve-specific attributes + for key, val in c_group.attrs.items(): + if isinstance(val, bytes): + val = val.decode('utf-8') + curve_meta[i][key] = val + + # Extract segment-specific attributes (usually '0' for trace, '1' for retrace) + for d_str in ['0', '1']: + if d_str in c_group: + s_group = c_group[d_str] + idx = i * 2 + int(d_str) + for key, val in s_group.attrs.items(): + if isinstance(val, bytes): + val = val.decode('utf-8') + segment_meta[idx][key] = val + + full_metadata = { + "top_level": top_level_meta, + "curves": curve_meta, + "segments": segment_meta + } for direction in ["Segment_0", "Segment_1"]: if direction in qi_data_group: diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index a2b4de1..b58ad35 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -1,6 +1,7 @@ from pathlib import Path from contextlib import nullcontext import io +import re import zipfile import numpy as np @@ -65,7 +66,7 @@ def _load_preprocessed_image(qi_archive, channel, config_path=None): for file_name in qi_archive.namelist(): if file_name.endswith(".jpk-qi-image"): path_to_image = file_name - if path_to_image not in qi_archive.namelist(): + if path_to_image is None: raise FileNotFoundError(f"{path_to_image} not found in JPK archive") tif_bytes = qi_archive.read(path_to_image) @@ -86,33 +87,28 @@ def load_jpk_qi( # Load the file path passed to the function file_path = Path(file_path) all_curve_data = None + channels_units = {} + + # Initialize metadata containers + top_level_meta = {} + full_metadata = {} + with zipfile.ZipFile(file_path, "r") as qi_archive: if channel not in ADDITIONAL_CHANNELS and not save_as_h5: image, px2nm = _load_preprocessed_image(qi_archive=qi_archive, channel=channel, config_path=config_path) else: - if save_as_h5: - top_level_meta = {} - changing_curve_keys = set() - changing_segment_keys = set() - all_curve_keys = set() - all_segment_keys = set() - - with qi_archive.open("header.properties") as archive_meta_file: - props = javaproperties.load(archive_meta_file) - if save_as_h5: + if "header.properties" in qi_archive.namelist(): + with qi_archive.open("header.properties") as archive_meta_file: + props = javaproperties.load(archive_meta_file) for key, value in props.items(): top_level_meta[f"shared-data.{key}"] = value - size_x, size_y, shape_x, shape_y = None, None, None, None - for key, value in props.items(): - if key.endswith(".ulength"): - size_x = float(value) - if key.endswith(".vlength"): - size_y = float(value) - if key.endswith(".ilength"): - shape_x = int(value) - if key.endswith(".jlength"): - shape_y = int(value) + + # Parse dimensions + size_x = float(props.get("position-pattern.grid.ulength", 0)) if "position-pattern.grid.ulength" in props else None + size_y = float(props.get("position-pattern.grid.vlength", 0)) if "position-pattern.grid.vlength" in props else None + shape_x = int(props.get("position-pattern.grid.ilength", 0)) if "position-pattern.grid.ilength" in props else None + shape_y = int(props.get("position-pattern.grid.jlength", 0)) if "position-pattern.grid.jlength" in props else None if None in [size_x, size_y, shape_x, shape_y]: @@ -125,39 +121,57 @@ def load_jpk_qi( px2nm = (pixel_to_nm_scaling_factor_x + pixel_to_nm_scaling_factor_y) / 2 segment_channels = [] - with qi_archive.open("shared-data/header.properties") as shared_data_file: - shared_meta = javaproperties.load(shared_data_file) - channel_i = 0 - if save_as_h5: + if "shared-data/header.properties" in qi_archive.namelist(): + with qi_archive.open("shared-data/header.properties") as shared_data_file: + shared_meta = javaproperties.load(shared_data_file) + channel_i = 0 for key, value in shared_meta.items(): top_level_meta[f"shared-data.{key}"] = value - while f"lcd-info.{channel_i}.channel.name" in shared_meta: - channel_dict = {} - channel_dict["name"] = shared_meta[f"lcd-info.{channel_i}.channel.name"] - multiplier, offset, unit = _get_channel_scaling(shared_meta, channel_i) - channel_dict["offset"] = offset - channel_dict["multiplier"] = multiplier - channel_dict["unit"] = unit - segment_channels.append(channel_dict) - channel_i += 1 + + while f"lcd-info.{channel_i}.channel.name" in shared_meta: + channel_dict = {"name": shared_meta[f"lcd-info.{channel_i}.channel.name"]} + multiplier, offset, unit = _get_channel_scaling(shared_meta, channel_i) + channel_dict["offset"] = offset + channel_dict["multiplier"] = multiplier + channel_dict["unit"] = unit + segment_channels.append(channel_dict) + channel_i += 1 if len(segment_channels) == 0: logger.error("Could not find channels for segments") + channels_units = {seg_chan['name'] : seg_chan['unit'] for seg_chan in segment_channels} + h5_context = h5py.File(file_path.parent / f"{file_path.stem}.h5-jpk", "a") if save_as_h5 else nullcontext() with h5_context as h5file: - vlen_type = h5py.vlen_dtype(np.float32) num_of_curves = shape_x * shape_y - all_curve_data = [] + + # Pre-allocate data structures + curve_meta_dict = {} + segment_meta_dict = {} + flat_curve_data = [{} for _ in range(num_of_curves)] + all_curve_keys = set() + all_segment_keys = set() + + # Lookup map for binary scaling + chan_scaling = {chan["name"]: chan for chan in segment_channels} + + # Compile Regexes + dat_regex = re.compile(r"index/(\d+)/segments/(\d+)/channels/([^/]+)\.dat") + curve_meta_regex = re.compile(r"index/(\d+)/header\.properties") + segment_meta_regex = re.compile(r"index/(\d+)/segments/(\d+)/segment-header\.properties") + + # Setup H5 Data structures if needed + curve_datasets = {} + if save_as_h5: - curve_meta = [{} for _ in range(num_of_curves)] - segment_meta = [{} for _ in range(num_of_curves * 2)] + vlen_type = h5py.vlen_dtype(np.float32) qi_group = h5file.require_group("QI_Curve_Data") global_meta_group = qi_group.require_group("Global_Metadata") curves_meta_group = qi_group.require_group("Curve_Metadata") - curve_datasets = {} + # curve_datasets = {} for direction in range(2): dir_group = qi_group.require_group(f"Segment_{direction}") for seg_chan in segment_channels: @@ -169,85 +183,101 @@ def load_jpk_qi( else: curve_datasets[f"{direction}_{ds_name}"] = dir_group[ds_name] + for file_info in qi_archive.infolist(): + filename = file_info.filename + + # Check Binary Data + dat_match = dat_regex.match(filename) + if dat_match: + curve_num, direction, chan_name = int(dat_match.group(1)), int(dat_match.group(2)), dat_match.group(3) + if chan_name in chan_scaling: + scale = chan_scaling[chan_name] + with qi_archive.open(file_info) as f: + raw_array = np.frombuffer(f.read(), dtype='>i4') + segment_array = (raw_array * scale["multiplier"]) + scale["offset"] + + if chan_name not in flat_curve_data[curve_num]: + flat_curve_data[curve_num][chan_name] = {} + flat_curve_data[curve_num][chan_name][f"Segment_{direction}"] = segment_array + + if save_as_h5: + curve_datasets[f"{direction}_{chan_name}"][curve_num] = segment_array + continue + + # Check Curve Metadata + c_match = curve_meta_regex.match(filename) + if c_match: + curve_num = int(c_match.group(1)) + with qi_archive.open(file_info) as f: + cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} + curve_meta_dict[curve_num] = cleaned_meta + all_curve_keys.update(cleaned_meta.keys()) + continue + + # Check Segment Metadata + s_match = segment_meta_regex.match(filename) + if s_match: + curve_num, direction = int(s_match.group(1)), int(s_match.group(2)) + idx = curve_num * 2 + direction + with qi_archive.open(file_info) as f: + cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} + segment_meta_dict[idx] = cleaned_meta + all_segment_keys.update(cleaned_meta.keys()) + + curve_meta = [curve_meta_dict.get(i, {}) for i in range(num_of_curves)] + segment_meta = [segment_meta_dict.get(i, {}) for i in range(num_of_curves * 2)] + + # Find keys that change across curves/segments + changing_curve_keys = {k for k in all_curve_keys if any(curve_meta[i].get(k) != curve_meta[0].get(k) for i in range(1, num_of_curves))} + changing_segment_keys = {k for k in all_segment_keys if any(segment_meta[i].get(k) != segment_meta[0].get(k) for i in range(1, len(segment_meta)))} + + all_curve_data = [] for y in range(shape_y): row = [] for x in range(shape_x): - curve_num = shape_x * y + x - curve_data = {} - if save_as_h5: - with qi_archive.open(f"index/{curve_num}/header.properties") as curve_meta_file: - curve_meta_raw = javaproperties.load(curve_meta_file) - for key, value in curve_meta_raw.items(): - key = ".".join(key.split(".")[1:]) - curve_meta[curve_num][key] = value - all_curve_keys.add(key) - if curve_num != 0 and (key not in curve_meta[0] or curve_meta[0][key] != value): - changing_curve_keys.add(key) - - for direction in range(2): - if save_as_h5: - try: - with qi_archive.open(f"index/{curve_num}/segments/{direction}/segment-header.properties") as segment_meta_file: - segment_meta_raw = javaproperties.load(segment_meta_file) - for key, value in segment_meta_raw.items(): - key = ".".join(key.split(".")[1:]) - segment_meta[curve_num * 2 + direction][key] = value - all_segment_keys.add(key) - if curve_num != 0 and (key not in segment_meta[0] or segment_meta[0][key] != value): - changing_segment_keys.add(key) - except KeyError: - pass - segment_dict = {} - for segment_channel in segment_channels: - try: - with qi_archive.open(f"index/{curve_num}/segments/{direction}/channels/{segment_channel['name']}.dat") as segment_raw: - dtype_str = '>i4' - raw_bytes = segment_raw.read() - raw_array = np.frombuffer(raw_bytes, dtype=dtype_str) - segment_array = (raw_array * segment_channel["multiplier"]) + segment_channel["offset"] - segment_dict[segment_channel['name']] = segment_array - if save_as_h5: - curve_datasets[f"{direction}_{segment_channel['name']}"][curve_num] = segment_array - if segment_channel['name'] not in curve_data: - curve_data[segment_channel['name']] = {} - curve_data[segment_channel['name']][f"Segment_{direction}"] = segment_array - - except KeyError: - break + curve_num = y * shape_x + x + curve_data = flat_curve_data[curve_num] + row.append(curve_data) + + # Calculate on-the-fly image data if required + if channel in ADDITIONAL_CHANNELS: + seg_0_dict = {c: data["Segment_0"] for c, data in curve_data.items() if "Segment_0" in data} if channel == "contactPoint": - if direction == 0: - image[y, x] = _find_contact_point(segment_dict) + image[y, x] = _find_contact_point(seg_0_dict) elif channel == "manualTriggerPoint": - if direction == 0: - image[y, x] = _find_trigger_point(segment_dict) - row.append(curve_data) + image[y, x] = _find_trigger_point(seg_0_dict) all_curve_data.append(row) if channel not in ADDITIONAL_CHANNELS: image, px2nm = _load_preprocessed_image(qi_archive=qi_archive, channel=channel, config_path=config_path) - channels_units = {} - for segment_channel in segment_channels: - channels_units[segment_channel['name']] = segment_channel['unit'] - if save_as_h5: - # Move all the duplicated metadata to the top level metadata dict - for key in all_curve_keys - changing_curve_keys: + # Move duplicated meta to top level + for key in all_curve_keys - changing_curve_keys: + if curve_meta and key in curve_meta[0]: top_level_meta[f"curve.{key}"] = curve_meta[0][key] - for curve_metadata in curve_meta: - curve_metadata.pop(key) - for key in all_segment_keys - changing_segment_keys: + for key in all_segment_keys - changing_segment_keys: + if segment_meta and key in segment_meta[0]: top_level_meta[f"segment.{key}"] = segment_meta[0][key] - for segment_metadata in segment_meta: - try: - segment_metadata.pop(key) - except KeyError: - pass - for segment_channel in segment_channels: - global_meta_group.attrs[f"channel.unit.{segment_channel['name']}"] = segment_channel['unit'] + + # Strip duplicated keys from individual curve/segment dicts + for c_meta in curve_meta: + for k in all_curve_keys - changing_curve_keys: c_meta.pop(k, None) + for s_meta in segment_meta: + for k in all_segment_keys - changing_segment_keys: s_meta.pop(k, None) + + full_metadata = { + "top_level": top_level_meta, + "curves": curve_meta, + "segments": segment_meta + } + + if save_as_h5: + for seg_chan in segment_channels: + global_meta_group.attrs[f"channel.unit.{seg_chan['name']}"] = seg_chan['unit'] for key, value in top_level_meta.items(): global_meta_group.attrs[key] = str(value).encode('utf-8') - for i, curve_metadata in enumerate(curve_meta): + for i, c_meta in enumerate(curve_meta): curve_meta_group = curves_meta_group.require_group(f"{i}") - for key, value in curve_metadata.items(): + for key, value in c_meta.items(): curve_meta_group.attrs[key] = str(value).encode('utf-8') for d in range(2): @@ -275,6 +305,7 @@ def load_jpk_qi( for file_name in qi_archive.namelist(): if file_name.endswith(".jpk-qi-image"): path_to_image = file_name + break # Add the channels which exist in the jpk-qi-image file with qi_archive.open(path_to_image, "r") as image_file: h5_channels += jpk._get_jpk_channels(file=image_file, filename=file_path.stem, file_path=file_path / Path(path_to_image)) @@ -311,7 +342,7 @@ def load_jpk_qi( if flip_image: image = np.flipud(image) if all_curve_data: - return (image, px2nm, (all_curve_data, channels_units)) + return (image, px2nm, (all_curve_data, channels_units, full_metadata)) return image, px2nm From 06ec0dee106ba1115ee624f5f1d2468bb9b250fc Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Thu, 12 Mar 2026 15:54:18 +0000 Subject: [PATCH 14/54] Refactoring jpk-qi-data reader to use a loader class for greater modularity --- AFMReader/jpk_qi.py | 737 +++++++++++++++++++++++++++++--------------- 1 file changed, 490 insertions(+), 247 deletions(-) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index b58ad35..e6f0015 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -75,86 +75,75 @@ def _load_preprocessed_image(qi_archive, channel, config_path=None): logger.info(f"Looking for channel {channel} in ") return jpk._load_jpk(virtual_file, path_to_image, channel=channel, file_suffix=".jpk-qi-data", config_path=config_path, flip_image=False) - -def load_jpk_qi( - file_path: Path | str, - channel: str, - config_path: Path | str | None = None, - flip_image: bool | None = True, - save_as_h5: bool | None = False -) -> tuple[np.ndarray, float]: - - # Load the file path passed to the function - file_path = Path(file_path) - all_curve_data = None - channels_units = {} - - # Initialize metadata containers - top_level_meta = {} - full_metadata = {} - - with zipfile.ZipFile(file_path, "r") as qi_archive: - if channel not in ADDITIONAL_CHANNELS and not save_as_h5: - image, px2nm = _load_preprocessed_image(qi_archive=qi_archive, channel=channel, config_path=config_path) - - else: - if "header.properties" in qi_archive.namelist(): - with qi_archive.open("header.properties") as archive_meta_file: - props = javaproperties.load(archive_meta_file) - for key, value in props.items(): - top_level_meta[f"shared-data.{key}"] = value - - # Parse dimensions - size_x = float(props.get("position-pattern.grid.ulength", 0)) if "position-pattern.grid.ulength" in props else None - size_y = float(props.get("position-pattern.grid.vlength", 0)) if "position-pattern.grid.vlength" in props else None - shape_x = int(props.get("position-pattern.grid.ilength", 0)) if "position-pattern.grid.ilength" in props else None - shape_y = int(props.get("position-pattern.grid.jlength", 0)) if "position-pattern.grid.jlength" in props else None - - - if None in [size_x, size_y, shape_x, shape_y]: - logger.error(f"Incomplete dimension data in {file_path}") - - image = np.zeros((shape_y, shape_x), dtype=np.float32) - - pixel_to_nm_scaling_factor_x = size_x / shape_x * 1e9 if shape_x > 0 else 1.0 - pixel_to_nm_scaling_factor_y = size_y / shape_y * 1e9 if shape_y > 0 else 1.0 - px2nm = (pixel_to_nm_scaling_factor_x + pixel_to_nm_scaling_factor_y) / 2 - - segment_channels = [] - if "shared-data/header.properties" in qi_archive.namelist(): - with qi_archive.open("shared-data/header.properties") as shared_data_file: - shared_meta = javaproperties.load(shared_data_file) - channel_i = 0 - for key, value in shared_meta.items(): - top_level_meta[f"shared-data.{key}"] = value - - while f"lcd-info.{channel_i}.channel.name" in shared_meta: - channel_dict = {"name": shared_meta[f"lcd-info.{channel_i}.channel.name"]} - multiplier, offset, unit = _get_channel_scaling(shared_meta, channel_i) - channel_dict["offset"] = offset - channel_dict["multiplier"] = multiplier - channel_dict["unit"] = unit - segment_channels.append(channel_dict) - channel_i += 1 - - if len(segment_channels) == 0: - logger.error("Could not find channels for segments") - - channels_units = {seg_chan['name'] : seg_chan['unit'] for seg_chan in segment_channels} - - h5_context = h5py.File(file_path.parent / f"{file_path.stem}.h5-jpk", "a") if save_as_h5 else nullcontext() - with h5_context as h5file: - num_of_curves = shape_x * shape_y - - # Pre-allocate data structures - curve_meta_dict = {} - segment_meta_dict = {} - flat_curve_data = [{} for _ in range(num_of_curves)] - all_curve_keys = set() - all_segment_keys = set() - - # Lookup map for binary scaling - chan_scaling = {chan["name"]: chan for chan in segment_channels} +class jpk_qi_loader: + """Class for readability and improving modularity in the load jpk qi data function""" + def __init__(self, + filepath: Path | str, + channel: str, + config_path: Path | str | None = None, + flip_image: bool | None = True, + save_as: str | None = None): + """ + Initializes the loader with the provided parameters. + + Parameters + ---------- + filepath : Path | str + The path to the .jpk-qi file to be loaded. + channel : str + The specific channel to be extracted from the file (e.g., "measuredHeight")." + config_path : Path | str | None, optional + The path to the configuration file, if any. Default is None. + flip_image : bool | None, optional + Whether to flip the image vertically. Default is True. + save_as : str | None, optional + The format to save the loaded data. Default is None. + """ + + self.filepath = Path(filepath) + self.channel = channel + self.config_path = config_path + self.flip_image = flip_image + self.save_as = save_as + + # Initialize key attributes that will be returned / accessed frequently + + # Just the top level metadata extracted from the header files + self.top_level_meta = {} + # A dictionary containing all metadata, splitting the top level metadata and the metadata for each curve and segment + self.full_metadata = {} + + self.all_curve_data = None + # A lookup for channel name to unit to be returned + self.channels_units = {} + # The list of channels for the segments with their scaling information extracted from the shared header + self.segment_channels = [] + + # Dictionary of the metadata for each curve, indexed by curve number + self.curve_meta_dict = {} + # Dictionary of the metadata for each segment, indexed by segement number (calculated as curve_num * 2 + direction) + self.segment_meta_dict = {} + # The keys that exist in the curve metadata across all curves, used to determine which keys to move to the top level metadata + self.all_curve_keys = set() + # The keys that exist in the segment metadata across all segments, used to determine which keys to move to the top level metadata + self.all_segment_keys = set() + # Define the image shape and size attributes + self.size_x, self.size_y, self.shape_x, self.shape_y = None, None, None, None + + # Instantiate containers for data to be saved (so an exception is not caused if not saving) + self.curve_datasets = None + + def load(self): + with zipfile.ZipFile(self.filepath, "r") as qi_archive: + self.extract_global_metadata(qi_archive) + + self.parse_dimension_data() + + # Pre-allocate the image array based on the dimensions parsed from the metadata + self.image = np.zeros((self.shape_y, self.shape_x), dtype=np.float32) + + # Access the curve data and metadata, and save to given file format + with self.get_saving_context() as file: # Compile Regexes dat_regex = re.compile(r"index/(\d+)/segments/(\d+)/channels/([^/]+)\.dat") @@ -162,26 +151,8 @@ def load_jpk_qi( segment_meta_regex = re.compile(r"index/(\d+)/segments/(\d+)/segment-header\.properties") # Setup H5 Data structures if needed - curve_datasets = {} - - - if save_as_h5: - vlen_type = h5py.vlen_dtype(np.float32) - qi_group = h5file.require_group("QI_Curve_Data") - global_meta_group = qi_group.require_group("Global_Metadata") - curves_meta_group = qi_group.require_group("Curve_Metadata") - - # curve_datasets = {} - for direction in range(2): - dir_group = qi_group.require_group(f"Segment_{direction}") - for seg_chan in segment_channels: - ds_name = seg_chan["name"] - if ds_name not in dir_group: - curve_datasets[f"{direction}_{ds_name}"] = dir_group.create_dataset( - ds_name, shape=(num_of_curves,), dtype=vlen_type - ) - else: - curve_datasets[f"{direction}_{ds_name}"] = dir_group[ds_name] + if self.save_as == "h5": + self.curve_datasets, self.global_meta_group, self.curves_meta_group = self.setup_h5_structure(file) for file_info in qi_archive.infolist(): filename = file_info.filename @@ -189,162 +160,432 @@ def load_jpk_qi( # Check Binary Data dat_match = dat_regex.match(filename) if dat_match: + # If file is a .dat file, extract the curve number, segment direction and channel name from the filename curve_num, direction, chan_name = int(dat_match.group(1)), int(dat_match.group(2)), dat_match.group(3) - if chan_name in chan_scaling: - scale = chan_scaling[chan_name] - with qi_archive.open(file_info) as f: - raw_array = np.frombuffer(f.read(), dtype='>i4') - segment_array = (raw_array * scale["multiplier"]) + scale["offset"] - - if chan_name not in flat_curve_data[curve_num]: - flat_curve_data[curve_num][chan_name] = {} - flat_curve_data[curve_num][chan_name][f"Segment_{direction}"] = segment_array - - if save_as_h5: - curve_datasets[f"{direction}_{chan_name}"][curve_num] = segment_array + # Then load the data from the file + self.extract_dat_file(qi_archive, file_info, curve_num, direction, chan_name) continue # Check Curve Metadata - c_match = curve_meta_regex.match(filename) - if c_match: - curve_num = int(c_match.group(1)) - with qi_archive.open(file_info) as f: - cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} - curve_meta_dict[curve_num] = cleaned_meta - all_curve_keys.update(cleaned_meta.keys()) + curve_meta_match = curve_meta_regex.match(filename) + if curve_meta_match: + # If file is a curve metadata file, extract the curve number from the filename + curve_num = int(curve_meta_match.group(1)) + # Then load the metadata from the file + self.extract_curve_metadata(qi_archive, file_info, curve_num) continue # Check Segment Metadata - s_match = segment_meta_regex.match(filename) - if s_match: - curve_num, direction = int(s_match.group(1)), int(s_match.group(2)) - idx = curve_num * 2 + direction - with qi_archive.open(file_info) as f: - cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} - segment_meta_dict[idx] = cleaned_meta - all_segment_keys.update(cleaned_meta.keys()) - - curve_meta = [curve_meta_dict.get(i, {}) for i in range(num_of_curves)] - segment_meta = [segment_meta_dict.get(i, {}) for i in range(num_of_curves * 2)] - - # Find keys that change across curves/segments - changing_curve_keys = {k for k in all_curve_keys if any(curve_meta[i].get(k) != curve_meta[0].get(k) for i in range(1, num_of_curves))} - changing_segment_keys = {k for k in all_segment_keys if any(segment_meta[i].get(k) != segment_meta[0].get(k) for i in range(1, len(segment_meta)))} - - all_curve_data = [] - for y in range(shape_y): - row = [] - for x in range(shape_x): - curve_num = y * shape_x + x - curve_data = flat_curve_data[curve_num] - row.append(curve_data) - - # Calculate on-the-fly image data if required - if channel in ADDITIONAL_CHANNELS: - seg_0_dict = {c: data["Segment_0"] for c, data in curve_data.items() if "Segment_0" in data} - if channel == "contactPoint": - image[y, x] = _find_contact_point(seg_0_dict) - elif channel == "manualTriggerPoint": - image[y, x] = _find_trigger_point(seg_0_dict) - all_curve_data.append(row) - if channel not in ADDITIONAL_CHANNELS: - image, px2nm = _load_preprocessed_image(qi_archive=qi_archive, channel=channel, config_path=config_path) - - # Move duplicated meta to top level - for key in all_curve_keys - changing_curve_keys: - if curve_meta and key in curve_meta[0]: - top_level_meta[f"curve.{key}"] = curve_meta[0][key] - for key in all_segment_keys - changing_segment_keys: - if segment_meta and key in segment_meta[0]: - top_level_meta[f"segment.{key}"] = segment_meta[0][key] - - # Strip duplicated keys from individual curve/segment dicts - for c_meta in curve_meta: - for k in all_curve_keys - changing_curve_keys: c_meta.pop(k, None) - for s_meta in segment_meta: - for k in all_segment_keys - changing_segment_keys: s_meta.pop(k, None) - - full_metadata = { - "top_level": top_level_meta, - "curves": curve_meta, - "segments": segment_meta - } - - if save_as_h5: - for seg_chan in segment_channels: - global_meta_group.attrs[f"channel.unit.{seg_chan['name']}"] = seg_chan['unit'] - for key, value in top_level_meta.items(): - global_meta_group.attrs[key] = str(value).encode('utf-8') - for i, c_meta in enumerate(curve_meta): - curve_meta_group = curves_meta_group.require_group(f"{i}") - for key, value in c_meta.items(): - curve_meta_group.attrs[key] = str(value).encode('utf-8') - - for d in range(2): - segment_meta_group = curve_meta_group.require_group(f"{d}") - for key, value in segment_meta[i*2+d].items(): - segment_meta_group.attrs[key] = str(value).encode('utf-8') + segment_meta_match = segment_meta_regex.match(filename) + if segment_meta_match: + # If file is a segment metadata file, extract the curve number and segment direction from the filename + curve_num, direction = int(segment_meta_match.group(1)), int(segment_meta_match.group(2)) + # Then load the segment metadata from the file + self.extract_segment_metadata(qi_archive, file_info, curve_num, direction) + continue + + self.all_curve_data = self.process_flat_curve_data() + + # TODO can we remove curve_meta_dict and just use curve_meta or is the non duplicating necessary + self.curve_meta = [self.curve_meta_dict.get(i, {}) for i in range(self.num_of_curves)] + self.segment_meta = [self.segment_meta_dict.get(i, {}) for i in range(self.num_of_curves * 2)] + self.full_metadata = self.construct_full_metadata() + + # Save the full metadata to appropiate format if saving + self.save_metadata() + # Convert to nanometers if in meters - if channel in ADDITIONAL_CHANNELS_IN_M: - image = image * 1e9 - - if save_as_h5: - with h5py.File(file_path.parent / f"{file_path.stem}.h5-jpk", "a") as h5file: - # Save data required for reading the h5 file as a normal image file - meas_grp = h5file.require_group("Measurement_000") - meas_grp.attrs["position-pattern.grid.ulength"] = size_x - meas_grp.attrs["position-pattern.grid.ilength"] = shape_x - meas_grp.attrs["position-pattern.grid.vlength"] = size_y - meas_grp.attrs["position-pattern.grid.jlength"] = shape_y - meas_grp.attrs["timing-settings.scanRate"] = 1.0 # Dummy value to satisfy reader - - logger.info(f"Saving a hdf5 copy of the data {file_path.parent / f'{file_path.stem}.h5-jpk'}") - - h5_channels = [channel] - for file_name in qi_archive.namelist(): - if file_name.endswith(".jpk-qi-image"): - path_to_image = file_name - break - # Add the channels which exist in the jpk-qi-image file - with qi_archive.open(path_to_image, "r") as image_file: - h5_channels += jpk._get_jpk_channels(file=image_file, filename=file_path.stem, file_path=file_path / Path(path_to_image)) - for i, h5_channel in enumerate(h5_channels): - # For each available channel, save the required data to the h5 file - # TODO make sure this metadata is accurate for the channels coming from the .jpk-qi-image file - chan_grp = meas_grp.require_group(f"Channel_{_make_num_min_characters(i)}") - if "_" in h5_channel: - base_name, trace_dir = h5_channel.rsplit("_", 1) - is_retrace = "true" if trace_dir.lower() == "retrace" else "false" - else: - base_name = h5_channel - is_retrace = "false" - - chan_grp.attrs["channel.name"] = base_name.encode("utf-8") - chan_grp.attrs["retrace"] = is_retrace.encode("utf-8") - chan_grp.attrs["net-encoder.scaling.multiplier"] = 1.0 - chan_grp.attrs["net-encoder.scaling.offset"] = 0.0 - - # Format name and reshape image (flattened frame stack) - dataset_name = h5_channel.split("_")[0].capitalize() - if h5_channel == channel: - channel_image = image - else: - channel_image, _ = _load_preprocessed_image(qi_archive=qi_archive, channel=h5_channel, config_path=config_path) - frame_stack = channel_image.flatten().reshape(-1, 1) - - if dataset_name in chan_grp: - del chan_grp[dataset_name] - chan_grp.create_dataset(dataset_name, data=frame_stack) - - - # Need to include flip image as _load_jpk flip image is set to false - if flip_image: - image = np.flipud(image) - if all_curve_data: - return (image, px2nm, (all_curve_data, channels_units, full_metadata)) - - return image, px2nm + if self.channel in ADDITIONAL_CHANNELS_IN_M: + self.image = self.image * 1e9 + + # Save a lite form of the images (precalculated) if saving to a file + if self.save_as is not None: + self.save_lite_data(qi_archive) + + # Need to include flip image as _load_jpk flip image is set to false + if self.flip_image: + self.image = np.flipud(self.image) + if self.all_curve_data: + return (self.image, self.px2nm, (self.all_curve_data, self.channels_units, self.full_metadata)) + + return self.image, self.px2nm + + def process_flat_curve_data(self): + all_curve_data = [] + for y in range(self.shape_y): + row = [] + for x in range(self.shape_x): + curve_num = y * self.shape_x + x + curve_data = self.flat_curve_data[curve_num] + row.append(curve_data) + + # Calculate on-the-fly image data if required + if self.channel in ADDITIONAL_CHANNELS: + seg_0_dict = {c: data["Segment_0"] for c, data in curve_data.items() if "Segment_0" in data} + if self.channel == "contactPoint": + self.image[y, x] = _find_contact_point(seg_0_dict) + elif self.channel == "manualTriggerPoint": + self.image[y, x] = _find_trigger_point(seg_0_dict) + all_curve_data.append(row) + return all_curve_data + + def save_lite_data(self, qi_archive): + """ + Saves a lite form of the data (e.g., the calculated image data) to the appropriate format based on the save_as attribute. + + Parameters + ---------- + qi_archive : zipfile.ZipFile + The archive containing the .jpk-qi-image file. + """ + if self.save_as == "h5": + with h5py.File(self.filepath.parent / f"{self.filepath.stem}.h5-jpk", "a") as h5file: + # Save data required for reading the h5 file as a normal image file + meas_grp = h5file.require_group("Measurement_000") + # Save dimensions data + meas_grp.attrs["position-pattern.grid.ulength"] = self.size_x + meas_grp.attrs["position-pattern.grid.ilength"] = self.shape_x + meas_grp.attrs["position-pattern.grid.vlength"] = self.size_y + meas_grp.attrs["position-pattern.grid.jlength"] = self.shape_y + meas_grp.attrs["timing-settings.scanRate"] = 1.0 # Dummy value to satisfy reader + + logger.info(f"Saving a hdf5 copy of the data {self.filepath.parent / f'{self.filepath.stem}.h5-jpk'}") + + h5_channels = [self.channel] + # Look for the jpk-qi-image file in the archive + for file_name in qi_archive.namelist(): + if file_name.endswith(".jpk-qi-image"): + path_to_image = file_name + break + # Add the channels which exist in the jpk-qi-image file + with qi_archive.open(path_to_image, "r") as image_file: + h5_channels += jpk._get_jpk_channels(file=image_file, filename=self.filepath.stem, file_path=self.filepath / Path(path_to_image)) + for i, h5_channel in enumerate(h5_channels): + # For each available channel, save the required data to the h5 file + # TODO make sure this metadata is accurate for the channels coming from the .jpk-qi-image file + chan_grp = meas_grp.require_group(f"Channel_{_make_num_min_characters(i)}") + # Extract name and retrace information from the channel name + if "_" in h5_channel: + base_name, trace_dir = h5_channel.rsplit("_", 1) + is_retrace = "true" if trace_dir.lower() == "retrace" else "false" + else: + base_name = h5_channel + is_retrace = "false" + + # Add the necessary attributes to the channel group + chan_grp.attrs["channel.name"] = base_name.encode("utf-8") + chan_grp.attrs["retrace"] = is_retrace.encode("utf-8") + chan_grp.attrs["net-encoder.scaling.multiplier"] = 1.0 + chan_grp.attrs["net-encoder.scaling.offset"] = 0.0 + + # Format name and reshape image (flattened frame stack) + dataset_name = h5_channel.split("_")[0].capitalize() + # Include all the channels including the calculated channel + if h5_channel == self.channel: + channel_image = self.image + else: + channel_image, _ = _load_preprocessed_image(qi_archive=qi_archive, channel=h5_channel, config_path=self.config_path) + frame_stack = channel_image.flatten().reshape(-1, 1) + + # Update/ replace the channels dataset + if dataset_name in chan_grp: + del chan_grp[dataset_name] + chan_grp.create_dataset(dataset_name, data=frame_stack) + + + def save_metadata(self): + """Saves the metadata to the appropriate format based on the save_as attribute.""" + if self.save_as == "h5": + for seg_chan in self.segment_channels: + self.global_meta_group.attrs[f"channel.unit.{seg_chan['name']}"] = seg_chan['unit'] + for key, value in self.top_level_meta.items(): + self.global_meta_group.attrs[key] = str(value).encode('utf-8') + for i, c_meta in enumerate(self.curve_meta): + curve_meta_group = self.curves_meta_group.require_group(f"{i}") + for key, value in c_meta.items(): + curve_meta_group.attrs[key] = str(value).encode('utf-8') + + for d in range(2): + segment_meta_group = curve_meta_group.require_group(f"{d}") + for key, value in self.segment_meta[i*2+d].items(): + segment_meta_group.attrs[key] = str(value).encode('utf-8') + + + def construct_full_metadata(self): + """ + Constructs the full metadata dictionary by determining which keys in the curve and segment metadata change across curves/segments and which do not, moving the non-changing keys to the top level metadata, and then combining everything into a single dictionary. + + Returns + ------- + dict + The full metadata dictionary with the structure: + { + "top_level": { ... }, + "curves": [ { ... }, { ... }, ... ], + "segments": [ { ... }, { ... }, ... ] + } + """ + + # Find keys that change across curves/segments + changing_curve_keys = {k for k in self.all_curve_keys if any(self.curve_meta[i].get(k) != self.curve_meta[0].get(k) for i in range(1, self.num_of_curves))} + changing_segment_keys = {k for k in self.all_segment_keys if any(self.segment_meta[i].get(k) != self.segment_meta[0].get(k) for i in range(1, len(self.segment_meta)))} + + # Move duplicated meta to top level for both segments and curves + for key in self.all_curve_keys - changing_curve_keys: + if self.curve_meta and key in self.curve_meta[0]: + self.top_level_meta[f"curve.{key}"] = self.curve_meta[0][key] + for key in self.all_segment_keys - changing_segment_keys: + if self.segment_meta and key in self.segment_meta[0]: + self.top_level_meta[f"segment.{key}"] = self.segment_meta[0][key] + + # Strip duplicated keys from individual curve/segment dicts + for c_meta in self.curve_meta: + for k in self.all_curve_keys - changing_curve_keys: c_meta.pop(k, None) + for s_meta in self.segment_meta: + for k in self.all_segment_keys - changing_segment_keys: s_meta.pop(k, None) + + # Construct full metadata dict from subdicts + full_metadata = { + "top_level": self.top_level_meta, + "curves": self.curve_meta, + "segments": self.segment_meta + } + return full_metadata + + def extract_dat_file(self, qi_archive: zipfile.ZipFile, file_info: zipfile.ZipInfo, curve_num: int, direction: int, chan_name: str): + """ + Extracts the data from a .dat file in the JPK QI archive, applies the appropriate scaling, and saves it to the internal data structure and h5 dataset if required. + + Parameters + ---------- + qi_archive : zipfile.ZipFile + The JPK QI archive from which to extract the .dat file. + file_info : zipfile.ZipInfo + The ZipInfo object corresponding to the .dat file to be extracted. + curve_num : int + The curve number associated with the .dat file, parsed from the filename. + direction : int + The segment direction (0 or 1) associated with the .dat file, parsed from the filename. + chan_name : str + The channel name associated with the .dat file, parsed from the filename. + """ + if chan_name in self.channel_scaling: + scale = self.channel_scaling[chan_name] + with qi_archive.open(file_info) as f: + # Read the binary data as big-endian 32-bit integers + raw_array = np.frombuffer(f.read(), dtype='>i4') + # Apply the scaling to convert raw values into real world values + segment_array = (raw_array * scale["multiplier"]) + scale["offset"] + + # If the channel doesn't exist in the curve data for this curve, add it as a new entry + if chan_name not in self.flat_curve_data[curve_num]: + self.flat_curve_data[curve_num][chan_name] = {} + # Add the segment data to the curve data under the appropriate channel and segment direction + self.flat_curve_data[curve_num][chan_name][f"Segment_{direction}"] = segment_array + + # Update the dataset if saving as h5 + if self.curve_datasets is not None: + self.curve_datasets[f"{direction}_{chan_name}"][curve_num] = segment_array + else: + logger.warning(f"Channel {chan_name} not found in scaling information. Skipping data for curve {curve_num}, direction {direction}.") + + + def extract_curve_metadata(self, qi_archive: zipfile.ZipFile, file_info: zipfile.ZipInfo, curve_num: int): + """ + Extracts the curve metadata from its header.properties file in the JPK QI archive and saves it to the internal data structure. + + Parameters + ---------- + qi_archive : zipfile.ZipFile + The JPK QI archive from which to extract the curve metadata. + file_info : zipfile.ZipInfo + The ZipInfo object corresponding to the header.properties file to be extracted. + curve_num : int + The curve number associated with the metadata, parsed from the filename. + """ + with qi_archive.open(file_info) as f: + cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} + self.curve_meta_dict[curve_num] = cleaned_meta + self.all_curve_keys.update(cleaned_meta.keys()) + + + def extract_segment_metadata(self, qi_archive: zipfile.ZipFile, file_info: zipfile.ZipInfo, curve_num: int, direction: int): + """ + Extracts the segment metadata from its header.properties file in the JPK QI archive and saves it to the internal data structure. + + Parameters + ---------- + qi_archive : zipfile.ZipFile + The JPK QI archive from which to extract the segment metadata. + file_info : zipfile.ZipInfo + The ZipInfo object corresponding to the header.properties file to be extracted. + curve_num : int + The curve number associated with the metadata, parsed from the filename. + direction : int + The segment direction (0 or 1) associated with the metadata, parsed from the filename. + """ + idx = curve_num * 2 + direction + with qi_archive.open(file_info) as f: + cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} + self.segment_meta_dict[idx] = cleaned_meta + self.all_segment_keys.update(cleaned_meta.keys()) + + def setup_h5_structure(self, h5file): + """ + Sets up the structure in the h5 file for saving the curve data and metadata, and returns the datasets and metadata groups for later use. + + Parameters + ---------- + h5file : h5py.File + The h5 file in which to set up the structure for saving the curve data and metadata. + Returns + ------- + curve_datasets : dict + A dictionary containing the datasets for each curve and segment direction. + global_meta_group : h5py.Group + The h5 group for storing global metadata. + curves_meta_group : h5py.Group + The h5 group for storing curve metadata. + """ + vlen_type = h5py.vlen_dtype(np.float32) + + # Create the main group for the QI curve data that all the curve data will be in + qi_group = h5file.require_group("QI_Curve_Data") + curve_datasets = {} + + # Establish empty groups for global metadata and curve metadata + global_meta_group = qi_group.require_group("Global_Metadata") + curves_meta_group = qi_group.require_group("Curve_Metadata") + + for direction in range(2): + # For each segment direction, establish a group and datasets for each channel in the segment channels list + dir_group = qi_group.require_group(f"Segment_{direction}") + for seg_chan in self.segment_channels: + ds_name = seg_chan["name"] + if ds_name not in dir_group: + # Create the dataset for the given channel and segment direction if it doesn't already exist + curve_datasets[f"{direction}_{ds_name}"] = dir_group.create_dataset( + ds_name, shape=(self.num_of_curves,), dtype=vlen_type + ) + else: + # If the dataset already exists, just add it to the curve datasets dictionary for later use + curve_datasets[f"{direction}_{ds_name}"] = dir_group[ds_name] + return curve_datasets, global_meta_group, curves_meta_group + + def get_saving_context(self): + """ + Returns the appropriate context manager for saving the data based on the save_as attribute. + If save_as is "h5", it returns a context manager for an h5 file. Otherwise, it returns a null context. + + Returns + ------- + contextlib.AbstractContextManager + The context manager for saving the data. + """ + if self.save_as == "h5": + return h5py.File(self.filepath.parent / f"{self.filepath.stem}.h5-jpk", "a") + else: + return nullcontext() + + + def parse_dimension_data(self): + """ + Parses the dimension data from the provided properties dictionary and calculates the pixel to nanometer scaling factor. + """ + # Extract both real size and pixel dimensions from the metadata + for key, value in self.top_level_meta.items(): + if key.endswith(".ulength"): + self.size_x = float(value) + if key.endswith(".vlength"): + self.size_y = float(value) + if key.endswith(".ilength"): + self.shape_x = int(value) + if key.endswith(".jlength"): + self.shape_y = int(value) + + # Log an error if any of these do not exist + if None in [self.size_x, self.size_y, self.shape_x, self.shape_y]: + logger.error(f"Incomplete dimension data in {self.filepath}") + + # Calculate the pixel to nano metre scaling as an average of the scale for each direction + pixel_to_nm_scaling_factor_x = self.size_x / self.shape_x * 1e9 if self.shape_x > 0 else 1.0 + pixel_to_nm_scaling_factor_y = self.size_y / self.shape_y * 1e9 if self.shape_y > 0 else 1.0 + self.px2nm = (pixel_to_nm_scaling_factor_x + pixel_to_nm_scaling_factor_y) / 2 + + # Establish number of curves + self.num_of_curves = self.shape_x * self.shape_y + + # Setup the data structure to hold the curve data as it's read in + self.flat_curve_data = [{} for _ in range(self.num_of_curves)] + + def extract_global_metadata(self, qi_archive): + """ + Extracts the global metadata from the JPK QI archive and populates the top level metadata dictionary and segment channels list. + + Parameters + ---------- + qi_archive : zipfile.ZipFile + The JPK QI archive from which to extract metadata. + """ + + # Load the metadata from the global properties file + if "header.properties" in qi_archive.namelist(): + with qi_archive.open("header.properties") as archive_meta_file: + props = javaproperties.load(archive_meta_file) + + # Add all the data from the main header to the top level metadata with a prefix to avoid key clashes + for key, value in props.items(): + self.top_level_meta[f"main-header.{key}"] = value + else: + logger.error(f"File {self.filepath} does not contain essential metadata and cannot be loaded") + + # Load the metadata from the shared header and parse the channel information for the segments + if "shared-data/header.properties" in qi_archive.namelist(): + with qi_archive.open("shared-data/header.properties") as shared_data_file: + shared_meta = javaproperties.load(shared_data_file) + channel_idx = 0 + + # Add all the data from the shared header to the top level metadata with a prefix to avoid key clashes + for key, value in shared_meta.items(): + self.top_level_meta[f"shared-data.{key}"] = value + + # Collect channel data from the shared metadata + while f"lcd-info.{channel_idx}.channel.name" in shared_meta: + channel_dict = {} + # Calculate and store the offset and multiplier to convert raw values into real world values + multiplier, offset, unit = _get_channel_scaling(shared_meta, channel_idx) + channel_dict["name"] = shared_meta[f"lcd-info.{channel_idx}.channel.name"] + channel_dict["offset"] = offset + channel_dict["multiplier"] = multiplier + channel_dict["unit"] = unit + # Add the channel dict to the list + self.segment_channels.append(channel_dict) + # Increment the channel index to look for the next channel + channel_idx += 1 + else: + logger.error(f"File {self.filepath} does not contain essential channel metadata and cannot be loaded") + + if len(self.segment_channels) == 0: + logger.error("Could not find channels for segments") + + # Create a lookup for channel name to unit to be returned + self.channels_units = {seg_chan['name'] : seg_chan['unit'] for seg_chan in self.segment_channels} + # Lookup map for binary scaling + self.channel_scaling = {chan["name"]: chan for chan in self.segment_channels} + + + +def load_jpk_qi( + file_path: Path | str, + channel: str, + config_path: Path | str | None = None, + flip_image: bool | None = True, + save_as_h5: bool | None = False +) -> tuple[np.ndarray, float]: + + jpk_loader = jpk_qi_loader(filepath=file_path, channel=channel, config_path=config_path, flip_image=flip_image, save_as="h5" if save_as_h5 else None) + return jpk_loader.load() def load_fdcurves_from_h5(file_path: Path | str): @@ -377,6 +618,7 @@ def _make_num_min_characters(num : int, min_chars: int = 3): return string_num def get_jpk_qi_channels(file_path: Path | str): + logger.debug("Starting to get jpk qi data channels") file_path = Path(file_path) channels = [] with zipfile.ZipFile(file_path, "r") as qi_archive: @@ -386,6 +628,7 @@ def get_jpk_qi_channels(file_path: Path | str): with qi_archive.open(path_to_image, "r") as image_file: channels += jpk._get_jpk_channels(file=image_file, filename=file_path.stem, file_path=file_path / Path(path_to_image)) channels += ADDITIONAL_CHANNELS + logger.debug("Got jpk qi data channels") return channels, {"save_as_h5": bool} def _find_contact_point(curve): From 78c54a73bd14fab4a9915f941e1ae77b31e35775 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Mon, 16 Mar 2026 13:19:00 +0000 Subject: [PATCH 15/54] Starting to implement lazy loading --- AFMReader/general_loader.py | 2 + AFMReader/h5_jpk.py | 154 +++++++------- AFMReader/jpk_qi.py | 403 ++++++++++++++++++++++++++++-------- 3 files changed, 401 insertions(+), 158 deletions(-) diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index 7bbd5c0..1efc028 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -4,7 +4,9 @@ import numpy.typing as npt +print("Importing AFMReader modules...") from AFMReader import asd, gwy, h5_jpk, ibw, jpk, spm, stp, top, topostats, jpk_qi, bin +print("AFMReader modules imported.") from AFMReader.logging import logger diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index 003bd3a..bdfeb2f 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -272,6 +272,69 @@ def get_h5jpk_channels(file_path: Path | str): available_channels = list(_available_channels(f)) return available_channels +class LazyQIData: + def __init__(self, qi_data_group: h5py.Group, shape_x: int): + self.qi_data_group = qi_data_group + self.shape_x = shape_x + + def __getitem__(self, y: int): + + class RowProxy: + def __init__(self, parent, y): + self.parent = parent + self.y = y + def __getitem__(self, x: int): + return self.parent._fetch_curve(self.y, x) + return RowProxy(self, y) + + def _fetch_curve(self, y: int, x: int): + curve_dict = {} + curve_num = self.shape_x * y + x + for segment, segment_group in self.qi_data_group["Curves"].items(): + for channel in segment_group["Indicies"]: + start_idx = segment_group["Indicies"][channel][curve_num] + end_idx = segment_group["Indicies"][channel][curve_num + 1] + if channel not in curve_dict: + curve_dict[channel] = {} + curve_dict[channel][segment] = segment_group["Data"][channel][start_idx:end_idx] + return curve_dict + +class LazyCurveMetadata: + """A proxy class that fetches header.properties files on demand.""" + def __init__(self, qi_data_group: h5py.Group, top_level_meta: dict): + self.qi_data_group = qi_data_group + self.top_level_meta = top_level_meta + # Expose top_level so the frontend can still do `raw_metadata["top_level"]` + self.top_level = top_level_meta + + def __getitem__(self, key): + if key == "top_level": + return self.top_level + elif key == "curves": + return LazyMetaProxy(self.qi_data_group, "curve") + elif key == "segments": + return LazyMetaProxy(self.qi_data_group, "segment") + raise KeyError(key) + +class LazyMetaProxy: + def __init__(self, qi_data_group: h5py.Group, meta_type: str, idx: int = None): + self.qi_data_group = qi_data_group + self.meta_type = meta_type + self.idx = idx + + + def __getitem__(self, key): + if isinstance(key, int): + return LazyMetaProxy(self.qi_data_group, self.meta_type, key) + else: + value = self.qi_data_group["Curve_Metadata"][key] + if isinstance(value, h5py.Dataset): + return value[self.idx] + else: + return value.decode("utf-8") if isinstance(value, bytes) else value + + + def load_h5jpk( file_path: Path | str, channel: str, flip_image: bool = True, load_curves: bool = True ) -> tuple[np.ndarray, float, dict[str, float]]: @@ -352,73 +415,24 @@ def load_h5jpk( logger.info(f"[{file_path.stem}] : Extracted {num_frames} frames from channel '{channel}'") - if load_curves and "QI_Curve_Data" in f: - logger.info(f"[{file_path.stem}] : Found Force Curves QI data in file.") - qi_data_group = f["QI_Curve_Data"] - loaded_channels_data = {} - channels_units = {} - top_level_meta = {} - for key, value in qi_data_group["Global_Metadata"].attrs.items(): - if key.startswith("channel.unit."): - channels_units[key.split(".")[-1]] = value - top_level_meta[key] = value - if "Curve_Metadata" in qi_data_group: - curves_group = qi_data_group["Curve_Metadata"] - num_of_curves = len(curves_group.keys()) - - # Pre-allocate the lists - curve_meta = [{} for _ in range(num_of_curves)] - segment_meta = [{} for _ in range(num_of_curves * 2)] - - # Iterate through the curve groups - for i_str in curves_group.keys(): - i = int(i_str) - c_group = curves_group[i_str] - - # Extract curve-specific attributes - for key, val in c_group.attrs.items(): - if isinstance(val, bytes): - val = val.decode('utf-8') - curve_meta[i][key] = val - - # Extract segment-specific attributes (usually '0' for trace, '1' for retrace) - for d_str in ['0', '1']: - if d_str in c_group: - s_group = c_group[d_str] - idx = i * 2 + int(d_str) - for key, val in s_group.attrs.items(): - if isinstance(val, bytes): - val = val.decode('utf-8') - segment_meta[idx][key] = val - - full_metadata = { - "top_level": top_level_meta, - "curves": curve_meta, - "segments": segment_meta - } - - for direction in ["Segment_0", "Segment_1"]: - if direction in qi_data_group: - loaded_channels_data[direction] = {} - for channel, channel_group in qi_data_group[direction].items(): - if channel != "error": - loaded_channels_data[direction][channel] = channel_group[:] - - all_curve_data = [] - for y in range(shape_y): - row = [] - for x in range(shape_x): - curve_num = shape_x * y + x - curve_data = {} - - for direction, channels in loaded_channels_data.items(): - for channel, data_array in channels.items(): - if channel not in curve_data: - curve_data[channel] = {} - curve_data[channel][direction] = data_array[curve_num] - - row.append(curve_data) - all_curve_data.append(row) - return (image_stack, _jpk_pixel_to_nm_scaling_h5(measurement_group), (all_curve_data, channels_units), timestamps) - - return (image_stack, _jpk_pixel_to_nm_scaling_h5(measurement_group), timestamps) + if "QI_Curve_Data" not in f: + load_curves = False + + if load_curves: + logger.info(f"[{file_path.stem}] : Found Force Curves QI data in file.") + qi_data_group = f["QI_Curve_Data"] + loaded_channels_data = {} + channels_units = {} + top_level_meta = {} + for key, value in qi_data_group["Global_Metadata"].attrs.items(): + if key.startswith("channel.unit."): + channels_units[key.split(".")[-1]] = value + top_level_meta[key] = value + + full_metadata = LazyCurveMetadata(qi_data_group, top_level_meta) + + all_curve_data = LazyQIData(qi_data_group, shape_x) + + return (image_stack, _jpk_pixel_to_nm_scaling_h5(measurement_group), (all_curve_data, channels_units, full_metadata), timestamps) + + return (image_stack, _jpk_pixel_to_nm_scaling_h5(measurement_group), timestamps) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index e6f0015..0229588 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -3,6 +3,7 @@ import io import re import zipfile +import time import numpy as np import javaproperties @@ -15,6 +16,102 @@ ADDITIONAL_CHANNELS = ["contactPoint", "manualTriggerPoint"] ADDITIONAL_CHANNELS_IN_M = ["contactPoint", "manualTriggerPoint"] +class LazyCurveData: + """A proxy class that behaves like a 2D list but fetches .dat files on demand.""" + def __init__(self, filepath, shape_x, channel_scaling, archive): + self.filepath = filepath + self.shape_x = shape_x + self.channel_scaling = channel_scaling + + # 1. OPEN THE ARCHIVE ONCE AND KEEP IT OPEN + self.archive = archive + + def __getitem__(self, y: int): + # Return a row proxy to handle the second index [x] + class RowProxy: + def __init__(self, parent, y): + self.parent = parent + self.y = y + def __getitem__(self, x: int): + return self.parent._fetch_curve(self.y, x) + return RowProxy(self, y) + + + def _fetch_curve(self, y: int, x: int): + t_start_total = time.perf_counter() + + curve_num = y * self.shape_x + x + curve_data = {} + + # 3. REUSE THE ALREADY-OPEN ARCHIVE + for chan_name, scale in self.channel_scaling.items(): + curve_data[chan_name] = {} + for direction in (0, 1): + dat_path = f"index/{curve_num}/segments/{direction}/channels/{chan_name}.dat" + try: + # Access the file directly without re-parsing the ZIP directory + with self.archive.open(dat_path) as f: + raw_array = np.frombuffer(f.read(), dtype='>i4') + curve_data[chan_name][f"Segment_{direction}"] = ( + raw_array * scale["multiplier"] + ) + scale["offset"] + except KeyError: + pass # File doesn't exist for this segment + + t_end_total = time.perf_counter() + print(f"[Lazy Data] Total fetch for curve {curve_num} at (y={y}, x={x}) took {t_end_total - t_start_total:.6f} seconds") + + return curve_data + + # Good practice: add a method to close the archive when the user closes the image + def close(self): + self.archive.close() + + +class LazyCurveMetadata: + """A proxy class that fetches header.properties files on demand.""" + def __init__(self, filepath, top_level_meta, archive): + self.filepath = filepath + # Expose top_level so the frontend can still do `raw_metadata["top_level"]` + self.top_level = top_level_meta + self.archive = archive + + def __getitem__(self, key): + if key == "top_level": + return self.top_level + elif key == "curves": + return LazyMetaProxy(self.filepath, "curve", self.archive) + elif key == "segments": + return LazyMetaProxy(self.filepath, "segment", self.archive) + raise KeyError(key) + +class LazyMetaProxy: + def __init__(self, filepath, meta_type, archive): + self.filepath = filepath + self.meta_type = meta_type + self.archive = archive + + def __getitem__(self, idx: int): + t_start = time.perf_counter() + + if self.meta_type == "curve": + path = f"index/{idx}/header.properties" + else: + curve_num = idx // 2 + direction = idx % 2 + path = f"index/{curve_num}/segments/{direction}/segment-header.properties" + + try: + with self.archive.open(path) as f: + meta_dict = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} + except KeyError: + meta_dict = {} + + t_end = time.perf_counter() + print(f"[Lazy Meta] Fetched {self.meta_type} metadata for index {idx} in {t_end - t_start:.6f} seconds") + + return meta_dict + def _get_channel_scaling(props, channel_index): """ Parses the JPK properties dictionary to find the cumulative multiplier @@ -82,7 +179,8 @@ def __init__(self, channel: str, config_path: Path | str | None = None, flip_image: bool | None = True, - save_as: str | None = None): + save_as: str | None = None, + return_meta: bool = False): """ Initializes the loader with the provided parameters. @@ -98,6 +196,8 @@ def __init__(self, Whether to flip the image vertically. Default is True. save_as : str | None, optional The format to save the loaded data. Default is None. + return_meta : bool, optional + Whether to return the full metadata. Default is False. """ self.filepath = Path(filepath) @@ -105,6 +205,9 @@ def __init__(self, self.config_path = config_path self.flip_image = flip_image self.save_as = save_as + self.return_meta = return_meta + # Open the ZIP archive once and keep it open for the duration of the loading process to improve performance when accessing multiple files within the archive + self.qi_archive = zipfile.ZipFile(self.filepath, "r") # Initialize key attributes that will be returned / accessed frequently @@ -112,13 +215,12 @@ def __init__(self, self.top_level_meta = {} # A dictionary containing all metadata, splitting the top level metadata and the metadata for each curve and segment self.full_metadata = {} - + # A 2D list of curve data dictionaries, where each dictionary contains the data for all channels and segments self.all_curve_data = None # A lookup for channel name to unit to be returned self.channels_units = {} # The list of channels for the segments with their scaling information extracted from the shared header self.segment_channels = [] - # Dictionary of the metadata for each curve, indexed by curve number self.curve_meta_dict = {} # Dictionary of the metadata for each segment, indexed by segement number (calculated as curve_num * 2 + direction) @@ -133,28 +235,53 @@ def __init__(self, # Instantiate containers for data to be saved (so an exception is not caused if not saving) self.curve_datasets = None + def get_channels(self): + """ + Retrieves the available channels from the .jpk-qi-image file within the archive, and adds any additional calculated channels. + + Returns + ------- + channels : list + A list of available channels including the calculated channels. + metadata_options : dict + A dictionary of options for what metadata to return + """ + channels = [] + + # Look for the jpk-qi-image file in the archive + for file_name in self.qi_archive.namelist(): + if file_name.endswith(".jpk-qi-image"): + path_to_image = file_name + + # Add the channels which exist in the jpk-qi-image file + with self.qi_archive.open(path_to_image, "r") as image_file: + channels += jpk._get_jpk_channels(file=image_file, filename=self.filepath.stem, file_path=self.filepath / Path(path_to_image)) + channels += ADDITIONAL_CHANNELS + return channels, {"save_as_h5": bool} + def load(self): - with zipfile.ZipFile(self.filepath, "r") as qi_archive: - self.extract_global_metadata(qi_archive) + """Loads the .jpk-qi file""" + self.extract_global_metadata() - self.parse_dimension_data() + self.parse_dimension_data() - # Pre-allocate the image array based on the dimensions parsed from the metadata - self.image = np.zeros((self.shape_y, self.shape_x), dtype=np.float32) + # Pre-allocate the image array based on the dimensions parsed from the metadata + self.image = np.zeros((self.shape_y, self.shape_x), dtype=np.float32) - # Access the curve data and metadata, and save to given file format - with self.get_saving_context() as file: + # Access the curve data and metadata, and save to given file format + with self.get_saving_context() as file: - # Compile Regexes - dat_regex = re.compile(r"index/(\d+)/segments/(\d+)/channels/([^/]+)\.dat") - curve_meta_regex = re.compile(r"index/(\d+)/header\.properties") - segment_meta_regex = re.compile(r"index/(\d+)/segments/(\d+)/segment-header\.properties") + # Compile Regexes + dat_regex = re.compile(r"index/(\d+)/segments/(\d+)/channels/([^/]+)\.dat") + curve_meta_regex = re.compile(r"index/(\d+)/header\.properties") + segment_meta_regex = re.compile(r"index/(\d+)/segments/(\d+)/segment-header\.properties") - # Setup H5 Data structures if needed - if self.save_as == "h5": - self.curve_datasets, self.global_meta_group, self.curves_meta_group = self.setup_h5_structure(file) + # Setup H5 Data structures if needed + if self.save_as == "h5": + self.curve_datasets, self.global_meta_group, self.curves_meta_group = self.setup_h5_structure(file) - for file_info in qi_archive.infolist(): + if self.channel in ADDITIONAL_CHANNELS or self.save_as is not None or self.return_meta: + for file_info in self.qi_archive.infolist(): filename = file_info.filename # Check Binary Data @@ -163,55 +290,156 @@ def load(self): # If file is a .dat file, extract the curve number, segment direction and channel name from the filename curve_num, direction, chan_name = int(dat_match.group(1)), int(dat_match.group(2)), dat_match.group(3) # Then load the data from the file - self.extract_dat_file(qi_archive, file_info, curve_num, direction, chan_name) - continue - - # Check Curve Metadata - curve_meta_match = curve_meta_regex.match(filename) - if curve_meta_match: - # If file is a curve metadata file, extract the curve number from the filename - curve_num = int(curve_meta_match.group(1)) - # Then load the metadata from the file - self.extract_curve_metadata(qi_archive, file_info, curve_num) - continue - - # Check Segment Metadata - segment_meta_match = segment_meta_regex.match(filename) - if segment_meta_match: - # If file is a segment metadata file, extract the curve number and segment direction from the filename - curve_num, direction = int(segment_meta_match.group(1)), int(segment_meta_match.group(2)) - # Then load the segment metadata from the file - self.extract_segment_metadata(qi_archive, file_info, curve_num, direction) + self.extract_dat_file(file_info, curve_num, direction, chan_name) continue - self.all_curve_data = self.process_flat_curve_data() - + if self.return_meta or self.save_as is not None: + # Check Curve Metadata + curve_meta_match = curve_meta_regex.match(filename) + if curve_meta_match: + # If file is a curve metadata file, extract the curve number from the filename + curve_num = int(curve_meta_match.group(1)) + # Then load the metadata from the file + self.extract_curve_metadata(file_info, curve_num) + continue + + # Check Segment Metadata + segment_meta_match = segment_meta_regex.match(filename) + if segment_meta_match: + # If file is a segment metadata file, extract the curve number and segment direction from the filename + curve_num, direction = int(segment_meta_match.group(1)), int(segment_meta_match.group(2)) + # Then load the segment metadata from the file + self.extract_segment_metadata(file_info, curve_num, direction) + continue + + if self.return_meta or self.save_as is not None: # TODO can we remove curve_meta_dict and just use curve_meta or is the non duplicating necessary self.curve_meta = [self.curve_meta_dict.get(i, {}) for i in range(self.num_of_curves)] self.segment_meta = [self.segment_meta_dict.get(i, {}) for i in range(self.num_of_curves * 2)] self.full_metadata = self.construct_full_metadata() - # Save the full metadata to appropiate format if saving - self.save_metadata() - - - # Convert to nanometers if in meters - if self.channel in ADDITIONAL_CHANNELS_IN_M: - self.image = self.image * 1e9 - - # Save a lite form of the images (precalculated) if saving to a file if self.save_as is not None: - self.save_lite_data(qi_archive) + # If saving, need to collate the curve data into a format that can be easily saved to the h5 file (a dataset per channel per segment direction) + self.collated_curve_data, self.indicies = self.get_collated_curves() + self.collated_metadata = self.get_collated_metadata() - # Need to include flip image as _load_jpk flip image is set to false - if self.flip_image: - self.image = np.flipud(self.image) + # Save as h5 if required + if self.save_as == "h5": + # Save the curve data to the appropriate datasets in the h5 file + for chan_name, chan_data in self.collated_curve_data.items(): + for direction in range(2): + # Save the curve data and indicies to the appropriate dataset in the h5 file + self.curve_datasets[f"{direction}_{chan_name}_data"] = chan_data[f"Segment_{direction}"] + self.curve_datasets[f"{direction}_{chan_name}_indices"] = self.indicies[chan_name][f"Segment_{direction}"] + + # Save the global metadata to the h5 file + vlen_str_dt = h5py.string_dtype(encoding='utf-8') + for key, value in self.collated_metadata.items(): + if isinstance(value, list): + # If the key is a changing key, save as a dataset with one entry per curve/ segment + self.curves_meta_group.create_dataset( + name=key, + data=value, + dtype=vlen_str_dt + ) + else: + self.global_meta_group.attrs[key] = str(value).encode('utf-8') + + self.full_metadata = LazyCurveMetadata(self.filepath, self.top_level_meta, self.qi_archive) + self.all_curve_data = LazyCurveData(self.filepath, self.shape_x, self.channel_scaling, self.qi_archive) + + # Convert to nanometers if in meters + if self.channel in ADDITIONAL_CHANNELS_IN_M: + self.image = self.image * 1e9 + + # Save a lite form of the images (precalculated) if saving to a file + if self.save_as is not None: + self.save_lite_data() + + # Need to include flip image as _load_jpk flip image is set to false + if self.flip_image: + self.image = np.flipud(self.image) if self.all_curve_data: return (self.image, self.px2nm, (self.all_curve_data, self.channels_units, self.full_metadata)) return self.image, self.px2nm + def get_collated_curves(self): + """ + Collates the curve data from the flat structure it is extracted in into a structure grouped by channel and segment for easier saving to h5. + + Returns + ------- + collated_curve_data : dict + A dictionary containing the curve data collated by channel and segment, with the structure: + { + "channel_name": { + "Segment_0": [...], + "Segment_1": [...], + ... + }, + ... + } + indicies : dict + A dictionary containing the indexes of the curve data within each segment, with the structure: + { + "channel_name": { + "Segment_0": [...], + "Segment_1": [...], + ... + }, + ... + } + """ + collated_curve_data = {} + indicies = {} + for curve_data in self.flat_curve_data: + for chan_name, chan_data in curve_data.items(): + for seg_name, seg_data in chan_data.items(): + if chan_name not in collated_curve_data: + collated_curve_data[chan_name] = {} + indicies[chan_name] = {} + if seg_name not in collated_curve_data[chan_name]: + collated_curve_data[chan_name][seg_name] = [] + indicies[chan_name][seg_name] = [] + indicies[chan_name][seg_name].append(len(collated_curve_data[chan_name][seg_name])) + collated_curve_data[chan_name][seg_name].append(seg_data) + + return collated_curve_data, indicies + + + def get_collated_metadata(self): + """ + Collates the metadata from being split by curve, to being split by attribute so data can be saved more efficiently + + Returns + ------- + collated_meta : dict + A dictionary containing the collated metadata. + """ + collated_meta = {} + for key, value in self.top_level_meta.items(): + collated_meta[key] = value + for curve_dict in self.full_metadata["curves"]: + for key, value in curve_dict.items(): + if key not in collated_meta: + collated_meta[key] = [] + collated_meta[key].append(value) + for segment_dict in self.full_metadata["segments"]: + for key, value in segment_dict.items(): + if key not in collated_meta: + collated_meta[key] = [] + collated_meta[key].append(value) + return collated_meta + + def process_flat_curve_data(self): + """ + Processes the flat curve data dictionary into a 2D list structure matching the image dimensions. + ------- + all_curve_data : list + A 2D list of curve data dictionaries, where each dictionary contains the data for all channels and segments for that curve. + """ all_curve_data = [] for y in range(self.shape_y): row = [] @@ -230,7 +458,7 @@ def process_flat_curve_data(self): all_curve_data.append(row) return all_curve_data - def save_lite_data(self, qi_archive): + def save_lite_data(self): """ Saves a lite form of the data (e.g., the calculated image data) to the appropriate format based on the save_as attribute. @@ -254,12 +482,12 @@ def save_lite_data(self, qi_archive): h5_channels = [self.channel] # Look for the jpk-qi-image file in the archive - for file_name in qi_archive.namelist(): + for file_name in self.qi_archive.namelist(): if file_name.endswith(".jpk-qi-image"): path_to_image = file_name break # Add the channels which exist in the jpk-qi-image file - with qi_archive.open(path_to_image, "r") as image_file: + with self.qi_archive.open(path_to_image, "r") as image_file: h5_channels += jpk._get_jpk_channels(file=image_file, filename=self.filepath.stem, file_path=self.filepath / Path(path_to_image)) for i, h5_channel in enumerate(h5_channels): # For each available channel, save the required data to the h5 file @@ -285,7 +513,7 @@ def save_lite_data(self, qi_archive): if h5_channel == self.channel: channel_image = self.image else: - channel_image, _ = _load_preprocessed_image(qi_archive=qi_archive, channel=h5_channel, config_path=self.config_path) + channel_image, _ = _load_preprocessed_image(qi_archive=self.qi_archive, channel=h5_channel, config_path=self.config_path) frame_stack = channel_image.flatten().reshape(-1, 1) # Update/ replace the channels dataset @@ -353,14 +581,12 @@ def construct_full_metadata(self): } return full_metadata - def extract_dat_file(self, qi_archive: zipfile.ZipFile, file_info: zipfile.ZipInfo, curve_num: int, direction: int, chan_name: str): + def extract_dat_file(self, file_info: zipfile.ZipInfo, curve_num: int, direction: int, chan_name: str): """ Extracts the data from a .dat file in the JPK QI archive, applies the appropriate scaling, and saves it to the internal data structure and h5 dataset if required. Parameters ---------- - qi_archive : zipfile.ZipFile - The JPK QI archive from which to extract the .dat file. file_info : zipfile.ZipInfo The ZipInfo object corresponding to the .dat file to be extracted. curve_num : int @@ -372,7 +598,7 @@ def extract_dat_file(self, qi_archive: zipfile.ZipFile, file_info: zipfile.ZipIn """ if chan_name in self.channel_scaling: scale = self.channel_scaling[chan_name] - with qi_archive.open(file_info) as f: + with self.qi_archive.open(file_info) as f: # Read the binary data as big-endian 32-bit integers raw_array = np.frombuffer(f.read(), dtype='>i4') # Apply the scaling to convert raw values into real world values @@ -384,40 +610,33 @@ def extract_dat_file(self, qi_archive: zipfile.ZipFile, file_info: zipfile.ZipIn # Add the segment data to the curve data under the appropriate channel and segment direction self.flat_curve_data[curve_num][chan_name][f"Segment_{direction}"] = segment_array - # Update the dataset if saving as h5 - if self.curve_datasets is not None: - self.curve_datasets[f"{direction}_{chan_name}"][curve_num] = segment_array else: logger.warning(f"Channel {chan_name} not found in scaling information. Skipping data for curve {curve_num}, direction {direction}.") - def extract_curve_metadata(self, qi_archive: zipfile.ZipFile, file_info: zipfile.ZipInfo, curve_num: int): + def extract_curve_metadata(self, file_info: zipfile.ZipInfo, curve_num: int): """ Extracts the curve metadata from its header.properties file in the JPK QI archive and saves it to the internal data structure. Parameters ---------- - qi_archive : zipfile.ZipFile - The JPK QI archive from which to extract the curve metadata. file_info : zipfile.ZipInfo The ZipInfo object corresponding to the header.properties file to be extracted. curve_num : int The curve number associated with the metadata, parsed from the filename. """ - with qi_archive.open(file_info) as f: + with self.qi_archive.open(file_info) as f: cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} self.curve_meta_dict[curve_num] = cleaned_meta self.all_curve_keys.update(cleaned_meta.keys()) - def extract_segment_metadata(self, qi_archive: zipfile.ZipFile, file_info: zipfile.ZipInfo, curve_num: int, direction: int): + def extract_segment_metadata(self, file_info: zipfile.ZipInfo, curve_num: int, direction: int): """ Extracts the segment metadata from its header.properties file in the JPK QI archive and saves it to the internal data structure. Parameters ---------- - qi_archive : zipfile.ZipFile - The JPK QI archive from which to extract the segment metadata. file_info : zipfile.ZipInfo The ZipInfo object corresponding to the header.properties file to be extracted. curve_num : int @@ -426,7 +645,7 @@ def extract_segment_metadata(self, qi_archive: zipfile.ZipFile, file_info: zipfi The segment direction (0 or 1) associated with the metadata, parsed from the filename. """ idx = curve_num * 2 + direction - with qi_archive.open(file_info) as f: + with self.qi_archive.open(file_info) as f: cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} self.segment_meta_dict[idx] = cleaned_meta self.all_segment_keys.update(cleaned_meta.keys()) @@ -461,16 +680,26 @@ def setup_h5_structure(self, h5file): for direction in range(2): # For each segment direction, establish a group and datasets for each channel in the segment channels list dir_group = qi_group.require_group(f"Segment_{direction}") + data_group = dir_group.require_group("Data") + indices_group = dir_group.require_group("Indicies") for seg_chan in self.segment_channels: ds_name = seg_chan["name"] - if ds_name not in dir_group: + if f"{ds_name}_data" not in data_group: # Create the dataset for the given channel and segment direction if it doesn't already exist - curve_datasets[f"{direction}_{ds_name}"] = dir_group.create_dataset( - ds_name, shape=(self.num_of_curves,), dtype=vlen_type + curve_datasets[f"{direction}_{ds_name}_data"] = data_group.create_dataset( + f"{ds_name}", shape=(self.num_of_curves,), dtype=vlen_type ) else: # If the dataset already exists, just add it to the curve datasets dictionary for later use - curve_datasets[f"{direction}_{ds_name}"] = dir_group[ds_name] + curve_datasets[f"{direction}_{ds_name}_data"] = data_group[f"{ds_name}"] + if f"{ds_name}" not in indices_group: + # Create a corresponding dataset to hold the indicies for the curve data + curve_datasets[f"{direction}_{ds_name}_indices"] = indices_group.create_dataset( + f"{ds_name}", shape=(self.num_of_curves,), dtype=vlen_type + ) + else: + # If the indicies dataset already exists, just add it to the curve datasets dictionary for later use + curve_datasets[f"{direction}_{ds_name}_indices"] = indices_group[f"{ds_name}"] return curve_datasets, global_meta_group, curves_meta_group def get_saving_context(self): @@ -519,19 +748,12 @@ def parse_dimension_data(self): # Setup the data structure to hold the curve data as it's read in self.flat_curve_data = [{} for _ in range(self.num_of_curves)] - def extract_global_metadata(self, qi_archive): - """ - Extracts the global metadata from the JPK QI archive and populates the top level metadata dictionary and segment channels list. - - Parameters - ---------- - qi_archive : zipfile.ZipFile - The JPK QI archive from which to extract metadata. - """ + def extract_global_metadata(self): + """Extracts the global metadata from the JPK QI archive and populates the top level metadata dictionary and segment channels list.""" # Load the metadata from the global properties file - if "header.properties" in qi_archive.namelist(): - with qi_archive.open("header.properties") as archive_meta_file: + if "header.properties" in self.qi_archive.namelist(): + with self.qi_archive.open("header.properties") as archive_meta_file: props = javaproperties.load(archive_meta_file) # Add all the data from the main header to the top level metadata with a prefix to avoid key clashes @@ -541,8 +763,8 @@ def extract_global_metadata(self, qi_archive): logger.error(f"File {self.filepath} does not contain essential metadata and cannot be loaded") # Load the metadata from the shared header and parse the channel information for the segments - if "shared-data/header.properties" in qi_archive.namelist(): - with qi_archive.open("shared-data/header.properties") as shared_data_file: + if "shared-data/header.properties" in self.qi_archive.namelist(): + with self.qi_archive.open("shared-data/header.properties") as shared_data_file: shared_meta = javaproperties.load(shared_data_file) channel_idx = 0 @@ -574,6 +796,11 @@ def extract_global_metadata(self, qi_archive): # Lookup map for binary scaling self.channel_scaling = {chan["name"]: chan for chan in self.segment_channels} + def close(self): + """Closes the ZIP archive when done to free up system resources.""" + self.qi_archive.close() + + def load_jpk_qi( From 3c327e9a12803f3872ec2efd8a51cc1594dc58c5 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Mon, 16 Mar 2026 18:28:46 +0000 Subject: [PATCH 16/54] Made force curves lazy loaded for h5-jpk and jpk-qi-data --- AFMReader/h5_jpk.py | 38 ++++++------ AFMReader/jpk_qi.py | 145 ++++++++++++++++++++++++-------------------- 2 files changed, 100 insertions(+), 83 deletions(-) diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index bdfeb2f..dbdb7fc 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -292,8 +292,9 @@ def _fetch_curve(self, y: int, x: int): curve_num = self.shape_x * y + x for segment, segment_group in self.qi_data_group["Curves"].items(): for channel in segment_group["Indicies"]: - start_idx = segment_group["Indicies"][channel][curve_num] - end_idx = segment_group["Indicies"][channel][curve_num + 1] + start_idx = int(segment_group["Indicies"][channel][curve_num]) + end_idx = int(segment_group["Indicies"][channel][curve_num + 1]) + print(f"Fetching curve for pixel (y={y}, x={x}), segment '{segment}', channel '{channel}': start_idx={start_idx}, end_idx={end_idx}") if channel not in curve_dict: curve_dict[channel] = {} curve_dict[channel][segment] = segment_group["Data"][channel][start_idx:end_idx] @@ -317,22 +318,21 @@ def __getitem__(self, key): raise KeyError(key) class LazyMetaProxy: - def __init__(self, qi_data_group: h5py.Group, meta_type: str, idx: int = None): + def __init__(self, qi_data_group: h5py.Group, meta_type: str): self.qi_data_group = qi_data_group self.meta_type = meta_type - self.idx = idx - def __getitem__(self, key): - if isinstance(key, int): - return LazyMetaProxy(self.qi_data_group, self.meta_type, key) - else: - value = self.qi_data_group["Curve_Metadata"][key] - if isinstance(value, h5py.Dataset): - return value[self.idx] - else: - return value.decode("utf-8") if isinstance(value, bytes) else value - + def __getitem__(self, idx: int): + meta_dict = {} + for key in self.qi_data_group["Curve_Metadata"]: + if key.startswith(f"{self.meta_type}."): + new_key = key.split(".", 1)[1] + if isinstance(self.qi_data_group["Curve_Metadata"][key], h5py.Dataset): + meta_dict[new_key] = self.qi_data_group["Curve_Metadata"][key][idx] + else: + meta_dict[new_key] = self.qi_data_group["Curve_Metadata"][key] + return meta_dict def load_h5jpk( @@ -414,17 +414,21 @@ def load_h5jpk( timestamps = generate_timestamps(num_frames, line_rate, shape_y) logger.info(f"[{file_path.stem}] : Extracted {num_frames} frames from channel '{channel}'") + px2nm = _jpk_pixel_to_nm_scaling_h5(measurement_group) if "QI_Curve_Data" not in f: load_curves = False + if load_curves: + f = h5py.File(file_path, "r") + logger.debug(f"QI_Curve_Data group keys: {list(f.keys())}") logger.info(f"[{file_path.stem}] : Found Force Curves QI data in file.") qi_data_group = f["QI_Curve_Data"] - loaded_channels_data = {} channels_units = {} top_level_meta = {} for key, value in qi_data_group["Global_Metadata"].attrs.items(): + print(f"Global Metadata - {key}: {value}") if key.startswith("channel.unit."): channels_units[key.split(".")[-1]] = value top_level_meta[key] = value @@ -433,6 +437,6 @@ def load_h5jpk( all_curve_data = LazyQIData(qi_data_group, shape_x) - return (image_stack, _jpk_pixel_to_nm_scaling_h5(measurement_group), (all_curve_data, channels_units, full_metadata), timestamps) + return (image_stack, px2nm, (all_curve_data, channels_units, full_metadata), timestamps) - return (image_stack, _jpk_pixel_to_nm_scaling_h5(measurement_group), timestamps) + return (image_stack, px2nm, timestamps) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index 0229588..807ae83 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -233,7 +233,7 @@ def __init__(self, self.size_x, self.size_y, self.shape_x, self.shape_y = None, None, None, None # Instantiate containers for data to be saved (so an exception is not caused if not saving) - self.curve_datasets = None + self.curve_groups = None def get_channels(self): """ @@ -265,9 +265,6 @@ def load(self): self.parse_dimension_data() - # Pre-allocate the image array based on the dimensions parsed from the metadata - self.image = np.zeros((self.shape_y, self.shape_x), dtype=np.float32) - # Access the curve data and metadata, and save to given file format with self.get_saving_context() as file: @@ -278,7 +275,7 @@ def load(self): # Setup H5 Data structures if needed if self.save_as == "h5": - self.curve_datasets, self.global_meta_group, self.curves_meta_group = self.setup_h5_structure(file) + self.curve_groups, self.global_meta_group, self.curves_meta_group = self.setup_h5_structure(file) if self.channel in ADDITIONAL_CHANNELS or self.save_as is not None or self.return_meta: for file_info in self.qi_archive.infolist(): @@ -329,8 +326,17 @@ def load(self): for chan_name, chan_data in self.collated_curve_data.items(): for direction in range(2): # Save the curve data and indicies to the appropriate dataset in the h5 file - self.curve_datasets[f"{direction}_{chan_name}_data"] = chan_data[f"Segment_{direction}"] - self.curve_datasets[f"{direction}_{chan_name}_indices"] = self.indicies[chan_name][f"Segment_{direction}"] + seg_name = f"Segment_{direction}" + self.curve_groups["Data"][seg_name].create_dataset( + name=chan_name, + data=chan_data[seg_name], + dtype=np.float32, + ) + self.curve_groups["Indicies"][seg_name].create_dataset( + name=chan_name, + data=self.indicies[chan_name][seg_name], + dtype=np.int32, + ) # Save the global metadata to the h5 file vlen_str_dt = h5py.string_dtype(encoding='utf-8') @@ -348,6 +354,9 @@ def load(self): self.full_metadata = LazyCurveMetadata(self.filepath, self.top_level_meta, self.qi_archive) self.all_curve_data = LazyCurveData(self.filepath, self.shape_x, self.channel_scaling, self.qi_archive) + # Load the image + self.image = self.get_image() + # Convert to nanometers if in meters if self.channel in ADDITIONAL_CHANNELS_IN_M: self.image = self.image * 1e9 @@ -393,18 +402,34 @@ def get_collated_curves(self): """ collated_curve_data = {} indicies = {} + for curve_data in self.flat_curve_data: for chan_name, chan_data in curve_data.items(): for seg_name, seg_data in chan_data.items(): if chan_name not in collated_curve_data: collated_curve_data[chan_name] = {} indicies[chan_name] = {} + if seg_name not in collated_curve_data[chan_name]: collated_curve_data[chan_name][seg_name] = [] - indicies[chan_name][seg_name] = [] - indicies[chan_name][seg_name].append(len(collated_curve_data[chan_name][seg_name])) + indicies[chan_name][seg_name] = [0] + + # Append the segment data as an array to the list (creates a 2D list) collated_curve_data[chan_name][seg_name].append(seg_data) + last_index = indicies[chan_name][seg_name][-1] + next_index = last_index + len(seg_data) + + indicies[chan_name][seg_name].append(next_index) + + for chan_name, segments in collated_curve_data.items(): + for seg_name in segments: + # Flattens the list of arrays into one massive 1D array for more efficiency + collated_curve_data[chan_name][seg_name] = np.concatenate(collated_curve_data[chan_name][seg_name]) + + # Converts the indices list into a standard fixed-length integer array + indicies[chan_name][seg_name] = np.array(indicies[chan_name][seg_name], dtype=np.int32) + return collated_curve_data, indicies @@ -418,55 +443,54 @@ def get_collated_metadata(self): A dictionary containing the collated metadata. """ collated_meta = {} + for seg_chan in self.segment_channels: + collated_meta[f"channel.unit.{seg_chan['name']}"] = seg_chan['unit'] for key, value in self.top_level_meta.items(): collated_meta[key] = value for curve_dict in self.full_metadata["curves"]: for key, value in curve_dict.items(): - if key not in collated_meta: - collated_meta[key] = [] - collated_meta[key].append(value) + if f"curve.{key}" not in collated_meta: + collated_meta[f"curve.{key}"] = [] + collated_meta[f"curve.{key}"].append(value) for segment_dict in self.full_metadata["segments"]: for key, value in segment_dict.items(): - if key not in collated_meta: - collated_meta[key] = [] - collated_meta[key].append(value) + if f"segment.{key}" not in collated_meta: + collated_meta[f"segment.{key}"] = [] + collated_meta[f"segment.{key}"].append(value) return collated_meta - def process_flat_curve_data(self): + def get_image(self): """ Processes the flat curve data dictionary into a 2D list structure matching the image dimensions. + + Returns ------- - all_curve_data : list - A 2D list of curve data dictionaries, where each dictionary contains the data for all channels and segments for that curve. + image : np.ndarray + A 2D array representing the image data. """ - all_curve_data = [] - for y in range(self.shape_y): - row = [] - for x in range(self.shape_x): - curve_num = y * self.shape_x + x - curve_data = self.flat_curve_data[curve_num] - row.append(curve_data) - - # Calculate on-the-fly image data if required - if self.channel in ADDITIONAL_CHANNELS: + + # If the image needs to be calculated, do so + if self.channel in ADDITIONAL_CHANNELS: + # Create an empty image array + image = np.zeros((self.shape_y, self.shape_x), dtype=np.float32) + for y in range(self.shape_y): + for x in range(self.shape_x): + curve_num = y * self.shape_x + x + curve_data = self.flat_curve_data[curve_num] + # Calculate on-the-fly image data if required seg_0_dict = {c: data["Segment_0"] for c, data in curve_data.items() if "Segment_0" in data} if self.channel == "contactPoint": - self.image[y, x] = _find_contact_point(seg_0_dict) + image[y, x] = _find_contact_point(seg_0_dict) elif self.channel == "manualTriggerPoint": - self.image[y, x] = _find_trigger_point(seg_0_dict) - all_curve_data.append(row) - return all_curve_data + image[y, x] = _find_trigger_point(seg_0_dict) + # Load the image directly if it already exists as a precalculated channel in the .jpk-qi-image file + else: + image, _ = _load_preprocessed_image(qi_archive=self.qi_archive, channel=self.channel, config_path=self.config_path) + return image def save_lite_data(self): - """ - Saves a lite form of the data (e.g., the calculated image data) to the appropriate format based on the save_as attribute. - - Parameters - ---------- - qi_archive : zipfile.ZipFile - The archive containing the .jpk-qi-image file. - """ + """Saves a lite form of the data (e.g., the calculated image data) to the appropriate format based on the save_as attribute.""" if self.save_as == "h5": with h5py.File(self.filepath.parent / f"{self.filepath.stem}.h5-jpk", "a") as h5file: # Save data required for reading the h5 file as a normal image file @@ -660,8 +684,8 @@ def setup_h5_structure(self, h5file): The h5 file in which to set up the structure for saving the curve data and metadata. Returns ------- - curve_datasets : dict - A dictionary containing the datasets for each curve and segment direction. + curve_groups : dict + A dictionary containing the group structure for each segment direction. global_meta_group : h5py.Group The h5 group for storing global metadata. curves_meta_group : h5py.Group @@ -671,36 +695,25 @@ def setup_h5_structure(self, h5file): # Create the main group for the QI curve data that all the curve data will be in qi_group = h5file.require_group("QI_Curve_Data") - curve_datasets = {} # Establish empty groups for global metadata and curve metadata global_meta_group = qi_group.require_group("Global_Metadata") curves_meta_group = qi_group.require_group("Curve_Metadata") + curves_group = qi_group.require_group("Curves") + + curve_groups = { + "Data": {}, + "Indicies": {} + } for direction in range(2): - # For each segment direction, establish a group and datasets for each channel in the segment channels list - dir_group = qi_group.require_group(f"Segment_{direction}") - data_group = dir_group.require_group("Data") - indices_group = dir_group.require_group("Indicies") - for seg_chan in self.segment_channels: - ds_name = seg_chan["name"] - if f"{ds_name}_data" not in data_group: - # Create the dataset for the given channel and segment direction if it doesn't already exist - curve_datasets[f"{direction}_{ds_name}_data"] = data_group.create_dataset( - f"{ds_name}", shape=(self.num_of_curves,), dtype=vlen_type - ) - else: - # If the dataset already exists, just add it to the curve datasets dictionary for later use - curve_datasets[f"{direction}_{ds_name}_data"] = data_group[f"{ds_name}"] - if f"{ds_name}" not in indices_group: - # Create a corresponding dataset to hold the indicies for the curve data - curve_datasets[f"{direction}_{ds_name}_indices"] = indices_group.create_dataset( - f"{ds_name}", shape=(self.num_of_curves,), dtype=vlen_type - ) - else: - # If the indicies dataset already exists, just add it to the curve datasets dictionary for later use - curve_datasets[f"{direction}_{ds_name}_indices"] = indices_group[f"{ds_name}"] - return curve_datasets, global_meta_group, curves_meta_group + # For each segment direction, establish the necessary group structure that will contain each channel dataset + seg_name = f"Segment_{direction}" + dir_group = curves_group.require_group(seg_name) + # Create the Data and Indicies subfolders and store their references + curve_groups["Data"][seg_name] = dir_group.require_group("Data") + curve_groups["Indicies"][seg_name] = dir_group.require_group("Indicies") + return curve_groups, global_meta_group, curves_meta_group def get_saving_context(self): """ From e776c665aadc84dca67df4d5607ab84610c2ee51 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Wed, 18 Mar 2026 15:43:15 +0000 Subject: [PATCH 17/54] Implementing caching of heavy data objects/ references to large open files and removing the ability to run curve analysis directly in the reader --- AFMReader/general_loader.py | 13 +- AFMReader/h5_jpk.py | 13 +- AFMReader/jpk_qi.py | 250 ++++++++++++++++-------------------- 3 files changed, 130 insertions(+), 146 deletions(-) diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index 1efc028..413f766 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -41,7 +41,10 @@ def __init__(self, filepath: str | Path, channel: str, kwargs: dict = None): self.channel = channel self.suffix = self.filepath.suffix self.loaded_curves = False - self.kwargs = kwargs + self.kwargs = kwargs if kwargs else {} + + # Store heavy loaded data in a dict to avoid having to reload it + self.cached_data = {} def load(self, channel: str | None = None, kwargs: dict = None) -> tuple[npt.NDArray | str, float | None]: # noqa: C901 """ @@ -85,7 +88,9 @@ def load(self, channel: str | None = None, kwargs: dict = None) -> tuple[npt.NDA else: logger.error(f"Loading h5-jpk file returned unexpected number of items: {len(h5_returned)}") elif self.suffix == ".jpk-qi-data": - jpk_qi_returned = jpk_qi.load_jpk_qi(self.filepath, self.channel, **self.kwargs) + if "jpk_qi_loader" not in self.cached_data: + self.cached_data["jpk_qi_loader"] = jpk_qi.jpk_qi_loader(filepath=self.filepath, channel=self.channel, **self.kwargs) + jpk_qi_returned = self.cached_data["jpk_qi_loader"].load(channel=self.channel, **self.kwargs) if len(jpk_qi_returned) == 2: image, pixel_to_nanometre_scaling_factor = jpk_qi_returned elif len(jpk_qi_returned) == 3: @@ -136,7 +141,9 @@ def get_available_channels(self): elif self.suffix == ".h5-jpk": available_channels = h5_jpk.get_h5jpk_channels(self.filepath) elif self.suffix == ".jpk-qi-data": - available_channels = jpk_qi.get_jpk_qi_channels(self.filepath) + if "jpk_qi_loader" not in self.cached_data: + self.cached_data["jpk_qi_loader"] = jpk_qi.jpk_qi_loader(filepath=self.filepath, **self.kwargs) + available_channels = self.cached_data["jpk_qi_loader"].get_available_channels() elif self.suffix == ".topostats": available_channels = ["image", "image_original"] elif self.suffix == ".bin": diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index dbdb7fc..388a5a8 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -273,9 +273,12 @@ def get_h5jpk_channels(file_path: Path | str): return available_channels class LazyQIData: - def __init__(self, qi_data_group: h5py.Group, shape_x: int): + def __init__(self, qi_data_group: h5py.Group, shape_x: int, shape_y: int): self.qi_data_group = qi_data_group self.shape_x = shape_x + self.shape_y = shape_y + self.dims = (shape_y, shape_x) + def __getitem__(self, y: int): @@ -287,6 +290,11 @@ def __getitem__(self, x: int): return self.parent._fetch_curve(self.y, x) return RowProxy(self, y) + def __iter__(self): + for y in range(self.shape_y): + for x in range(self.shape_x): + yield self._fetch_curve(y, x) + def _fetch_curve(self, y: int, x: int): curve_dict = {} curve_num = self.shape_x * y + x @@ -294,7 +302,6 @@ def _fetch_curve(self, y: int, x: int): for channel in segment_group["Indicies"]: start_idx = int(segment_group["Indicies"][channel][curve_num]) end_idx = int(segment_group["Indicies"][channel][curve_num + 1]) - print(f"Fetching curve for pixel (y={y}, x={x}), segment '{segment}', channel '{channel}': start_idx={start_idx}, end_idx={end_idx}") if channel not in curve_dict: curve_dict[channel] = {} curve_dict[channel][segment] = segment_group["Data"][channel][start_idx:end_idx] @@ -435,7 +442,7 @@ def load_h5jpk( full_metadata = LazyCurveMetadata(qi_data_group, top_level_meta) - all_curve_data = LazyQIData(qi_data_group, shape_x) + all_curve_data = LazyQIData(qi_data_group, shape_x, shape_y) return (image_stack, px2nm, (all_curve_data, channels_units, full_metadata), timestamps) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index 807ae83..833e423 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -3,7 +3,6 @@ import io import re import zipfile -import time import numpy as np import javaproperties @@ -13,17 +12,12 @@ from AFMReader import jpk -ADDITIONAL_CHANNELS = ["contactPoint", "manualTriggerPoint"] -ADDITIONAL_CHANNELS_IN_M = ["contactPoint", "manualTriggerPoint"] - class LazyCurveData: """A proxy class that behaves like a 2D list but fetches .dat files on demand.""" def __init__(self, filepath, shape_x, channel_scaling, archive): self.filepath = filepath self.shape_x = shape_x self.channel_scaling = channel_scaling - - # 1. OPEN THE ARCHIVE ONCE AND KEEP IT OPEN self.archive = archive def __getitem__(self, y: int): @@ -38,12 +32,10 @@ def __getitem__(self, x: int): def _fetch_curve(self, y: int, x: int): - t_start_total = time.perf_counter() curve_num = y * self.shape_x + x curve_data = {} - # 3. REUSE THE ALREADY-OPEN ARCHIVE for chan_name, scale in self.channel_scaling.items(): curve_data[chan_name] = {} for direction in (0, 1): @@ -58,12 +50,8 @@ def _fetch_curve(self, y: int, x: int): except KeyError: pass # File doesn't exist for this segment - t_end_total = time.perf_counter() - print(f"[Lazy Data] Total fetch for curve {curve_num} at (y={y}, x={x}) took {t_end_total - t_start_total:.6f} seconds") - return curve_data - # Good practice: add a method to close the archive when the user closes the image def close(self): self.archive.close() @@ -92,7 +80,6 @@ def __init__(self, filepath, meta_type, archive): self.archive = archive def __getitem__(self, idx: int): - t_start = time.perf_counter() if self.meta_type == "curve": path = f"index/{idx}/header.properties" @@ -107,15 +94,28 @@ def __getitem__(self, idx: int): except KeyError: meta_dict = {} - t_end = time.perf_counter() - print(f"[Lazy Meta] Fetched {self.meta_type} metadata for index {idx} in {t_end - t_start:.6f} seconds") - return meta_dict def _get_channel_scaling(props, channel_index): """ Parses the JPK properties dictionary to find the cumulative multiplier and offset for a specific channel index (e.g., '1' for vDeflection). + + Parameters + ---------- + props : dict + The properties dictionary loaded from the JPK file. + channel_index : str + The index of the channel to find the scaling for (e.g., '1' for vDeflection). + + Returns + ------- + final_multiplier : float + The cumulative multiplier for the specified channel. + final_offset : float + The cumulative offset for the specified channel. + unit : str + The unit of the channel. """ prefix = f"lcd-info.{channel_index}." @@ -158,29 +158,15 @@ def _get_channel_scaling(props, channel_index): return final_multiplier, final_offset, unit -def _load_preprocessed_image(qi_archive, channel, config_path=None): - path_to_image = None - for file_name in qi_archive.namelist(): - if file_name.endswith(".jpk-qi-image"): - path_to_image = file_name - if path_to_image is None: - raise FileNotFoundError(f"{path_to_image} not found in JPK archive") - - tif_bytes = qi_archive.read(path_to_image) - - virtual_file = io.BytesIO(tif_bytes) - logger.info(f"Looking for channel {channel} in ") - return jpk._load_jpk(virtual_file, path_to_image, channel=channel, file_suffix=".jpk-qi-data", config_path=config_path, flip_image=False) class jpk_qi_loader: """Class for readability and improving modularity in the load jpk qi data function""" def __init__(self, filepath: Path | str, - channel: str, + channel: str | None = None, config_path: Path | str | None = None, flip_image: bool | None = True, - save_as: str | None = None, - return_meta: bool = False): + save_as_h5: bool = False): """ Initializes the loader with the provided parameters. @@ -188,26 +174,25 @@ def __init__(self, ---------- filepath : Path | str The path to the .jpk-qi file to be loaded. - channel : str - The specific channel to be extracted from the file (e.g., "measuredHeight")." + channel : str | None, optional + The specific channel to be extracted from the file (e.g., "measuredHeight"). Default is None. config_path : Path | str | None, optional The path to the configuration file, if any. Default is None. flip_image : bool | None, optional Whether to flip the image vertically. Default is True. - save_as : str | None, optional - The format to save the loaded data. Default is None. - return_meta : bool, optional - Whether to return the full metadata. Default is False. + save_as_h5 : bool, optional + Whether to save the loaded data as an H5 file. Default is False. """ self.filepath = Path(filepath) self.channel = channel self.config_path = config_path self.flip_image = flip_image - self.save_as = save_as - self.return_meta = return_meta + self.save_as_h5 = save_as_h5 # Open the ZIP archive once and keep it open for the duration of the loading process to improve performance when accessing multiple files within the archive self.qi_archive = zipfile.ZipFile(self.filepath, "r") + # Set path to the .jpk-qi-image file within the archive for later use + self.path_to_image = None # Initialize key attributes that will be returned / accessed frequently @@ -235,7 +220,7 @@ def __init__(self, # Instantiate containers for data to be saved (so an exception is not caused if not saving) self.curve_groups = None - def get_channels(self): + def get_available_channels(self): """ Retrieves the available channels from the .jpk-qi-image file within the archive, and adds any additional calculated channels. @@ -246,21 +231,49 @@ def get_channels(self): metadata_options : dict A dictionary of options for what metadata to return """ - channels = [] # Look for the jpk-qi-image file in the archive - for file_name in self.qi_archive.namelist(): - if file_name.endswith(".jpk-qi-image"): - path_to_image = file_name + if self.path_to_image is None: + for file_name in self.qi_archive.namelist(): + if file_name.endswith(".jpk-qi-image"): + self.path_to_image = file_name # Add the channels which exist in the jpk-qi-image file - with self.qi_archive.open(path_to_image, "r") as image_file: - channels += jpk._get_jpk_channels(file=image_file, filename=self.filepath.stem, file_path=self.filepath / Path(path_to_image)) - channels += ADDITIONAL_CHANNELS + with self.qi_archive.open(self.path_to_image, "r") as image_file: + channels = jpk._get_jpk_channels(file=image_file, filename=self.filepath.stem, file_path=self.filepath / Path(self.path_to_image)) return channels, {"save_as_h5": bool} - def load(self): - """Loads the .jpk-qi file""" + def load(self, + channel: str | None = None, + config_path: Path | str | None = None, + flip_image: bool | None = True, + save_as_h5: bool = False) -> tuple[np.ndarray, float, dict] | tuple[np.ndarray, float]: + """ + Loads the .jpk-qi file + + Parameters + ---------- + channel : str | None, optional + The specific channel to be extracted from the file (e.g., "measuredHeight"). If None, the default channel will be used. Default is None. + config_path : Path | str | None, optional + Path to the configuration file. If None, the default configuration will be used. Default is None. + flip_image : bool | None, optional + Whether to flip the image. If None, the default behavior will be used. Default is True. + save_as_h5 : bool, optional + Whether to save the data as an H5 file. Default is False. + + Returns + ------- + tuple + A tuple containing the image data (numpy.ndarray), the pixel to nanometre scaling factor (float), and optionally the curve data (dict) if available. + """ + + # Update instance attributes based on provided parameters, largely so loader can be called to get channels without setting a channel + self.channel = channel if channel else self.channel + self.config_path = config_path if config_path else self.config_path + self.flip_image = flip_image if flip_image is not None else self.flip_image + self.save_as_h5 = save_as_h5 if save_as_h5 is not None else self.save_as_h5 + self.extract_global_metadata() self.parse_dimension_data() @@ -274,10 +287,9 @@ def load(self): segment_meta_regex = re.compile(r"index/(\d+)/segments/(\d+)/segment-header\.properties") # Setup H5 Data structures if needed - if self.save_as == "h5": + if self.save_as_h5: self.curve_groups, self.global_meta_group, self.curves_meta_group = self.setup_h5_structure(file) - if self.channel in ADDITIONAL_CHANNELS or self.save_as is not None or self.return_meta: for file_info in self.qi_archive.infolist(): filename = file_info.filename @@ -290,38 +302,35 @@ def load(self): self.extract_dat_file(file_info, curve_num, direction, chan_name) continue - if self.return_meta or self.save_as is not None: - # Check Curve Metadata - curve_meta_match = curve_meta_regex.match(filename) - if curve_meta_match: - # If file is a curve metadata file, extract the curve number from the filename - curve_num = int(curve_meta_match.group(1)) - # Then load the metadata from the file - self.extract_curve_metadata(file_info, curve_num) - continue - - # Check Segment Metadata - segment_meta_match = segment_meta_regex.match(filename) - if segment_meta_match: - # If file is a segment metadata file, extract the curve number and segment direction from the filename - curve_num, direction = int(segment_meta_match.group(1)), int(segment_meta_match.group(2)) - # Then load the segment metadata from the file - self.extract_segment_metadata(file_info, curve_num, direction) - continue - - if self.return_meta or self.save_as is not None: + # Check Segment Metadata + segment_meta_match = segment_meta_regex.match(filename) + if segment_meta_match: + # If file is a segment metadata file, extract the curve number and segment direction from the filename + curve_num, direction = int(segment_meta_match.group(1)), int(segment_meta_match.group(2)) + # Then load the segment metadata from the file + self.extract_segment_metadata(file_info, curve_num, direction) + continue + + # Check Curve Metadata + curve_meta_match = curve_meta_regex.match(filename) + if curve_meta_match: + # If file is a curve metadata file, extract the curve number from the filename + curve_num = int(curve_meta_match.group(1)) + # Then load the metadata from the file + self.extract_curve_metadata(file_info, curve_num) + continue + # TODO can we remove curve_meta_dict and just use curve_meta or is the non duplicating necessary self.curve_meta = [self.curve_meta_dict.get(i, {}) for i in range(self.num_of_curves)] self.segment_meta = [self.segment_meta_dict.get(i, {}) for i in range(self.num_of_curves * 2)] self.full_metadata = self.construct_full_metadata() - if self.save_as is not None: # If saving, need to collate the curve data into a format that can be easily saved to the h5 file (a dataset per channel per segment direction) self.collated_curve_data, self.indicies = self.get_collated_curves() self.collated_metadata = self.get_collated_metadata() # Save as h5 if required - if self.save_as == "h5": + if self.save_as_h5: # Save the curve data to the appropriate datasets in the h5 file for chan_name, chan_data in self.collated_curve_data.items(): for direction in range(2): @@ -355,19 +364,12 @@ def load(self): self.all_curve_data = LazyCurveData(self.filepath, self.shape_x, self.channel_scaling, self.qi_archive) # Load the image - self.image = self.get_image() - - # Convert to nanometers if in meters - if self.channel in ADDITIONAL_CHANNELS_IN_M: - self.image = self.image * 1e9 + self.image, _ = self.get_image() # Save a lite form of the images (precalculated) if saving to a file - if self.save_as is not None: + if self.save_as_h5: self.save_lite_data() - # Need to include flip image as _load_jpk flip image is set to false - if self.flip_image: - self.image = np.flipud(self.image) if self.all_curve_data: return (self.image, self.px2nm, (self.all_curve_data, self.channels_units, self.full_metadata)) @@ -460,7 +462,7 @@ def get_collated_metadata(self): return collated_meta - def get_image(self): + def get_image(self, overide_channel: str | None = None): """ Processes the flat curve data dictionary into a 2D list structure matching the image dimensions. @@ -470,28 +472,27 @@ def get_image(self): A 2D array representing the image data. """ - # If the image needs to be calculated, do so - if self.channel in ADDITIONAL_CHANNELS: - # Create an empty image array - image = np.zeros((self.shape_y, self.shape_x), dtype=np.float32) - for y in range(self.shape_y): - for x in range(self.shape_x): - curve_num = y * self.shape_x + x - curve_data = self.flat_curve_data[curve_num] - # Calculate on-the-fly image data if required - seg_0_dict = {c: data["Segment_0"] for c, data in curve_data.items() if "Segment_0" in data} - if self.channel == "contactPoint": - image[y, x] = _find_contact_point(seg_0_dict) - elif self.channel == "manualTriggerPoint": - image[y, x] = _find_trigger_point(seg_0_dict) - # Load the image directly if it already exists as a precalculated channel in the .jpk-qi-image file + if overide_channel: + channel = overide_channel else: - image, _ = _load_preprocessed_image(qi_archive=self.qi_archive, channel=self.channel, config_path=self.config_path) - return image + channel = self.channel + + path_to_image = None + for file_name in self.qi_archive.namelist(): + if file_name.endswith(".jpk-qi-image"): + path_to_image = file_name + if path_to_image is None: + raise FileNotFoundError(f"{path_to_image} not found in JPK archive") + + tif_bytes = self.qi_archive.read(path_to_image) + + virtual_file = io.BytesIO(tif_bytes) + logger.info(f"Looking for channel {channel} in {path_to_image}") + return jpk._load_jpk(virtual_file, path_to_image, channel=channel, file_suffix=".jpk-qi-data", config_path=self.config_path) def save_lite_data(self): - """Saves a lite form of the data (e.g., the calculated image data) to the appropriate format based on the save_as attribute.""" - if self.save_as == "h5": + """Saves a lite form of the data (e.g., the calculated image data) to the appropriate format based on the save_as_h5 attribute.""" + if self.save_as_h5: with h5py.File(self.filepath.parent / f"{self.filepath.stem}.h5-jpk", "a") as h5file: # Save data required for reading the h5 file as a normal image file meas_grp = h5file.require_group("Measurement_000") @@ -537,7 +538,7 @@ def save_lite_data(self): if h5_channel == self.channel: channel_image = self.image else: - channel_image, _ = _load_preprocessed_image(qi_archive=self.qi_archive, channel=h5_channel, config_path=self.config_path) + channel_image, _ = self.get_image(overide_channel=h5_channel) frame_stack = channel_image.flatten().reshape(-1, 1) # Update/ replace the channels dataset @@ -547,8 +548,8 @@ def save_lite_data(self): def save_metadata(self): - """Saves the metadata to the appropriate format based on the save_as attribute.""" - if self.save_as == "h5": + """Saves the metadata to the appropriate format based on the save_as_h5 attribute.""" + if self.save_as_h5: for seg_chan in self.segment_channels: self.global_meta_group.attrs[f"channel.unit.{seg_chan['name']}"] = seg_chan['unit'] for key, value in self.top_level_meta.items(): @@ -717,15 +718,15 @@ def setup_h5_structure(self, h5file): def get_saving_context(self): """ - Returns the appropriate context manager for saving the data based on the save_as attribute. - If save_as is "h5", it returns a context manager for an h5 file. Otherwise, it returns a null context. + Returns the appropriate context manager for saving the data based on the save_as_h5 attribute. + If save_as_h5 is True, it returns a context manager for an h5 file. Otherwise, it returns a null context. Returns ------- contextlib.AbstractContextManager The context manager for saving the data. """ - if self.save_as == "h5": + if self.save_as_h5: return h5py.File(self.filepath.parent / f"{self.filepath.stem}.h5-jpk", "a") else: return nullcontext() @@ -857,37 +858,6 @@ def _make_num_min_characters(num : int, min_chars: int = 3): string_num = "0" * (min_chars - len(string_num)) + string_num return string_num -def get_jpk_qi_channels(file_path: Path | str): - logger.debug("Starting to get jpk qi data channels") - file_path = Path(file_path) - channels = [] - with zipfile.ZipFile(file_path, "r") as qi_archive: - for file_name in qi_archive.namelist(): - if file_name.endswith(".jpk-qi-image"): - path_to_image = file_name - with qi_archive.open(path_to_image, "r") as image_file: - channels += jpk._get_jpk_channels(file=image_file, filename=file_path.stem, file_path=file_path / Path(path_to_image)) - channels += ADDITIONAL_CHANNELS - logger.debug("Got jpk qi data channels") - return channels, {"save_as_h5": bool} - -def _find_contact_point(curve): - # find contact point in vertical deflection by peak in first derivative - vdef = curve["vDeflection"] - if len(vdef) < 2: - return np.nan - derivative_vert_deflection = np.diff(vdef) - # Doesn't look like this line is needed: peak_derivative_value = np.max(derivative_vert_deflection) - peak_derivative_index = np.argmax(derivative_vert_deflection) - - # find corresponding height value - corresponding_height_at_peak = curve["measuredHeight"][peak_derivative_index] - - return corresponding_height_at_peak - -def _find_trigger_point(curve): - trigger_point = curve["measuredHeight"][-1] - return trigger_point def _max_points_buffer(curves_data, samples=20, points_buffer=1.2): From 3e0b8b10e75984b51e5626855dbf50f6e526a31a Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Mon, 23 Mar 2026 12:00:31 +0000 Subject: [PATCH 18/54] Adjusting curve data access method to work more like a 2D array for more robust coordinate access --- AFMReader/h5_jpk.py | 53 +++++++++++++++++++++++++++++++------ AFMReader/jpk_qi.py | 64 ++++++++++++++++++++++++++++++++++++--------- 2 files changed, 96 insertions(+), 21 deletions(-) diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index 388a5a8..a2b5477 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -273,11 +273,12 @@ def get_h5jpk_channels(file_path: Path | str): return available_channels class LazyQIData: - def __init__(self, qi_data_group: h5py.Group, shape_x: int, shape_y: int): + def __init__(self, qi_data_group: h5py.Group, shape_x: int, shape_y: int, flip_image: bool = True): self.qi_data_group = qi_data_group self.shape_x = shape_x self.shape_y = shape_y self.dims = (shape_y, shape_x) + self.flip_image = flip_image def __getitem__(self, y: int): @@ -296,7 +297,11 @@ def __iter__(self): yield self._fetch_curve(y, x) def _fetch_curve(self, y: int, x: int): + if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x: + raise IndexError(f"Curve index out of bounds: ({x}, {y})") curve_dict = {} + if self.flip_image: + y = self.shape_y - 1 - y curve_num = self.shape_x * y + x for segment, segment_group in self.qi_data_group["Curves"].items(): for channel in segment_group["Indicies"]: @@ -309,9 +314,12 @@ def _fetch_curve(self, y: int, x: int): class LazyCurveMetadata: """A proxy class that fetches header.properties files on demand.""" - def __init__(self, qi_data_group: h5py.Group, top_level_meta: dict): + def __init__(self, qi_data_group: h5py.Group, top_level_meta: dict, shape_x: int, shape_y: int, flip_image: bool = True): self.qi_data_group = qi_data_group self.top_level_meta = top_level_meta + self.shape_x = shape_x + self.shape_y = shape_y + self.flip_image = flip_image # Expose top_level so the frontend can still do `raw_metadata["top_level"]` self.top_level = top_level_meta @@ -319,18 +327,46 @@ def __getitem__(self, key): if key == "top_level": return self.top_level elif key == "curves": - return LazyMetaProxy(self.qi_data_group, "curve") + return LazyMetaProxy(self.qi_data_group, "curve", self.shape_x, self.shape_y, self.flip_image) elif key == "segments": - return LazyMetaProxy(self.qi_data_group, "segment") + return LazyMetaProxy(self.qi_data_group, "segment", self.shape_x, self.shape_y, self.flip_image) raise KeyError(key) class LazyMetaProxy: - def __init__(self, qi_data_group: h5py.Group, meta_type: str): + def __init__(self, qi_data_group: h5py.Group, meta_type: str, shape_x: int, shape_y: int, flip_image: bool = True): self.qi_data_group = qi_data_group self.meta_type = meta_type + self.shape_x = shape_x + self.shape_y = shape_y + self.flip_image = flip_image + def __getitem__(self, y: int): + class RowProxy: + def __init__(self, parent, y): + self.parent = parent + self.y = y + def __getitem__(self, x): + if self.parent.meta_type == "curve": + return self.parent._fetch_meta(self.y, x) + elif self.parent.meta_type == "segment": + class SegmentMetaProxy: + def __init__(self, parent, y, x): + self.parent = parent + self.y = y + self.x = x + def __getitem__(self, direction): + return self.parent.parent._fetch_meta(self.y, self.x, direction) + return SegmentMetaProxy(self, self.y, x) + return RowProxy(self, y) - def __getitem__(self, idx: int): + def _fetch_meta(self, y: int, x: int, direction: int = None): + if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x: + raise IndexError(f"Curve index out of bounds: ({x}, {y})") + if self.flip_image: + y = self.shape_y - 1 - y + idx = (y * self.shape_x) + x + if direction is not None: + idx = (idx * 2) + direction meta_dict = {} for key in self.qi_data_group["Curve_Metadata"]: if key.startswith(f"{self.meta_type}."): @@ -342,6 +378,7 @@ def __getitem__(self, idx: int): return meta_dict + def load_h5jpk( file_path: Path | str, channel: str, flip_image: bool = True, load_curves: bool = True ) -> tuple[np.ndarray, float, dict[str, float]]: @@ -440,9 +477,9 @@ def load_h5jpk( channels_units[key.split(".")[-1]] = value top_level_meta[key] = value - full_metadata = LazyCurveMetadata(qi_data_group, top_level_meta) + full_metadata = LazyCurveMetadata(qi_data_group, top_level_meta, shape_x, shape_y, flip_image) - all_curve_data = LazyQIData(qi_data_group, shape_x, shape_y) + all_curve_data = LazyQIData(qi_data_group, shape_x, shape_y, flip_image) return (image_stack, px2nm, (all_curve_data, channels_units, full_metadata), timestamps) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index 833e423..b78e1fc 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -14,11 +14,13 @@ class LazyCurveData: """A proxy class that behaves like a 2D list but fetches .dat files on demand.""" - def __init__(self, filepath, shape_x, channel_scaling, archive): + def __init__(self, filepath, shape_x, shape_y, channel_scaling, archive, flip_image: bool = True): self.filepath = filepath self.shape_x = shape_x + self.shape_y = shape_y self.channel_scaling = channel_scaling self.archive = archive + self.flip_image = flip_image def __getitem__(self, y: int): # Return a row proxy to handle the second index [x] @@ -31,8 +33,16 @@ def __getitem__(self, x: int): return RowProxy(self, y) - def _fetch_curve(self, y: int, x: int): + def __iter__(self): + for y in range(self.shape_y): + for x in range(self.shape_x): + yield self._fetch_curve(y, x) + def _fetch_curve(self, y: int, x: int): + if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x: + raise IndexError(f"Curve index out of bounds: ({x}, {y})") + if self.flip_image: + y = self.shape_y - 1 - y curve_num = y * self.shape_x + x curve_data = {} @@ -58,35 +68,64 @@ def close(self): class LazyCurveMetadata: """A proxy class that fetches header.properties files on demand.""" - def __init__(self, filepath, top_level_meta, archive): + def __init__(self, filepath, top_level_meta, archive, shape_x: int, shape_y: int, flip_image: bool = True): self.filepath = filepath # Expose top_level so the frontend can still do `raw_metadata["top_level"]` self.top_level = top_level_meta self.archive = archive + self.shape_x = shape_x + self.shape_y = shape_y + self.flip_image = flip_image def __getitem__(self, key): if key == "top_level": return self.top_level elif key == "curves": - return LazyMetaProxy(self.filepath, "curve", self.archive) + return LazyMetaProxy(self.filepath, "curve", self.archive, self.shape_x, self.shape_y, self.flip_image) elif key == "segments": - return LazyMetaProxy(self.filepath, "segment", self.archive) + return LazyMetaProxy(self.filepath, "segment", self.archive, self.shape_x, self.shape_y, self.flip_image) raise KeyError(key) class LazyMetaProxy: - def __init__(self, filepath, meta_type, archive): + def __init__(self, filepath, meta_type, archive, shape_x: int, shape_y: int, flip_image: bool = True): self.filepath = filepath self.meta_type = meta_type self.archive = archive + self.shape_x = shape_x + self.shape_y = shape_y + self.flip_image = flip_image - def __getitem__(self, idx: int): + def __getitem__(self, y: int): + class RowProxy: + def __init__(self, parent, y): + self.parent = parent + self.y = y + def __getitem__(self, x): + if self.parent.meta_type == "curve": + return self.parent._fetch_meta(self.y, x) + elif self.parent.meta_type == "segment": + class SegmentMetaProxy: + def __init__(self, parent, y, x): + self.parent = parent + self.y = y + self.x = x + def __getitem__(self, direction): + return self.parent.parent._fetch_meta(self.y, self.x, direction) + return SegmentMetaProxy(self, self.y, x) + return RowProxy(self, y) + def _fetch_meta(self, y: int, x: int, direction: int = None): + if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x: + raise IndexError(f"Curve index out of bounds: ({x}, {y})") + if self.flip_image: + y = self.shape_y - 1 - y + idx = (y * self.shape_x) + x if self.meta_type == "curve": path = f"index/{idx}/header.properties" else: - curve_num = idx // 2 - direction = idx % 2 - path = f"index/{curve_num}/segments/{direction}/segment-header.properties" + if direction is None: + raise ValueError("Direction must be provided for segment metadata") + path = f"index/{idx}/segments/{direction}/segment-header.properties" try: with self.archive.open(path) as f: @@ -289,7 +328,6 @@ def load(self, # Setup H5 Data structures if needed if self.save_as_h5: self.curve_groups, self.global_meta_group, self.curves_meta_group = self.setup_h5_structure(file) - for file_info in self.qi_archive.infolist(): filename = file_info.filename @@ -360,8 +398,8 @@ def load(self, else: self.global_meta_group.attrs[key] = str(value).encode('utf-8') - self.full_metadata = LazyCurveMetadata(self.filepath, self.top_level_meta, self.qi_archive) - self.all_curve_data = LazyCurveData(self.filepath, self.shape_x, self.channel_scaling, self.qi_archive) + self.full_metadata = LazyCurveMetadata(self.filepath, self.top_level_meta, self.qi_archive, self.shape_x, self.shape_y, flip_image=self.flip_image) + self.all_curve_data = LazyCurveData(self.filepath, self.shape_x, self.shape_y, self.channel_scaling, self.qi_archive, flip_image=self.flip_image) # Load the image self.image, _ = self.get_image() From b8ec43a308c6ff89dc6b740276d96e97bdb1a799 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Wed, 25 Mar 2026 16:52:42 +0000 Subject: [PATCH 19/54] Separating saving functionality for jpk-qi-data and adding function to eagerly load lazy loaded data --- AFMReader/general_loader.py | 7 +- AFMReader/h5_jpk.py | 20 ++- AFMReader/jpk_qi.py | 314 ++++++++++++++++++++---------------- 3 files changed, 197 insertions(+), 144 deletions(-) diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index 413f766..c1e699b 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -4,9 +4,7 @@ import numpy.typing as npt -print("Importing AFMReader modules...") from AFMReader import asd, gwy, h5_jpk, ibw, jpk, spm, stp, top, topostats, jpk_qi, bin -print("AFMReader modules imported.") from AFMReader.logging import logger @@ -84,6 +82,8 @@ def load(self, channel: str | None = None, kwargs: dict = None) -> tuple[npt.NDA elif len(h5_returned) == 4: image, pixel_to_nanometre_scaling_factor, curve_data, _ = h5_returned self.loaded_curves = True + print(f"Loaded image with shape {image.shape} and pixel to nanometre scaling factor {pixel_to_nanometre_scaling_factor}") + print(f"Image has max value {image.max()} and min value {image.min()}") return image, pixel_to_nanometre_scaling_factor, curve_data else: logger.error(f"Loading h5-jpk file returned unexpected number of items: {len(h5_returned)}") @@ -120,6 +120,9 @@ def load(self, channel: str | None = None, kwargs: dict = None) -> tuple[npt.NDA else: raise ValueError(f"File type '{self.suffix}' is not currently handled by AFMReader.") + print(f"Loaded image with shape {image.shape} and pixel to nanometre scaling factor {pixel_to_nanometre_scaling_factor}") + print(f"Image has max value {image.max()} and min value {image.min()}") + return image, pixel_to_nanometre_scaling_factor except ValueError as e: diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index a2b5477..97a97b5 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -312,6 +312,25 @@ def _fetch_curve(self, y: int, x: int): curve_dict[channel][segment] = segment_group["Data"][channel][start_idx:end_idx] return curve_dict + def load_all_curves(self): + all_curves = [[{} for _ in range(self.shape_x)] for _ in range(self.shape_y)] + for segment, segment_group in self.qi_data_group["Curves"].items(): + for channel in segment_group["Indicies"]: + indicies = segment_group["Indicies"][channel][:] + data = segment_group["Data"][channel][:] + for i in range(len(indicies) - 1): + start_idx = int(indicies[i]) + end_idx = int(indicies[i + 1]) + x = i % self.shape_x + y = i // self.shape_x + if self.flip_image: + y = self.shape_y - 1 - y + if channel not in all_curves[y][x]: + all_curves[y][x][channel] = {} + all_curves[y][x][channel][segment] = data[start_idx:end_idx] + + return all_curves + class LazyCurveMetadata: """A proxy class that fetches header.properties files on demand.""" def __init__(self, qi_data_group: h5py.Group, top_level_meta: dict, shape_x: int, shape_y: int, flip_image: bool = True): @@ -472,7 +491,6 @@ def load_h5jpk( channels_units = {} top_level_meta = {} for key, value in qi_data_group["Global_Metadata"].attrs.items(): - print(f"Global Metadata - {key}: {value}") if key.startswith("channel.unit."): channels_units[key.split(".")[-1]] = value top_level_meta[key] = value diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index b78e1fc..a0edf96 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -14,6 +14,7 @@ class LazyCurveData: """A proxy class that behaves like a 2D list but fetches .dat files on demand.""" + def __init__(self, filepath, shape_x, shape_y, channel_scaling, archive, flip_image: bool = True): self.filepath = filepath self.shape_x = shape_x @@ -28,10 +29,11 @@ class RowProxy: def __init__(self, parent, y): self.parent = parent self.y = y + def __getitem__(self, x: int): return self.parent._fetch_curve(self.y, x) - return RowProxy(self, y) + return RowProxy(self, y) def __iter__(self): for y in range(self.shape_y): @@ -53,21 +55,38 @@ def _fetch_curve(self, y: int, x: int): try: # Access the file directly without re-parsing the ZIP directory with self.archive.open(dat_path) as f: - raw_array = np.frombuffer(f.read(), dtype='>i4') - curve_data[chan_name][f"Segment_{direction}"] = ( - raw_array * scale["multiplier"] - ) + scale["offset"] + raw_array = np.frombuffer(f.read(), dtype=">i4") + curve_data[chan_name][f"Segment_{direction}"] = (raw_array * scale["multiplier"]) + scale[ + "offset" + ] except KeyError: - pass # File doesn't exist for this segment + pass # File doesn't exist for this segment return curve_data + def load_all_curves(self): + """ + Eagerly loads all curve data into a 2D list structure matching the image dimensions. + + This can be used if the user wants to have all the curve data available at once, but it is not recommended + for large datasets as it will consume a lot of memory. In this case, it is not notably faster as the zip + structure means each curve is effectively loaded individually anyway + """ + all_curve_data = [[None for _ in range(self.shape_x)] for _ in range(self.shape_y)] + for y in range(self.shape_y): + for x in range(self.shape_x): + all_curve_data[y][x] = self._fetch_curve(y, x) + # TODO may be good to just return self here as not faster and lots of memory + # return self + return all_curve_data + def close(self): self.archive.close() class LazyCurveMetadata: """A proxy class that fetches header.properties files on demand.""" + def __init__(self, filepath, top_level_meta, archive, shape_x: int, shape_y: int, flip_image: bool = True): self.filepath = filepath # Expose top_level so the frontend can still do `raw_metadata["top_level"]` @@ -86,6 +105,7 @@ def __getitem__(self, key): return LazyMetaProxy(self.filepath, "segment", self.archive, self.shape_x, self.shape_y, self.flip_image) raise KeyError(key) + class LazyMetaProxy: def __init__(self, filepath, meta_type, archive, shape_x: int, shape_y: int, flip_image: bool = True): self.filepath = filepath @@ -100,18 +120,23 @@ class RowProxy: def __init__(self, parent, y): self.parent = parent self.y = y + def __getitem__(self, x): if self.parent.meta_type == "curve": return self.parent._fetch_meta(self.y, x) elif self.parent.meta_type == "segment": + class SegmentMetaProxy: def __init__(self, parent, y, x): self.parent = parent self.y = y self.x = x + def __getitem__(self, direction): return self.parent.parent._fetch_meta(self.y, self.x, direction) + return SegmentMetaProxy(self, self.y, x) + return RowProxy(self, y) def _fetch_meta(self, y: int, x: int, direction: int = None): @@ -135,6 +160,7 @@ def _fetch_meta(self, y: int, x: int, direction: int = None): return meta_dict + def _get_channel_scaling(props, channel_index): """ Parses the JPK properties dictionary to find the cumulative multiplier @@ -200,12 +226,15 @@ def _get_channel_scaling(props, channel_index): class jpk_qi_loader: """Class for readability and improving modularity in the load jpk qi data function""" - def __init__(self, + + def __init__( + self, filepath: Path | str, channel: str | None = None, config_path: Path | str | None = None, flip_image: bool | None = True, - save_as_h5: bool = False): + save_as_h5: bool = False, + ): """ Initializes the loader with the provided parameters. @@ -279,14 +308,18 @@ def get_available_channels(self): # Add the channels which exist in the jpk-qi-image file with self.qi_archive.open(self.path_to_image, "r") as image_file: - channels = jpk._get_jpk_channels(file=image_file, filename=self.filepath.stem, file_path=self.filepath / Path(self.path_to_image)) + channels = jpk._get_jpk_channels( + file=image_file, filename=self.filepath.stem, file_path=self.filepath / Path(self.path_to_image) + ) return channels, {"save_as_h5": bool} - def load(self, + def load( + self, channel: str | None = None, config_path: Path | str | None = None, flip_image: bool | None = True, - save_as_h5: bool = False) -> tuple[np.ndarray, float, dict] | tuple[np.ndarray, float]: + save_as_h5: bool | None = None, + ) -> tuple[np.ndarray, float, dict] | tuple[np.ndarray, float]: """ Loads the .jpk-qi file @@ -317,89 +350,17 @@ def load(self, self.parse_dimension_data() - # Access the curve data and metadata, and save to given file format - with self.get_saving_context() as file: - - # Compile Regexes - dat_regex = re.compile(r"index/(\d+)/segments/(\d+)/channels/([^/]+)\.dat") - curve_meta_regex = re.compile(r"index/(\d+)/header\.properties") - segment_meta_regex = re.compile(r"index/(\d+)/segments/(\d+)/segment-header\.properties") + # Setup H5 Data structures if needed + if self.save_as_h5: + self.load_all_data() + self.save_to_h5(collated_curve_data=self.collated_curve_data, indicies=self.indicies, collated_metadata=self.collated_metadata) - # Setup H5 Data structures if needed - if self.save_as_h5: - self.curve_groups, self.global_meta_group, self.curves_meta_group = self.setup_h5_structure(file) - for file_info in self.qi_archive.infolist(): - filename = file_info.filename - - # Check Binary Data - dat_match = dat_regex.match(filename) - if dat_match: - # If file is a .dat file, extract the curve number, segment direction and channel name from the filename - curve_num, direction, chan_name = int(dat_match.group(1)), int(dat_match.group(2)), dat_match.group(3) - # Then load the data from the file - self.extract_dat_file(file_info, curve_num, direction, chan_name) - continue - - # Check Segment Metadata - segment_meta_match = segment_meta_regex.match(filename) - if segment_meta_match: - # If file is a segment metadata file, extract the curve number and segment direction from the filename - curve_num, direction = int(segment_meta_match.group(1)), int(segment_meta_match.group(2)) - # Then load the segment metadata from the file - self.extract_segment_metadata(file_info, curve_num, direction) - continue - - # Check Curve Metadata - curve_meta_match = curve_meta_regex.match(filename) - if curve_meta_match: - # If file is a curve metadata file, extract the curve number from the filename - curve_num = int(curve_meta_match.group(1)) - # Then load the metadata from the file - self.extract_curve_metadata(file_info, curve_num) - continue - - # TODO can we remove curve_meta_dict and just use curve_meta or is the non duplicating necessary - self.curve_meta = [self.curve_meta_dict.get(i, {}) for i in range(self.num_of_curves)] - self.segment_meta = [self.segment_meta_dict.get(i, {}) for i in range(self.num_of_curves * 2)] - self.full_metadata = self.construct_full_metadata() - - # If saving, need to collate the curve data into a format that can be easily saved to the h5 file (a dataset per channel per segment direction) - self.collated_curve_data, self.indicies = self.get_collated_curves() - self.collated_metadata = self.get_collated_metadata() - - # Save as h5 if required - if self.save_as_h5: - # Save the curve data to the appropriate datasets in the h5 file - for chan_name, chan_data in self.collated_curve_data.items(): - for direction in range(2): - # Save the curve data and indicies to the appropriate dataset in the h5 file - seg_name = f"Segment_{direction}" - self.curve_groups["Data"][seg_name].create_dataset( - name=chan_name, - data=chan_data[seg_name], - dtype=np.float32, - ) - self.curve_groups["Indicies"][seg_name].create_dataset( - name=chan_name, - data=self.indicies[chan_name][seg_name], - dtype=np.int32, - ) - - # Save the global metadata to the h5 file - vlen_str_dt = h5py.string_dtype(encoding='utf-8') - for key, value in self.collated_metadata.items(): - if isinstance(value, list): - # If the key is a changing key, save as a dataset with one entry per curve/ segment - self.curves_meta_group.create_dataset( - name=key, - data=value, - dtype=vlen_str_dt - ) - else: - self.global_meta_group.attrs[key] = str(value).encode('utf-8') - - self.full_metadata = LazyCurveMetadata(self.filepath, self.top_level_meta, self.qi_archive, self.shape_x, self.shape_y, flip_image=self.flip_image) - self.all_curve_data = LazyCurveData(self.filepath, self.shape_x, self.shape_y, self.channel_scaling, self.qi_archive, flip_image=self.flip_image) + self.full_metadata = LazyCurveMetadata( + self.filepath, self.top_level_meta, self.qi_archive, self.shape_x, self.shape_y, flip_image=self.flip_image + ) + self.all_curve_data = LazyCurveData( + self.filepath, self.shape_x, self.shape_y, self.channel_scaling, self.qi_archive, flip_image=self.flip_image + ) # Load the image self.image, _ = self.get_image() @@ -413,6 +374,91 @@ def load(self, return self.image, self.px2nm + def load_all_data(self, include_metadata: bool = True): + # Compile Regexes + dat_regex = re.compile(r"index/(\d+)/segments/(\d+)/channels/([^/]+)\.dat") + if include_metadata: + curve_meta_regex = re.compile(r"index/(\d+)/header\.properties") + segment_meta_regex = re.compile(r"index/(\d+)/segments/(\d+)/segment-header\.properties") + for file_info in self.qi_archive.infolist(): + filename = file_info.filename + + # Check Binary Data + dat_match = dat_regex.match(filename) + if dat_match: + # If file is a .dat file, extract the curve number, segment direction and channel name from the filename + curve_num, direction, chan_name = int(dat_match.group(1)), int(dat_match.group(2)), dat_match.group(3) + # Then load the data from the file + self.extract_dat_file(file_info, curve_num, direction, chan_name) + continue + + if include_metadata: + # Check Segment Metadata + segment_meta_match = segment_meta_regex.match(filename) + if segment_meta_match: + # If file is a segment metadata file, extract the curve number and segment direction from the filename + curve_num, direction = int(segment_meta_match.group(1)), int(segment_meta_match.group(2)) + # Then load the segment metadata from the file + self.extract_segment_metadata(file_info, curve_num, direction) + continue + + # Check Curve Metadata + curve_meta_match = curve_meta_regex.match(filename) + if curve_meta_match: + # If file is a curve metadata file, extract the curve number from the filename + curve_num = int(curve_meta_match.group(1)) + # Then load the metadata from the file + self.extract_curve_metadata(file_info, curve_num) + continue + + # If saving, need to collate the curve data into a format that can be easily saved to the h5 file (a dataset per channel per segment direction) + self.collated_curve_data, self.indicies = self.get_collated_curves() + + # TODO can we remove curve_meta_dict and just use curve_meta or is the non duplicating necessary + if include_metadata: + self.curve_meta = [self.curve_meta_dict.get(i, {}) for i in range(self.num_of_curves)] + self.segment_meta = [self.segment_meta_dict.get(i, {}) for i in range(self.num_of_curves * 2)] + self.full_metadata = self.construct_full_metadata() + self.collated_metadata = self.get_collated_metadata() + + def save_to_h5( + self, + include_metadata: bool = True, + collated_curve_data: dict | None = None, + indicies: dict | None = None, + collated_metadata: dict | None = None, + ): + """Saves the data as an H5 file. If include_metadata is False, only the curve data will be saved.""" + with self.get_saving_context() as file: + + curve_groups, global_meta_group, curves_meta_group = self.setup_h5_structure(file) + + # Save the curve data to the appropriate datasets in the h5 file + for chan_name, chan_data in collated_curve_data.items(): + for direction in range(2): + # Save the curve data and indicies to the appropriate dataset in the h5 file + seg_name = f"Segment_{direction}" + curve_groups["Data"][seg_name].create_dataset( + name=chan_name, + data=chan_data[seg_name], + dtype=np.float32, + ) + curve_groups["Indicies"][seg_name].create_dataset( + name=chan_name, + data=indicies[chan_name][seg_name], + dtype=np.int32, + ) + + if include_metadata: + # Save the global metadata to the h5 file + vlen_str_dt = h5py.string_dtype(encoding="utf-8") + for key, value in collated_metadata.items(): + if isinstance(value, list): + # If the key is a changing key, save as a dataset with one entry per curve/ segment + curves_meta_group.create_dataset(name=key, data=value, dtype=vlen_str_dt) + else: + global_meta_group.attrs[key] = str(value).encode("utf-8") + def get_collated_curves(self): """ Collates the curve data from the flat structure it is extracted in into a structure grouped by channel and segment for easier saving to h5. @@ -472,7 +518,6 @@ def get_collated_curves(self): return collated_curve_data, indicies - def get_collated_metadata(self): """ Collates the metadata from being split by curve, to being split by attribute so data can be saved more efficiently @@ -484,7 +529,7 @@ def get_collated_metadata(self): """ collated_meta = {} for seg_chan in self.segment_channels: - collated_meta[f"channel.unit.{seg_chan['name']}"] = seg_chan['unit'] + collated_meta[f"channel.unit.{seg_chan['name']}"] = seg_chan["unit"] for key, value in self.top_level_meta.items(): collated_meta[key] = value for curve_dict in self.full_metadata["curves"]: @@ -499,7 +544,6 @@ def get_collated_metadata(self): collated_meta[f"segment.{key}"].append(value) return collated_meta - def get_image(self, overide_channel: str | None = None): """ Processes the flat curve data dictionary into a 2D list structure matching the image dimensions. @@ -526,7 +570,9 @@ def get_image(self, overide_channel: str | None = None): virtual_file = io.BytesIO(tif_bytes) logger.info(f"Looking for channel {channel} in {path_to_image}") - return jpk._load_jpk(virtual_file, path_to_image, channel=channel, file_suffix=".jpk-qi-data", config_path=self.config_path) + return jpk._load_jpk( + virtual_file, path_to_image, channel=channel, file_suffix=".jpk-qi-data", config_path=self.config_path + ) def save_lite_data(self): """Saves a lite form of the data (e.g., the calculated image data) to the appropriate format based on the save_as_h5 attribute.""" @@ -551,7 +597,9 @@ def save_lite_data(self): break # Add the channels which exist in the jpk-qi-image file with self.qi_archive.open(path_to_image, "r") as image_file: - h5_channels += jpk._get_jpk_channels(file=image_file, filename=self.filepath.stem, file_path=self.filepath / Path(path_to_image)) + h5_channels += jpk._get_jpk_channels( + file=image_file, filename=self.filepath.stem, file_path=self.filepath / Path(path_to_image) + ) for i, h5_channel in enumerate(h5_channels): # For each available channel, save the required data to the h5 file # TODO make sure this metadata is accurate for the channels coming from the .jpk-qi-image file @@ -584,24 +632,22 @@ def save_lite_data(self): del chan_grp[dataset_name] chan_grp.create_dataset(dataset_name, data=frame_stack) - def save_metadata(self): """Saves the metadata to the appropriate format based on the save_as_h5 attribute.""" if self.save_as_h5: for seg_chan in self.segment_channels: - self.global_meta_group.attrs[f"channel.unit.{seg_chan['name']}"] = seg_chan['unit'] + self.global_meta_group.attrs[f"channel.unit.{seg_chan['name']}"] = seg_chan["unit"] for key, value in self.top_level_meta.items(): - self.global_meta_group.attrs[key] = str(value).encode('utf-8') + self.global_meta_group.attrs[key] = str(value).encode("utf-8") for i, c_meta in enumerate(self.curve_meta): curve_meta_group = self.curves_meta_group.require_group(f"{i}") for key, value in c_meta.items(): - curve_meta_group.attrs[key] = str(value).encode('utf-8') + curve_meta_group.attrs[key] = str(value).encode("utf-8") for d in range(2): segment_meta_group = curve_meta_group.require_group(f"{d}") - for key, value in self.segment_meta[i*2+d].items(): - segment_meta_group.attrs[key] = str(value).encode('utf-8') - + for key, value in self.segment_meta[i * 2 + d].items(): + segment_meta_group.attrs[key] = str(value).encode("utf-8") def construct_full_metadata(self): """ @@ -619,8 +665,16 @@ def construct_full_metadata(self): """ # Find keys that change across curves/segments - changing_curve_keys = {k for k in self.all_curve_keys if any(self.curve_meta[i].get(k) != self.curve_meta[0].get(k) for i in range(1, self.num_of_curves))} - changing_segment_keys = {k for k in self.all_segment_keys if any(self.segment_meta[i].get(k) != self.segment_meta[0].get(k) for i in range(1, len(self.segment_meta)))} + changing_curve_keys = { + k + for k in self.all_curve_keys + if any(self.curve_meta[i].get(k) != self.curve_meta[0].get(k) for i in range(1, self.num_of_curves)) + } + changing_segment_keys = { + k + for k in self.all_segment_keys + if any(self.segment_meta[i].get(k) != self.segment_meta[0].get(k) for i in range(1, len(self.segment_meta))) + } # Move duplicated meta to top level for both segments and curves for key in self.all_curve_keys - changing_curve_keys: @@ -632,16 +686,14 @@ def construct_full_metadata(self): # Strip duplicated keys from individual curve/segment dicts for c_meta in self.curve_meta: - for k in self.all_curve_keys - changing_curve_keys: c_meta.pop(k, None) + for k in self.all_curve_keys - changing_curve_keys: + c_meta.pop(k, None) for s_meta in self.segment_meta: - for k in self.all_segment_keys - changing_segment_keys: s_meta.pop(k, None) + for k in self.all_segment_keys - changing_segment_keys: + s_meta.pop(k, None) # Construct full metadata dict from subdicts - full_metadata = { - "top_level": self.top_level_meta, - "curves": self.curve_meta, - "segments": self.segment_meta - } + full_metadata = {"top_level": self.top_level_meta, "curves": self.curve_meta, "segments": self.segment_meta} return full_metadata def extract_dat_file(self, file_info: zipfile.ZipInfo, curve_num: int, direction: int, chan_name: str): @@ -663,7 +715,7 @@ def extract_dat_file(self, file_info: zipfile.ZipInfo, curve_num: int, direction scale = self.channel_scaling[chan_name] with self.qi_archive.open(file_info) as f: # Read the binary data as big-endian 32-bit integers - raw_array = np.frombuffer(f.read(), dtype='>i4') + raw_array = np.frombuffer(f.read(), dtype=">i4") # Apply the scaling to convert raw values into real world values segment_array = (raw_array * scale["multiplier"]) + scale["offset"] @@ -674,8 +726,9 @@ def extract_dat_file(self, file_info: zipfile.ZipInfo, curve_num: int, direction self.flat_curve_data[curve_num][chan_name][f"Segment_{direction}"] = segment_array else: - logger.warning(f"Channel {chan_name} not found in scaling information. Skipping data for curve {curve_num}, direction {direction}.") - + logger.warning( + f"Channel {chan_name} not found in scaling information. Skipping data for curve {curve_num}, direction {direction}." + ) def extract_curve_metadata(self, file_info: zipfile.ZipInfo, curve_num: int): """ @@ -693,7 +746,6 @@ def extract_curve_metadata(self, file_info: zipfile.ZipInfo, curve_num: int): self.curve_meta_dict[curve_num] = cleaned_meta self.all_curve_keys.update(cleaned_meta.keys()) - def extract_segment_metadata(self, file_info: zipfile.ZipInfo, curve_num: int, direction: int): """ Extracts the segment metadata from its header.properties file in the JPK QI archive and saves it to the internal data structure. @@ -740,10 +792,7 @@ def setup_h5_structure(self, h5file): curves_meta_group = qi_group.require_group("Curve_Metadata") curves_group = qi_group.require_group("Curves") - curve_groups = { - "Data": {}, - "Indicies": {} - } + curve_groups = {"Data": {}, "Indicies": {}} for direction in range(2): # For each segment direction, establish the necessary group structure that will contain each channel dataset @@ -769,7 +818,6 @@ def get_saving_context(self): else: return nullcontext() - def parse_dimension_data(self): """ Parses the dimension data from the provided properties dictionary and calculates the pixel to nanometer scaling factor. @@ -844,7 +892,7 @@ def extract_global_metadata(self): logger.error("Could not find channels for segments") # Create a lookup for channel name to unit to be returned - self.channels_units = {seg_chan['name'] : seg_chan['unit'] for seg_chan in self.segment_channels} + self.channels_units = {seg_chan["name"]: seg_chan["unit"] for seg_chan in self.segment_channels} # Lookup map for binary scaling self.channel_scaling = {chan["name"]: chan for chan in self.segment_channels} @@ -854,19 +902,6 @@ def close(self): - -def load_jpk_qi( - file_path: Path | str, - channel: str, - config_path: Path | str | None = None, - flip_image: bool | None = True, - save_as_h5: bool | None = False -) -> tuple[np.ndarray, float]: - - jpk_loader = jpk_qi_loader(filepath=file_path, channel=channel, config_path=config_path, flip_image=flip_image, save_as="h5" if save_as_h5 else None) - return jpk_loader.load() - - def load_fdcurves_from_h5(file_path: Path | str): file_path = Path(file_path) @@ -889,7 +924,7 @@ def load_fdcurves_from_h5(file_path: Path | str): return all_curve_data -def _make_num_min_characters(num : int, min_chars: int = 3): +def _make_num_min_characters(num: int, min_chars: int = 3): string_num = str(num) if len(string_num) >= min_chars: return string_num @@ -902,6 +937,3 @@ def _max_points_buffer(curves_data, samples=20, points_buffer=1.2): step = len(curves_data) // samples max_points = np.max(len(curves_data[i]["segment"]) for i in range(0, len(curves_data), step)) return max_points * points_buffer - - - From 10c1050b64b0472c55208100c6a40644f3ae8aac Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Thu, 26 Mar 2026 15:43:20 +0000 Subject: [PATCH 20/54] Fixing duplicated converting to nm bug --- AFMReader/jpk.py | 4 ++-- AFMReader/jpk_qi.py | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/AFMReader/jpk.py b/AFMReader/jpk.py index 62f2792..a06c3f8 100644 --- a/AFMReader/jpk.py +++ b/AFMReader/jpk.py @@ -245,7 +245,7 @@ def load_jpk( image, px2nm = _load_jpk(file=file_path, filename=filename, channel=channel, file_suffix=file_path.suffix, config_path=config_path, flip_image=flip_image) return (image, px2nm) -def _load_jpk(file: Path | BytesIO, filename: str, channel: str, file_suffix: str, config_path: Path | str | None = None, flip_image: bool = True): +def _load_jpk(file: Path | BytesIO, filename: str, channel: str, file_suffix: str, config_path: Path | str | None = None, flip_image: bool = True, convert_to_nm: bool = True) -> tuple[np.ndarray, float]: jpk_tags = _load_jpk_tags(config_path) try: tif = tifffile.TiffFile(file) @@ -275,7 +275,7 @@ def _load_jpk(file: Path | BytesIO, filename: str, channel: str, file_suffix: st if flip_image is True: image = np.flipud(image) - if channel_page.tags[jpk_tags["channel_name"]].value in ("height", "measuredHeight", "amplitude"): + if convert_to_nm and channel_page.tags[jpk_tags["channel_name"]].value in ("height", "measuredHeight", "amplitude"): image = image * 1e9 # Get page for common metadata between scans diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index a0edf96..bc8496c 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -346,6 +346,7 @@ def load( self.flip_image = flip_image if flip_image is not None else self.flip_image self.save_as_h5 = save_as_h5 if save_as_h5 is not None else self.save_as_h5 + logger.info(f"Loading JPK QI data from {self.filepath} with channel {self.channel}") self.extract_global_metadata() self.parse_dimension_data() @@ -459,6 +460,8 @@ def save_to_h5( else: global_meta_group.attrs[key] = str(value).encode("utf-8") + logger.info(f"QI data copied to h5 data {file.name}") + def get_collated_curves(self): """ Collates the curve data from the flat structure it is extracted in into a structure grouped by channel and segment for easier saving to h5. @@ -544,7 +547,7 @@ def get_collated_metadata(self): collated_meta[f"segment.{key}"].append(value) return collated_meta - def get_image(self, overide_channel: str | None = None): + def get_image(self, overide_channel: str | None = None, convert_to_nm: bool = True) -> tuple[np.ndarray, float]: """ Processes the flat curve data dictionary into a 2D list structure matching the image dimensions. @@ -571,7 +574,7 @@ def get_image(self, overide_channel: str | None = None): virtual_file = io.BytesIO(tif_bytes) logger.info(f"Looking for channel {channel} in {path_to_image}") return jpk._load_jpk( - virtual_file, path_to_image, channel=channel, file_suffix=".jpk-qi-data", config_path=self.config_path + virtual_file, path_to_image, channel=channel, file_suffix=".jpk-qi-data", config_path=self.config_path, convert_to_nm=convert_to_nm ) def save_lite_data(self): @@ -624,7 +627,7 @@ def save_lite_data(self): if h5_channel == self.channel: channel_image = self.image else: - channel_image, _ = self.get_image(overide_channel=h5_channel) + channel_image, _ = self.get_image(overide_channel=h5_channel, convert_to_nm=False) frame_stack = channel_image.flatten().reshape(-1, 1) # Update/ replace the channels dataset From b510141c91b3166dd060057dcd9365e10c70a168 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Thu, 26 Mar 2026 15:51:47 +0000 Subject: [PATCH 21/54] Minor changes to fix double scaling on current channel as well --- AFMReader/jpk_qi.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index bc8496c..729c1e7 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -321,7 +321,7 @@ def load( save_as_h5: bool | None = None, ) -> tuple[np.ndarray, float, dict] | tuple[np.ndarray, float]: """ - Loads the .jpk-qi file + Loads the .jpk-qi-data file Parameters ---------- @@ -547,7 +547,7 @@ def get_collated_metadata(self): collated_meta[f"segment.{key}"].append(value) return collated_meta - def get_image(self, overide_channel: str | None = None, convert_to_nm: bool = True) -> tuple[np.ndarray, float]: + def get_image(self, overide_channel: str | None = None, convert_to_nm: bool = True): """ Processes the flat curve data dictionary into a 2D list structure matching the image dimensions. @@ -624,9 +624,7 @@ def save_lite_data(self): # Format name and reshape image (flattened frame stack) dataset_name = h5_channel.split("_")[0].capitalize() # Include all the channels including the calculated channel - if h5_channel == self.channel: - channel_image = self.image - else: + # TODO make this slightly faster by remembering we have load a channel already but difficult cause of scaling channel_image, _ = self.get_image(overide_channel=h5_channel, convert_to_nm=False) frame_stack = channel_image.flatten().reshape(-1, 1) From afc65ccf241e88f51efa5d9c4b8155918fa4280d Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Fri, 27 Mar 2026 20:47:36 +0000 Subject: [PATCH 22/54] Fixing minor error --- AFMReader/jpk_qi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index 729c1e7..eda81fb 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -625,7 +625,7 @@ def save_lite_data(self): dataset_name = h5_channel.split("_")[0].capitalize() # Include all the channels including the calculated channel # TODO make this slightly faster by remembering we have load a channel already but difficult cause of scaling - channel_image, _ = self.get_image(overide_channel=h5_channel, convert_to_nm=False) + channel_image, _ = self.get_image(overide_channel=h5_channel, convert_to_nm=False) frame_stack = channel_image.flatten().reshape(-1, 1) # Update/ replace the channels dataset From b1ec5275fe591ca1a631a1f3be6f9db14ab5a81c Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Mon, 30 Mar 2026 11:29:28 +0100 Subject: [PATCH 23/54] Added timing for testing and started converting to more memory and time efficient method --- .gitignore | 3 ++ AFMReader/jpk_qi.py | 76 ++++++++++++++++++++++++++++++++++++++------- 2 files changed, 67 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index f5b3d71..8b2bbe7 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,9 @@ __pycache__/ *.py[cod] *$py.class +AFMReader/data/* +AFMReader/notebooks/* + # C extensions *.so diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index eda81fb..1bf85d1 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -2,6 +2,7 @@ from contextlib import nullcontext import io import re +import time import zipfile import numpy as np @@ -262,6 +263,12 @@ def __init__( # Set path to the .jpk-qi-image file within the archive for later use self.path_to_image = None + self.time_loading_data = 0.0 + self.time_loading_metadata = 0.0 + self.time_saving_h5 = 0.0 + self.time_with_regex = 0.0 + self.time_collating_data = 0.0 + # Initialize key attributes that will be returned / accessed frequently # Just the top level metadata extracted from the header files @@ -282,6 +289,8 @@ def __init__( self.all_curve_keys = set() # The keys that exist in the segment metadata across all segments, used to determine which keys to move to the top level metadata self.all_segment_keys = set() + self.curve_meta = {} + self.segment_meta = {} # Define the image shape and size attributes self.size_x, self.size_y, self.shape_x, self.shape_y = None, None, None, None @@ -354,7 +363,9 @@ def load( # Setup H5 Data structures if needed if self.save_as_h5: self.load_all_data() + start_time = time.perf_counter() self.save_to_h5(collated_curve_data=self.collated_curve_data, indicies=self.indicies, collated_metadata=self.collated_metadata) + self.time_saving_h5 += time.perf_counter() - start_time self.full_metadata = LazyCurveMetadata( self.filepath, self.top_level_meta, self.qi_archive, self.shape_x, self.shape_y, flip_image=self.flip_image @@ -368,7 +379,16 @@ def load( # Save a lite form of the images (precalculated) if saving to a file if self.save_as_h5: + start_time = time.perf_counter() self.save_lite_data() + self.time_saving_h5 += time.perf_counter() - start_time + + logger.info(f"Finished loading JPK QI data from {self.filepath} in {self.time_loading_data + self.time_loading_metadata + self.time_collating_data + self.time_saving_h5:.2f} seconds \n" + f"Data loading: {self.time_loading_data:.2f}s\n" + f"Metadata loading: {self.time_loading_metadata:.2f}s\n" + f"Time spent in regex matching: {self.time_with_regex:.2f}s\n" + f"Collating data: {self.time_collating_data:.2f}s\n" + f"Saving H5: {self.time_saving_h5:.2f}s") if self.all_curve_data: return (self.image, self.px2nm, (self.all_curve_data, self.channels_units, self.full_metadata)) @@ -381,16 +401,26 @@ def load_all_data(self, include_metadata: bool = True): if include_metadata: curve_meta_regex = re.compile(r"index/(\d+)/header\.properties") segment_meta_regex = re.compile(r"index/(\d+)/segments/(\d+)/segment-header\.properties") - for file_info in self.qi_archive.infolist(): + archive_infolist = self.qi_archive.infolist() + logger.info(f"Loading all curve data from JPK QI archive with {len(archive_infolist)} files {'' if include_metadata else 'not '}including metadata") + progress_counter = 0 + for file_info in archive_infolist: filename = file_info.filename + if progress_counter % 10000 == 0: + logger.info(f"Progress: {progress_counter}/{len(archive_infolist)} files processed") + progress_counter += 1 # Check Binary Data + start_time = time.perf_counter() dat_match = dat_regex.match(filename) if dat_match: # If file is a .dat file, extract the curve number, segment direction and channel name from the filename curve_num, direction, chan_name = int(dat_match.group(1)), int(dat_match.group(2)), dat_match.group(3) # Then load the data from the file + self.time_with_regex += time.perf_counter() - start_time + start_time = time.perf_counter() self.extract_dat_file(file_info, curve_num, direction, chan_name) + self.time_loading_data += time.perf_counter() - start_time continue if include_metadata: @@ -399,8 +429,11 @@ def load_all_data(self, include_metadata: bool = True): if segment_meta_match: # If file is a segment metadata file, extract the curve number and segment direction from the filename curve_num, direction = int(segment_meta_match.group(1)), int(segment_meta_match.group(2)) + self.time_with_regex += time.perf_counter() - start_time # Then load the segment metadata from the file + start_time = time.perf_counter() self.extract_segment_metadata(file_info, curve_num, direction) + self.time_loading_metadata += time.perf_counter() - start_time continue # Check Curve Metadata @@ -408,10 +441,15 @@ def load_all_data(self, include_metadata: bool = True): if curve_meta_match: # If file is a curve metadata file, extract the curve number from the filename curve_num = int(curve_meta_match.group(1)) + self.time_with_regex += time.perf_counter() - start_time # Then load the metadata from the file + start_time = time.perf_counter() + self.extract_curve_metadata(file_info, curve_num) + self.time_loading_metadata += time.perf_counter() - start_time continue + start_time = time.perf_counter() # If saving, need to collate the curve data into a format that can be easily saved to the h5 file (a dataset per channel per segment direction) self.collated_curve_data, self.indicies = self.get_collated_curves() @@ -422,6 +460,8 @@ def load_all_data(self, include_metadata: bool = True): self.full_metadata = self.construct_full_metadata() self.collated_metadata = self.get_collated_metadata() + self.time_collating_data += time.perf_counter() - start_time + def save_to_h5( self, include_metadata: bool = True, @@ -460,7 +500,7 @@ def save_to_h5( else: global_meta_group.attrs[key] = str(value).encode("utf-8") - logger.info(f"QI data copied to h5 data {file.name}") + logger.info(f"QI data copied to h5 data {file.filename}") def get_collated_curves(self): """ @@ -742,9 +782,20 @@ def extract_curve_metadata(self, file_info: zipfile.ZipInfo, curve_num: int): curve_num : int The curve number associated with the metadata, parsed from the filename. """ + # with self.qi_archive.open(file_info) as f: + # cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} + # self.curve_meta_dict[curve_num] = cleaned_meta + # self.all_curve_keys.update(cleaned_meta.keys()) + with self.qi_archive.open(file_info) as f: cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} - self.curve_meta_dict[curve_num] = cleaned_meta + for key, value in cleaned_meta.items(): + if key not in self.curve_meta: + self.curve_meta[key] = [None for _ in range(len(self.curve_meta))] + self.curve_meta[key].append(value) + for key in self.curve_meta.keys(): + if key not in cleaned_meta: + self.curve_meta[key].append(None) self.all_curve_keys.update(cleaned_meta.keys()) def extract_segment_metadata(self, file_info: zipfile.ZipInfo, curve_num: int, direction: int): @@ -763,8 +814,16 @@ def extract_segment_metadata(self, file_info: zipfile.ZipInfo, curve_num: int, d idx = curve_num * 2 + direction with self.qi_archive.open(file_info) as f: cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} - self.segment_meta_dict[idx] = cleaned_meta - self.all_segment_keys.update(cleaned_meta.keys()) + for key, value in cleaned_meta.items(): + if key not in self.segment_meta: + self.segment_meta[key] = [None for _ in range(len(self.segment_meta))] + self.segment_meta[key].append(value) + for key in self.segment_meta.keys(): + if key not in cleaned_meta: + self.segment_meta[key].append(None) + self.all_se.update(cleaned_meta.keys()) + # self.segment_meta_dict[idx] = cleaned_meta + # self.all_segment_keys.update(cleaned_meta.keys()) def setup_h5_structure(self, h5file): """ @@ -931,10 +990,3 @@ def _make_num_min_characters(num: int, min_chars: int = 3): return string_num string_num = "0" * (min_chars - len(string_num)) + string_num return string_num - - -def _max_points_buffer(curves_data, samples=20, points_buffer=1.2): - - step = len(curves_data) // samples - max_points = np.max(len(curves_data[i]["segment"]) for i in range(0, len(curves_data), step)) - return max_points * points_buffer From 1fd1eedd282c1262d696f3fb0e495a0e5327c8e4 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Mon, 30 Mar 2026 16:17:49 +0100 Subject: [PATCH 24/54] Making jpk-qi-data loading stream data into h5 file rather than saving to memory then saving in one go when duplicating data to h5 --- AFMReader/jpk_qi.py | 347 ++++++++++++++++++++------------------------ 1 file changed, 155 insertions(+), 192 deletions(-) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index 1bf85d1..12082cc 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -1,18 +1,18 @@ +import os from pathlib import Path from contextlib import nullcontext import io -import re -import time import zipfile import numpy as np import javaproperties import h5py +import psutil +from pympler import asizeof from AFMReader.logging import logger from AFMReader import jpk - class LazyCurveData: """A proxy class that behaves like a 2D list but fetches .dat files on demand.""" @@ -260,14 +260,13 @@ def __init__( self.save_as_h5 = save_as_h5 # Open the ZIP archive once and keep it open for the duration of the loading process to improve performance when accessing multiple files within the archive self.qi_archive = zipfile.ZipFile(self.filepath, "r") + self.namelist = self.qi_archive.namelist() + logger.debug(f"JPK QI archive of size {asizeof.asizeof(self.qi_archive) / 1024 / 1024:.2f} MB, namelist is of size: {asizeof.asizeof(self.namelist) / 1024 / 1024:.2f} MB") # Set path to the .jpk-qi-image file within the archive for later use self.path_to_image = None - self.time_loading_data = 0.0 - self.time_loading_metadata = 0.0 - self.time_saving_h5 = 0.0 - self.time_with_regex = 0.0 - self.time_collating_data = 0.0 + self.DATA_CHUNKSIZE = 2 * 1024 * 1024 # chunk size for H5 datasets + self.INDICIES_CHUNKSIZE = 64 * 1024 # chunk size for indicies datasets # Initialize key attributes that will be returned / accessed frequently @@ -281,18 +280,11 @@ def __init__( self.channels_units = {} # The list of channels for the segments with their scaling information extracted from the shared header self.segment_channels = [] - # Dictionary of the metadata for each curve, indexed by curve number - self.curve_meta_dict = {} - # Dictionary of the metadata for each segment, indexed by segement number (calculated as curve_num * 2 + direction) - self.segment_meta_dict = {} - # The keys that exist in the curve metadata across all curves, used to determine which keys to move to the top level metadata - self.all_curve_keys = set() - # The keys that exist in the segment metadata across all segments, used to determine which keys to move to the top level metadata - self.all_segment_keys = set() self.curve_meta = {} self.segment_meta = {} # Define the image shape and size attributes self.size_x, self.size_y, self.shape_x, self.shape_y = None, None, None, None + self.failed_curves = set() # Instantiate containers for data to be saved (so an exception is not caused if not saving) self.curve_groups = None @@ -311,7 +303,7 @@ def get_available_channels(self): # Look for the jpk-qi-image file in the archive if self.path_to_image is None: - for file_name in self.qi_archive.namelist(): + for file_name in self.namelist: if file_name.endswith(".jpk-qi-image"): self.path_to_image = file_name @@ -362,10 +354,7 @@ def load( # Setup H5 Data structures if needed if self.save_as_h5: - self.load_all_data() - start_time = time.perf_counter() - self.save_to_h5(collated_curve_data=self.collated_curve_data, indicies=self.indicies, collated_metadata=self.collated_metadata) - self.time_saving_h5 += time.perf_counter() - start_time + self.save_to_h5() self.full_metadata = LazyCurveMetadata( self.filepath, self.top_level_meta, self.qi_archive, self.shape_x, self.shape_y, flip_image=self.flip_image @@ -379,126 +368,82 @@ def load( # Save a lite form of the images (precalculated) if saving to a file if self.save_as_h5: - start_time = time.perf_counter() self.save_lite_data() - self.time_saving_h5 += time.perf_counter() - start_time - - logger.info(f"Finished loading JPK QI data from {self.filepath} in {self.time_loading_data + self.time_loading_metadata + self.time_collating_data + self.time_saving_h5:.2f} seconds \n" - f"Data loading: {self.time_loading_data:.2f}s\n" - f"Metadata loading: {self.time_loading_metadata:.2f}s\n" - f"Time spent in regex matching: {self.time_with_regex:.2f}s\n" - f"Collating data: {self.time_collating_data:.2f}s\n" - f"Saving H5: {self.time_saving_h5:.2f}s") if self.all_curve_data: return (self.image, self.px2nm, (self.all_curve_data, self.channels_units, self.full_metadata)) + return self.image, self.px2nm - def load_all_data(self, include_metadata: bool = True): - # Compile Regexes - dat_regex = re.compile(r"index/(\d+)/segments/(\d+)/channels/([^/]+)\.dat") - if include_metadata: - curve_meta_regex = re.compile(r"index/(\d+)/header\.properties") - segment_meta_regex = re.compile(r"index/(\d+)/segments/(\d+)/segment-header\.properties") - archive_infolist = self.qi_archive.infolist() - logger.info(f"Loading all curve data from JPK QI archive with {len(archive_infolist)} files {'' if include_metadata else 'not '}including metadata") - progress_counter = 0 - for file_info in archive_infolist: - filename = file_info.filename - if progress_counter % 10000 == 0: - logger.info(f"Progress: {progress_counter}/{len(archive_infolist)} files processed") - progress_counter += 1 + def output_summary(self): + if self.failed_curves: + logger.warning(f"Failed to load {len(self.failed_curves)} files.") + logger.warning("Summary of missing files (up to 10 shown):") + for i, (curve_num, direction, chan_name) in enumerate(self.failed_curves): + if i < 10: # Limit the number of warnings to avoid spamming the logs + if chan_name: + logger.warning(f"Failed to load data file for curve {curve_num}, direction {direction}, channel {chan_name}") + else: + if direction is not None: + logger.warning(f"Failed to load segment meta file for curve {curve_num}, direction {direction}") + else: + logger.warning(f"Failed to load curve meta file for curve {curve_num}") + - # Check Binary Data - start_time = time.perf_counter() - dat_match = dat_regex.match(filename) - if dat_match: - # If file is a .dat file, extract the curve number, segment direction and channel name from the filename - curve_num, direction, chan_name = int(dat_match.group(1)), int(dat_match.group(2)), dat_match.group(3) - # Then load the data from the file - self.time_with_regex += time.perf_counter() - start_time - start_time = time.perf_counter() - self.extract_dat_file(file_info, curve_num, direction, chan_name) - self.time_loading_data += time.perf_counter() - start_time - continue + else: + logger.info("Successfully loaded all curve data without any missing files.") + def load_all_data(self, h5_datasets, include_metadata: bool = True): + logger.info(f"Loading all curve data from JPK QI archive with {len(self.namelist)} files {'' if include_metadata else 'not '}including metadata") + progress_counter = 0 + process = psutil.Process(os.getpid()) + for curve_num in range(self.num_of_curves): + # Output progress every 1000 curves to give some indication of how long the loading is taking + if progress_counter % 1000 == 0: + mem = process.memory_info().rss / 1024 / 1024 + logger.info(f"Progress: {progress_counter}/{self.num_of_curves} curves processed, Memory usage: {mem:.2f} MB") + progress_counter += 1 + for direction in range(2): + for chan in self.segment_channels: + # Save the actual curve data to the h5 datasets + self.extract_dat_file(h5_datasets=h5_datasets, curve_num=curve_num, direction=direction, chan_name=chan['name']) + if include_metadata: + # Extract and store the segment metadata for later saving + self.extract_segment_metadata(curve_num=curve_num, direction=direction) if include_metadata: - # Check Segment Metadata - segment_meta_match = segment_meta_regex.match(filename) - if segment_meta_match: - # If file is a segment metadata file, extract the curve number and segment direction from the filename - curve_num, direction = int(segment_meta_match.group(1)), int(segment_meta_match.group(2)) - self.time_with_regex += time.perf_counter() - start_time - # Then load the segment metadata from the file - start_time = time.perf_counter() - self.extract_segment_metadata(file_info, curve_num, direction) - self.time_loading_metadata += time.perf_counter() - start_time - continue - - # Check Curve Metadata - curve_meta_match = curve_meta_regex.match(filename) - if curve_meta_match: - # If file is a curve metadata file, extract the curve number from the filename - curve_num = int(curve_meta_match.group(1)) - self.time_with_regex += time.perf_counter() - start_time - # Then load the metadata from the file - start_time = time.perf_counter() - - self.extract_curve_metadata(file_info, curve_num) - self.time_loading_metadata += time.perf_counter() - start_time - continue - - start_time = time.perf_counter() - # If saving, need to collate the curve data into a format that can be easily saved to the h5 file (a dataset per channel per segment direction) - self.collated_curve_data, self.indicies = self.get_collated_curves() - - # TODO can we remove curve_meta_dict and just use curve_meta or is the non duplicating necessary - if include_metadata: - self.curve_meta = [self.curve_meta_dict.get(i, {}) for i in range(self.num_of_curves)] - self.segment_meta = [self.segment_meta_dict.get(i, {}) for i in range(self.num_of_curves * 2)] - self.full_metadata = self.construct_full_metadata() - self.collated_metadata = self.get_collated_metadata() - - self.time_collating_data += time.perf_counter() - start_time + # Extract and store the curve metadata for later saving + self.extract_curve_metadata(curve_num=curve_num) + self.output_summary() def save_to_h5( self, include_metadata: bool = True, - collated_curve_data: dict | None = None, - indicies: dict | None = None, - collated_metadata: dict | None = None, ): """Saves the data as an H5 file. If include_metadata is False, only the curve data will be saved.""" with self.get_saving_context() as file: - curve_groups, global_meta_group, curves_meta_group = self.setup_h5_structure(file) - - # Save the curve data to the appropriate datasets in the h5 file - for chan_name, chan_data in collated_curve_data.items(): - for direction in range(2): - # Save the curve data and indicies to the appropriate dataset in the h5 file - seg_name = f"Segment_{direction}" - curve_groups["Data"][seg_name].create_dataset( - name=chan_name, - data=chan_data[seg_name], - dtype=np.float32, - ) - curve_groups["Indicies"][seg_name].create_dataset( - name=chan_name, - data=indicies[chan_name][seg_name], - dtype=np.int32, - ) + global_meta_group, curves_meta_group, h5_datasets = self.setup_h5_structure(file) + + self.load_all_data(h5_datasets, include_metadata=include_metadata) + logger.debug(f"Curve meta size in memory: {asizeof.asizeof(self.curve_meta) / 1024 / 1024:.2f} MB") + logger.debug(f"Segment meta size in memory: {asizeof.asizeof(self.segment_meta) / 1024 / 1024:.2f} MB") if include_metadata: # Save the global metadata to the h5 file vlen_str_dt = h5py.string_dtype(encoding="utf-8") - for key, value in collated_metadata.items(): - if isinstance(value, list): - # If the key is a changing key, save as a dataset with one entry per curve/ segment - curves_meta_group.create_dataset(name=key, data=value, dtype=vlen_str_dt) - else: - global_meta_group.attrs[key] = str(value).encode("utf-8") + for key, value in self.get_collated_metadata().items(): + global_meta_group.attrs[key] = str(value).encode("utf-8") + for prefix, meta_dict in [("curve", self.curve_meta), ("segment", self.segment_meta)]: + for key, values in meta_dict.items(): + cleaned_values = [str(v) if v is not None else "" for v in values] + unique_values = set(cleaned_values) + if len(unique_values) == 1: + # If all values are the same, save as a single attribute + global_meta_group.attrs[f"{prefix}.{key}"] = str(unique_values.pop()).encode("utf-8") + else: + # If the key is a changing key, save as a dataset with one entry per curve/ segment + curves_meta_group.create_dataset(name=f"{prefix}.{key}", data=cleaned_values, dtype=vlen_str_dt) logger.info(f"QI data copied to h5 data {file.filename}") @@ -575,19 +520,9 @@ def get_collated_metadata(self): collated_meta[f"channel.unit.{seg_chan['name']}"] = seg_chan["unit"] for key, value in self.top_level_meta.items(): collated_meta[key] = value - for curve_dict in self.full_metadata["curves"]: - for key, value in curve_dict.items(): - if f"curve.{key}" not in collated_meta: - collated_meta[f"curve.{key}"] = [] - collated_meta[f"curve.{key}"].append(value) - for segment_dict in self.full_metadata["segments"]: - for key, value in segment_dict.items(): - if f"segment.{key}" not in collated_meta: - collated_meta[f"segment.{key}"] = [] - collated_meta[f"segment.{key}"].append(value) return collated_meta - def get_image(self, overide_channel: str | None = None, convert_to_nm: bool = True): + def get_image(self, overide_channel: str | None = None, convert_to_nm: bool = True, flip_image: bool | None = None) -> tuple[np.ndarray, float]: """ Processes the flat curve data dictionary into a 2D list structure matching the image dimensions. @@ -602,8 +537,11 @@ def get_image(self, overide_channel: str | None = None, convert_to_nm: bool = Tr else: channel = self.channel + if flip_image is None: + flip_image = self.flip_image + path_to_image = None - for file_name in self.qi_archive.namelist(): + for file_name in self.namelist: if file_name.endswith(".jpk-qi-image"): path_to_image = file_name if path_to_image is None: @@ -614,7 +552,7 @@ def get_image(self, overide_channel: str | None = None, convert_to_nm: bool = Tr virtual_file = io.BytesIO(tif_bytes) logger.info(f"Looking for channel {channel} in {path_to_image}") return jpk._load_jpk( - virtual_file, path_to_image, channel=channel, file_suffix=".jpk-qi-data", config_path=self.config_path, convert_to_nm=convert_to_nm + virtual_file, path_to_image, channel=channel, file_suffix=".jpk-qi-data", config_path=self.config_path, convert_to_nm=convert_to_nm, flip_image=flip_image ) def save_lite_data(self): @@ -634,7 +572,7 @@ def save_lite_data(self): h5_channels = [self.channel] # Look for the jpk-qi-image file in the archive - for file_name in self.qi_archive.namelist(): + for file_name in self.namelist: if file_name.endswith(".jpk-qi-image"): path_to_image = file_name break @@ -665,7 +603,7 @@ def save_lite_data(self): dataset_name = h5_channel.split("_")[0].capitalize() # Include all the channels including the calculated channel # TODO make this slightly faster by remembering we have load a channel already but difficult cause of scaling - channel_image, _ = self.get_image(overide_channel=h5_channel, convert_to_nm=False) + channel_image, _ = self.get_image(overide_channel=h5_channel, convert_to_nm=False, flip_image=False) frame_stack = channel_image.flatten().reshape(-1, 1) # Update/ replace the channels dataset @@ -737,14 +675,14 @@ def construct_full_metadata(self): full_metadata = {"top_level": self.top_level_meta, "curves": self.curve_meta, "segments": self.segment_meta} return full_metadata - def extract_dat_file(self, file_info: zipfile.ZipInfo, curve_num: int, direction: int, chan_name: str): + def extract_dat_file(self, h5_datasets, curve_num: int, direction: int, chan_name: str): """ Extracts the data from a .dat file in the JPK QI archive, applies the appropriate scaling, and saves it to the internal data structure and h5 dataset if required. Parameters ---------- - file_info : zipfile.ZipInfo - The ZipInfo object corresponding to the .dat file to be extracted. + h5_datasets : dict + A dictionary containing the h5 datasets for each channel and segment direction, used for saving the data curve_num : int The curve number associated with the .dat file, parsed from the filename. direction : int @@ -754,24 +692,45 @@ def extract_dat_file(self, file_info: zipfile.ZipInfo, curve_num: int, direction """ if chan_name in self.channel_scaling: scale = self.channel_scaling[chan_name] - with self.qi_archive.open(file_info) as f: - # Read the binary data as big-endian 32-bit integers - raw_array = np.frombuffer(f.read(), dtype=">i4") - # Apply the scaling to convert raw values into real world values - segment_array = (raw_array * scale["multiplier"]) + scale["offset"] - - # If the channel doesn't exist in the curve data for this curve, add it as a new entry - if chan_name not in self.flat_curve_data[curve_num]: - self.flat_curve_data[curve_num][chan_name] = {} - # Add the segment data to the curve data under the appropriate channel and segment direction - self.flat_curve_data[curve_num][chan_name][f"Segment_{direction}"] = segment_array - + dat_path = f"index/{curve_num}/segments/{direction}/channels/{chan_name}.dat" + try: + with self.qi_archive.open(dat_path) as f: + # Read the binary data as big-endian 32-bit integers + raw_array = np.frombuffer(f.read(), dtype=">i4") + # Apply the scaling to convert raw values into real world values + segment_array = (raw_array * scale["multiplier"]) + scale["offset"] + + # Fetch and resize the existing dataset for this channel and segment to fit the new data + current_dataset = h5_datasets[f"Segment_{direction}"][chan_name]["Data"] + current_size = current_dataset.shape[0] + current_dataset.resize((current_size + segment_array.shape[0],)) + + # Append the new data to the end of the existing dataset + current_dataset[current_size:] = segment_array + + # Fetch and resize the indicies dataset for this channel and segment to fit the new index + indicies_set = h5_datasets[f"Segment_{direction}"][chan_name]["Indicies"] + indicies_set.resize((indicies_set.shape[0] + 1,)) + + # Append the new index to the end of the existing indicies dataset + indicies_set[-1] = current_size + + except KeyError: + self.failed_curves.add((curve_num, direction, chan_name)) + if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs + logger.warning( + f"Data file {dat_path} not found in archive. Skipping data for curve {curve_num}, direction {direction}, channel {chan_name}." + ) + elif len(self.failed_curves) == 10: + logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.") else: - logger.warning( - f"Channel {chan_name} not found in scaling information. Skipping data for curve {curve_num}, direction {direction}." - ) + self.failed_curves.add((curve_num, direction, chan_name)) + if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs + logger.warning( + f"Channel {chan_name} not found in scaling information. Skipping data for curve {curve_num}, direction {direction}." + ) - def extract_curve_metadata(self, file_info: zipfile.ZipInfo, curve_num: int): + def extract_curve_metadata(self, curve_num: int): """ Extracts the curve metadata from its header.properties file in the JPK QI archive and saves it to the internal data structure. @@ -782,49 +741,49 @@ def extract_curve_metadata(self, file_info: zipfile.ZipInfo, curve_num: int): curve_num : int The curve number associated with the metadata, parsed from the filename. """ - # with self.qi_archive.open(file_info) as f: - # cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} - # self.curve_meta_dict[curve_num] = cleaned_meta - # self.all_curve_keys.update(cleaned_meta.keys()) - - with self.qi_archive.open(file_info) as f: - cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} - for key, value in cleaned_meta.items(): - if key not in self.curve_meta: - self.curve_meta[key] = [None for _ in range(len(self.curve_meta))] - self.curve_meta[key].append(value) - for key in self.curve_meta.keys(): - if key not in cleaned_meta: - self.curve_meta[key].append(None) - self.all_curve_keys.update(cleaned_meta.keys()) - - def extract_segment_metadata(self, file_info: zipfile.ZipInfo, curve_num: int, direction: int): + + meta_path = f"index/{curve_num}/header.properties" + try: + with self.qi_archive.open(meta_path) as f: + cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} + for key, value in cleaned_meta.items(): + if key not in self.curve_meta: + # Initialize with None for all curves when a key is discovered + self.curve_meta[key] = [None] * self.num_of_curves + self.curve_meta[key][curve_num] = value + except KeyError: + self.failed_curves.add((curve_num, None, None)) + if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs + logger.warning(f"Metadata file {meta_path} not found in archive. Skipping metadata for curve {curve_num}.") + elif len(self.failed_curves) == 10: + logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.") + + def extract_segment_metadata(self, curve_num: int, direction: int): """ Extracts the segment metadata from its header.properties file in the JPK QI archive and saves it to the internal data structure. Parameters ---------- - file_info : zipfile.ZipInfo - The ZipInfo object corresponding to the header.properties file to be extracted. curve_num : int The curve number associated with the metadata, parsed from the filename. direction : int The segment direction (0 or 1) associated with the metadata, parsed from the filename. """ - idx = curve_num * 2 + direction - with self.qi_archive.open(file_info) as f: - cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} - for key, value in cleaned_meta.items(): - if key not in self.segment_meta: - self.segment_meta[key] = [None for _ in range(len(self.segment_meta))] - self.segment_meta[key].append(value) - for key in self.segment_meta.keys(): - if key not in cleaned_meta: - self.segment_meta[key].append(None) - self.all_se.update(cleaned_meta.keys()) - # self.segment_meta_dict[idx] = cleaned_meta - # self.all_segment_keys.update(cleaned_meta.keys()) - + meta_path = f"index/{curve_num}/segments/{direction}/segment-header.properties" + try: + with self.qi_archive.open(meta_path) as f: + cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} + for key, value in cleaned_meta.items(): + if key not in self.segment_meta: + # Initialize with None for all segments when a key is discovered + self.segment_meta[key] = [None] * self.num_of_curves * 2 + self.segment_meta[key][curve_num * 2 + direction] = value + except KeyError: + self.failed_curves.add((curve_num, direction, None)) + if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs + logger.warning(f"Metadata file {meta_path} not found in archive. Skipping metadata for curve {curve_num}, direction {direction}.") + elif len(self.failed_curves) == 10: + logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.") def setup_h5_structure(self, h5file): """ Sets up the structure in the h5 file for saving the curve data and metadata, and returns the datasets and metadata groups for later use. @@ -842,7 +801,6 @@ def setup_h5_structure(self, h5file): curves_meta_group : h5py.Group The h5 group for storing curve metadata. """ - vlen_type = h5py.vlen_dtype(np.float32) # Create the main group for the QI curve data that all the curve data will be in qi_group = h5file.require_group("QI_Curve_Data") @@ -853,15 +811,22 @@ def setup_h5_structure(self, h5file): curves_group = qi_group.require_group("Curves") curve_groups = {"Data": {}, "Indicies": {}} + h5_datasets = {} for direction in range(2): # For each segment direction, establish the necessary group structure that will contain each channel dataset seg_name = f"Segment_{direction}" dir_group = curves_group.require_group(seg_name) + h5_datasets[seg_name] = {} # Create the Data and Indicies subfolders and store their references curve_groups["Data"][seg_name] = dir_group.require_group("Data") curve_groups["Indicies"][seg_name] = dir_group.require_group("Indicies") - return curve_groups, global_meta_group, curves_meta_group + for chan in self.segment_channels: + h5_datasets[seg_name][chan['name']] = {} + # For each channel, create an empty dataset for the curve data and indicies with the appropriate name and data type + h5_datasets[seg_name][chan['name']]['Data'] = curve_groups["Data"][seg_name].create_dataset(name=chan["name"], shape=(0,), maxshape=(None,), chunks=(self.DATA_CHUNKSIZE,), dtype=np.float32) + h5_datasets[seg_name][chan['name']]['Indicies'] = curve_groups["Indicies"][seg_name].create_dataset(name=chan["name"], shape=(0,), maxshape=(None,), chunks=(self.INDICIES_CHUNKSIZE,), dtype=np.int32) + return global_meta_group, curves_meta_group, h5_datasets def get_saving_context(self): """ @@ -905,14 +870,12 @@ def parse_dimension_data(self): # Establish number of curves self.num_of_curves = self.shape_x * self.shape_y - # Setup the data structure to hold the curve data as it's read in - self.flat_curve_data = [{} for _ in range(self.num_of_curves)] def extract_global_metadata(self): """Extracts the global metadata from the JPK QI archive and populates the top level metadata dictionary and segment channels list.""" # Load the metadata from the global properties file - if "header.properties" in self.qi_archive.namelist(): + if "header.properties" in self.namelist: with self.qi_archive.open("header.properties") as archive_meta_file: props = javaproperties.load(archive_meta_file) @@ -923,7 +886,7 @@ def extract_global_metadata(self): logger.error(f"File {self.filepath} does not contain essential metadata and cannot be loaded") # Load the metadata from the shared header and parse the channel information for the segments - if "shared-data/header.properties" in self.qi_archive.namelist(): + if "shared-data/header.properties" in self.namelist: with self.qi_archive.open("shared-data/header.properties") as shared_data_file: shared_meta = javaproperties.load(shared_data_file) channel_idx = 0 From 4e90188ee2f9e2696e8d51dbc8db36dcf479fc8b Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Wed, 1 Apr 2026 13:49:26 +0100 Subject: [PATCH 25/54] Changing metdata data saving so 'changing keys' are assumed based on sampled curves then are streamed directly into h5 --- AFMReader/jpk_qi.py | 173 ++++++++++++++++++++++++++++++++------------ 1 file changed, 127 insertions(+), 46 deletions(-) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index 12082cc..b692d53 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -265,8 +265,14 @@ def __init__( # Set path to the .jpk-qi-image file within the archive for later use self.path_to_image = None - self.DATA_CHUNKSIZE = 2 * 1024 * 1024 # chunk size for H5 datasets - self.INDICIES_CHUNKSIZE = 64 * 1024 # chunk size for indicies datasets + # Chunk size for H5 datasets + self.DATA_CHUNKSIZE = 2 * 1024 * 1024 + # Chunk size for indicies datasets + self.INDICIES_CHUNKSIZE = 64 * 1024 + # Chunk size for metadata datasets (if needed) + self.META_CHUNKSIZE = 64 * 1024 + # Maximum number of curves to check for changing metadata keys (to avoid checking every curve in large datasets) + self.MAX_CURVE_CHECKS = 20 # Initialize key attributes that will be returned / accessed frequently @@ -389,12 +395,10 @@ def output_summary(self): logger.warning(f"Failed to load segment meta file for curve {curve_num}, direction {direction}") else: logger.warning(f"Failed to load curve meta file for curve {curve_num}") - - else: logger.info("Successfully loaded all curve data without any missing files.") - def load_all_data(self, h5_datasets, include_metadata: bool = True): + def load_all_data(self, h5_datasets, h5_meta_datasets, include_metadata: bool = True): logger.info(f"Loading all curve data from JPK QI archive with {len(self.namelist)} files {'' if include_metadata else 'not '}including metadata") progress_counter = 0 process = psutil.Process(os.getpid()) @@ -410,10 +414,19 @@ def load_all_data(self, h5_datasets, include_metadata: bool = True): self.extract_dat_file(h5_datasets=h5_datasets, curve_num=curve_num, direction=direction, chan_name=chan['name']) if include_metadata: # Extract and store the segment metadata for later saving - self.extract_segment_metadata(curve_num=curve_num, direction=direction) + self.extract_segment_metadata(h5_meta_datasets=h5_meta_datasets, curve_num=curve_num, direction=direction) if include_metadata: # Extract and store the curve metadata for later saving - self.extract_curve_metadata(curve_num=curve_num) + self.extract_curve_metadata(h5_meta_datasets=h5_meta_datasets, curve_num=curve_num) + # Add the last index to the indicies datasets to mark the end of the last curve + for direction in range(2): + seg_name = f"Segment_{direction}" + for chan in self.segment_channels: + chan_name = chan['name'] + current_dataset = h5_datasets[seg_name][chan_name]["Data"] + indicies_dataset = h5_datasets[seg_name][chan_name]["Indicies"] + indicies_dataset[-1] = current_dataset.shape[0] + self.output_summary() def save_to_h5( @@ -423,30 +436,80 @@ def save_to_h5( """Saves the data as an H5 file. If include_metadata is False, only the curve data will be saved.""" with self.get_saving_context() as file: - global_meta_group, curves_meta_group, h5_datasets = self.setup_h5_structure(file) + self.changing_curve_keys, self.changing_segment_keys = self.get_changing_keys() + global_meta_group, h5_datasets, h5_meta_datasets = self.setup_h5_structure(file) - self.load_all_data(h5_datasets, include_metadata=include_metadata) + self.load_all_data(h5_datasets, h5_meta_datasets, include_metadata=include_metadata) logger.debug(f"Curve meta size in memory: {asizeof.asizeof(self.curve_meta) / 1024 / 1024:.2f} MB") logger.debug(f"Segment meta size in memory: {asizeof.asizeof(self.segment_meta) / 1024 / 1024:.2f} MB") if include_metadata: # Save the global metadata to the h5 file - vlen_str_dt = h5py.string_dtype(encoding="utf-8") for key, value in self.get_collated_metadata().items(): global_meta_group.attrs[key] = str(value).encode("utf-8") - for prefix, meta_dict in [("curve", self.curve_meta), ("segment", self.segment_meta)]: - for key, values in meta_dict.items(): - cleaned_values = [str(v) if v is not None else "" for v in values] - unique_values = set(cleaned_values) - if len(unique_values) == 1: - # If all values are the same, save as a single attribute - global_meta_group.attrs[f"{prefix}.{key}"] = str(unique_values.pop()).encode("utf-8") - else: - # If the key is a changing key, save as a dataset with one entry per curve/ segment - curves_meta_group.create_dataset(name=f"{prefix}.{key}", data=cleaned_values, dtype=vlen_str_dt) logger.info(f"QI data copied to h5 data {file.filename}") + def get_changing_keys(self): + # Check evenly spaced curves in the dataset to sample metadata without having to load every curve + step = 1 if self.num_of_curves <= self.MAX_CURVE_CHECKS else self.num_of_curves // self.MAX_CURVE_CHECKS + # If the step is equal to a shape dimension, we might just go down the row or column + while step in [self.shape_x, self.shape_y] and step > 1: + # So make the step slightly smaller (more checks) to ensure we get a good sample + step -= 1 + curves_to_check = range(0, self.num_of_curves, step) + curve_meta_dict = {} + segment_meta_dict = {} + for curve_num in curves_to_check: + for direction in range(2): + meta_path = f"index/{curve_num}/segments/{direction}/segment-header.properties" + while True: + try: + with self.qi_archive.open(meta_path) as f: + meta_dict = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} + for k, v in meta_dict.items(): + if k not in segment_meta_dict: + segment_meta_dict[k] = [] + segment_meta_dict[k].append(v) + break + + except KeyError: + if curve_num + 1 >= self.num_of_curves: + break # If we've gone past the number of curves, stop checking + curve_num += 1 + continue + meta_path = f"index/{curve_num}/header.properties" + while True: + try: + with self.qi_archive.open(meta_path) as f: + meta_dict = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} + for k, v in meta_dict.items(): + if k not in curve_meta_dict: + curve_meta_dict[k] = [] + curve_meta_dict[k].append(v) + break + except KeyError: + if curve_num + 1 >= self.num_of_curves: + break # If we've gone past the number of curves, stop checking + curve_num += 1 + continue + + changing_curve_keys, changing_segment_keys = set(), set() + for key, values in curve_meta_dict.items(): + if len(set(v for v in values if v is not None)) > 1: + changing_curve_keys.add(key) + else: + # If the key does not change across curves, move it to the top level metadata + self.top_level_meta[f"curve.{key}"] = values[0] + for key, values in segment_meta_dict.items(): + if len(set(v for v in values if v is not None)) > 1: + changing_segment_keys.add(key) + else: + # If the key does not change across segments, move it to the top level metadata + self.top_level_meta[f"segment.{key}"] = values[0] + return changing_curve_keys, changing_segment_keys + + def get_collated_curves(self): """ Collates the curve data from the flat structure it is extracted in into a structure grouped by channel and segment for easier saving to h5. @@ -693,36 +756,36 @@ def extract_dat_file(self, h5_datasets, curve_num: int, direction: int, chan_nam if chan_name in self.channel_scaling: scale = self.channel_scaling[chan_name] dat_path = f"index/{curve_num}/segments/{direction}/channels/{chan_name}.dat" + data_set = h5_datasets[f"Segment_{direction}"][chan_name]["Data"] + data_size = data_set.shape[0] + indicies_set = h5_datasets[f"Segment_{direction}"][chan_name]["Indicies"] try: with self.qi_archive.open(dat_path) as f: # Read the binary data as big-endian 32-bit integers raw_array = np.frombuffer(f.read(), dtype=">i4") + # Apply the scaling to convert raw values into real world values segment_array = (raw_array * scale["multiplier"]) + scale["offset"] # Fetch and resize the existing dataset for this channel and segment to fit the new data - current_dataset = h5_datasets[f"Segment_{direction}"][chan_name]["Data"] - current_size = current_dataset.shape[0] - current_dataset.resize((current_size + segment_array.shape[0],)) + data_set.resize((data_size + segment_array.shape[0],)) # Append the new data to the end of the existing dataset - current_dataset[current_size:] = segment_array - - # Fetch and resize the indicies dataset for this channel and segment to fit the new index - indicies_set = h5_datasets[f"Segment_{direction}"][chan_name]["Indicies"] - indicies_set.resize((indicies_set.shape[0] + 1,)) - - # Append the new index to the end of the existing indicies dataset - indicies_set[-1] = current_size + data_set[data_size:] = segment_array except KeyError: self.failed_curves.add((curve_num, direction, chan_name)) - if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs + + # Limit the number of warnings to avoid spamming the logs + if len(self.failed_curves) < 10: logger.warning( f"Data file {dat_path} not found in archive. Skipping data for curve {curve_num}, direction {direction}, channel {chan_name}." ) elif len(self.failed_curves) == 10: logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.") + + # Append the new index to the end of the existing indicies dataset + indicies_set[curve_num] = data_size else: self.failed_curves.add((curve_num, direction, chan_name)) if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs @@ -730,12 +793,14 @@ def extract_dat_file(self, h5_datasets, curve_num: int, direction: int, chan_nam f"Channel {chan_name} not found in scaling information. Skipping data for curve {curve_num}, direction {direction}." ) - def extract_curve_metadata(self, curve_num: int): + def extract_curve_metadata(self, h5_meta_datasets, curve_num: int): """ Extracts the curve metadata from its header.properties file in the JPK QI archive and saves it to the internal data structure. Parameters ---------- + h5_meta_datasets : dict + A dictionary containing the h5 datasets for each metadata key, used for saving the metadata. file_info : zipfile.ZipInfo The ZipInfo object corresponding to the header.properties file to be extracted. curve_num : int @@ -743,14 +808,14 @@ def extract_curve_metadata(self, curve_num: int): """ meta_path = f"index/{curve_num}/header.properties" + cleaned_meta = {} try: with self.qi_archive.open(meta_path) as f: cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} - for key, value in cleaned_meta.items(): - if key not in self.curve_meta: - # Initialize with None for all curves when a key is discovered - self.curve_meta[key] = [None] * self.num_of_curves - self.curve_meta[key][curve_num] = value + for key in self.changing_curve_keys: + meta_set = h5_meta_datasets.get(f"curve.{key}") + meta_set.resize((meta_set.shape[0] + 1 if meta_set else curve_num,)) + meta_set[curve_num] = cleaned_meta.get(key, "No data") except KeyError: self.failed_curves.add((curve_num, None, None)) if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs @@ -758,32 +823,43 @@ def extract_curve_metadata(self, curve_num: int): elif len(self.failed_curves) == 10: logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.") - def extract_segment_metadata(self, curve_num: int, direction: int): + for key in self.changing_curve_keys: + meta_set = h5_meta_datasets.get(f"curve.{key}") + if meta_set: + meta_set[curve_num] = cleaned_meta.get(key, "No data") + else: + logger.error(f"Metadata dataset for key curve.{key} not found when trying to save metadata for curve {curve_num}") + + def extract_segment_metadata(self, h5_meta_datasets, curve_num: int, direction: int): """ Extracts the segment metadata from its header.properties file in the JPK QI archive and saves it to the internal data structure. Parameters ---------- + h5_meta_datasets : dict + A dictionary containing the h5 datasets for each metadata key, used for saving the metadata. curve_num : int The curve number associated with the metadata, parsed from the filename. direction : int The segment direction (0 or 1) associated with the metadata, parsed from the filename. """ meta_path = f"index/{curve_num}/segments/{direction}/segment-header.properties" + cleaned_meta = {} try: with self.qi_archive.open(meta_path) as f: cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} - for key, value in cleaned_meta.items(): - if key not in self.segment_meta: - # Initialize with None for all segments when a key is discovered - self.segment_meta[key] = [None] * self.num_of_curves * 2 - self.segment_meta[key][curve_num * 2 + direction] = value except KeyError: self.failed_curves.add((curve_num, direction, None)) if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs logger.warning(f"Metadata file {meta_path} not found in archive. Skipping metadata for curve {curve_num}, direction {direction}.") elif len(self.failed_curves) == 10: logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.") + + for key in self.changing_segment_keys: + meta_set = h5_meta_datasets.get(f"segment.{key}") + if meta_set: + meta_set[curve_num * 2 + direction] = cleaned_meta.get(key, "No data") + def setup_h5_structure(self, h5file): """ Sets up the structure in the h5 file for saving the curve data and metadata, and returns the datasets and metadata groups for later use. @@ -812,6 +888,11 @@ def setup_h5_structure(self, h5file): curve_groups = {"Data": {}, "Indicies": {}} h5_datasets = {} + h5_meta_datasets = {} + for key in self.changing_curve_keys: + h5_meta_datasets[f"curve.{key}"] = curves_meta_group.create_dataset(name=f"{key}", shape=(self.num_of_curves,), maxshape=(None,), chunks=self.META_CHUNKSIZE, dtype=h5py.string_dtype(encoding="utf-8")) + for key in self.changing_segment_keys: + h5_meta_datasets[f"segment.{key}"] = curves_meta_group.create_dataset(name=f"{key}", shape=(self.num_of_curves * 2,), maxshape=(None,), chunks=self.META_CHUNKSIZE, dtype=h5py.string_dtype(encoding="utf-8")) for direction in range(2): # For each segment direction, establish the necessary group structure that will contain each channel dataset @@ -825,8 +906,8 @@ def setup_h5_structure(self, h5file): h5_datasets[seg_name][chan['name']] = {} # For each channel, create an empty dataset for the curve data and indicies with the appropriate name and data type h5_datasets[seg_name][chan['name']]['Data'] = curve_groups["Data"][seg_name].create_dataset(name=chan["name"], shape=(0,), maxshape=(None,), chunks=(self.DATA_CHUNKSIZE,), dtype=np.float32) - h5_datasets[seg_name][chan['name']]['Indicies'] = curve_groups["Indicies"][seg_name].create_dataset(name=chan["name"], shape=(0,), maxshape=(None,), chunks=(self.INDICIES_CHUNKSIZE,), dtype=np.int32) - return global_meta_group, curves_meta_group, h5_datasets + h5_datasets[seg_name][chan['name']]['Indicies'] = curve_groups["Indicies"][seg_name].create_dataset(name=chan["name"], shape=(self.num_of_curves, ), maxshape=(None,), chunks=(self.INDICIES_CHUNKSIZE,), dtype=np.int32) + return global_meta_group, h5_datasets, h5_meta_datasets def get_saving_context(self): """ From 66f388a96fae22551c648daf92c900e62a5fbe1e Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Wed, 1 Apr 2026 15:47:14 +0100 Subject: [PATCH 26/54] Pre-sizing the curve data to make loading faster (using a best guess) --- AFMReader/jpk_qi.py | 83 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 72 insertions(+), 11 deletions(-) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index b692d53..ffa31f0 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -3,6 +3,7 @@ from contextlib import nullcontext import io import zipfile +import time import numpy as np import javaproperties @@ -258,8 +259,10 @@ def __init__( self.config_path = config_path self.flip_image = flip_image self.save_as_h5 = save_as_h5 + logger.debug(f"Initialized JPK QI loader for file {self.filepath} with channel {self.channel}") # Open the ZIP archive once and keep it open for the duration of the loading process to improve performance when accessing multiple files within the archive self.qi_archive = zipfile.ZipFile(self.filepath, "r") + logger.debug(f"Opened JPK QI archive at {self.filepath}") self.namelist = self.qi_archive.namelist() logger.debug(f"JPK QI archive of size {asizeof.asizeof(self.qi_archive) / 1024 / 1024:.2f} MB, namelist is of size: {asizeof.asizeof(self.namelist) / 1024 / 1024:.2f} MB") # Set path to the .jpk-qi-image file within the archive for later use @@ -292,6 +295,13 @@ def __init__( self.size_x, self.size_y, self.shape_x, self.shape_y = None, None, None, None self.failed_curves = set() + # Timing counters for performance monitoring + self.t_load_data = 0.0 + self.t_proc_data = 0.0 + self.t_load_meta = 0.0 + self.t_proc_meta = 0.0 + self.t_changing_keys = 0.0 + # Instantiate containers for data to be saved (so an exception is not caused if not saving) self.curve_groups = None @@ -436,12 +446,30 @@ def save_to_h5( """Saves the data as an H5 file. If include_metadata is False, only the curve data will be saved.""" with self.get_saving_context() as file: + t0 = time.perf_counter() self.changing_curve_keys, self.changing_segment_keys = self.get_changing_keys() + self.points_for_channel_segment = self.predict_total_points() + self.t_changing_keys = time.perf_counter() - t0 + global_meta_group, h5_datasets, h5_meta_datasets = self.setup_h5_structure(file) + for direction in range(2): + for chan in self.segment_channels: + self.points_for_channel_segment[direction][chan['name']] = 0 self.load_all_data(h5_datasets, h5_meta_datasets, include_metadata=include_metadata) + for direction in range(2): + for chan in self.segment_channels: + h5_datasets[f"Segment_{direction}"][chan['name']]["Data"].resize((self.points_for_channel_segment[direction][chan['name']],)) logger.debug(f"Curve meta size in memory: {asizeof.asizeof(self.curve_meta) / 1024 / 1024:.2f} MB") logger.debug(f"Segment meta size in memory: {asizeof.asizeof(self.segment_meta) / 1024 / 1024:.2f} MB") + summary = ( + f"\n--- Performance Summary ---\n" + f"Changing Keys Detection: {self.t_changing_keys:.2f}s\n" + f"Raw Data - Loading: {self.t_load_data:.2f}s | Processing: {self.t_proc_data:.2f}s\n" + f"Metadata - Loading: {self.t_load_meta:.2f}s | Processing: {self.t_proc_meta:.2f}s\n" + f"---------------------------" + ) + logger.info(summary) if include_metadata: # Save the global metadata to the h5 file @@ -450,20 +478,47 @@ def save_to_h5( logger.info(f"QI data copied to h5 data {file.filename}") - def get_changing_keys(self): + def get_curves_sample(self): # Check evenly spaced curves in the dataset to sample metadata without having to load every curve step = 1 if self.num_of_curves <= self.MAX_CURVE_CHECKS else self.num_of_curves // self.MAX_CURVE_CHECKS # If the step is equal to a shape dimension, we might just go down the row or column while step in [self.shape_x, self.shape_y] and step > 1: # So make the step slightly smaller (more checks) to ensure we get a good sample step -= 1 - curves_to_check = range(0, self.num_of_curves, step) + return range(0, self.num_of_curves, step) + + def predict_total_points(self): + curves_to_check = self.get_curves_sample() + points_for_channel_segment = {} + for direction in range(2): + points_for_channel_segment[direction] = {} + for channel in self.segment_channels: + points_for_channel_segment[direction][channel['name']] = [] + for curve_num in curves_to_check: + while True: + dat_path = f"index/{curve_num}/segments/{direction}/channels/{channel['name']}.dat" + try: + with self.qi_archive.open(dat_path) as f: + raw_array = np.frombuffer(f.read(), dtype=">i4") + points_for_channel_segment[direction][channel['name']].append(len(raw_array)) + break + except KeyError: + if curve_num + 1 >= self.num_of_curves: + break # If we've gone past the number of curves, stop checking + curve_num += 1 + continue + points_for_channel_segment[direction][channel['name']] = int(np.max(points_for_channel_segment[direction][channel['name']]) * 1.1) * self.num_of_curves + return points_for_channel_segment + + def get_changing_keys(self): + curve_meta_dict = {} segment_meta_dict = {} + curves_to_check = self.get_curves_sample() for curve_num in curves_to_check: for direction in range(2): - meta_path = f"index/{curve_num}/segments/{direction}/segment-header.properties" while True: + meta_path = f"index/{curve_num}/segments/{direction}/segment-header.properties" try: with self.qi_archive.open(meta_path) as f: meta_dict = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} @@ -760,18 +815,26 @@ def extract_dat_file(self, h5_datasets, curve_num: int, direction: int, chan_nam data_size = data_set.shape[0] indicies_set = h5_datasets[f"Segment_{direction}"][chan_name]["Indicies"] try: + t0 = time.perf_counter() with self.qi_archive.open(dat_path) as f: # Read the binary data as big-endian 32-bit integers - raw_array = np.frombuffer(f.read(), dtype=">i4") + raw_bytes = f.read() + self.t_load_data += time.perf_counter() - t0 + + t1 = time.perf_counter() + raw_array = np.frombuffer(raw_bytes, dtype=">i4") # Apply the scaling to convert raw values into real world values segment_array = (raw_array * scale["multiplier"]) + scale["offset"] - # Fetch and resize the existing dataset for this channel and segment to fit the new data - data_set.resize((data_size + segment_array.shape[0],)) + self.points_for_channel_segment[direction][chan_name] += len(segment_array) + if self.points_for_channel_segment[direction][chan_name] > data_size: + # Fetch and resize the existing dataset for this channel and segment to fit the new data + data_set.resize((self.points_for_channel_segment[direction][chan_name],)) # Append the new data to the end of the existing dataset data_set[data_size:] = segment_array + self.t_proc_data += time.perf_counter() - t1 except KeyError: self.failed_curves.add((curve_num, direction, chan_name)) @@ -812,10 +875,6 @@ def extract_curve_metadata(self, h5_meta_datasets, curve_num: int): try: with self.qi_archive.open(meta_path) as f: cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} - for key in self.changing_curve_keys: - meta_set = h5_meta_datasets.get(f"curve.{key}") - meta_set.resize((meta_set.shape[0] + 1 if meta_set else curve_num,)) - meta_set[curve_num] = cleaned_meta.get(key, "No data") except KeyError: self.failed_curves.add((curve_num, None, None)) if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs @@ -859,6 +918,8 @@ def extract_segment_metadata(self, h5_meta_datasets, curve_num: int, direction: meta_set = h5_meta_datasets.get(f"segment.{key}") if meta_set: meta_set[curve_num * 2 + direction] = cleaned_meta.get(key, "No data") + else: + logger.error(f"Metadata dataset for key segment.{key} not found when trying to save metadata for curve {curve_num}, direction {direction}") def setup_h5_structure(self, h5file): """ @@ -905,7 +966,7 @@ def setup_h5_structure(self, h5file): for chan in self.segment_channels: h5_datasets[seg_name][chan['name']] = {} # For each channel, create an empty dataset for the curve data and indicies with the appropriate name and data type - h5_datasets[seg_name][chan['name']]['Data'] = curve_groups["Data"][seg_name].create_dataset(name=chan["name"], shape=(0,), maxshape=(None,), chunks=(self.DATA_CHUNKSIZE,), dtype=np.float32) + h5_datasets[seg_name][chan['name']]['Data'] = curve_groups["Data"][seg_name].create_dataset(name=chan["name"], shape=(self.points_for_channel_segment[direction][chan['name']],), maxshape=(None,), chunks=(self.DATA_CHUNKSIZE,), dtype=np.float32) h5_datasets[seg_name][chan['name']]['Indicies'] = curve_groups["Indicies"][seg_name].create_dataset(name=chan["name"], shape=(self.num_of_curves, ), maxshape=(None,), chunks=(self.INDICIES_CHUNKSIZE,), dtype=np.int32) return global_meta_group, h5_datasets, h5_meta_datasets From f286e5774d2494476b29b47ac896ed0714a6b709 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Wed, 1 Apr 2026 17:09:31 +0100 Subject: [PATCH 27/54] Improving performance by removing javaproperties reliance in loop --- AFMReader/jpk_qi.py | 77 +++++++++++++++++++++++++++++++-------------- 1 file changed, 54 insertions(+), 23 deletions(-) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index ffa31f0..e3526be 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -264,7 +264,6 @@ def __init__( self.qi_archive = zipfile.ZipFile(self.filepath, "r") logger.debug(f"Opened JPK QI archive at {self.filepath}") self.namelist = self.qi_archive.namelist() - logger.debug(f"JPK QI archive of size {asizeof.asizeof(self.qi_archive) / 1024 / 1024:.2f} MB, namelist is of size: {asizeof.asizeof(self.namelist) / 1024 / 1024:.2f} MB") # Set path to the .jpk-qi-image file within the archive for later use self.path_to_image = None @@ -298,8 +297,10 @@ def __init__( # Timing counters for performance monitoring self.t_load_data = 0.0 self.t_proc_data = 0.0 + self.t_save_data = 0.0 self.t_load_meta = 0.0 self.t_proc_meta = 0.0 + self.t_save_meta = 0.0 self.t_changing_keys = 0.0 # Instantiate containers for data to be saved (so an exception is not caused if not saving) @@ -412,6 +413,10 @@ def load_all_data(self, h5_datasets, h5_meta_datasets, include_metadata: bool = logger.info(f"Loading all curve data from JPK QI archive with {len(self.namelist)} files {'' if include_metadata else 'not '}including metadata") progress_counter = 0 process = psutil.Process(os.getpid()) + curve_work = [(f"{k}=".encode("utf-8"), h5_meta_datasets[f"curve.{k}"]) + for k in self.changing_curve_keys] + seg_work = [(f"{k}=".encode("utf-8"), h5_meta_datasets[f"segment.{k}"]) + for k in self.changing_segment_keys] for curve_num in range(self.num_of_curves): # Output progress every 1000 curves to give some indication of how long the loading is taking if progress_counter % 1000 == 0: @@ -424,10 +429,10 @@ def load_all_data(self, h5_datasets, h5_meta_datasets, include_metadata: bool = self.extract_dat_file(h5_datasets=h5_datasets, curve_num=curve_num, direction=direction, chan_name=chan['name']) if include_metadata: # Extract and store the segment metadata for later saving - self.extract_segment_metadata(h5_meta_datasets=h5_meta_datasets, curve_num=curve_num, direction=direction) + self.extract_segment_metadata(h5_meta_datasets=h5_meta_datasets, curve_num=curve_num, direction=direction, seg_work=seg_work) if include_metadata: # Extract and store the curve metadata for later saving - self.extract_curve_metadata(h5_meta_datasets=h5_meta_datasets, curve_num=curve_num) + self.extract_curve_metadata(h5_meta_datasets=h5_meta_datasets, curve_num=curve_num, curve_work=curve_work) # Add the last index to the indicies datasets to mark the end of the last curve for direction in range(2): seg_name = f"Segment_{direction}" @@ -465,8 +470,8 @@ def save_to_h5( summary = ( f"\n--- Performance Summary ---\n" f"Changing Keys Detection: {self.t_changing_keys:.2f}s\n" - f"Raw Data - Loading: {self.t_load_data:.2f}s | Processing: {self.t_proc_data:.2f}s\n" - f"Metadata - Loading: {self.t_load_meta:.2f}s | Processing: {self.t_proc_meta:.2f}s\n" + f"Raw Data - Loading: {self.t_load_data:.2f}s | Processing: {self.t_proc_data:.2f}s | Saving: {self.t_save_data:.2f}s\n" + f"Metadata - Loading: {self.t_load_meta:.2f}s | Processing: {self.t_proc_meta:.2f}s | Saving: {self.t_save_meta:.2f}s\n" f"---------------------------" ) logger.info(summary) @@ -814,6 +819,7 @@ def extract_dat_file(self, h5_datasets, curve_num: int, direction: int, chan_nam data_set = h5_datasets[f"Segment_{direction}"][chan_name]["Data"] data_size = data_set.shape[0] indicies_set = h5_datasets[f"Segment_{direction}"][chan_name]["Indicies"] + filled_size = self.points_for_channel_segment[direction][chan_name] try: t0 = time.perf_counter() with self.qi_archive.open(dat_path) as f: @@ -832,9 +838,11 @@ def extract_dat_file(self, h5_datasets, curve_num: int, direction: int, chan_nam # Fetch and resize the existing dataset for this channel and segment to fit the new data data_set.resize((self.points_for_channel_segment[direction][chan_name],)) - # Append the new data to the end of the existing dataset - data_set[data_size:] = segment_array self.t_proc_data += time.perf_counter() - t1 + # Append the new data to the end of the existing dataset + start_time = time.perf_counter() + data_set[filled_size:filled_size + len(segment_array)] = segment_array + self.t_save_data += time.perf_counter() - start_time except KeyError: self.failed_curves.add((curve_num, direction, chan_name)) @@ -848,7 +856,7 @@ def extract_dat_file(self, h5_datasets, curve_num: int, direction: int, chan_nam logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.") # Append the new index to the end of the existing indicies dataset - indicies_set[curve_num] = data_size + indicies_set[curve_num] = filled_size else: self.failed_curves.add((curve_num, direction, chan_name)) if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs @@ -856,7 +864,7 @@ def extract_dat_file(self, h5_datasets, curve_num: int, direction: int, chan_nam f"Channel {chan_name} not found in scaling information. Skipping data for curve {curve_num}, direction {direction}." ) - def extract_curve_metadata(self, h5_meta_datasets, curve_num: int): + def extract_curve_metadata(self, curve_num: int, curve_work): """ Extracts the curve metadata from its header.properties file in the JPK QI archive and saves it to the internal data structure. @@ -871,10 +879,12 @@ def extract_curve_metadata(self, h5_meta_datasets, curve_num: int): """ meta_path = f"index/{curve_num}/header.properties" - cleaned_meta = {} + raw_bytes = b"" try: + start_time = time.perf_counter() with self.qi_archive.open(meta_path) as f: - cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} + raw_bytes = f.read() + self.t_load_meta += time.perf_counter() - start_time except KeyError: self.failed_curves.add((curve_num, None, None)) if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs @@ -882,14 +892,26 @@ def extract_curve_metadata(self, h5_meta_datasets, curve_num: int): elif len(self.failed_curves) == 10: logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.") - for key in self.changing_curve_keys: - meta_set = h5_meta_datasets.get(f"curve.{key}") + for search_term, meta_set in curve_work: + start_time = time.perf_counter() + start = raw_bytes.find(search_term) + if start != -1: + start += len(search_term) + end = raw_bytes.find(b"\n", start) + value = raw_bytes[start:end].decode("utf-8").strip() + else: + value = "No data" + self.t_proc_meta += time.perf_counter() - start_time if meta_set: - meta_set[curve_num] = cleaned_meta.get(key, "No data") + start_time = time.perf_counter() + meta_set[curve_num] = value + self.t_save_meta += time.perf_counter() - start_time else: - logger.error(f"Metadata dataset for key curve.{key} not found when trying to save metadata for curve {curve_num}") + logger.error(f"Metadata dataset for key {search_term.decode('utf-8')} not found when trying to save metadata for curve {curve_num}") + - def extract_segment_metadata(self, h5_meta_datasets, curve_num: int, direction: int): + + def extract_segment_metadata(self, curve_num: int, direction: int, seg_work): """ Extracts the segment metadata from its header.properties file in the JPK QI archive and saves it to the internal data structure. @@ -903,23 +925,32 @@ def extract_segment_metadata(self, h5_meta_datasets, curve_num: int, direction: The segment direction (0 or 1) associated with the metadata, parsed from the filename. """ meta_path = f"index/{curve_num}/segments/{direction}/segment-header.properties" - cleaned_meta = {} + raw_content = b"" try: + start_time = time.perf_counter() with self.qi_archive.open(meta_path) as f: - cleaned_meta = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} + raw_content = f.read() + self.t_load_meta += time.perf_counter() - start_time except KeyError: self.failed_curves.add((curve_num, direction, None)) if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs logger.warning(f"Metadata file {meta_path} not found in archive. Skipping metadata for curve {curve_num}, direction {direction}.") elif len(self.failed_curves) == 10: logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.") - - for key in self.changing_segment_keys: - meta_set = h5_meta_datasets.get(f"segment.{key}") + start_time = time.perf_counter() + for search_term, meta_set in seg_work: + start = raw_content.find(search_term) + if start != -1: + start += len(search_term) + end = raw_content.find(b"\n", start) + value = raw_content[start:end].decode("utf-8").strip() + else: + value = "No data" if meta_set: - meta_set[curve_num * 2 + direction] = cleaned_meta.get(key, "No data") + meta_set[curve_num * 2 + direction] = value else: - logger.error(f"Metadata dataset for key segment.{key} not found when trying to save metadata for curve {curve_num}, direction {direction}") + logger.error(f"Metadata dataset for key {search_term.decode('utf-8')} not found when trying to save metadata for curve {curve_num}, direction {direction}") + self.t_proc_meta += time.perf_counter() - start_time def setup_h5_structure(self, h5file): """ From 5391179e34e502f889e68220bfff04f449d2eea2 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Thu, 2 Apr 2026 11:49:34 +0100 Subject: [PATCH 28/54] Making the saving to h5 save in sections using a buffer --- AFMReader/jpk_qi.py | 116 ++++++++++++++++++++++++++++++++------------ 1 file changed, 84 insertions(+), 32 deletions(-) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index e3526be..cc4e70c 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -275,6 +275,8 @@ def __init__( self.META_CHUNKSIZE = 64 * 1024 # Maximum number of curves to check for changing metadata keys (to avoid checking every curve in large datasets) self.MAX_CURVE_CHECKS = 20 + # Number of curves to hold in buffer + self.BUFFER_SIZE = 500 # Initialize key attributes that will be returned / accessed frequently @@ -409,13 +411,13 @@ def output_summary(self): else: logger.info("Successfully loaded all curve data without any missing files.") - def load_all_data(self, h5_datasets, h5_meta_datasets, include_metadata: bool = True): + def load_all_data(self, h5_datasets, h5_meta_datasets, h5_datasets_buffer, h5_meta_datasets_buffer, include_metadata: bool = True): logger.info(f"Loading all curve data from JPK QI archive with {len(self.namelist)} files {'' if include_metadata else 'not '}including metadata") progress_counter = 0 process = psutil.Process(os.getpid()) - curve_work = [(f"{k}=".encode("utf-8"), h5_meta_datasets[f"curve.{k}"]) + curve_work = [(f"{k}=".encode("utf-8"), h5_meta_datasets[f"curve.{k}"], h5_meta_datasets_buffer[f"curve.{k}"]) for k in self.changing_curve_keys] - seg_work = [(f"{k}=".encode("utf-8"), h5_meta_datasets[f"segment.{k}"]) + seg_work = [(f"{k}=".encode("utf-8"), h5_meta_datasets[f"segment.{k}"], h5_meta_datasets_buffer[f"segment.{k}"]) for k in self.changing_segment_keys] for curve_num in range(self.num_of_curves): # Output progress every 1000 curves to give some indication of how long the loading is taking @@ -426,13 +428,13 @@ def load_all_data(self, h5_datasets, h5_meta_datasets, include_metadata: bool = for direction in range(2): for chan in self.segment_channels: # Save the actual curve data to the h5 datasets - self.extract_dat_file(h5_datasets=h5_datasets, curve_num=curve_num, direction=direction, chan_name=chan['name']) + self.extract_dat_file(h5_datasets=h5_datasets, h5_datasets_buffer=h5_datasets_buffer, curve_num=curve_num, direction=direction, chan_name=chan['name']) if include_metadata: # Extract and store the segment metadata for later saving - self.extract_segment_metadata(h5_meta_datasets=h5_meta_datasets, curve_num=curve_num, direction=direction, seg_work=seg_work) + self.extract_segment_metadata(curve_num=curve_num, direction=direction, seg_work=seg_work) if include_metadata: # Extract and store the curve metadata for later saving - self.extract_curve_metadata(h5_meta_datasets=h5_meta_datasets, curve_num=curve_num, curve_work=curve_work) + self.extract_curve_metadata(curve_num=curve_num, curve_work=curve_work) # Add the last index to the indicies datasets to mark the end of the last curve for direction in range(2): seg_name = f"Segment_{direction}" @@ -455,13 +457,16 @@ def save_to_h5( self.changing_curve_keys, self.changing_segment_keys = self.get_changing_keys() self.points_for_channel_segment = self.predict_total_points() self.t_changing_keys = time.perf_counter() - t0 + self.current_offsets = {} - global_meta_group, h5_datasets, h5_meta_datasets = self.setup_h5_structure(file) + global_meta_group, h5_datasets, h5_meta_datasets, h5_datasets_buffer, h5_meta_datasets_buffer = self.setup_h5_structure(file) for direction in range(2): + self.current_offsets[direction] = {} for chan in self.segment_channels: + self.current_offsets[direction][chan['name']] = 0 self.points_for_channel_segment[direction][chan['name']] = 0 - self.load_all_data(h5_datasets, h5_meta_datasets, include_metadata=include_metadata) + self.load_all_data(h5_datasets, h5_meta_datasets, h5_datasets_buffer, h5_meta_datasets_buffer, include_metadata=include_metadata) for direction in range(2): for chan in self.segment_channels: h5_datasets[f"Segment_{direction}"][chan['name']]["Data"].resize((self.points_for_channel_segment[direction][chan['name']],)) @@ -798,7 +803,7 @@ def construct_full_metadata(self): full_metadata = {"top_level": self.top_level_meta, "curves": self.curve_meta, "segments": self.segment_meta} return full_metadata - def extract_dat_file(self, h5_datasets, curve_num: int, direction: int, chan_name: str): + def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, direction: int, chan_name: str): """ Extracts the data from a .dat file in the JPK QI archive, applies the appropriate scaling, and saves it to the internal data structure and h5 dataset if required. @@ -806,6 +811,8 @@ def extract_dat_file(self, h5_datasets, curve_num: int, direction: int, chan_nam ---------- h5_datasets : dict A dictionary containing the h5 datasets for each channel and segment direction, used for saving the data + h5_datasets_buffer : dict + A dictionary containing the buffer for each h5 dataset, used for temporary storage before writing to the dataset curve_num : int The curve number associated with the .dat file, parsed from the filename. direction : int @@ -817,9 +824,11 @@ def extract_dat_file(self, h5_datasets, curve_num: int, direction: int, chan_nam scale = self.channel_scaling[chan_name] dat_path = f"index/{curve_num}/segments/{direction}/channels/{chan_name}.dat" data_set = h5_datasets[f"Segment_{direction}"][chan_name]["Data"] - data_size = data_set.shape[0] indicies_set = h5_datasets[f"Segment_{direction}"][chan_name]["Indicies"] + data_size = data_set.shape[0] + buf = h5_datasets_buffer[f"Segment_{direction}"][chan_name] filled_size = self.points_for_channel_segment[direction][chan_name] + start_offset = self.current_offsets[direction][chan_name] try: t0 = time.perf_counter() with self.qi_archive.open(dat_path) as f: @@ -833,16 +842,24 @@ def extract_dat_file(self, h5_datasets, curve_num: int, direction: int, chan_nam # Apply the scaling to convert raw values into real world values segment_array = (raw_array * scale["multiplier"]) + scale["offset"] - self.points_for_channel_segment[direction][chan_name] += len(segment_array) - if self.points_for_channel_segment[direction][chan_name] > data_size: - # Fetch and resize the existing dataset for this channel and segment to fit the new data - data_set.resize((self.points_for_channel_segment[direction][chan_name],)) + self.current_offsets[direction][chan_name] += len(segment_array) + + buf["Data"].append(segment_array) + if len(buf["Data"]) >= self.BUFFER_SIZE: + if self.points_for_channel_segment[direction][chan_name] > data_size: + # Fetch and resize the existing dataset for this channel and segment to fit the new data + data_set.resize((self.points_for_channel_segment[direction][chan_name],)) + + buffered_data = np.concatenate(buf["Data"]) + + self.t_proc_data += time.perf_counter() - t1 + # Append the new data to the end of the existing dataset + start_time = time.perf_counter() + data_set[filled_size:filled_size + len(buffered_data)] = buffered_data + self.points_for_channel_segment[direction][chan_name] += len(buffered_data) + buf["Data"].clear() + self.t_save_data += time.perf_counter() - start_time - self.t_proc_data += time.perf_counter() - t1 - # Append the new data to the end of the existing dataset - start_time = time.perf_counter() - data_set[filled_size:filled_size + len(segment_array)] = segment_array - self.t_save_data += time.perf_counter() - start_time except KeyError: self.failed_curves.add((curve_num, direction, chan_name)) @@ -855,8 +872,23 @@ def extract_dat_file(self, h5_datasets, curve_num: int, direction: int, chan_nam elif len(self.failed_curves) == 10: logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.") - # Append the new index to the end of the existing indicies dataset - indicies_set[curve_num] = filled_size + start_time = time.perf_counter() + # Append the new index to the indicies buffer + buf["Indicies"].append(start_offset) + + # If the indicies buffer is full add it to the indicies dataset and clear the buffer + if len(buf["Indicies"]) > 0 and len(buf["Indicies"]) % self.BUFFER_SIZE == 0: + indicies_set[curve_num - self.BUFFER_SIZE + 1 : curve_num + 1] = buf["Indicies"] + buf["Indicies"].clear() + + # Or if this is the last curve and there are still indicies in the buffer, add them to the indicies dataset and clear the buffer + elif len(buf["Indicies"]) > 0 and curve_num == self.num_of_curves - 1: + items_in_buffer = len(buf["Indicies"]) + indicies_set[curve_num - items_in_buffer + 1 : curve_num + 1] = buf["Indicies"] + buf["Indicies"].clear() + + self.t_save_data += time.perf_counter() - start_time + else: self.failed_curves.add((curve_num, direction, chan_name)) if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs @@ -892,7 +924,7 @@ def extract_curve_metadata(self, curve_num: int, curve_work): elif len(self.failed_curves) == 10: logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.") - for search_term, meta_set in curve_work: + for search_term, meta_set, meta_buffer in curve_work: start_time = time.perf_counter() start = raw_bytes.find(search_term) if start != -1: @@ -902,9 +934,12 @@ def extract_curve_metadata(self, curve_num: int, curve_work): else: value = "No data" self.t_proc_meta += time.perf_counter() - start_time - if meta_set: + if meta_buffer is not None: start_time = time.perf_counter() - meta_set[curve_num] = value + meta_buffer.append(value) + if len(meta_buffer) >= self.BUFFER_SIZE or curve_num == self.num_of_curves - 1: + meta_set[curve_num - len(meta_buffer) + 1 : curve_num + 1] = meta_buffer + meta_buffer.clear() self.t_save_meta += time.perf_counter() - start_time else: logger.error(f"Metadata dataset for key {search_term.decode('utf-8')} not found when trying to save metadata for curve {curve_num}") @@ -917,8 +952,6 @@ def extract_segment_metadata(self, curve_num: int, direction: int, seg_work): Parameters ---------- - h5_meta_datasets : dict - A dictionary containing the h5 datasets for each metadata key, used for saving the metadata. curve_num : int The curve number associated with the metadata, parsed from the filename. direction : int @@ -938,7 +971,7 @@ def extract_segment_metadata(self, curve_num: int, direction: int, seg_work): elif len(self.failed_curves) == 10: logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.") start_time = time.perf_counter() - for search_term, meta_set in seg_work: + for search_term, meta_set, meta_buffer in seg_work: start = raw_content.find(search_term) if start != -1: start += len(search_term) @@ -946,8 +979,12 @@ def extract_segment_metadata(self, curve_num: int, direction: int, seg_work): value = raw_content[start:end].decode("utf-8").strip() else: value = "No data" - if meta_set: - meta_set[curve_num * 2 + direction] = value + if meta_buffer is not None: + meta_buffer.append(value) + if len(meta_buffer) >= self.BUFFER_SIZE or curve_num == self.num_of_curves - 1: + idx = curve_num * 2 + direction + meta_set[idx - len(meta_buffer) + 1 : idx + 1] = meta_buffer + meta_buffer.clear() else: logger.error(f"Metadata dataset for key {search_term.decode('utf-8')} not found when trying to save metadata for curve {curve_num}, direction {direction}") self.t_proc_meta += time.perf_counter() - start_time @@ -981,16 +1018,21 @@ def setup_h5_structure(self, h5file): curve_groups = {"Data": {}, "Indicies": {}} h5_datasets = {} h5_meta_datasets = {} + h5_datasets_buffer = {} + h5_meta_datasets_buffer = {} for key in self.changing_curve_keys: - h5_meta_datasets[f"curve.{key}"] = curves_meta_group.create_dataset(name=f"{key}", shape=(self.num_of_curves,), maxshape=(None,), chunks=self.META_CHUNKSIZE, dtype=h5py.string_dtype(encoding="utf-8")) + h5_meta_datasets[f"curve.{key}"] = curves_meta_group.create_dataset(name=f"curve.{key}", shape=(self.num_of_curves,), maxshape=(None,), chunks=self.META_CHUNKSIZE, dtype=h5py.string_dtype(encoding="utf-8")) + h5_meta_datasets_buffer[f"curve.{key}"] = [] for key in self.changing_segment_keys: - h5_meta_datasets[f"segment.{key}"] = curves_meta_group.create_dataset(name=f"{key}", shape=(self.num_of_curves * 2,), maxshape=(None,), chunks=self.META_CHUNKSIZE, dtype=h5py.string_dtype(encoding="utf-8")) + h5_meta_datasets[f"segment.{key}"] = curves_meta_group.create_dataset(name=f"segment.{key}", shape=(self.num_of_curves * 2,), maxshape=(None,), chunks=self.META_CHUNKSIZE, dtype=h5py.string_dtype(encoding="utf-8")) + h5_meta_datasets_buffer[f"segment.{key}"] = [] for direction in range(2): # For each segment direction, establish the necessary group structure that will contain each channel dataset seg_name = f"Segment_{direction}" dir_group = curves_group.require_group(seg_name) h5_datasets[seg_name] = {} + h5_datasets_buffer[seg_name] = {} # Create the Data and Indicies subfolders and store their references curve_groups["Data"][seg_name] = dir_group.require_group("Data") curve_groups["Indicies"][seg_name] = dir_group.require_group("Indicies") @@ -999,7 +1041,8 @@ def setup_h5_structure(self, h5file): # For each channel, create an empty dataset for the curve data and indicies with the appropriate name and data type h5_datasets[seg_name][chan['name']]['Data'] = curve_groups["Data"][seg_name].create_dataset(name=chan["name"], shape=(self.points_for_channel_segment[direction][chan['name']],), maxshape=(None,), chunks=(self.DATA_CHUNKSIZE,), dtype=np.float32) h5_datasets[seg_name][chan['name']]['Indicies'] = curve_groups["Indicies"][seg_name].create_dataset(name=chan["name"], shape=(self.num_of_curves, ), maxshape=(None,), chunks=(self.INDICIES_CHUNKSIZE,), dtype=np.int32) - return global_meta_group, h5_datasets, h5_meta_datasets + h5_datasets_buffer[seg_name][chan['name']] = {'Data': [], 'Indicies': []} + return global_meta_group, h5_datasets, h5_meta_datasets, h5_datasets_buffer, h5_meta_datasets_buffer def get_saving_context(self): """ @@ -1095,6 +1138,15 @@ def extract_global_metadata(self): def close(self): """Closes the ZIP archive when done to free up system resources.""" self.qi_archive.close() + self.image = None + self.all_curve_data = None + self.curve_meta = {} + self.segment_meta = {} + self.top_level_meta = {} + self.full_metadata = {} + self.failed_curves = set() + self.points_for_channel_segment = {} + self.namelist = [] From f0235ad67381e7e6b81e9616ce2bec7c95d4f1e5 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Thu, 2 Apr 2026 14:00:10 +0100 Subject: [PATCH 29/54] Removing redundant functions and fixing minor index bugs --- AFMReader/h5_jpk.py | 26 ++++++--- AFMReader/jpk_qi.py | 127 ++------------------------------------------ 2 files changed, 24 insertions(+), 129 deletions(-) diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index 97a97b5..62a166d 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -267,11 +267,13 @@ def generate_timestamps(num_frames: int, line_rate: float, image_size: int) -> d # Compose a dictionary of timestamsps return {f"frame {i}": timestamp for i, timestamp in enumerate(timestamps)} + def get_h5jpk_channels(file_path: Path | str): with h5py.File(file_path, "r") as f: available_channels = list(_available_channels(f)) return available_channels + class LazyQIData: def __init__(self, qi_data_group: h5py.Group, shape_x: int, shape_y: int, flip_image: bool = True): self.qi_data_group = qi_data_group @@ -280,15 +282,16 @@ def __init__(self, qi_data_group: h5py.Group, shape_x: int, shape_y: int, flip_i self.dims = (shape_y, shape_x) self.flip_image = flip_image - def __getitem__(self, y: int): class RowProxy: def __init__(self, parent, y): self.parent = parent self.y = y + def __getitem__(self, x: int): return self.parent._fetch_curve(self.y, x) + return RowProxy(self, y) def __iter__(self): @@ -331,9 +334,13 @@ def load_all_curves(self): return all_curves + class LazyCurveMetadata: """A proxy class that fetches header.properties files on demand.""" - def __init__(self, qi_data_group: h5py.Group, top_level_meta: dict, shape_x: int, shape_y: int, flip_image: bool = True): + + def __init__( + self, qi_data_group: h5py.Group, top_level_meta: dict, shape_x: int, shape_y: int, flip_image: bool = True + ): self.qi_data_group = qi_data_group self.top_level_meta = top_level_meta self.shape_x = shape_x @@ -351,6 +358,7 @@ def __getitem__(self, key): return LazyMetaProxy(self.qi_data_group, "segment", self.shape_x, self.shape_y, self.flip_image) raise KeyError(key) + class LazyMetaProxy: def __init__(self, qi_data_group: h5py.Group, meta_type: str, shape_x: int, shape_y: int, flip_image: bool = True): self.qi_data_group = qi_data_group @@ -364,18 +372,23 @@ class RowProxy: def __init__(self, parent, y): self.parent = parent self.y = y + def __getitem__(self, x): if self.parent.meta_type == "curve": return self.parent._fetch_meta(self.y, x) elif self.parent.meta_type == "segment": + class SegmentMetaProxy: def __init__(self, parent, y, x): self.parent = parent self.y = y self.x = x + def __getitem__(self, direction): return self.parent.parent._fetch_meta(self.y, self.x, direction) + return SegmentMetaProxy(self, self.y, x) + return RowProxy(self, y) def _fetch_meta(self, y: int, x: int, direction: int = None): @@ -391,13 +404,16 @@ def _fetch_meta(self, y: int, x: int, direction: int = None): if key.startswith(f"{self.meta_type}."): new_key = key.split(".", 1)[1] if isinstance(self.qi_data_group["Curve_Metadata"][key], h5py.Dataset): - meta_dict[new_key] = self.qi_data_group["Curve_Metadata"][key][idx] + meta_dict[new_key] = ( + self.qi_data_group["Curve_Metadata"][key][idx].decode("utf-8") + if isinstance(self.qi_data_group["Curve_Metadata"][key][idx], bytes) + else self.qi_data_group["Curve_Metadata"][key][idx] + ) else: meta_dict[new_key] = self.qi_data_group["Curve_Metadata"][key] return meta_dict - def load_h5jpk( file_path: Path | str, channel: str, flip_image: bool = True, load_curves: bool = True ) -> tuple[np.ndarray, float, dict[str, float]]: @@ -447,7 +463,6 @@ def load_h5jpk( channel_group, measurement_group, dataset_name = _get_channel_info(f, channel) - # Load images and scaling factors from channel dataset images = channel_group[dataset_name][:] scaling, offset = _get_z_scaling_h5(channel_group) @@ -482,7 +497,6 @@ def load_h5jpk( if "QI_Curve_Data" not in f: load_curves = False - if load_curves: f = h5py.File(file_path, "r") logger.debug(f"QI_Curve_Data group keys: {list(f.keys())}") diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index cc4e70c..0eeaaa4 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -268,7 +268,7 @@ def __init__( self.path_to_image = None # Chunk size for H5 datasets - self.DATA_CHUNKSIZE = 2 * 1024 * 1024 + self.DATA_CHUNKSIZE = 512 * 1024 # Chunk size for indicies datasets self.INDICIES_CHUNKSIZE = 64 * 1024 # Chunk size for metadata datasets (if needed) @@ -469,6 +469,7 @@ def save_to_h5( self.load_all_data(h5_datasets, h5_meta_datasets, h5_datasets_buffer, h5_meta_datasets_buffer, include_metadata=include_metadata) for direction in range(2): for chan in self.segment_channels: + logger.debug(f"Resizing dataset for channel {chan['name']} in segment {direction} from {h5_datasets[f'Segment_{direction}'][chan['name']]['Data'].shape[0]} to final size {self.points_for_channel_segment[direction][chan['name']]}") h5_datasets[f"Segment_{direction}"][chan['name']]["Data"].resize((self.points_for_channel_segment[direction][chan['name']],)) logger.debug(f"Curve meta size in memory: {asizeof.asizeof(self.curve_meta) / 1024 / 1024:.2f} MB") logger.debug(f"Segment meta size in memory: {asizeof.asizeof(self.segment_meta) / 1024 / 1024:.2f} MB") @@ -575,64 +576,7 @@ def get_changing_keys(self): return changing_curve_keys, changing_segment_keys - def get_collated_curves(self): - """ - Collates the curve data from the flat structure it is extracted in into a structure grouped by channel and segment for easier saving to h5. - - Returns - ------- - collated_curve_data : dict - A dictionary containing the curve data collated by channel and segment, with the structure: - { - "channel_name": { - "Segment_0": [...], - "Segment_1": [...], - ... - }, - ... - } - indicies : dict - A dictionary containing the indexes of the curve data within each segment, with the structure: - { - "channel_name": { - "Segment_0": [...], - "Segment_1": [...], - ... - }, - ... - } - """ - collated_curve_data = {} - indicies = {} - - for curve_data in self.flat_curve_data: - for chan_name, chan_data in curve_data.items(): - for seg_name, seg_data in chan_data.items(): - if chan_name not in collated_curve_data: - collated_curve_data[chan_name] = {} - indicies[chan_name] = {} - - if seg_name not in collated_curve_data[chan_name]: - collated_curve_data[chan_name][seg_name] = [] - indicies[chan_name][seg_name] = [0] - - # Append the segment data as an array to the list (creates a 2D list) - collated_curve_data[chan_name][seg_name].append(seg_data) - - last_index = indicies[chan_name][seg_name][-1] - next_index = last_index + len(seg_data) - - indicies[chan_name][seg_name].append(next_index) - for chan_name, segments in collated_curve_data.items(): - for seg_name in segments: - # Flattens the list of arrays into one massive 1D array for more efficiency - collated_curve_data[chan_name][seg_name] = np.concatenate(collated_curve_data[chan_name][seg_name]) - - # Converts the indices list into a standard fixed-length integer array - indicies[chan_name][seg_name] = np.array(indicies[chan_name][seg_name], dtype=np.int32) - - return collated_curve_data, indicies def get_collated_metadata(self): """ @@ -739,69 +683,6 @@ def save_lite_data(self): del chan_grp[dataset_name] chan_grp.create_dataset(dataset_name, data=frame_stack) - def save_metadata(self): - """Saves the metadata to the appropriate format based on the save_as_h5 attribute.""" - if self.save_as_h5: - for seg_chan in self.segment_channels: - self.global_meta_group.attrs[f"channel.unit.{seg_chan['name']}"] = seg_chan["unit"] - for key, value in self.top_level_meta.items(): - self.global_meta_group.attrs[key] = str(value).encode("utf-8") - for i, c_meta in enumerate(self.curve_meta): - curve_meta_group = self.curves_meta_group.require_group(f"{i}") - for key, value in c_meta.items(): - curve_meta_group.attrs[key] = str(value).encode("utf-8") - - for d in range(2): - segment_meta_group = curve_meta_group.require_group(f"{d}") - for key, value in self.segment_meta[i * 2 + d].items(): - segment_meta_group.attrs[key] = str(value).encode("utf-8") - - def construct_full_metadata(self): - """ - Constructs the full metadata dictionary by determining which keys in the curve and segment metadata change across curves/segments and which do not, moving the non-changing keys to the top level metadata, and then combining everything into a single dictionary. - - Returns - ------- - dict - The full metadata dictionary with the structure: - { - "top_level": { ... }, - "curves": [ { ... }, { ... }, ... ], - "segments": [ { ... }, { ... }, ... ] - } - """ - - # Find keys that change across curves/segments - changing_curve_keys = { - k - for k in self.all_curve_keys - if any(self.curve_meta[i].get(k) != self.curve_meta[0].get(k) for i in range(1, self.num_of_curves)) - } - changing_segment_keys = { - k - for k in self.all_segment_keys - if any(self.segment_meta[i].get(k) != self.segment_meta[0].get(k) for i in range(1, len(self.segment_meta))) - } - - # Move duplicated meta to top level for both segments and curves - for key in self.all_curve_keys - changing_curve_keys: - if self.curve_meta and key in self.curve_meta[0]: - self.top_level_meta[f"curve.{key}"] = self.curve_meta[0][key] - for key in self.all_segment_keys - changing_segment_keys: - if self.segment_meta and key in self.segment_meta[0]: - self.top_level_meta[f"segment.{key}"] = self.segment_meta[0][key] - - # Strip duplicated keys from individual curve/segment dicts - for c_meta in self.curve_meta: - for k in self.all_curve_keys - changing_curve_keys: - c_meta.pop(k, None) - for s_meta in self.segment_meta: - for k in self.all_segment_keys - changing_segment_keys: - s_meta.pop(k, None) - - # Construct full metadata dict from subdicts - full_metadata = {"top_level": self.top_level_meta, "curves": self.curve_meta, "segments": self.segment_meta} - return full_metadata def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, direction: int, chan_name: str): """ @@ -845,7 +726,7 @@ def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, dire self.current_offsets[direction][chan_name] += len(segment_array) buf["Data"].append(segment_array) - if len(buf["Data"]) >= self.BUFFER_SIZE: + if len(buf["Data"]) >= self.BUFFER_SIZE or curve_num == self.num_of_curves - 1: if self.points_for_channel_segment[direction][chan_name] > data_size: # Fetch and resize the existing dataset for this channel and segment to fit the new data data_set.resize((self.points_for_channel_segment[direction][chan_name],)) @@ -1040,7 +921,7 @@ def setup_h5_structure(self, h5file): h5_datasets[seg_name][chan['name']] = {} # For each channel, create an empty dataset for the curve data and indicies with the appropriate name and data type h5_datasets[seg_name][chan['name']]['Data'] = curve_groups["Data"][seg_name].create_dataset(name=chan["name"], shape=(self.points_for_channel_segment[direction][chan['name']],), maxshape=(None,), chunks=(self.DATA_CHUNKSIZE,), dtype=np.float32) - h5_datasets[seg_name][chan['name']]['Indicies'] = curve_groups["Indicies"][seg_name].create_dataset(name=chan["name"], shape=(self.num_of_curves, ), maxshape=(None,), chunks=(self.INDICIES_CHUNKSIZE,), dtype=np.int32) + h5_datasets[seg_name][chan['name']]['Indicies'] = curve_groups["Indicies"][seg_name].create_dataset(name=chan["name"], shape=(self.num_of_curves + 1, ), maxshape=(None,), chunks=(self.INDICIES_CHUNKSIZE,), dtype=np.int32) h5_datasets_buffer[seg_name][chan['name']] = {'Data': [], 'Indicies': []} return global_meta_group, h5_datasets, h5_meta_datasets, h5_datasets_buffer, h5_meta_datasets_buffer From 62481da2470fcfb7b88d8a7382401a9e56c85307 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Fri, 3 Apr 2026 19:23:29 +0100 Subject: [PATCH 30/54] removing possibility of size 1 image stack --- AFMReader/h5_jpk.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index 62a166d..8d113d5 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -474,14 +474,20 @@ def load_h5jpk( # Reshape each column vector (height, width) to get (num_frames, height, width) num_frames = images.shape[1] - image_stack = np.empty((num_frames, shape_y, shape_x), dtype=images.dtype) + if num_frames == 1: + image_stack = np.empty((shape_y, shape_x), dtype=images.dtype) + else: + image_stack = np.empty((num_frames, shape_y, shape_x), dtype=images.dtype) for i in range(num_frames): frame = images[:, i].reshape((shape_y, shape_x)) # Flip images if flip_image: frame = np.flipud(frame) - image_stack[i] = frame + if num_frames == 1: + image_stack = frame + else: + image_stack[i] = frame # Convert to nm if dataset_name.lower() in ("height", "error", "measuredheight", "amplitude"): From 588c7252d72ab7f5cac0b2d31f7387b5bdab3ebb Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Fri, 3 Apr 2026 19:24:04 +0100 Subject: [PATCH 31/54] adding comments and formatting --- AFMReader/jpk_qi.py | 411 +++++++++++++++++++++++++++++++------------- 1 file changed, 291 insertions(+), 120 deletions(-) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index 0eeaaa4..c34827f 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -14,6 +14,7 @@ from AFMReader.logging import logger from AFMReader import jpk + class LazyCurveData: """A proxy class that behaves like a 2D list but fetches .dat files on demand.""" @@ -375,6 +376,8 @@ def load( if self.save_as_h5: self.save_to_h5() + # Establish the lazy loading structures for curve data and metadata. Note how lazy structure is used even if + # all the data has been accessed and saved to H5 to prevent excessive memory usage self.full_metadata = LazyCurveMetadata( self.filepath, self.top_level_meta, self.qi_archive, self.shape_x, self.shape_y, flip_image=self.flip_image ) @@ -392,95 +395,174 @@ def load( if self.all_curve_data: return (self.image, self.px2nm, (self.all_curve_data, self.channels_units, self.full_metadata)) - return self.image, self.px2nm def output_summary(self): + """ + Outputs a summary of the loading process, including any failed curve loads and their details. + """ if self.failed_curves: logger.warning(f"Failed to load {len(self.failed_curves)} files.") logger.warning("Summary of missing files (up to 10 shown):") + + # Output the first 10 failed loads with details for i, (curve_num, direction, chan_name) in enumerate(self.failed_curves): - if i < 10: # Limit the number of warnings to avoid spamming the logs + if i < 10: if chan_name: - logger.warning(f"Failed to load data file for curve {curve_num}, direction {direction}, channel {chan_name}") + logger.warning( + f"Failed to load data file for curve {curve_num}, direction {direction}, channel {chan_name}" + ) else: if direction is not None: - logger.warning(f"Failed to load segment meta file for curve {curve_num}, direction {direction}") + logger.warning( + f"Failed to load segment meta file for curve {curve_num}, direction {direction}" + ) else: logger.warning(f"Failed to load curve meta file for curve {curve_num}") + else: + break else: + # If there are no failed loads, log that all data was loaded successfully logger.info("Successfully loaded all curve data without any missing files.") - def load_all_data(self, h5_datasets, h5_meta_datasets, h5_datasets_buffer, h5_meta_datasets_buffer, include_metadata: bool = True): - logger.info(f"Loading all curve data from JPK QI archive with {len(self.namelist)} files {'' if include_metadata else 'not '}including metadata") + # Output the performance summary for the loading process + summary = ( + f"\n--- Performance Summary ---\n" + f"Changing Keys Detection: {self.t_changing_keys:.2f}s\n" + f"Raw Data - Loading: {self.t_load_data:.2f}s | Processing: {self.t_proc_data:.2f}s | Saving: {self.t_save_data:.2f}s\n" + f"Metadata - Loading: {self.t_load_meta:.2f}s | Processing: {self.t_proc_meta:.2f}s | Saving: {self.t_save_meta:.2f}s\n" + f"---------------------------" + ) + logger.info(summary) + + def extract_data_to_h5( + self, h5_datasets, h5_meta_datasets, h5_datasets_buffer, h5_meta_datasets_buffer, include_metadata: bool = True + ): + """ + Load all curve data and optionally metadata from the JPK QI archive into HDF5 datasets. + + Parameters + ---------- + h5_datasets : dict + Dictionary of HDF5 datasets for storing curve data. + h5_meta_datasets : dict + Dictionary of HDF5 datasets for storing metadata. + h5_datasets_buffer : dict + Dictionary of buffers for HDF5 curve data. + h5_meta_datasets_buffer : dict + Dictionary of buffers for HDF5 metadata. + include_metadata : bool, optional + Whether to include metadata in the loading process, by default True. + """ + + logger.info( + f"Loading all curve data from JPK QI archive with {len(self.namelist)} files {'' if include_metadata else 'not '}including metadata" + ) progress_counter = 0 process = psutil.Process(os.getpid()) - curve_work = [(f"{k}=".encode("utf-8"), h5_meta_datasets[f"curve.{k}"], h5_meta_datasets_buffer[f"curve.{k}"]) - for k in self.changing_curve_keys] - seg_work = [(f"{k}=".encode("utf-8"), h5_meta_datasets[f"segment.{k}"], h5_meta_datasets_buffer[f"segment.{k}"]) - for k in self.changing_segment_keys] + if include_metadata: + # Prepare keys for metadata to speed up processing + curve_work = [ + (f"{k}=".encode("utf-8"), h5_meta_datasets[f"curve.{k}"], h5_meta_datasets_buffer[f"curve.{k}"]) + for k in self.changing_curve_keys + ] + seg_work = [ + (f"{k}=".encode("utf-8"), h5_meta_datasets[f"segment.{k}"], h5_meta_datasets_buffer[f"segment.{k}"]) + for k in self.changing_segment_keys + ] for curve_num in range(self.num_of_curves): # Output progress every 1000 curves to give some indication of how long the loading is taking if progress_counter % 1000 == 0: mem = process.memory_info().rss / 1024 / 1024 - logger.info(f"Progress: {progress_counter}/{self.num_of_curves} curves processed, Memory usage: {mem:.2f} MB") + logger.info( + f"Progress: {progress_counter}/{self.num_of_curves} curves processed, Memory usage: {mem:.2f} MB" + ) progress_counter += 1 + for direction in range(2): for chan in self.segment_channels: # Save the actual curve data to the h5 datasets - self.extract_dat_file(h5_datasets=h5_datasets, h5_datasets_buffer=h5_datasets_buffer, curve_num=curve_num, direction=direction, chan_name=chan['name']) + self.extract_dat_file( + h5_datasets=h5_datasets, + h5_datasets_buffer=h5_datasets_buffer, + curve_num=curve_num, + direction=direction, + chan_name=chan["name"], + ) + if include_metadata: # Extract and store the segment metadata for later saving self.extract_segment_metadata(curve_num=curve_num, direction=direction, seg_work=seg_work) + if include_metadata: # Extract and store the curve metadata for later saving self.extract_curve_metadata(curve_num=curve_num, curve_work=curve_work) + # Add the last index to the indicies datasets to mark the end of the last curve for direction in range(2): seg_name = f"Segment_{direction}" for chan in self.segment_channels: - chan_name = chan['name'] + chan_name = chan["name"] current_dataset = h5_datasets[seg_name][chan_name]["Data"] indicies_dataset = h5_datasets[seg_name][chan_name]["Indicies"] indicies_dataset[-1] = current_dataset.shape[0] - self.output_summary() def save_to_h5( self, include_metadata: bool = True, ): - """Saves the data as an H5 file. If include_metadata is False, only the curve data will be saved.""" + """ + Saves the data as an H5 file. If include_metadata is False, only the curve data will be saved. + + Parameters + ---------- + include_metadata : bool, optional + If True, metadata will be included in the saved H5 file. Default is True. + """ with self.get_saving_context() as file: t0 = time.perf_counter() + + # Sample curves in dataset to make a best guess for the meta keys that need to be extracted from each curve self.changing_curve_keys, self.changing_segment_keys = self.get_changing_keys() self.points_for_channel_segment = self.predict_total_points() self.t_changing_keys = time.perf_counter() - t0 - self.current_offsets = {} - global_meta_group, h5_datasets, h5_meta_datasets, h5_datasets_buffer, h5_meta_datasets_buffer = self.setup_h5_structure(file) + # Setup H5 structure for saving the data, creating datasets for curve data and metadata as needed + global_meta_group, h5_datasets, h5_meta_datasets, h5_datasets_buffer, h5_meta_datasets_buffer = ( + self.setup_h5_structure(file) + ) + # Set up current_offsets to keep track of how many points have been read + self.current_offsets = {} for direction in range(2): self.current_offsets[direction] = {} for chan in self.segment_channels: - self.current_offsets[direction][chan['name']] = 0 - self.points_for_channel_segment[direction][chan['name']] = 0 - self.load_all_data(h5_datasets, h5_meta_datasets, h5_datasets_buffer, h5_meta_datasets_buffer, include_metadata=include_metadata) + self.current_offsets[direction][chan["name"]] = 0 + + # Reset the points for channel segment to 0 so it can be used to store the actual number of points held in each dataset + self.points_for_channel_segment[direction][chan["name"]] = 0 + + # Extract data from the JPK QI archive and save to H5 datasets, optionally including metadata + self.extract_data_to_h5( + h5_datasets, + h5_meta_datasets, + h5_datasets_buffer, + h5_meta_datasets_buffer, + include_metadata=include_metadata, + ) + # Resize the datasets to the actual number of points read for direction in range(2): for chan in self.segment_channels: - logger.debug(f"Resizing dataset for channel {chan['name']} in segment {direction} from {h5_datasets[f'Segment_{direction}'][chan['name']]['Data'].shape[0]} to final size {self.points_for_channel_segment[direction][chan['name']]}") - h5_datasets[f"Segment_{direction}"][chan['name']]["Data"].resize((self.points_for_channel_segment[direction][chan['name']],)) - logger.debug(f"Curve meta size in memory: {asizeof.asizeof(self.curve_meta) / 1024 / 1024:.2f} MB") - logger.debug(f"Segment meta size in memory: {asizeof.asizeof(self.segment_meta) / 1024 / 1024:.2f} MB") - summary = ( - f"\n--- Performance Summary ---\n" - f"Changing Keys Detection: {self.t_changing_keys:.2f}s\n" - f"Raw Data - Loading: {self.t_load_data:.2f}s | Processing: {self.t_proc_data:.2f}s | Saving: {self.t_save_data:.2f}s\n" - f"Metadata - Loading: {self.t_load_meta:.2f}s | Processing: {self.t_proc_meta:.2f}s | Saving: {self.t_save_meta:.2f}s\n" - f"---------------------------" - ) - logger.info(summary) + logger.debug( + f"Resizing dataset for channel {chan['name']} in segment {direction} from {h5_datasets[f'Segment_{direction}'][chan['name']]['Data'].shape[0]} to final size {self.points_for_channel_segment[direction][chan['name']]}" + ) + h5_datasets[f"Segment_{direction}"][chan["name"]]["Data"].resize( + (self.points_for_channel_segment[direction][chan["name"]],) + ) + + self.output_summary() if include_metadata: # Save the global metadata to the h5 file @@ -490,6 +572,15 @@ def save_to_h5( logger.info(f"QI data copied to h5 data {file.filename}") def get_curves_sample(self): + """ + Get a sample of curve numbers distrubuted evenly across the dataset + + Returns + ------- + range: + A range object representing the sampled curve numbers. + """ + # Check evenly spaced curves in the dataset to sample metadata without having to load every curve step = 1 if self.num_of_curves <= self.MAX_CURVE_CHECKS else self.num_of_curves // self.MAX_CURVE_CHECKS # If the step is equal to a shape dimension, we might just go down the row or column @@ -499,29 +590,62 @@ def get_curves_sample(self): return range(0, self.num_of_curves, step) def predict_total_points(self): + """ + Predicts the total number of points for each channel and segment by sampling a subset of curves + and extrapolating based on the maximum number of points found in the sample. + + Returns + ------- + dict: + A dictionary containing the predicted total points for each channel and segment. + """ + + # Get a sample of curve (indicies) curves_to_check = self.get_curves_sample() points_for_channel_segment = {} + + # Iterate through the segments, channels and our curve indicies for direction in range(2): points_for_channel_segment[direction] = {} for channel in self.segment_channels: - points_for_channel_segment[direction][channel['name']] = [] + points_for_channel_segment[direction][channel["name"]] = [] for curve_num in curves_to_check: + # Loop until we successfully retrieve some data while True: dat_path = f"index/{curve_num}/segments/{direction}/channels/{channel['name']}.dat" try: + # Count points in extracted data with self.qi_archive.open(dat_path) as f: raw_array = np.frombuffer(f.read(), dtype=">i4") - points_for_channel_segment[direction][channel['name']].append(len(raw_array)) + points_for_channel_segment[direction][channel["name"]].append(len(raw_array)) break + except KeyError: + # If the file doesn't exist for this curve, check the next curve so we don't just get a smaller sample if curve_num + 1 >= self.num_of_curves: - break # If we've gone past the number of curves, stop checking + # If we've gone past the number of curves, stop checking + break curve_num += 1 continue - points_for_channel_segment[direction][channel['name']] = int(np.max(points_for_channel_segment[direction][channel['name']]) * 1.1) * self.num_of_curves + # Calculate a prediction for total number of points based on maximum number of points then assuming + # maximum points throughout data is no more than 10% higher + points_for_channel_segment[direction][channel["name"]] = ( + int(np.max(points_for_channel_segment[direction][channel["name"]]) * 1.1) * self.num_of_curves + ) return points_for_channel_segment def get_changing_keys(self): + """ + Checks a sample of curves to see which metadata keys change across curves and segments, + so we can extract only the changing keys for each curve and segment. + + None changing keys are moved to the top level metadata and not extracted for each curve/segment. + + Returns + ------- + tuple: + A tuple containing two sets: changing_curve_keys and changing_segment_keys. + """ curve_meta_dict = {} segment_meta_dict = {} @@ -575,9 +699,6 @@ def get_changing_keys(self): self.top_level_meta[f"segment.{key}"] = values[0] return changing_curve_keys, changing_segment_keys - - - def get_collated_metadata(self): """ Collates the metadata from being split by curve, to being split by attribute so data can be saved more efficiently @@ -594,7 +715,9 @@ def get_collated_metadata(self): collated_meta[key] = value return collated_meta - def get_image(self, overide_channel: str | None = None, convert_to_nm: bool = True, flip_image: bool | None = None) -> tuple[np.ndarray, float]: + def get_image( + self, overide_channel: str | None = None, convert_to_nm: bool = True, flip_image: bool | None = None + ) -> tuple[np.ndarray, float]: """ Processes the flat curve data dictionary into a 2D list structure matching the image dimensions. @@ -604,6 +727,7 @@ def get_image(self, overide_channel: str | None = None, convert_to_nm: bool = Tr A 2D array representing the image data. """ + # Get channel and flip_image parameters, defaulting to the instance attributes if not provided if overide_channel: channel = overide_channel else: @@ -612,6 +736,7 @@ def get_image(self, overide_channel: str | None = None, convert_to_nm: bool = Tr if flip_image is None: flip_image = self.flip_image + # Search through the namelist to find the .jpk-qi-image file path_to_image = None for file_name in self.namelist: if file_name.endswith(".jpk-qi-image"): @@ -619,70 +744,77 @@ def get_image(self, overide_channel: str | None = None, convert_to_nm: bool = Tr if path_to_image is None: raise FileNotFoundError(f"{path_to_image} not found in JPK archive") + # Read the .jpk-qi-image file as bytes and load it using the existing jpk loading function tif_bytes = self.qi_archive.read(path_to_image) virtual_file = io.BytesIO(tif_bytes) logger.info(f"Looking for channel {channel} in {path_to_image}") return jpk._load_jpk( - virtual_file, path_to_image, channel=channel, file_suffix=".jpk-qi-data", config_path=self.config_path, convert_to_nm=convert_to_nm, flip_image=flip_image + virtual_file, + path_to_image, + channel=channel, + file_suffix=".jpk-qi-data", + config_path=self.config_path, + convert_to_nm=convert_to_nm, + flip_image=flip_image, ) def save_lite_data(self): - """Saves a lite form of the data (e.g., the calculated image data) to the appropriate format based on the save_as_h5 attribute.""" - if self.save_as_h5: - with h5py.File(self.filepath.parent / f"{self.filepath.stem}.h5-jpk", "a") as h5file: - # Save data required for reading the h5 file as a normal image file - meas_grp = h5file.require_group("Measurement_000") - # Save dimensions data - meas_grp.attrs["position-pattern.grid.ulength"] = self.size_x - meas_grp.attrs["position-pattern.grid.ilength"] = self.shape_x - meas_grp.attrs["position-pattern.grid.vlength"] = self.size_y - meas_grp.attrs["position-pattern.grid.jlength"] = self.shape_y - meas_grp.attrs["timing-settings.scanRate"] = 1.0 # Dummy value to satisfy reader - - logger.info(f"Saving a hdf5 copy of the data {self.filepath.parent / f'{self.filepath.stem}.h5-jpk'}") - - h5_channels = [self.channel] - # Look for the jpk-qi-image file in the archive - for file_name in self.namelist: - if file_name.endswith(".jpk-qi-image"): - path_to_image = file_name - break - # Add the channels which exist in the jpk-qi-image file - with self.qi_archive.open(path_to_image, "r") as image_file: - h5_channels += jpk._get_jpk_channels( - file=image_file, filename=self.filepath.stem, file_path=self.filepath / Path(path_to_image) - ) - for i, h5_channel in enumerate(h5_channels): - # For each available channel, save the required data to the h5 file - # TODO make sure this metadata is accurate for the channels coming from the .jpk-qi-image file - chan_grp = meas_grp.require_group(f"Channel_{_make_num_min_characters(i)}") - # Extract name and retrace information from the channel name - if "_" in h5_channel: - base_name, trace_dir = h5_channel.rsplit("_", 1) - is_retrace = "true" if trace_dir.lower() == "retrace" else "false" - else: - base_name = h5_channel - is_retrace = "false" - - # Add the necessary attributes to the channel group - chan_grp.attrs["channel.name"] = base_name.encode("utf-8") - chan_grp.attrs["retrace"] = is_retrace.encode("utf-8") - chan_grp.attrs["net-encoder.scaling.multiplier"] = 1.0 - chan_grp.attrs["net-encoder.scaling.offset"] = 0.0 - - # Format name and reshape image (flattened frame stack) - dataset_name = h5_channel.split("_")[0].capitalize() - # Include all the channels including the calculated channel - # TODO make this slightly faster by remembering we have load a channel already but difficult cause of scaling - channel_image, _ = self.get_image(overide_channel=h5_channel, convert_to_nm=False, flip_image=False) - frame_stack = channel_image.flatten().reshape(-1, 1) - - # Update/ replace the channels dataset - if dataset_name in chan_grp: - del chan_grp[dataset_name] - chan_grp.create_dataset(dataset_name, data=frame_stack) - + """ + Saves a lite form of the data (e.g., the calculated image data) to the appropriate format based on the save_as_h5 attribute. + """ + with h5py.File(self.filepath.parent / f"{self.filepath.stem}.h5-jpk", "a") as h5file: + # Save data required for reading the h5 file as a normal image file + meas_grp = h5file.require_group("Measurement_000") + # Save dimensions data + meas_grp.attrs["position-pattern.grid.ulength"] = self.size_x + meas_grp.attrs["position-pattern.grid.ilength"] = self.shape_x + meas_grp.attrs["position-pattern.grid.vlength"] = self.size_y + meas_grp.attrs["position-pattern.grid.jlength"] = self.shape_y + meas_grp.attrs["timing-settings.scanRate"] = 1.0 # Dummy value to satisfy reader + + logger.info(f"Saving a hdf5 copy of the data {self.filepath.parent / f'{self.filepath.stem}.h5-jpk'}") + + h5_channels = [self.channel] + # Look for the jpk-qi-image file in the archive + for file_name in self.namelist: + if file_name.endswith(".jpk-qi-image"): + path_to_image = file_name + break + # Add the channels which exist in the jpk-qi-image file + with self.qi_archive.open(path_to_image, "r") as image_file: + h5_channels += jpk._get_jpk_channels( + file=image_file, filename=self.filepath.stem, file_path=self.filepath / Path(path_to_image) + ) + for i, h5_channel in enumerate(h5_channels): + # For each available channel, save the required data to the h5 file + # TODO make sure this metadata is accurate for the channels coming from the .jpk-qi-image file + chan_grp = meas_grp.require_group(f"Channel_{_make_num_min_characters(i)}") + # Extract name and retrace information from the channel name + if "_" in h5_channel: + base_name, trace_dir = h5_channel.rsplit("_", 1) + is_retrace = "true" if trace_dir.lower() == "retrace" else "false" + else: + base_name = h5_channel + is_retrace = "false" + + # Add the necessary attributes to the channel group + chan_grp.attrs["channel.name"] = base_name.encode("utf-8") + chan_grp.attrs["retrace"] = is_retrace.encode("utf-8") + chan_grp.attrs["net-encoder.scaling.multiplier"] = 1.0 + chan_grp.attrs["net-encoder.scaling.offset"] = 0.0 + + # Format name and reshape image (flattened frame stack) + dataset_name = h5_channel.split("_")[0].capitalize() + # Include all the channels including the calculated channel + # TODO make this slightly faster by remembering we have load a channel already but difficult cause of scaling + channel_image, _ = self.get_image(overide_channel=h5_channel, convert_to_nm=False, flip_image=False) + frame_stack = channel_image.flatten().reshape(-1, 1) + + # Update/ replace the channels dataset + if dataset_name in chan_grp: + del chan_grp[dataset_name] + chan_grp.create_dataset(dataset_name, data=frame_stack) def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, direction: int, chan_name: str): """ @@ -702,6 +834,7 @@ def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, dire The channel name associated with the .dat file, parsed from the filename. """ if chan_name in self.channel_scaling: + # Get data structures for this channel and segment scale = self.channel_scaling[chan_name] dat_path = f"index/{curve_num}/segments/{direction}/channels/{chan_name}.dat" data_set = h5_datasets[f"Segment_{direction}"][chan_name]["Data"] @@ -710,6 +843,7 @@ def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, dire buf = h5_datasets_buffer[f"Segment_{direction}"][chan_name] filled_size = self.points_for_channel_segment[direction][chan_name] start_offset = self.current_offsets[direction][chan_name] + try: t0 = time.perf_counter() with self.qi_archive.open(dat_path) as f: @@ -723,6 +857,7 @@ def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, dire # Apply the scaling to convert raw values into real world values segment_array = (raw_array * scale["multiplier"]) + scale["offset"] + # Update the current offset so it include the length of the data we have just read self.current_offsets[direction][chan_name] += len(segment_array) buf["Data"].append(segment_array) @@ -734,14 +869,16 @@ def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, dire buffered_data = np.concatenate(buf["Data"]) self.t_proc_data += time.perf_counter() - t1 - # Append the new data to the end of the existing dataset start_time = time.perf_counter() - data_set[filled_size:filled_size + len(buffered_data)] = buffered_data + + # Add the buffer to the dataset + data_set[filled_size : filled_size + len(buffered_data)] = buffered_data + # Update the filled size for this channel and segment self.points_for_channel_segment[direction][chan_name] += len(buffered_data) + # Clear the buffer buf["Data"].clear() self.t_save_data += time.perf_counter() - start_time - except KeyError: self.failed_curves.add((curve_num, direction, chan_name)) @@ -751,7 +888,9 @@ def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, dire f"Data file {dat_path} not found in archive. Skipping data for curve {curve_num}, direction {direction}, channel {chan_name}." ) elif len(self.failed_curves) == 10: - logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.") + logger.warning( + "Lots of missing files, further warnings will be suppressed. View summary at the end." + ) start_time = time.perf_counter() # Append the new index to the indicies buffer @@ -771,6 +910,7 @@ def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, dire self.t_save_data += time.perf_counter() - start_time else: + # Log if curve failed self.failed_curves.add((curve_num, direction, chan_name)) if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs logger.warning( @@ -779,39 +919,44 @@ def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, dire def extract_curve_metadata(self, curve_num: int, curve_work): """ - Extracts the curve metadata from its header.properties file in the JPK QI archive and saves it to the internal data structure. + Extracts the curve metadata from its header.properties file in the JPK QI archive and save to h5 Parameters ---------- - h5_meta_datasets : dict - A dictionary containing the h5 datasets for each metadata key, used for saving the metadata. - file_info : zipfile.ZipInfo - The ZipInfo object corresponding to the header.properties file to be extracted. curve_num : int The curve number associated with the metadata, parsed from the filename. + curve_work : list + A list of tuples containing the search term for the metadata, the h5 dataset to save to, and the buffer for that dataset. """ meta_path = f"index/{curve_num}/header.properties" raw_bytes = b"" try: start_time = time.perf_counter() + # Read metadata file as raw bytes with self.qi_archive.open(meta_path) as f: raw_bytes = f.read() self.t_load_meta += time.perf_counter() - start_time except KeyError: self.failed_curves.add((curve_num, None, None)) - if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs - logger.warning(f"Metadata file {meta_path} not found in archive. Skipping metadata for curve {curve_num}.") + # Limit the number of warnings to avoid spamming the logs + if len(self.failed_curves) < 10: + logger.warning( + f"Metadata file {meta_path} not found in archive. Skipping metadata for curve {curve_num}." + ) elif len(self.failed_curves) == 10: logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.") for search_term, meta_set, meta_buffer in curve_work: start_time = time.perf_counter() + # Find the location of the metadata value in the raw bytes start = raw_bytes.find(search_term) + # If found, extract the actual value if start != -1: start += len(search_term) end = raw_bytes.find(b"\n", start) value = raw_bytes[start:end].decode("utf-8").strip() + # Save a no data value if the search term is not found in the metadata file else: value = "No data" self.t_proc_meta += time.perf_counter() - start_time @@ -823,9 +968,9 @@ def extract_curve_metadata(self, curve_num: int, curve_work): meta_buffer.clear() self.t_save_meta += time.perf_counter() - start_time else: - logger.error(f"Metadata dataset for key {search_term.decode('utf-8')} not found when trying to save metadata for curve {curve_num}") - - + logger.error( + f"Metadata dataset for key {search_term.decode('utf-8')} not found when trying to save metadata for curve {curve_num}" + ) def extract_segment_metadata(self, curve_num: int, direction: int, seg_work): """ @@ -848,7 +993,9 @@ def extract_segment_metadata(self, curve_num: int, direction: int, seg_work): except KeyError: self.failed_curves.add((curve_num, direction, None)) if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs - logger.warning(f"Metadata file {meta_path} not found in archive. Skipping metadata for curve {curve_num}, direction {direction}.") + logger.warning( + f"Metadata file {meta_path} not found in archive. Skipping metadata for curve {curve_num}, direction {direction}." + ) elif len(self.failed_curves) == 10: logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.") start_time = time.perf_counter() @@ -867,7 +1014,9 @@ def extract_segment_metadata(self, curve_num: int, direction: int, seg_work): meta_set[idx - len(meta_buffer) + 1 : idx + 1] = meta_buffer meta_buffer.clear() else: - logger.error(f"Metadata dataset for key {search_term.decode('utf-8')} not found when trying to save metadata for curve {curve_num}, direction {direction}") + logger.error( + f"Metadata dataset for key {search_term.decode('utf-8')} not found when trying to save metadata for curve {curve_num}, direction {direction}" + ) self.t_proc_meta += time.perf_counter() - start_time def setup_h5_structure(self, h5file): @@ -902,10 +1051,22 @@ def setup_h5_structure(self, h5file): h5_datasets_buffer = {} h5_meta_datasets_buffer = {} for key in self.changing_curve_keys: - h5_meta_datasets[f"curve.{key}"] = curves_meta_group.create_dataset(name=f"curve.{key}", shape=(self.num_of_curves,), maxshape=(None,), chunks=self.META_CHUNKSIZE, dtype=h5py.string_dtype(encoding="utf-8")) + h5_meta_datasets[f"curve.{key}"] = curves_meta_group.create_dataset( + name=f"curve.{key}", + shape=(self.num_of_curves,), + maxshape=(None,), + chunks=self.META_CHUNKSIZE, + dtype=h5py.string_dtype(encoding="utf-8"), + ) h5_meta_datasets_buffer[f"curve.{key}"] = [] for key in self.changing_segment_keys: - h5_meta_datasets[f"segment.{key}"] = curves_meta_group.create_dataset(name=f"segment.{key}", shape=(self.num_of_curves * 2,), maxshape=(None,), chunks=self.META_CHUNKSIZE, dtype=h5py.string_dtype(encoding="utf-8")) + h5_meta_datasets[f"segment.{key}"] = curves_meta_group.create_dataset( + name=f"segment.{key}", + shape=(self.num_of_curves * 2,), + maxshape=(None,), + chunks=self.META_CHUNKSIZE, + dtype=h5py.string_dtype(encoding="utf-8"), + ) h5_meta_datasets_buffer[f"segment.{key}"] = [] for direction in range(2): @@ -918,11 +1079,23 @@ def setup_h5_structure(self, h5file): curve_groups["Data"][seg_name] = dir_group.require_group("Data") curve_groups["Indicies"][seg_name] = dir_group.require_group("Indicies") for chan in self.segment_channels: - h5_datasets[seg_name][chan['name']] = {} + h5_datasets[seg_name][chan["name"]] = {} # For each channel, create an empty dataset for the curve data and indicies with the appropriate name and data type - h5_datasets[seg_name][chan['name']]['Data'] = curve_groups["Data"][seg_name].create_dataset(name=chan["name"], shape=(self.points_for_channel_segment[direction][chan['name']],), maxshape=(None,), chunks=(self.DATA_CHUNKSIZE,), dtype=np.float32) - h5_datasets[seg_name][chan['name']]['Indicies'] = curve_groups["Indicies"][seg_name].create_dataset(name=chan["name"], shape=(self.num_of_curves + 1, ), maxshape=(None,), chunks=(self.INDICIES_CHUNKSIZE,), dtype=np.int32) - h5_datasets_buffer[seg_name][chan['name']] = {'Data': [], 'Indicies': []} + h5_datasets[seg_name][chan["name"]]["Data"] = curve_groups["Data"][seg_name].create_dataset( + name=chan["name"], + shape=(self.points_for_channel_segment[direction][chan["name"]],), + maxshape=(None,), + chunks=(self.DATA_CHUNKSIZE,), + dtype=np.float32, + ) + h5_datasets[seg_name][chan["name"]]["Indicies"] = curve_groups["Indicies"][seg_name].create_dataset( + name=chan["name"], + shape=(self.num_of_curves + 1,), + maxshape=(None,), + chunks=(self.INDICIES_CHUNKSIZE,), + dtype=np.int32, + ) + h5_datasets_buffer[seg_name][chan["name"]] = {"Data": [], "Indicies": []} return global_meta_group, h5_datasets, h5_meta_datasets, h5_datasets_buffer, h5_meta_datasets_buffer def get_saving_context(self): @@ -967,7 +1140,6 @@ def parse_dimension_data(self): # Establish number of curves self.num_of_curves = self.shape_x * self.shape_y - def extract_global_metadata(self): """Extracts the global metadata from the JPK QI archive and populates the top level metadata dictionary and segment channels list.""" @@ -1030,7 +1202,6 @@ def close(self): self.namelist = [] - def load_fdcurves_from_h5(file_path: Path | str): file_path = Path(file_path) From 4c7ea88a52f187b2a02ff1f75df8a7db7ae7f105 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Fri, 24 Apr 2026 11:41:03 +0100 Subject: [PATCH 32/54] Ensuring that the h5-jpk copy of the jpk-qi-data doesn't get overwritten --- AFMReader/jpk_qi.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index c34827f..7b0b6f2 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -9,7 +9,6 @@ import javaproperties import h5py import psutil -from pympler import asizeof from AFMReader.logging import logger from AFMReader import jpk @@ -260,10 +259,10 @@ def __init__( self.config_path = config_path self.flip_image = flip_image self.save_as_h5 = save_as_h5 - logger.debug(f"Initialized JPK QI loader for file {self.filepath} with channel {self.channel}") + # Open the ZIP archive once and keep it open for the duration of the loading process to improve performance when accessing multiple files within the archive self.qi_archive = zipfile.ZipFile(self.filepath, "r") - logger.debug(f"Opened JPK QI archive at {self.filepath}") + logger.info(f"Opened JPK QI archive at {self.filepath}") self.namelist = self.qi_archive.namelist() # Set path to the .jpk-qi-image file within the archive for later use self.path_to_image = None @@ -286,7 +285,7 @@ def __init__( # A dictionary containing all metadata, splitting the top level metadata and the metadata for each curve and segment self.full_metadata = {} # A 2D list of curve data dictionaries, where each dictionary contains the data for all channels and segments - self.all_curve_data = None + self.curve_data = None # A lookup for channel name to unit to be returned self.channels_units = {} # The list of channels for the segments with their scaling information extracted from the shared header @@ -367,6 +366,13 @@ def load( self.flip_image = flip_image if flip_image is not None else self.flip_image self.save_as_h5 = save_as_h5 if save_as_h5 is not None else self.save_as_h5 + if self.save_as_h5: + self.h5_path = self.filepath.parent / f"{self.filepath.stem}.h5-jpk" + i = 0 + while self.h5_path.exists(): + self.h5_path = self.filepath.parent / f"{self.filepath.stem}_{i}.h5-jpk" + i += 1 + logger.info(f"Loading JPK QI data from {self.filepath} with channel {self.channel}") self.extract_global_metadata() @@ -381,7 +387,7 @@ def load( self.full_metadata = LazyCurveMetadata( self.filepath, self.top_level_meta, self.qi_archive, self.shape_x, self.shape_y, flip_image=self.flip_image ) - self.all_curve_data = LazyCurveData( + self.curve_data = LazyCurveData( self.filepath, self.shape_x, self.shape_y, self.channel_scaling, self.qi_archive, flip_image=self.flip_image ) @@ -392,10 +398,7 @@ def load( if self.save_as_h5: self.save_lite_data() - if self.all_curve_data: - return (self.image, self.px2nm, (self.all_curve_data, self.channels_units, self.full_metadata)) - - return self.image, self.px2nm + return (self.image, self.px2nm, (self.curve_data, self.channels_units, self.full_metadata)) def output_summary(self): """ @@ -555,9 +558,6 @@ def save_to_h5( # Resize the datasets to the actual number of points read for direction in range(2): for chan in self.segment_channels: - logger.debug( - f"Resizing dataset for channel {chan['name']} in segment {direction} from {h5_datasets[f'Segment_{direction}'][chan['name']]['Data'].shape[0]} to final size {self.points_for_channel_segment[direction][chan['name']]}" - ) h5_datasets[f"Segment_{direction}"][chan["name"]]["Data"].resize( (self.points_for_channel_segment[direction][chan["name"]],) ) @@ -763,7 +763,7 @@ def save_lite_data(self): """ Saves a lite form of the data (e.g., the calculated image data) to the appropriate format based on the save_as_h5 attribute. """ - with h5py.File(self.filepath.parent / f"{self.filepath.stem}.h5-jpk", "a") as h5file: + with h5py.File(self.h5_path, "a") as h5file: # Save data required for reading the h5 file as a normal image file meas_grp = h5file.require_group("Measurement_000") # Save dimensions data @@ -773,7 +773,7 @@ def save_lite_data(self): meas_grp.attrs["position-pattern.grid.jlength"] = self.shape_y meas_grp.attrs["timing-settings.scanRate"] = 1.0 # Dummy value to satisfy reader - logger.info(f"Saving a hdf5 copy of the data {self.filepath.parent / f'{self.filepath.stem}.h5-jpk'}") + logger.info(f"Saving a hdf5 copy of the data {self.h5_path}") h5_channels = [self.channel] # Look for the jpk-qi-image file in the archive @@ -1109,7 +1109,7 @@ def get_saving_context(self): The context manager for saving the data. """ if self.save_as_h5: - return h5py.File(self.filepath.parent / f"{self.filepath.stem}.h5-jpk", "a") + return h5py.File(self.h5_path, "a") else: return nullcontext() @@ -1192,7 +1192,7 @@ def close(self): """Closes the ZIP archive when done to free up system resources.""" self.qi_archive.close() self.image = None - self.all_curve_data = None + self.curve_data = None self.curve_meta = {} self.segment_meta = {} self.top_level_meta = {} From 62637c231e7a60c5de05b18e4c0d980217fc2faa Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Fri, 24 Apr 2026 11:44:39 +0100 Subject: [PATCH 33/54] Improving __iter__ function so loading of all curves for analysis is much more memory efficient --- AFMReader/h5_jpk.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index 8d113d5..12f4a97 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -281,6 +281,7 @@ def __init__(self, qi_data_group: h5py.Group, shape_x: int, shape_y: int, flip_i self.shape_y = shape_y self.dims = (shape_y, shape_x) self.flip_image = flip_image + self.CHUNK_SIZE = 500 # Number of curves to load at once when iterating def __getitem__(self, y: int): @@ -295,9 +296,34 @@ def __getitem__(self, x: int): return RowProxy(self, y) def __iter__(self): - for y in range(self.shape_y): + indicies_map = {} + for segment, segment_group in self.qi_data_group["Curves"].items(): + for channel in segment_group["Indicies"]: + if channel not in indicies_map: + indicies_map[channel] = {} + indicies_map[channel][segment] = segment_group["Indicies"][channel][:] + for y_idx in range(self.shape_y): + data = {} + y = self.shape_y - 1 - y_idx if self.flip_image else y_idx + for segment, segment_group in self.qi_data_group["Curves"].items(): + for channel in segment_group["Indicies"]: + if channel not in data: + data[channel] = {} + indicies = indicies_map[channel][segment] + start_idx = int(indicies[self.shape_x * y]) + end_idx = int(indicies[self.shape_x * (y + 1)]) + + data[channel][segment] = segment_group["Data"][channel][start_idx:end_idx] for x in range(self.shape_x): - yield self._fetch_curve(y, x) + curve_data = {} + for channel in data: + curve_data[channel] = {} + for segment in data[channel]: + indicies = indicies_map[channel][segment] + start_idx = int(indicies[self.shape_x * y + x]) - int(indicies[self.shape_x * y]) + end_idx = int(indicies[self.shape_x * y + x + 1]) - int(indicies[self.shape_x * y]) + curve_data[channel][segment] = data[channel][segment][start_idx:end_idx] + yield curve_data def _fetch_curve(self, y: int, x: int): if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x: From b9104cc71c373ecca2519653a8be88bc9a9a67cc Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Fri, 24 Apr 2026 11:45:31 +0100 Subject: [PATCH 34/54] Fixing errors with tests caused by logging problems --- AFMReader/general_loader.py | 4 ---- AFMReader/logging.py | 2 +- tests/test_general_loader.py | 21 ++++++++++----------- tests/test_spm.py | 5 +++-- 4 files changed, 14 insertions(+), 18 deletions(-) diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index c1e699b..fdb3874 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -120,15 +120,11 @@ def load(self, channel: str | None = None, kwargs: dict = None) -> tuple[npt.NDA else: raise ValueError(f"File type '{self.suffix}' is not currently handled by AFMReader.") - print(f"Loaded image with shape {image.shape} and pixel to nanometre scaling factor {pixel_to_nanometre_scaling_factor}") - print(f"Image has max value {image.max()} and min value {image.min()}") - return image, pixel_to_nanometre_scaling_factor except ValueError as e: logger.error(f"{e}") raise e - return (e, None) # cheeky return of an image, px2nm-like tuple object to propagate error message to Napari def get_available_channels(self): if self.suffix == ".asd": diff --git a/AFMReader/logging.py b/AFMReader/logging.py index 5d3a70f..f722dff 100644 --- a/AFMReader/logging.py +++ b/AFMReader/logging.py @@ -7,7 +7,7 @@ logger.remove() # Set the format to have blue time, green file, module, function and line, and white message logger.add( - sys.stderr, + lambda msg: sys.stderr.write(msg), colorize=True, format="{time:HH:mm:ss} | {level} |" "{file}:{module}:" diff --git a/tests/test_general_loader.py b/tests/test_general_loader.py index b04547d..d1c6861 100644 --- a/tests/test_general_loader.py +++ b/tests/test_general_loader.py @@ -126,23 +126,22 @@ ), ], ) -def test_load(caplog: pytest.LogCaptureFixture, filepath: Path, channel: str, error: bool, message: str) -> None: +def test_load(capsys: pytest.CaptureFixture, filepath: Path, channel: str, error: bool, message: str) -> None: """Test loading of all (asd, gwy, ibw, jpk, spm, stp, top, topostats) filetypes.""" loader = general_loader.LoadFile(filepath, channel) - - image, px2nm = loader.load() - - if not error: + try: + image, px2nm = loader.load() # check array and px2nm returned assert isinstance(image, np.ndarray) assert isinstance(px2nm, float) - else: - # check when channel wrong - assert isinstance(image, ValueError) - assert px2nm is None - + except ValueError as e: + if error: + assert message in str(e) # check output logs - assert message in caplog.text + captured = capsys.readouterr() + assert message in captured.err + + @pytest.mark.parametrize( diff --git a/tests/test_spm.py b/tests/test_spm.py index f15632e..466be2c 100644 --- a/tests/test_spm.py +++ b/tests/test_spm.py @@ -132,7 +132,7 @@ def test_load_spm_file_not_found() -> None: ], ) def test_load_spm_channel_not_found( - caplog: pytest.LogCaptureFixture, + capsys: pytest.CaptureFixture, channel: str, message: str, error: bool, @@ -143,4 +143,5 @@ def test_load_spm_channel_not_found( spm.load_spm(RESOURCES / "sample_0.spm", channel) else: spm.load_spm(RESOURCES / "sample_0.spm", channel) - assert message in caplog.text + captured = capsys.readouterr() + assert message in captured.err From 0a1f459aeb40ddd67d27fc8f440f47835e48c02c Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Fri, 24 Apr 2026 12:59:03 +0100 Subject: [PATCH 35/54] Updating tests for jpk-qi-data and h5-jpk with curve data --- AFMReader/general_loader.py | 2 +- AFMReader/h5_jpk.py | 2 +- tests/test_h5jpk.py | 45 +++++++++++ tests/test_jpk_qi.py | 154 ++++++++++++++++++++++++++++++++++++ 4 files changed, 201 insertions(+), 2 deletions(-) create mode 100644 tests/test_jpk_qi.py diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index fdb3874..600727e 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -80,7 +80,7 @@ def load(self, channel: str | None = None, kwargs: dict = None) -> tuple[npt.NDA if len(h5_returned) == 3: image, pixel_to_nanometre_scaling_factor, _ = h5_returned elif len(h5_returned) == 4: - image, pixel_to_nanometre_scaling_factor, curve_data, _ = h5_returned + image, pixel_to_nanometre_scaling_factor, _, curve_data = h5_returned self.loaded_curves = True print(f"Loaded image with shape {image.shape} and pixel to nanometre scaling factor {pixel_to_nanometre_scaling_factor}") print(f"Image has max value {image.max()} and min value {image.min()}") diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index 12f4a97..18ea4bc 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -545,6 +545,6 @@ def load_h5jpk( all_curve_data = LazyQIData(qi_data_group, shape_x, shape_y, flip_image) - return (image_stack, px2nm, (all_curve_data, channels_units, full_metadata), timestamps) + return (image_stack, px2nm, timestamps, (all_curve_data, channels_units, full_metadata)) return (image_stack, px2nm, timestamps) diff --git a/tests/test_h5jpk.py b/tests/test_h5jpk.py index b0336b2..ff9ca8b 100644 --- a/tests/test_h5jpk.py +++ b/tests/test_h5jpk.py @@ -143,6 +143,51 @@ def test_load_h5jpk( for i in range(len(results_timestamps) - 1) ) +@pytest.mark.parametrize( + ( + "file_name", + "channel", + "flip_image", + "curve_coords", + "curve_direction", + "curve_targets", + ), + [ + pytest.param( + "sample_0_curves.h5-jpk", + "height_trace", + True, + (0, 0), + "Segment_0", + { + "height": (31, 0.00019106604), + "measuredHeight": (31, 0.00027384484), + "smoothedMeasuredHeight": (31, -38066577408.0), + "vDeflection": (31, 2.7409627e-07), + }, + id="test curves 0", + ), + ], +) +def test_load_h5jpk_curves( + file_name: str, + channel: str, + flip_image: bool, + curve_coords: tuple[int, int], + curve_direction: str, + curve_targets: dict[str, tuple[int, float]], + +) -> None: + + _, _, _, result_curve_data = h5_jpk.load_h5jpk( + RESOURCES / file_name, channel, flip_image + ) + all_curves, _, _ = result_curve_data + curve_at_coords = all_curves[curve_coords[0]][curve_coords[1]] + for curve_channel, (expected_size, expected_sum) in curve_targets.items(): + curve = curve_at_coords[curve_channel][curve_direction] + assert curve.shape == (expected_size,) + assert curve.sum() == pytest.approx(expected_sum) def test_load_h5jpk_file_not_found() -> None: """Ensure FileNotFound error is raised.""" diff --git a/tests/test_jpk_qi.py b/tests/test_jpk_qi.py new file mode 100644 index 0000000..a1e6654 --- /dev/null +++ b/tests/test_jpk_qi.py @@ -0,0 +1,154 @@ +"""Test the loading of jpk-qi-data files.""" + +from pathlib import Path + +import numpy as np +import pytest + +from AFMReader import jpk_qi + +BASE_DIR = Path.cwd() +RESOURCES = BASE_DIR / "tests" / "resources" + + +@pytest.mark.parametrize( + ( + "file_name", + "channel", + "pixel_to_nm_scaling", + "image_shape", + "image_dtype", + "image_sum", + "curve_coords", + "curve_direction", + "curve_targets", + ), + [ + pytest.param( + "sample_0.jpk-qi-data", + "height_trace", + 390.62499999999994, + (256, 256), + float, + 412271271.9961158, + (0, 0), + "Segment_0", + { + "height": (31, 0.00019106601492896875), + "vDeflection": (31, 2.740962611337846e-07), + "measuredHeight": (31, 0.00027384485398464497), + "smoothedMeasuredHeight": (31, -38066578894.999535), + }, + id="qi-data 0; height_trace", + ), + pytest.param( + "sample_0.jpk-qi-data", + "slope_trace", + 390.62499999999994, + (256, 256), + float, + 267675.3050073493, + (0, 0), + "Segment_0", + { + "height": (31, 0.00019106601492896875), + "vDeflection": (31, 2.740962611337846e-07), + "measuredHeight": (31, 0.00027384485398464497), + "smoothedMeasuredHeight": (31, -38066578894.999535), + }, + id="qi-data 0; slope_trace", + ), + pytest.param( + "sample_0.jpk-qi-data", + "adhesion_trace", + 390.62499999999994, + (256, 256), + float, + 0.0008930453784792601, + (0, 0), + "Segment_0", + { + "height": (31, 0.00019106601492896875), + "vDeflection": (31, 2.740962611337846e-07), + "measuredHeight": (31, 0.00027384485398464497), + "smoothedMeasuredHeight": (31, -38066578894.999535), + }, + id="qi-data 0; adhesion_trace", + ), + pytest.param( + "sample_0.jpk-qi-data", + "measuredHeight_trace", + 390.62499999999994, + (256, 256), + float, + 590908347.7454677, + (0, 0), + "Segment_0", + { + "height": (31, 0.00019106601492896875), + "vDeflection": (31, 2.740962611337846e-07), + "measuredHeight": (31, 0.00027384485398464497), + "smoothedMeasuredHeight": (31, -38066578894.999535), + }, + id="qi-data 0; measuredHeight_trace", + ), + pytest.param( + "sample_0.jpk-qi-data", + "vDeflection_trace", + 390.62499999999994, + (256, 256), + float, + 0.0004062236060247368, + (0, 0), + "Segment_0", + { + "height": (31, 0.00019106601492896875), + "vDeflection": (31, 2.740962611337846e-07), + "measuredHeight": (31, 0.00027384485398464497), + "smoothedMeasuredHeight": (31, -38066578894.999535), + }, + id="qi-data 0; vDeflection_trace", + ), + ], +) +def test_load_jpk_qi_data( + file_name: str, + channel: str, + pixel_to_nm_scaling: float, + image_shape: tuple[int, int], + image_dtype: type, + image_sum: float, + curve_coords: tuple[int, int], + curve_direction: str, + curve_targets: dict[str, tuple[int, float]], +) -> None: + """Test the normal operation of loading a .jpk-qi-data file.""" + result_image = np.ndarray + result_pixel_to_nm_scaling = float + file_path = RESOURCES / file_name + jpk_qi_loader = jpk_qi.jpk_qi_loader(file_path, channel) + result_image, result_pixel_to_nm_scaling, result_curve_data = jpk_qi_loader.load() # type: ignore + + assert result_pixel_to_nm_scaling == pytest.approx(pixel_to_nm_scaling) + assert isinstance(result_image, np.ndarray) + assert result_image.shape == image_shape + assert result_image.dtype == image_dtype + assert result_image.sum() == pytest.approx(image_sum) + + # Test curve data for all targets + all_curves, _, _ = result_curve_data + curve_at_coords = all_curves[curve_coords[0]][curve_coords[1]] + for curve_channel, (expected_size, expected_sum) in curve_targets.items(): + curve = curve_at_coords[curve_channel][curve_direction] + assert curve.shape == (expected_size,) + assert curve.sum() == pytest.approx(expected_sum) + + jpk_qi_loader.close() # type: ignore + + +def test_load_jpk_data_file_not_found() -> None: + """Ensure FileNotFound error is raised.""" + with pytest.raises(FileNotFoundError): + jpk_qi_loader = jpk_qi.jpk_qi_loader("noexistant_file.jpk-qi-data", "TP") + jpk_qi_loader.load() # type: ignore + jpk_qi_loader.close() # type: ignore From 58a3a0fdf35ef998308700ff57c0b15796d437c5 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Fri, 24 Apr 2026 13:11:06 +0100 Subject: [PATCH 36/54] Updating documentation --- README.md | 53 ++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index c574127..c902a84 100644 --- a/README.md +++ b/README.md @@ -35,10 +35,12 @@ Supported file formats | `.ibw` | [WaveMetrics](https://www.wavemetrics.com/) | | `.jpk-qi-image` | [Bruker](https://www.bruker.com/) | | `.jpk` | [Bruker](https://www.bruker.com/) | +| `.jpk-qi-data` | [Bruker](https://www.bruker.com/) | | `.spm` | [Bruker's Format](https://www.bruker.com/) | | `.stp` | [WSXM AFM software files](http://www.wsxm.eu) | | `.top` | `.stp` variant | | `.topostats` | [TopoStats](https://github.com/AFM-SPM/TopoStats) | +| `.bin` | Unspecificied binary file format | Support for the following additional formats is planned. Some of these are already supported in TopoStats and are awaiting refactoring to move their functionality into AFMReader these are denoted in bold below. @@ -116,9 +118,9 @@ from AFMReader.ibw import load_ibw image, pixel_to_nanometre_scaling_factor = load_ibw(file_path="./my_ibw_file.ibw", channel="HeightTracee") ``` -### .jpk +### .jpk and .jpk-qi-image -You can open `.jpk` files using the `load_jpk` function. Just pass in the path +You can open `.jpk` and `.jpk-qi-image` files using the `load_jpk` function. Just pass in the path to the file and the channel name you want to use. (If in doubt, use `height_trace` or `measuredHeight_trace`). ```python @@ -127,6 +129,19 @@ from AFMReader.jpk import load_jpk image, pixel_to_nanometre_scaling_factor = load_jpk(file_path="./my_jpk_file.jpk", channel="height_trace") ``` +### .jpk-qi-data + +You can open `.jpk-qi-data` files using the `jpk_qi_loader` class. Just pass in the path to the file +and the channel name you want to use. Then call the `my_jpk_qi_loader.load()` method. If in doubt, +use `height_trace` or `measuredHeight_trace`. + +```python +from AFMReader.jpk_qi import jpk_qi_loader + +my_jpk_qi_loader = jpk_qi_loader(file_path="./my_jpk_qi_data_file.jpk-qi-data", channel="height_trace") +image, pixel_to_nanometre_scaling_factor, force_curves = my_jpk_qi_loader.load() +``` + ### .h5-jpk You can open `.h5-jpk` files using the `load_h5jpk` function. Just pass in the path @@ -138,9 +153,19 @@ Note: Since `.h5-jpk` stores timeseries AFM data a dictionary of timestamps for ```python from AFMReader.h5_jpk import load_h5jpk -frames, pixel_to_nanometre_scaling_factor, timestamp_dict = load_h5jpk(file_path="./my_jpk_file.jpk", channel="height_trace") +frames, pixel_to_nanometre_scaling_factor, timestamp_dict = load_h5jpk(file_path="./my_jpk_file.h5-jpk", channel="height_trace") +``` + +If your `.h5-jpk` file was created from a `.jpk-qi-data` file, then the curve data can be read like so. Note that reading force +curves like this will keep the file open as the force curves are lazy loaded from your hard drive. + +```python +from AFMReader.h5_jpk import load_h5jpk + +frames, pixel_to_nanometre_scaling_factor, timestamp_dict, force_curves = load_h5jpk(file_path="./my_jpk_file.h5-jpk", channel="height_trace") ``` + ### .stp You can open `.stp` files using the `load_stp` function. Just pass in the path @@ -163,6 +188,28 @@ from AFMReader.top import load_top image, pixel_to_nanometre_scaling_factor = load_top(file_path="./my_top_file.top") ``` +### .bin + +You can open unspecified binary files using the `load_bin` function. You must supply the path +to the file, the data type, the byte offset where the image data begins, and the physical dimensions +of the scan. Supported `data_type` values include `"IEEE double"`, `"IEEE single"`, `"float64"`, +`"float32"`, `"I32"`, `"U32"`, `"I16"`, `"U16"`, `"I8"`, and `"U8"`. + +```python +from AFMReader.bin import load_bin + +image, pixel_to_nanometre_scaling_factor = load_bin( + filepath="./my_binary_file.bin", + data_type="IEEE double", + offset_bytes=0, + size_x=1000.0, # physical width in nm + size_y=1000.0, # physical height in nm + shape_x=512, # pixels along x + shape_y=512, # pixels along y + z_scaling=1.0, # optional z-axis scaling factor +) +``` + ## Contributing Bug reports and feature requests are welcome. Please search for existing issues, if none relating to your bug/feature From c0a2924cbc1f448ca333bfd64f8c1480474d8c6e Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Sat, 2 May 2026 11:15:23 +0100 Subject: [PATCH 37/54] Reformatting to match pre-commit conditions and make more robust --- AFMReader/asd.py | 17 +- AFMReader/bin.py | 57 ----- AFMReader/general_loader.py | 50 +++- AFMReader/gwy.py | 15 ++ AFMReader/h5_jpk.py | 224 ++++++++++++----- AFMReader/ibw.py | 15 ++ AFMReader/jpk.py | 45 +++- AFMReader/jpk_qi.py | 486 +++++++++++++++++++++--------------- AFMReader/jpk_utils.py | 186 ++++++++++++++ AFMReader/raw_bin.py | 116 +++++++++ AFMReader/spm.py | 17 +- 11 files changed, 880 insertions(+), 348 deletions(-) delete mode 100644 AFMReader/bin.py create mode 100644 AFMReader/jpk_utils.py create mode 100644 AFMReader/raw_bin.py diff --git a/AFMReader/asd.py b/AFMReader/asd.py index fd42733..4187e12 100644 --- a/AFMReader/asd.py +++ b/AFMReader/asd.py @@ -287,7 +287,21 @@ def load_asd(file_path: str | Path, channel: str): logger.info(f"[{filename}] : Extracted image.") return frames, pixel_to_nanometre_scaling_factor, header_dict + def get_asd_channels(file_path: Path): + """ + Get the channels available in given .asd file. + + Parameters + ---------- + file_path : Path + Path to the .asd file. + + Returns + ------- + list + List of channels available in the .asd file. + """ with Path.open(file_path, "rb", encoding=None) as open_file: # pylint: disable=unspecified-encoding file_version = read_file_version(open_file) @@ -303,8 +317,7 @@ def get_asd_channels(file_path: Path): raise ValueError( f"File version {file_version} unknown. Please add support if you know how to decode this file version." ) - channel_list = [header_dict["channel1"], header_dict["channel2"]] - return channel_list + return [header_dict["channel1"], header_dict["channel2"]] def read_file_version(open_file: BinaryIO) -> int: diff --git a/AFMReader/bin.py b/AFMReader/bin.py deleted file mode 100644 index 3209215..0000000 --- a/AFMReader/bin.py +++ /dev/null @@ -1,57 +0,0 @@ -import numpy as np -import math -from .logging import logger -from pathlib import Path - -DTYPE_MAP = { - 'IEEE double': np.float64, - 'DBL': np.float64, - 'IEEE single': np.float32, - 'SGL': np.float32, - 'U32': np.uint32, - 'I32': np.int32, - 'U16': np.uint16, - 'I16': np.int16, - 'U8': np.uint8, - 'I8': np.int8, - 'float64': np.float64, - 'float32': np.float32, - 'int32': np.int32, -} - -def load_bin(filepath: str | Path, data_type: str, offset_bytes: int, size_x: float = None, size_y: float = None, shape_x: int = None, shape_y: int = None, z_scaling: float = 1.0): - filepath = Path(filepath) - dt_key = str(data_type).strip() - shape_x = None if shape_x == 0 else shape_x - shape_y = None if shape_y == 0 else shape_y - - if dt_key in DTYPE_MAP: - np_dtype = DTYPE_MAP[dt_key] - else: - logger.warning(f"Unknown data type '{dt_key}'. Defaulting to float64.") - np_dtype = np.float64 - with open(filepath, 'rb') as f: - f.seek(offset_bytes) - flat_data = np.fromfile(f, dtype=np_dtype) - if None in [shape_x, shape_y]: - dimension = int(math.sqrt(len(flat_data))) - shape_x, shape_y = dimension, dimension - if shape_x * shape_y != len(flat_data): - logger.error(f"Loading binary file {filepath.stem} did not receive a shape and is not square") - image = flat_data.reshape((shape_x, shape_y)) - image *= z_scaling - pixel_to_nm_scaling_factor_x = size_x / shape_x if shape_x > 0 else 1.0 - pixel_to_nm_scaling_factor_y = size_y / shape_y if shape_y > 0 else 1.0 - px2nm = (pixel_to_nm_scaling_factor_x + pixel_to_nm_scaling_factor_y) / 2 - return image, px2nm - -def get_bin_channels(): - kwarg_types = {"data_type" : (str, DTYPE_MAP.keys()), - "offset_bytes": int, - "size_x": float, - "size_y": float, - "shape_x": int, - "shape_y": int, - "z_scaling": float} - return [], kwarg_types - diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index 600727e..e5efded 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -1,17 +1,17 @@ """Switchboard for input files.""" from pathlib import Path +from typing import Any import numpy.typing as npt -from AFMReader import asd, gwy, h5_jpk, ibw, jpk, spm, stp, top, topostats, jpk_qi, bin +from AFMReader import asd, gwy, h5_jpk, ibw, jpk, raw_bin, spm, stp, top, topostats, jpk_qi from AFMReader.logging import logger - logger.enable(__package__) -# pylint: disable=too-few-public-methods +# pylint: disable=too-few-public-methods,too-many-branches,too-many-statements,fixme class LoadFile: """ Class to handle the general loading of an AFM file. @@ -22,9 +22,11 @@ class LoadFile: Path to the AFM image. channel : str Channel to extract from the AFM image. + kwargs : dict, optional + Additional keyword arguments to pass to the specific loaders """ - def __init__(self, filepath: str | Path, channel: str, kwargs: dict = None): + def __init__(self, filepath: str | Path, channel: str, kwargs: dict | None = None): """ Initialise the general LoadFile class with a filepath and channel. @@ -34,6 +36,8 @@ def __init__(self, filepath: str | Path, channel: str, kwargs: dict = None): Path to the AFM image. channel : str Channel to extract from the AFM image. + kwargs : dict, optional + Additional keyword arguments to pass to the specific loaders """ self.filepath = Path(filepath) self.channel = channel @@ -42,17 +46,26 @@ def __init__(self, filepath: str | Path, channel: str, kwargs: dict = None): self.kwargs = kwargs if kwargs else {} # Store heavy loaded data in a dict to avoid having to reload it - self.cached_data = {} + self.cached_data: dict[str, Any] = {} - def load(self, channel: str | None = None, kwargs: dict = None) -> tuple[npt.NDArray | str, float | None]: # noqa: C901 + def load( # noqa: C901 + self, channel: str | None = None, kwargs: dict | None = None + ) -> tuple[npt.NDArray | str, float | None] | tuple[npt.NDArray | str, float | None, Any]: """ Generally loads a file type that can be handled by AFMReader. + Parameters + ---------- + channel : str, optional + Overriding channel to extract from the AFM image. + kwargs : dict, optional + Additional keyword arguments to pass to the specific loaders + Returns ------- tuple The image data (stack if ''.asd'' or ''.h5-jpk'') and the pixel to nanometre scaling ratio. - If curve data is found, also return the curve data (a large dict of all the curves) + If curve data is found, also return the curve data (a large dict of all the curves). Raises ------ @@ -82,14 +95,19 @@ def load(self, channel: str | None = None, kwargs: dict = None) -> tuple[npt.NDA elif len(h5_returned) == 4: image, pixel_to_nanometre_scaling_factor, _, curve_data = h5_returned self.loaded_curves = True - print(f"Loaded image with shape {image.shape} and pixel to nanometre scaling factor {pixel_to_nanometre_scaling_factor}") + print( + f"Loaded image with shape {image.shape} and pixel to nanometre " + f"scaling factor {pixel_to_nanometre_scaling_factor}" + ) print(f"Image has max value {image.max()} and min value {image.min()}") return image, pixel_to_nanometre_scaling_factor, curve_data else: logger.error(f"Loading h5-jpk file returned unexpected number of items: {len(h5_returned)}") elif self.suffix == ".jpk-qi-data": if "jpk_qi_loader" not in self.cached_data: - self.cached_data["jpk_qi_loader"] = jpk_qi.jpk_qi_loader(filepath=self.filepath, channel=self.channel, **self.kwargs) + self.cached_data["jpk_qi_loader"] = jpk_qi.jpk_qi_loader( + filepath=self.filepath, channel=self.channel, **self.kwargs + ) jpk_qi_returned = self.cached_data["jpk_qi_loader"].load(channel=self.channel, **self.kwargs) if len(jpk_qi_returned) == 2: image, pixel_to_nanometre_scaling_factor = jpk_qi_returned @@ -116,7 +134,7 @@ def load(self, channel: str | None = None, kwargs: dict = None) -> tuple[npt.NDA f"{[im for im in image_keys if im in topostats_keys]}" ) from exc elif self.suffix == ".bin": - image, pixel_to_nanometre_scaling_factor = bin.load_bin(self.filepath, **self.kwargs) + image, pixel_to_nanometre_scaling_factor = raw_bin.load_bin(self.filepath, **self.kwargs) else: raise ValueError(f"File type '{self.suffix}' is not currently handled by AFMReader.") @@ -126,7 +144,15 @@ def load(self, channel: str | None = None, kwargs: dict = None) -> tuple[npt.NDA logger.error(f"{e}") raise e - def get_available_channels(self): + def get_available_channels(self): # noqa: C901 + """ + Get the available channels for the file type. + + Returns + ------- + list + List of available channels. + """ if self.suffix == ".asd": available_channels = asd.get_asd_channels(self.filepath) elif self.suffix == ".gwy": @@ -146,7 +172,7 @@ def get_available_channels(self): elif self.suffix == ".topostats": available_channels = ["image", "image_original"] elif self.suffix == ".bin": - available_channels = bin.get_bin_channels() + available_channels = raw_bin.get_bin_channels() elif self.suffix in [".stp", ".top"]: return [] else: diff --git a/AFMReader/gwy.py b/AFMReader/gwy.py index c52cd3f..e36f50b 100644 --- a/AFMReader/gwy.py +++ b/AFMReader/gwy.py @@ -9,7 +9,21 @@ from AFMReader.io import read_char, read_double, read_null_terminated_string, read_uint32 + def get_gwy_channels(file_path): + """ + Extract a list of available channels and their corresponding dictionary key ids from the `.gwy` file. + + Parameters + ---------- + file_path : Path or str + Path to the .gwy file. + + Returns + ------- + list + List of available channels. + """ image_data_dict: dict[Any, Any] = {} with Path.open(file_path, "rb") as open_file: # pylint: disable=unspecified-encoding # Read header @@ -21,6 +35,7 @@ def get_gwy_channels(file_path): return list(channel_ids) + def load_gwy(file_path: Path | str, channel: str) -> tuple[np.ndarray[Any, np.float64], float]: """ Extract image and pixel to nm scaling from the .gwy file. diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index 18ea4bc..cf37cd7 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -6,14 +6,22 @@ """ from pathlib import Path +from typing import Any import h5py import numpy as np from AFMReader.logging import logger +from AFMReader.jpk_utils import ( + LazyMetaProxy, + LazyMetadata, + LazyQiData, +) logger.enable(__package__) +# pylint: disable=too-few-public-methods,too-many-locals,fixme + def _parse_channel_name(channel: str) -> tuple[str, str]: """ @@ -269,33 +277,58 @@ def generate_timestamps(num_frames: int, line_rate: float, image_size: int) -> d def get_h5jpk_channels(file_path: Path | str): - with h5py.File(file_path, "r") as f: - available_channels = list(_available_channels(f)) - return available_channels + """ + Get available channels from a .h5-jpk file. + + Parameters + ---------- + file_path : Path | str + Path to the .h5-jpk file. + Returns + ------- + list + List of available channels. + """ + with h5py.File(file_path, "r") as f: + return list(_available_channels(f)) -class LazyQIData: - def __init__(self, qi_data_group: h5py.Group, shape_x: int, shape_y: int, flip_image: bool = True): - self.qi_data_group = qi_data_group - self.shape_x = shape_x - self.shape_y = shape_y - self.dims = (shape_y, shape_x) - self.flip_image = flip_image - self.CHUNK_SIZE = 500 # Number of curves to load at once when iterating - def __getitem__(self, y: int): +class LazyH5QiData(LazyQiData): + """ + A proxy class that fetches QI curve data from the HDF5 file on demand. - class RowProxy: - def __init__(self, parent, y): - self.parent = parent - self.y = y + It behaves like a 2D array of shape (shape_y, shape_x) where each element + is a dictionary containing the QI curve data for that pixel. + """ - def __getitem__(self, x: int): - return self.parent._fetch_curve(self.y, x) + def __init__(self, qi_data_group: h5py.Group, shape_x: int, shape_y: int, flip_image: bool = True): + """ + Initialize the LazyH5QiData proxy. + + Parameters + ---------- + qi_data_group : h5py.Group + The HDF5 group containing the QI curve data. + shape_x : int + The number of columns in the image. + shape_y : int + The number of rows in the image. + flip_image : bool, optional + Whether to flip the image vertically. Default is ``True``. + """ + super().__init__(shape_x, shape_y, flip_image) + self.qi_data_group = qi_data_group - return RowProxy(self, y) + def __iter__(self): # noqa: C901 + """ + Efficiently iterate over the QI curve data, loading one row at a time. - def __iter__(self): + Yields + ------ + dict + A dictionary containing the QI curve data for each channel and segment. + """ indicies_map = {} for segment, segment_group in self.qi_data_group["Curves"].items(): for channel in segment_group["Indicies"]: @@ -326,6 +359,21 @@ def __iter__(self): yield curve_data def _fetch_curve(self, y: int, x: int): + """ + Fetch the QI curve data for a specific pixel (x, y) on demand. + + Parameters + ---------- + y : int + The row index. + x : int + The column index. + + Returns + ------- + dict + A dictionary containing the QI curve data for the specified pixel. + """ if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x: raise IndexError(f"Curve index out of bounds: ({x}, {y})") curve_dict = {} @@ -342,6 +390,14 @@ def _fetch_curve(self, y: int, x: int): return curve_dict def load_all_curves(self): + """ + Load all QI curve data into memory. + + Returns + ------- + list + A 2D list containing dictionaries with QI curve data for each pixel. + """ all_curves = [[{} for _ in range(self.shape_x)] for _ in range(self.shape_y)] for segment, segment_group in self.qi_data_group["Curves"].items(): for channel in segment_group["Indicies"]: @@ -361,63 +417,100 @@ def load_all_curves(self): return all_curves -class LazyCurveMetadata: +class LazyH5Metadata(LazyMetadata): """A proxy class that fetches header.properties files on demand.""" def __init__( self, qi_data_group: h5py.Group, top_level_meta: dict, shape_x: int, shape_y: int, flip_image: bool = True ): + """ + Initialize the LazyH5Metadata proxy. + + Parameters + ---------- + qi_data_group : h5py.Group + The HDF5 group containing the QI curve data. + top_level_meta : dict + The top-level metadata dictionary. + shape_x : int + The number of columns in the image. + shape_y : int + The number of rows in the image. + flip_image : bool, optional + Whether to flip the image vertically. Default is ``True``. + """ self.qi_data_group = qi_data_group - self.top_level_meta = top_level_meta - self.shape_x = shape_x - self.shape_y = shape_y - self.flip_image = flip_image - # Expose top_level so the frontend can still do `raw_metadata["top_level"]` - self.top_level = top_level_meta + super().__init__(top_level_meta, shape_x, shape_y, flip_image) def __getitem__(self, key): + """ + Fetch metadata based on the key. + + Parameters + ---------- + key : str + The key to fetch metadata for. + + Returns + ------- + object + The fetched metadata, a lazy object that can be further queried. + """ if key == "top_level": return self.top_level - elif key == "curves": - return LazyMetaProxy(self.qi_data_group, "curve", self.shape_x, self.shape_y, self.flip_image) - elif key == "segments": - return LazyMetaProxy(self.qi_data_group, "segment", self.shape_x, self.shape_y, self.flip_image) + if key == "curves": + return LazyH5MetaProxy(self.qi_data_group, "curve", self.shape_x, self.shape_y, self.flip_image) + if key == "segments": + return LazyH5MetaProxy(self.qi_data_group, "segment", self.shape_x, self.shape_y, self.flip_image) raise KeyError(key) -class LazyMetaProxy: - def __init__(self, qi_data_group: h5py.Group, meta_type: str, shape_x: int, shape_y: int, flip_image: bool = True): - self.qi_data_group = qi_data_group - self.meta_type = meta_type - self.shape_x = shape_x - self.shape_y = shape_y - self.flip_image = flip_image - - def __getitem__(self, y: int): - class RowProxy: - def __init__(self, parent, y): - self.parent = parent - self.y = y - - def __getitem__(self, x): - if self.parent.meta_type == "curve": - return self.parent._fetch_meta(self.y, x) - elif self.parent.meta_type == "segment": - - class SegmentMetaProxy: - def __init__(self, parent, y, x): - self.parent = parent - self.y = y - self.x = x - - def __getitem__(self, direction): - return self.parent.parent._fetch_meta(self.y, self.x, direction) +class LazyH5MetaProxy(LazyMetaProxy): + """ + A proxy class that fetches curve or segment metadata from the HDF5 file on demand. - return SegmentMetaProxy(self, self.y, x) + It behaves like a 2D array of shape (shape_y, shape_x) where each element + is a dictionary containing the requested metadata for that pixel. + """ - return RowProxy(self, y) + def __init__(self, qi_data_group: h5py.Group, meta_type: str, shape_x: int, shape_y: int, flip_image: bool = True): + """ + Initialize the LazyH5MetaProxy. + + Parameters + ---------- + qi_data_group : h5py.Group + The HDF5 group containing the QI curve data. + meta_type : str + The type of metadata to fetch ("curve" or "segment"). + shape_x : int + The number of columns in the image. + shape_y : int + The number of rows in the image. + flip_image : bool, optional + Whether to flip the image vertically. Default is ``True``. + """ + super().__init__(meta_type, shape_x, shape_y, flip_image) + self.qi_data_group = qi_data_group - def _fetch_meta(self, y: int, x: int, direction: int = None): + def _fetch_meta(self, y: int, x: int, direction: int | None = None): + """ + Fetch metadata for a specific pixel (x, y) on demand. + + Parameters + ---------- + y : int + The row index. + x : int + The column index. + direction : int, optional + The direction index for segment metadata (0 or 1), required if meta_type is "segment". + + Returns + ------- + dict + A dictionary containing the fetched metadata. + """ if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x: raise IndexError(f"Curve index out of bounds: ({x}, {y})") if self.flip_image: @@ -442,7 +535,7 @@ def _fetch_meta(self, y: int, x: int, direction: int = None): def load_h5jpk( file_path: Path | str, channel: str, flip_image: bool = True, load_curves: bool = True -) -> tuple[np.ndarray, float, dict[str, float]]: +) -> tuple[np.ndarray, float, dict[str, float]] | tuple[np.ndarray, float, dict[str, float], Any]: """ Load image from JPK Instruments .h5-jpk files. @@ -463,6 +556,9 @@ def load_h5jpk( Scaling factor converting pixels to nanometers. timestamps : dict[str, float] Dictionary mapping frame labels (e.g., "frame 0") to timestamp values in seconds. + curves_data : tuple(LazyH5QiData, dict, LazyH5Metadata), optional + Tuple containing lazy-loaded QI curve data, channel units, and metadata. + Returned only if load_curves is True and QI curve data is present in the file. Raises ------ @@ -541,9 +637,9 @@ def load_h5jpk( channels_units[key.split(".")[-1]] = value top_level_meta[key] = value - full_metadata = LazyCurveMetadata(qi_data_group, top_level_meta, shape_x, shape_y, flip_image) + full_metadata = LazyH5Metadata(qi_data_group, top_level_meta, shape_x, shape_y, flip_image) - all_curve_data = LazyQIData(qi_data_group, shape_x, shape_y, flip_image) + all_curve_data = LazyH5QiData(qi_data_group, shape_x, shape_y, flip_image) return (image_stack, px2nm, timestamps, (all_curve_data, channels_units, full_metadata)) diff --git a/AFMReader/ibw.py b/AFMReader/ibw.py index 51cd08a..624d2f6 100644 --- a/AFMReader/ibw.py +++ b/AFMReader/ibw.py @@ -38,7 +38,22 @@ def _ibw_pixel_to_nm_scaling(scan: dict) -> float: float(notes["FastScanSize"]) / scan["wave"]["wData"].shape[1] * 1e9, # as in m )[0] + def get_ibw_channels(file_path: Path | str): + """ + Extract a list of available channels and their corresponding dictionary key ids from the `.ibw` file. + + Parameters + ---------- + file_path : Path or str + Path to the .ibw file. + + Returns + ------- + list + List of available channels. + """ + file_path = Path(file_path) filename = file_path.stem scan = binarywave.load(file_path) logger.info(f"[{filename}] : Loaded image from : {file_path}") diff --git a/AFMReader/jpk.py b/AFMReader/jpk.py index a06c3f8..f69cc1d 100644 --- a/AFMReader/jpk.py +++ b/AFMReader/jpk.py @@ -12,7 +12,7 @@ logger.enable(__package__) -# pylint: disable=too-many-locals +# pylint: disable=too-many-locals,too-many-positional-arguments,fixme def _jpk_pixel_to_nm_scaling(tiff_page: tifffile.tifffile.TiffPage, jpk_tags: dict[str, int]) -> float: @@ -171,8 +171,9 @@ def _get_z_scaling(tif: tifffile.tifffile, channel_idx: int, jpk_tags: dict[str, raise ValueError(f"Scaling type {scaling_type} is not 'NullScaling' or 'LinearScaling'") return scaling, offset + def _get_jpk_channels( - file: Path | BytesIO, filename: str, file_path: Path | str, config_path: Path | str | None = None + file: Path | BytesIO, filename: str, file_path: Path | str, config_path: Path | str | None = None ): jpk_tags = _load_jpk_tags(config_path) try: @@ -191,10 +192,24 @@ def _get_jpk_channels( channel_list[f"{available_channel}_{tr_rt}"] = i + 1 return channel_list -def get_jpk_channels( - file_path: Path | str, config_path: Path | str | None = None -) -> list[str]: +def get_jpk_channels(file_path: Path | str, config_path: Path | str | None = None) -> list[str]: + """ + Get the list of channels available in the .jpk file. + + Parameters + ---------- + file_path : Path | str + Path to the .jpk file. + config_path : Path | str | None + Path to a configuration file. If ''None'' (default) then the packages + default configuration is loaded from ''default_config.yaml''. + + Returns + ------- + list[str] + List of available channels. + """ file_path = Path(file_path) filename = file_path.stem return _get_jpk_channels(file_path, filename, file_path, config_path) @@ -242,10 +257,26 @@ def load_jpk( logger.info(f"Loading image from : {file_path}") file_path = Path(file_path) filename = file_path.stem - image, px2nm = _load_jpk(file=file_path, filename=filename, channel=channel, file_suffix=file_path.suffix, config_path=config_path, flip_image=flip_image) + image, px2nm = _load_jpk( + file=file_path, + filename=filename, + channel=channel, + file_suffix=file_path.suffix, + config_path=config_path, + flip_image=flip_image, + ) return (image, px2nm) -def _load_jpk(file: Path | BytesIO, filename: str, channel: str, file_suffix: str, config_path: Path | str | None = None, flip_image: bool = True, convert_to_nm: bool = True) -> tuple[np.ndarray, float]: + +def _load_jpk( + file: Path | BytesIO, + filename: str, + channel: str, + file_suffix: str, + config_path: Path | str | None = None, + flip_image: bool = True, + convert_to_nm: bool = True, +) -> tuple[np.ndarray, float]: jpk_tags = _load_jpk_tags(config_path) try: tif = tifffile.TiffFile(file) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index 7b0b6f2..95b3884 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -1,54 +1,86 @@ +""" +Module to decode and load JPK QI (Quantitative Imaging) data files. + +It provides lazy loading for curve data and metadata to minimize memory usage, +and supports exporting to HDF5 format. +""" + +# pylint: disable=too-many-lines,too-many-positional-arguments,too-few-public-methods,too-many-instance-attributes +# pylint: disable=too-many-locals,too-many-branches,protected-access,attribute-defined-outside-init,fixme + import os -from pathlib import Path -from contextlib import nullcontext import io import zipfile import time +from pathlib import Path +from contextlib import nullcontext +from typing import Any import numpy as np import javaproperties import h5py import psutil +from AFMReader.jpk_utils import LazyMetadata, LazyMetaProxy, LazyQiData from AFMReader.logging import logger from AFMReader import jpk -class LazyCurveData: - """A proxy class that behaves like a 2D list but fetches .dat files on demand.""" +class LazyJpkQiData(LazyQiData): + """A proxy class that behaves like a 2D list of shape (shape_y, shape_x) but fetches .dat file data on demand.""" - def __init__(self, filepath, shape_x, shape_y, channel_scaling, archive, flip_image: bool = True): + def __init__(self, filepath, shape_x: int, shape_y: int, channel_scaling, archive, flip_image: bool = True): + """ + Initialize the LazyJpkQiData instance. + + Parameters + ---------- + filepath : str + Path to the .jpk file. + shape_x : int + Number of columns in the image. + shape_y : int + Number of rows in the image. + channel_scaling : dict + Dictionary containing scaling information for each channel. + archive : zipfile.ZipFile + The opened ZIP archive containing the .dat files. + flip_image : bool, optional + Whether to flip the image vertically. Default is ``True``. + """ + super().__init__(shape_x, shape_y, flip_image) self.filepath = filepath - self.shape_x = shape_x - self.shape_y = shape_y self.channel_scaling = channel_scaling self.archive = archive - self.flip_image = flip_image - - def __getitem__(self, y: int): - # Return a row proxy to handle the second index [x] - class RowProxy: - def __init__(self, parent, y): - self.parent = parent - self.y = y - - def __getitem__(self, x: int): - return self.parent._fetch_curve(self.y, x) - - return RowProxy(self, y) def __iter__(self): + """Yield the curve data for each pixel in the image, iterating in row-major order (y first, then x).""" for y in range(self.shape_y): for x in range(self.shape_x): yield self._fetch_curve(y, x) def _fetch_curve(self, y: int, x: int): + """ + Fetch the curve data for a specific pixel. + + Parameters + ---------- + y : int + Row index of the pixel. + x : int + Column index of the pixel. + + Returns + ------- + curve_data : dict + Dictionary containing the curve data for the specified pixel. + """ if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x: raise IndexError(f"Curve index out of bounds: ({x}, {y})") if self.flip_image: y = self.shape_y - 1 - y curve_num = y * self.shape_x + x - curve_data = {} + curve_data: dict[str, Any] = {} for chan_name, scale in self.channel_scaling.items(): curve_data[chan_name] = {} @@ -73,6 +105,11 @@ def load_all_curves(self): This can be used if the user wants to have all the curve data available at once, but it is not recommended for large datasets as it will consume a lot of memory. In this case, it is not notably faster as the zip structure means each curve is effectively loaded individually anyway + + Returns + ------- + all_curve_data : list + A 2D list containing dictionaries with curve data for each pixel. """ all_curve_data = [[None for _ in range(self.shape_x)] for _ in range(self.shape_y)] for y in range(self.shape_y): @@ -83,65 +120,116 @@ def load_all_curves(self): return all_curve_data def close(self): + """Close the ZIP archive when done to free up resources.""" self.archive.close() -class LazyCurveMetadata: - """A proxy class that fetches header.properties files on demand.""" +class LazyQiMetadata(LazyMetadata): + """ + A proxy class that fetches header.properties files on demand. + + It behaves like a 2D array of shape (shape_y, shape_x) where each element + is a dictionary containing the requested metadata for that pixel. + """ def __init__(self, filepath, top_level_meta, archive, shape_x: int, shape_y: int, flip_image: bool = True): + """ + Initialize the LazyQiMetadata instance. + + Parameters + ---------- + filepath : str + Path to the .jpk file. + top_level_meta : dict + Dictionary containing the top-level metadata extracted from the header files. + archive : zipfile.ZipFile + The opened ZIP archive containing the JPK file contents. + shape_x : int + The number of columns in the image. + shape_y : int + The number of rows in the image. + flip_image : bool, optional + Whether to flip the image vertically. Default is True. + """ self.filepath = filepath # Expose top_level so the frontend can still do `raw_metadata["top_level"]` - self.top_level = top_level_meta self.archive = archive - self.shape_x = shape_x - self.shape_y = shape_y - self.flip_image = flip_image + super().__init__(top_level_meta, shape_x, shape_y, flip_image) def __getitem__(self, key): + """ + Fetch requested metadata based on key. + + If the key is 'top_level', it returns the top-level metadata. + If the key is 'curves' or 'segments', it returns a LazyQiMetaProxy + that can be used to fetch curve or segment metadata on demand. + + Parameters + ---------- + key : str + The key to fetch metadata for. + + Returns + ------- + dict or LazyQiMetaProxy + The requested metadata. + """ if key == "top_level": return self.top_level - elif key == "curves": - return LazyMetaProxy(self.filepath, "curve", self.archive, self.shape_x, self.shape_y, self.flip_image) - elif key == "segments": - return LazyMetaProxy(self.filepath, "segment", self.archive, self.shape_x, self.shape_y, self.flip_image) + if key == "curves": + return LazyQiMetaProxy(self.filepath, "curve", self.archive, self.shape_x, self.shape_y, self.flip_image) + if key == "segments": + return LazyQiMetaProxy(self.filepath, "segment", self.archive, self.shape_x, self.shape_y, self.flip_image) raise KeyError(key) -class LazyMetaProxy: +class LazyQiMetaProxy(LazyMetaProxy): + """ + A proxy class that behaves like a 2D array of shape (shape_y, shape_x) + but fetches header.properties files on demand for curves or segments. + """ + def __init__(self, filepath, meta_type, archive, shape_x: int, shape_y: int, flip_image: bool = True): + """ + Initialize the LazyQiMetaProxy instance. + + Parameters + ---------- + filepath : str + Path to the .jpk file. + meta_type : str + The type of metadata to fetch ('curve' or 'segment'). + archive : zipfile.ZipFile + The opened ZIP archive containing the JPK file contents. + shape_x : int + The number of columns in the image. + shape_y : int + The number of rows in the image. + flip_image : bool, optional + Whether to flip the image vertically. Default is True. + """ self.filepath = filepath - self.meta_type = meta_type self.archive = archive - self.shape_x = shape_x - self.shape_y = shape_y - self.flip_image = flip_image - - def __getitem__(self, y: int): - class RowProxy: - def __init__(self, parent, y): - self.parent = parent - self.y = y - - def __getitem__(self, x): - if self.parent.meta_type == "curve": - return self.parent._fetch_meta(self.y, x) - elif self.parent.meta_type == "segment": - - class SegmentMetaProxy: - def __init__(self, parent, y, x): - self.parent = parent - self.y = y - self.x = x + super().__init__(meta_type, shape_x, shape_y, flip_image) - def __getitem__(self, direction): - return self.parent.parent._fetch_meta(self.y, self.x, direction) - - return SegmentMetaProxy(self, self.y, x) + def _fetch_meta(self, y: int, x: int, direction: int | None = None): + """ + Fetch the metadata for a specific curve or segment. - return RowProxy(self, y) + Parameters + ---------- + y : int + Row index of the curve or segment. + x : int + Column index of the curve or segment. + direction : int, optional + The direction index for segment metadata. Required if meta_type is 'segment'. - def _fetch_meta(self, y: int, x: int, direction: int = None): + Returns + ------- + dict + The metadata dictionary for the specified curve or segment. + """ if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x: raise IndexError(f"Curve index out of bounds: ({x}, {y})") if self.flip_image: @@ -165,8 +253,7 @@ def _fetch_meta(self, y: int, x: int, direction: int = None): def _get_channel_scaling(props, channel_index): """ - Parses the JPK properties dictionary to find the cumulative multiplier - and offset for a specific channel index (e.g., '1' for vDeflection). + Parse the JPK properties dictionary to find cumulative multiplier and offset for a specific channel index. Parameters ---------- @@ -227,7 +314,7 @@ def _get_channel_scaling(props, channel_index): class jpk_qi_loader: - """Class for readability and improving modularity in the load jpk qi data function""" + """Class for readability and improving modularity in the load jpk qi data function.""" def __init__( self, @@ -238,14 +325,14 @@ def __init__( save_as_h5: bool = False, ): """ - Initializes the loader with the provided parameters. + Initialize the loader with the provided parameters. Parameters ---------- filepath : Path | str The path to the .jpk-qi file to be loaded. channel : str | None, optional - The specific channel to be extracted from the file (e.g., "measuredHeight"). Default is None. + The specific channel to be extracted (e.g., "measuredHeight"). Default is None. config_path : Path | str | None, optional The path to the configuration file, if any. Default is None. flip_image : bool | None, optional @@ -253,14 +340,13 @@ def __init__( save_as_h5 : bool, optional Whether to save the loaded data as an H5 file. Default is False. """ - self.filepath = Path(filepath) self.channel = channel self.config_path = config_path self.flip_image = flip_image self.save_as_h5 = save_as_h5 - # Open the ZIP archive once and keep it open for the duration of the loading process to improve performance when accessing multiple files within the archive + # Open the ZIP archive once and keep it open for the duration of the loading process self.qi_archive = zipfile.ZipFile(self.filepath, "r") logger.info(f"Opened JPK QI archive at {self.filepath}") self.namelist = self.qi_archive.namelist() @@ -273,7 +359,7 @@ def __init__( self.INDICIES_CHUNKSIZE = 64 * 1024 # Chunk size for metadata datasets (if needed) self.META_CHUNKSIZE = 64 * 1024 - # Maximum number of curves to check for changing metadata keys (to avoid checking every curve in large datasets) + # Maximum number of curves to check for changing metadata keys (to avoid checking every curve) self.MAX_CURVE_CHECKS = 20 # Number of curves to hold in buffer self.BUFFER_SIZE = 500 @@ -281,20 +367,23 @@ def __init__( # Initialize key attributes that will be returned / accessed frequently # Just the top level metadata extracted from the header files - self.top_level_meta = {} - # A dictionary containing all metadata, splitting the top level metadata and the metadata for each curve and segment - self.full_metadata = {} - # A 2D list of curve data dictionaries, where each dictionary contains the data for all channels and segments - self.curve_data = None + self.top_level_meta: dict[str, Any] = {} + # A dictionary containing all metadata + self.full_metadata: dict[str, Any] = {} + # A 2D list of curve data dictionaries + self.curve_data: Any = None # A lookup for channel name to unit to be returned - self.channels_units = {} + self.channels_units: dict[str, str] = {} # The list of channels for the segments with their scaling information extracted from the shared header - self.segment_channels = [] - self.curve_meta = {} - self.segment_meta = {} + self.segment_channels: list[dict[str, Any]] = [] + self.curve_meta: dict[str, Any] = {} + self.segment_meta: dict[str, Any] = {} # Define the image shape and size attributes - self.size_x, self.size_y, self.shape_x, self.shape_y = None, None, None, None - self.failed_curves = set() + self.size_x: float | None = None + self.size_y: float | None = None + self.shape_x: int | None = None + self.shape_y: int | None = None + self.failed_curves: set[tuple[int, int | None, str | None]] = set() # Timing counters for performance monitoring self.t_load_data = 0.0 @@ -310,16 +399,15 @@ def __init__( def get_available_channels(self): """ - Retrieves the available channels from the .jpk-qi-image file within the archive, and adds any additional calculated channels. + Retrieve available channels from the .jpk-qi-image file within the archive. Returns ------- channels : list A list of available channels including the calculated channels. metadata_options : dict - A dictionary of options for what metadata to return + A dictionary of options for what metadata to return. """ - # Look for the jpk-qi-image file in the archive if self.path_to_image is None: for file_name in self.namelist: @@ -339,28 +427,27 @@ def load( config_path: Path | str | None = None, flip_image: bool | None = True, save_as_h5: bool | None = None, - ) -> tuple[np.ndarray, float, dict] | tuple[np.ndarray, float]: + ) -> tuple[np.ndarray, float, Any] | tuple[np.ndarray, float]: """ - Loads the .jpk-qi-data file + Load the .jpk-qi-data file. Parameters ---------- channel : str | None, optional - The specific channel to be extracted from the file (e.g., "measuredHeight"). If None, the default channel will be used. Default is None. + The specific channel to be extracted. Default is None. config_path : Path | str | None, optional - Path to the configuration file. If None, the default configuration will be used. Default is None. + Path to the configuration file. Default is None. flip_image : bool | None, optional - Whether to flip the image. If None, the default behavior will be used. Default is True. + Whether to flip the image. Default is True. save_as_h5 : bool, optional Whether to save the data as an H5 file. Default is False. Returns ------- tuple - A tuple containing the image data (numpy.ndarray), the pixel to nanometre scaling factor (float), and optionally the curve data (dict) if available. + A tuple containing image data, scaling factor, and optionally curve data. """ - - # Update instance attributes based on provided parameters, largely so loader can be called to get channels without setting a channel + # Update instance attributes based on provided parameters self.channel = channel if channel else self.channel self.config_path = config_path if config_path else self.config_path self.flip_image = flip_image if flip_image is not None else self.flip_image @@ -384,11 +471,21 @@ def load( # Establish the lazy loading structures for curve data and metadata. Note how lazy structure is used even if # all the data has been accessed and saved to H5 to prevent excessive memory usage - self.full_metadata = LazyCurveMetadata( - self.filepath, self.top_level_meta, self.qi_archive, self.shape_x, self.shape_y, flip_image=self.flip_image + self.full_metadata = LazyQiMetadata( + self.filepath, + self.top_level_meta, + self.qi_archive, + self.shape_x or 0, + self.shape_y or 0, + flip_image=bool(self.flip_image), ) - self.curve_data = LazyCurveData( - self.filepath, self.shape_x, self.shape_y, self.channel_scaling, self.qi_archive, flip_image=self.flip_image + self.curve_data = LazyJpkQiData( + self.filepath, + self.shape_x or 0, + self.shape_y or 0, + self.channel_scaling, + self.qi_archive, + flip_image=bool(self.flip_image), ) # Load the image @@ -401,9 +498,7 @@ def load( return (self.image, self.px2nm, (self.curve_data, self.channels_units, self.full_metadata)) def output_summary(self): - """ - Outputs a summary of the loading process, including any failed curve loads and their details. - """ + """Output a summary of the loading process, including any failed curve loads and their details.""" if self.failed_curves: logger.warning(f"Failed to load {len(self.failed_curves)} files.") logger.warning("Summary of missing files (up to 10 shown):") @@ -413,7 +508,7 @@ def output_summary(self): if i < 10: if chan_name: logger.warning( - f"Failed to load data file for curve {curve_num}, direction {direction}, channel {chan_name}" + f"Failed to load data for curve {curve_num}, direction {direction}, channel {chan_name}" ) else: if direction is not None: @@ -432,8 +527,10 @@ def output_summary(self): summary = ( f"\n--- Performance Summary ---\n" f"Changing Keys Detection: {self.t_changing_keys:.2f}s\n" - f"Raw Data - Loading: {self.t_load_data:.2f}s | Processing: {self.t_proc_data:.2f}s | Saving: {self.t_save_data:.2f}s\n" - f"Metadata - Loading: {self.t_load_meta:.2f}s | Processing: {self.t_proc_meta:.2f}s | Saving: {self.t_save_meta:.2f}s\n" + f"Raw Data - Loading: {self.t_load_data:.2f}s | Processing: {self.t_proc_data:.2f}s | " + f"Saving: {self.t_save_data:.2f}s\n" + f"Metadata - Loading: {self.t_load_meta:.2f}s | Processing: {self.t_proc_meta:.2f}s | " + f"Saving: {self.t_save_meta:.2f}s\n" f"---------------------------" ) logger.info(summary) @@ -457,20 +554,20 @@ def extract_data_to_h5( include_metadata : bool, optional Whether to include metadata in the loading process, by default True. """ - logger.info( - f"Loading all curve data from JPK QI archive with {len(self.namelist)} files {'' if include_metadata else 'not '}including metadata" + f"Loading all curve data from JPK QI archive with {len(self.namelist)} files " + f"{'' if include_metadata else 'not '}including metadata" ) progress_counter = 0 process = psutil.Process(os.getpid()) if include_metadata: # Prepare keys for metadata to speed up processing curve_work = [ - (f"{k}=".encode("utf-8"), h5_meta_datasets[f"curve.{k}"], h5_meta_datasets_buffer[f"curve.{k}"]) + (f"{k}=".encode(), h5_meta_datasets[f"curve.{k}"], h5_meta_datasets_buffer[f"curve.{k}"]) for k in self.changing_curve_keys ] seg_work = [ - (f"{k}=".encode("utf-8"), h5_meta_datasets[f"segment.{k}"], h5_meta_datasets_buffer[f"segment.{k}"]) + (f"{k}=".encode(), h5_meta_datasets[f"segment.{k}"], h5_meta_datasets_buffer[f"segment.{k}"]) for k in self.changing_segment_keys ] for curve_num in range(self.num_of_curves): @@ -510,13 +607,12 @@ def extract_data_to_h5( indicies_dataset = h5_datasets[seg_name][chan_name]["Indicies"] indicies_dataset[-1] = current_dataset.shape[0] - def save_to_h5( self, include_metadata: bool = True, ): """ - Saves the data as an H5 file. If include_metadata is False, only the curve data will be saved. + Save data as an H5 file. If include_metadata is False, only curve data is saved. Parameters ---------- @@ -527,27 +623,27 @@ def save_to_h5( t0 = time.perf_counter() - # Sample curves in dataset to make a best guess for the meta keys that need to be extracted from each curve + # Sample curves in dataset to make a best guess for the meta keys self.changing_curve_keys, self.changing_segment_keys = self.get_changing_keys() self.points_for_channel_segment = self.predict_total_points() self.t_changing_keys = time.perf_counter() - t0 - # Setup H5 structure for saving the data, creating datasets for curve data and metadata as needed + # Setup H5 structure for saving the data global_meta_group, h5_datasets, h5_meta_datasets, h5_datasets_buffer, h5_meta_datasets_buffer = ( self.setup_h5_structure(file) ) # Set up current_offsets to keep track of how many points have been read - self.current_offsets = {} + self.current_offsets: dict[int, dict[str, int]] = {} for direction in range(2): self.current_offsets[direction] = {} for chan in self.segment_channels: self.current_offsets[direction][chan["name"]] = 0 - # Reset the points for channel segment to 0 so it can be used to store the actual number of points held in each dataset + # Reset points_for_channel_segment to 0 to store actual number of points self.points_for_channel_segment[direction][chan["name"]] = 0 - # Extract data from the JPK QI archive and save to H5 datasets, optionally including metadata + # Extract data from the JPK QI archive and save to H5 datasets self.extract_data_to_h5( h5_datasets, h5_meta_datasets, @@ -573,14 +669,13 @@ def save_to_h5( def get_curves_sample(self): """ - Get a sample of curve numbers distrubuted evenly across the dataset + Get a sample of curve numbers distrubuted evenly across the dataset. Returns ------- range: A range object representing the sampled curve numbers. """ - # Check evenly spaced curves in the dataset to sample metadata without having to load every curve step = 1 if self.num_of_curves <= self.MAX_CURVE_CHECKS else self.num_of_curves // self.MAX_CURVE_CHECKS # If the step is equal to a shape dimension, we might just go down the row or column @@ -591,20 +686,19 @@ def get_curves_sample(self): def predict_total_points(self): """ - Predicts the total number of points for each channel and segment by sampling a subset of curves - and extrapolating based on the maximum number of points found in the sample. + Predict the total number of points for each channel and segment by sampling a subset of curves and + extrapolating based on the maximum number of points found in the sample. Returns ------- dict: A dictionary containing the predicted total points for each channel and segment. """ - - # Get a sample of curve (indicies) + # Get a sample of curve (indices) curves_to_check = self.get_curves_sample() points_for_channel_segment = {} - # Iterate through the segments, channels and our curve indicies + # Iterate through the segments, channels and our curve indices for direction in range(2): points_for_channel_segment[direction] = {} for channel in self.segment_channels: @@ -621,7 +715,8 @@ def predict_total_points(self): break except KeyError: - # If the file doesn't exist for this curve, check the next curve so we don't just get a smaller sample + # If the file doesn't exist for this curve, check the next curve so we don't just get + # a smaller sample if curve_num + 1 >= self.num_of_curves: # If we've gone past the number of curves, stop checking break @@ -634,10 +729,10 @@ def predict_total_points(self): ) return points_for_channel_segment - def get_changing_keys(self): + def get_changing_keys(self): # noqa: C901 """ - Checks a sample of curves to see which metadata keys change across curves and segments, - so we can extract only the changing keys for each curve and segment. + Check a sample of curves to see which metadata keys change across curves and segments, so we can + extract only the changing keys for each curve and segment. None changing keys are moved to the top level metadata and not extracted for each curve/segment. @@ -646,9 +741,8 @@ def get_changing_keys(self): tuple: A tuple containing two sets: changing_curve_keys and changing_segment_keys. """ - - curve_meta_dict = {} - segment_meta_dict = {} + curve_meta_dict: dict[str, list[Any]] = {} + segment_meta_dict: dict[str, list[Any]] = {} curves_to_check = self.get_curves_sample() for curve_num in curves_to_check: for direction in range(2): @@ -686,13 +780,13 @@ def get_changing_keys(self): changing_curve_keys, changing_segment_keys = set(), set() for key, values in curve_meta_dict.items(): - if len(set(v for v in values if v is not None)) > 1: + if len({v for v in values if v is not None}) > 1: changing_curve_keys.add(key) else: # If the key does not change across curves, move it to the top level metadata self.top_level_meta[f"curve.{key}"] = values[0] for key, values in segment_meta_dict.items(): - if len(set(v for v in values if v is not None)) > 1: + if len({v for v in values if v is not None}) > 1: changing_segment_keys.add(key) else: # If the key does not change across segments, move it to the top level metadata @@ -701,7 +795,7 @@ def get_changing_keys(self): def get_collated_metadata(self): """ - Collates the metadata from being split by curve, to being split by attribute so data can be saved more efficiently + Collate metadata from being split by curve to being split by attribute. Returns ------- @@ -719,22 +813,18 @@ def get_image( self, overide_channel: str | None = None, convert_to_nm: bool = True, flip_image: bool | None = None ) -> tuple[np.ndarray, float]: """ - Processes the flat curve data dictionary into a 2D list structure matching the image dimensions. + Process the flat curve data dictionary into a 2D list structure matching the image dimensions. Returns ------- image : np.ndarray A 2D array representing the image data. """ - - # Get channel and flip_image parameters, defaulting to the instance attributes if not provided - if overide_channel: - channel = overide_channel - else: - channel = self.channel + # Get channel and flip_image parameters + channel = str(overide_channel) if overide_channel else str(self.channel) if flip_image is None: - flip_image = self.flip_image + flip_image = bool(self.flip_image) # Search through the namelist to find the .jpk-qi-image file path_to_image = None @@ -756,13 +846,11 @@ def get_image( file_suffix=".jpk-qi-data", config_path=self.config_path, convert_to_nm=convert_to_nm, - flip_image=flip_image, + flip_image=bool(flip_image), ) def save_lite_data(self): - """ - Saves a lite form of the data (e.g., the calculated image data) to the appropriate format based on the save_as_h5 attribute. - """ + """Save a lite form of the data (e.g., the calculated image data) to H5.""" with h5py.File(self.h5_path, "a") as h5file: # Save data required for reading the h5 file as a normal image file meas_grp = h5file.require_group("Measurement_000") @@ -777,22 +865,24 @@ def save_lite_data(self): h5_channels = [self.channel] # Look for the jpk-qi-image file in the archive + path_to_image = None for file_name in self.namelist: if file_name.endswith(".jpk-qi-image"): path_to_image = file_name break # Add the channels which exist in the jpk-qi-image file - with self.qi_archive.open(path_to_image, "r") as image_file: - h5_channels += jpk._get_jpk_channels( - file=image_file, filename=self.filepath.stem, file_path=self.filepath / Path(path_to_image) - ) + if path_to_image: + with self.qi_archive.open(path_to_image, "r") as image_file: + h5_channels += jpk._get_jpk_channels( + file=image_file, filename=self.filepath.stem, file_path=self.filepath / Path(path_to_image) + ) for i, h5_channel in enumerate(h5_channels): # For each available channel, save the required data to the h5 file # TODO make sure this metadata is accurate for the channels coming from the .jpk-qi-image file chan_grp = meas_grp.require_group(f"Channel_{_make_num_min_characters(i)}") # Extract name and retrace information from the channel name - if "_" in h5_channel: - base_name, trace_dir = h5_channel.rsplit("_", 1) + if h5_channel and "_" in str(h5_channel): + base_name, trace_dir = str(h5_channel).rsplit("_", 1) is_retrace = "true" if trace_dir.lower() == "retrace" else "false" else: base_name = h5_channel @@ -818,7 +908,7 @@ def save_lite_data(self): def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, direction: int, chan_name: str): """ - Extracts the data from a .dat file in the JPK QI archive, applies the appropriate scaling, and saves it to the internal data structure and h5 dataset if required. + Extract the data from a .dat file in the JPK QI archive, applies the appropriate scaling, and saves it to the internal data structure and h5 dataset if required. Parameters ---------- @@ -847,14 +937,14 @@ def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, dire try: t0 = time.perf_counter() with self.qi_archive.open(dat_path) as f: - # Read the binary data as big-endian 32-bit integers + # Read binary data as big-endian 32-bit integers raw_bytes = f.read() self.t_load_data += time.perf_counter() - t0 t1 = time.perf_counter() raw_array = np.frombuffer(raw_bytes, dtype=">i4") - # Apply the scaling to convert raw values into real world values + # Apply scaling to convert raw values into real world values segment_array = (raw_array * scale["multiplier"]) + scale["offset"] # Update the current offset so it include the length of the data we have just read @@ -885,7 +975,8 @@ def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, dire # Limit the number of warnings to avoid spamming the logs if len(self.failed_curves) < 10: logger.warning( - f"Data file {dat_path} not found in archive. Skipping data for curve {curve_num}, direction {direction}, channel {chan_name}." + f"Data file {dat_path} not found in archive. Skipping data for curve {curve_num}, " + f"direction {direction}, channel {chan_name}." ) elif len(self.failed_curves) == 10: logger.warning( @@ -901,8 +992,9 @@ def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, dire indicies_set[curve_num - self.BUFFER_SIZE + 1 : curve_num + 1] = buf["Indicies"] buf["Indicies"].clear() - # Or if this is the last curve and there are still indicies in the buffer, add them to the indicies dataset and clear the buffer + # Or if this is the last curve and there are still indicies in the buffer elif len(buf["Indicies"]) > 0 and curve_num == self.num_of_curves - 1: + # Add the remaining indicies to the indicies dataset and clear the buffer items_in_buffer = len(buf["Indicies"]) indicies_set[curve_num - items_in_buffer + 1 : curve_num + 1] = buf["Indicies"] buf["Indicies"].clear() @@ -914,21 +1006,22 @@ def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, dire self.failed_curves.add((curve_num, direction, chan_name)) if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs logger.warning( - f"Channel {chan_name} not found in scaling information. Skipping data for curve {curve_num}, direction {direction}." + f"Channel {chan_name} not found in scaling information. Skipping data for curve {curve_num}, " + f"direction {direction}." ) def extract_curve_metadata(self, curve_num: int, curve_work): """ - Extracts the curve metadata from its header.properties file in the JPK QI archive and save to h5 + Extract the curve metadata from its header.properties file in the JPK QI archive and save to h5. Parameters ---------- curve_num : int - The curve number associated with the metadata, parsed from the filename. + The curve number associated with the metadata. curve_work : list - A list of tuples containing the search term for the metadata, the h5 dataset to save to, and the buffer for that dataset. + A list of tuples containing the search term for the metadata, the h5 dataset to save to, + and the buffer for that dataset. """ - meta_path = f"index/{curve_num}/header.properties" raw_bytes = b"" try: @@ -969,19 +1062,22 @@ def extract_curve_metadata(self, curve_num: int, curve_work): self.t_save_meta += time.perf_counter() - start_time else: logger.error( - f"Metadata dataset for key {search_term.decode('utf-8')} not found when trying to save metadata for curve {curve_num}" + f"Metadata dataset for key {search_term.decode('utf-8')} not found when trying to save " + f"metadata for curve {curve_num}" ) def extract_segment_metadata(self, curve_num: int, direction: int, seg_work): """ - Extracts the segment metadata from its header.properties file in the JPK QI archive and saves it to the internal data structure. + Extract segment metadata from its header.properties file. Parameters ---------- curve_num : int - The curve number associated with the metadata, parsed from the filename. + The curve number associated with the metadata. direction : int - The segment direction (0 or 1) associated with the metadata, parsed from the filename. + The segment direction (0 or 1) associated with the metadata. + seg_work : list + A list of tuples containing metadata extraction information. """ meta_path = f"index/{curve_num}/segments/{direction}/segment-header.properties" raw_content = b"" @@ -994,7 +1090,8 @@ def extract_segment_metadata(self, curve_num: int, direction: int, seg_work): self.failed_curves.add((curve_num, direction, None)) if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs logger.warning( - f"Metadata file {meta_path} not found in archive. Skipping metadata for curve {curve_num}, direction {direction}." + f"Metadata file {meta_path} not found in archive. Skipping metadata for curve {curve_num}, " + f"direction {direction}." ) elif len(self.failed_curves) == 10: logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.") @@ -1015,28 +1112,33 @@ def extract_segment_metadata(self, curve_num: int, direction: int, seg_work): meta_buffer.clear() else: logger.error( - f"Metadata dataset for key {search_term.decode('utf-8')} not found when trying to save metadata for curve {curve_num}, direction {direction}" + f"Metadata dataset for key {search_term.decode('utf-8')} not found when trying to save " + f"metadata for curve {curve_num}, direction {direction}" ) self.t_proc_meta += time.perf_counter() - start_time def setup_h5_structure(self, h5file): """ - Sets up the structure in the h5 file for saving the curve data and metadata, and returns the datasets and metadata groups for later use. + Set up structure in the h5 file for saving curve data and metadata. Parameters ---------- h5file : h5py.File - The h5 file in which to set up the structure for saving the curve data and metadata. + The h5 file in which to set up the structure. + Returns ------- - curve_groups : dict - A dictionary containing the group structure for each segment direction. global_meta_group : h5py.Group The h5 group for storing global metadata. - curves_meta_group : h5py.Group - The h5 group for storing curve metadata. + h5_datasets : dict + A dictionary containing the h5 datasets for storing curve data. + h5_meta_datasets : dict + A dictionary containing the h5 datasets for storing metadata. + h5_datasets_buffer : dict + A dictionary containing buffers for the curve data datasets for temporary pre-writing storage. + h5_meta_datasets_buffer : dict + A dictionary containing buffers for the metadata datasets for temporary pre-writing storage. """ - # Create the main group for the QI curve data that all the curve data will be in qi_group = h5file.require_group("QI_Curve_Data") @@ -1070,7 +1172,7 @@ def setup_h5_structure(self, h5file): h5_meta_datasets_buffer[f"segment.{key}"] = [] for direction in range(2): - # For each segment direction, establish the necessary group structure that will contain each channel dataset + # For each segment direction, establish necessary group structure that will contain each channel dataset seg_name = f"Segment_{direction}" dir_group = curves_group.require_group(seg_name) h5_datasets[seg_name] = {} @@ -1080,7 +1182,7 @@ def setup_h5_structure(self, h5file): curve_groups["Indicies"][seg_name] = dir_group.require_group("Indicies") for chan in self.segment_channels: h5_datasets[seg_name][chan["name"]] = {} - # For each channel, create an empty dataset for the curve data and indicies with the appropriate name and data type + # For each channel, create an empty dataset h5_datasets[seg_name][chan["name"]]["Data"] = curve_groups["Data"][seg_name].create_dataset( name=chan["name"], shape=(self.points_for_channel_segment[direction][chan["name"]],), @@ -1100,7 +1202,8 @@ def setup_h5_structure(self, h5file): def get_saving_context(self): """ - Returns the appropriate context manager for saving the data based on the save_as_h5 attribute. + Return the appropriate context manager for saving the data based on the save_as_h5 attribute. + If save_as_h5 is True, it returns a context manager for an h5 file. Otherwise, it returns a null context. Returns @@ -1110,13 +1213,10 @@ def get_saving_context(self): """ if self.save_as_h5: return h5py.File(self.h5_path, "a") - else: - return nullcontext() + return nullcontext() def parse_dimension_data(self): - """ - Parses the dimension data from the provided properties dictionary and calculates the pixel to nanometer scaling factor. - """ + """Parse dimension data and calculate the pixel to nanometer scaling factor.""" # Extract both real size and pixel dimensions from the metadata for key, value in self.top_level_meta.items(): if key.endswith(".ulength"): @@ -1141,20 +1241,19 @@ def parse_dimension_data(self): self.num_of_curves = self.shape_x * self.shape_y def extract_global_metadata(self): - """Extracts the global metadata from the JPK QI archive and populates the top level metadata dictionary and segment channels list.""" - + """Extract global metadata and populate top level metadata dictionary and segment channels list.""" # Load the metadata from the global properties file if "header.properties" in self.namelist: with self.qi_archive.open("header.properties") as archive_meta_file: props = javaproperties.load(archive_meta_file) - # Add all the data from the main header to the top level metadata with a prefix to avoid key clashes + # Add data from the main header to the top level metadata with a prefix to avoid key clashes for key, value in props.items(): self.top_level_meta[f"main-header.{key}"] = value else: logger.error(f"File {self.filepath} does not contain essential metadata and cannot be loaded") - # Load the metadata from the shared header and parse the channel information for the segments + # Load the metadata from the shared header if "shared-data/header.properties" in self.namelist: with self.qi_archive.open("shared-data/header.properties") as shared_data_file: shared_meta = javaproperties.load(shared_data_file) @@ -1189,7 +1288,7 @@ def extract_global_metadata(self): self.channel_scaling = {chan["name"]: chan for chan in self.segment_channels} def close(self): - """Closes the ZIP archive when done to free up system resources.""" + """Close the ZIP archive when done to free up system resources.""" self.qi_archive.close() self.image = None self.curve_data = None @@ -1202,31 +1301,8 @@ def close(self): self.namelist = [] -def load_fdcurves_from_h5(file_path: Path | str): - file_path = Path(file_path) - - with h5py.File(file_path, "r") as h5file: - meas_grp = h5file["Measurement_000"] - shape_x = meas_grp.attrs["position-pattern.grid.ilength"] - shape_y = meas_grp.attrs["position-pattern.grid.jlength"] - num_of_curves = shape_x * shape_y - qi_data_group = h5file["QI_Curve_Data"] - all_curve_data = [] - for i in range(num_of_curves): - curve_data = {} - for direction, direction_group in qi_data_group.items(): - for channel, channel_group in direction_group.items(): - if channel not in curve_data: - curve_data[channel] = {} - curve_data[channel][direction] = channel_group[str(i)] - all_curve_data.append(curve_data) - - return all_curve_data - - def _make_num_min_characters(num: int, min_chars: int = 3): string_num = str(num) if len(string_num) >= min_chars: return string_num - string_num = "0" * (min_chars - len(string_num)) + string_num - return string_num + return "0" * (min_chars - len(string_num)) + string_num diff --git a/AFMReader/jpk_utils.py b/AFMReader/jpk_utils.py new file mode 100644 index 0000000..5cf494b --- /dev/null +++ b/AFMReader/jpk_utils.py @@ -0,0 +1,186 @@ +"""Utility classes for lazy loading of QI curve data and metadata from JPK files.""" + +# pylint: disable=too-few-public-methods,fixme + + +class LazyQiData: + """ + A proxy class that fetches QI curve data on demand. + + It behaves like a 2D array of shape (shape_y, shape_x) where each element is a dictionary + containing the QI curve data for that pixel. + """ + + def __init__(self, shape_x: int, shape_y: int, flip_image: bool = True): + """ + Initialize the LazyQiData instance. + + Parameters + ---------- + shape_x : int + The number of columns in the image. + shape_y : int + The number of rows in the image. + flip_image : bool, optional + Whether to flip the image vertically. Default is True. + """ + self.shape_x = shape_x + self.shape_y = shape_y + self.dims = (shape_y, shape_x) + self.flip_image = flip_image + + def __getitem__(self, y: int): + """ + Return a proxy object for the specified row that can be indexed to fetch curve data for each pixel in that row. + + This allows for lazy loading. + + Parameters + ---------- + y : int + The row index. + + Returns + ------- + RowProxy + A proxy object for the specified row. + """ + + class RowProxy: + def __init__(self, parent, y): + self.parent = parent + self.y = y + + def __getitem__(self, x: int): + return self.parent._fetch_curve(self.y, x) + + return RowProxy(self, y) + + def _fetch_curve(self, y: int, x: int): + """ + Fetch the QI curve data for a specific pixel. + + Should be implemented by subclasses to define how the curve data is retrieved + from the underlying data source. + + Parameters + ---------- + y : int + Row index of the pixel. + x : int + Column index of the pixel. + + Returns + ------- + dict + The QI curve data for the specified pixel. + """ + raise NotImplementedError("This method should be implemented by subclasses to fetch curve data on demand.") + + +class LazyMetadata: + """A proxy class that fetches metadata on demand. Superclass for metadata proxy classes.""" + + def __init__(self, top_level_meta: dict, shape_x: int, shape_y: int, flip_image: bool = True): + """ + Initialize the LazyMetadata instance. + + Parameters + ---------- + top_level_meta : dict + The top-level metadata dictionary. + shape_x : int + The number of columns in the image. + shape_y : int + The number of rows in the image. + flip_image : bool, optional + Whether to flip the image vertically. Default is True. + """ + self.top_level = top_level_meta + self.shape_x = shape_x + self.shape_y = shape_y + self.flip_image = flip_image + + +class LazyMetaProxy: + """ + A proxy class that fetches metadata on demand. + + It behaves like a 2D array of shape (shape_y, shape_x) where each element is a dictionary + containing the metadata for that pixel. + """ + + def __init__(self, meta_type: str, shape_x: int, shape_y: int, flip_image: bool = True): + """ + Initialize the LazyMetaProxy instance. + + Parameters + ---------- + meta_type : str + The type of metadata to fetch ("curve" or "segment"). + shape_x : int + The number of columns in the image. + shape_y : int + The number of rows in the image. + flip_image : bool, optional + Whether to flip the image vertically. Default is True. + """ + self.meta_type = meta_type + self.shape_x = shape_x + self.shape_y = shape_y + self.flip_image = flip_image + + def __getitem__(self, y: int): + """ + Return a proxy object for the specified row that can be indexed to fetch metadata for each pixel in that row. + + This allows for lazy loading. + Nested proxy objects are used to allow for fetching segment metadata which requires both x and y indices + as well as the direction of the segment (approach or retract). + """ + + class RowProxy: + def __init__(self, parent, y): + self.parent = parent + self.y = y + + def __getitem__(self, x): + if self.parent.meta_type == "curve": + return self.parent._fetch_meta(self.y, x) + if self.parent.meta_type == "segment": + + class SegmentMetaProxy: + def __init__(self, parent, y, x): + self.parent = parent + self.y = y + self.x = x + + def __getitem__(self, direction): + return self.parent.parent._fetch_meta(self.y, self.x, direction) + + return SegmentMetaProxy(self, self.y, x) + raise IndexError(f"Unknown metadata type '{self.parent.meta_type}'. Expected 'curve' or 'segment'.") + + return RowProxy(self, y) + + def _fetch_meta(self, y: int, x: int, direction: int | None = None): + """ + Fetch the metadata for a specific pixel. + Should be implemented by subclasses to define how the metadata is retrieved + from the underlying data source. + + Parameters + ---------- + y : int + The row index of the pixel. + x : int + The column index of the pixel. + direction : int, optional + The direction of the segment ("approach" or "retract"). Only used for segment metadata. + + Returns + ------- + dict + The metadata for the specified pixel. + """ + raise NotImplementedError("This method should be implemented by subclasses to fetch metadata on demand.") diff --git a/AFMReader/raw_bin.py b/AFMReader/raw_bin.py new file mode 100644 index 0000000..b963afe --- /dev/null +++ b/AFMReader/raw_bin.py @@ -0,0 +1,116 @@ +"""Module to decode and load .bin AFM files into Python Numpy arrays.""" + +import numpy as np +import math +from .logging import logger +from pathlib import Path + +# pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,fixme + +DTYPE_MAP = { + "IEEE double": np.float64, + "DBL": np.float64, + "IEEE single": np.float32, + "SGL": np.float32, + "U32": np.uint32, + "I32": np.int32, + "U16": np.uint16, + "I16": np.int16, + "U8": np.uint8, + "I8": np.int8, + "float64": np.float64, + "float32": np.float32, + "int32": np.int32, +} + + +def load_bin( + filepath: str | Path, + data_type: str, + offset_bytes: int, + size_x: float | None = None, + size_y: float | None = None, + shape_x: int | None = None, + shape_y: int | None = None, + z_scaling: float = 1.0, +): + """ + Load image from binary file. Parameters to interpret the binary file must be provided. + + Parameters + ---------- + filepath : str | Path + Path to the binary file. + data_type : str + Data type of the binary file. + offset_bytes : int + Number of bytes to skip at the beginning of the file. + size_x : float, optional + Size of the image in the x direction (default is None). + size_y : float, optional + Size of the image in the y direction (default is None). + shape_x : int, optional + Number of pixels in the x direction (default is None). + shape_y : int, optional + Number of pixels in the y direction (default is None). + z_scaling : float, optional + Scaling factor for the z values (default is 1.0). + + Returns + ------- + image : np.ndarray + 2D array of shape (height, width) with image data. + px2nm : float + Scaling factor converting pixels to nanometers. + """ + filepath = Path(filepath) + dt_key = str(data_type).strip() + shape_x = None if shape_x == 0 else shape_x + shape_y = None if shape_y == 0 else shape_y + + if dt_key in DTYPE_MAP: + np_dtype = DTYPE_MAP[dt_key] + else: + logger.warning(f"Unknown data type '{dt_key}'. Defaulting to float64.") + np_dtype = np.float64 + with filepath.open("rb") as f: + f.seek(offset_bytes) + flat_data = np.fromfile(f, dtype=np_dtype) + if None in [shape_x, shape_y]: + dimension = int(math.sqrt(len(flat_data))) + shape_x, shape_y = dimension, dimension + if shape_x * shape_y != len(flat_data): + logger.error(f"Loading binary file {filepath.stem} did not receive a shape and is not square") + image = flat_data.reshape((shape_x, shape_y)) + image *= z_scaling + pixel_to_nm_scaling_factor_x = size_x / shape_x if shape_x > 0 else 1.0 + pixel_to_nm_scaling_factor_y = size_y / shape_y if shape_y > 0 else 1.0 + px2nm = (pixel_to_nm_scaling_factor_x + pixel_to_nm_scaling_factor_y) / 2 + return image, px2nm + + +def get_bin_channels(): + """ + Get the list of channels available in the binary file. + + Since binary files do not have a standard structure, + this function returns an empty list (as no standard channels are available) and the expected keyword + arguments for loading a binary file. + + Returns + ------- + list + Empty list. + dict + Dictionary of expected keyword arguments for loading a binary file. + """ + kwarg_types = { + "data_type": (str, DTYPE_MAP.keys()), + "offset_bytes": int, + "size_x": float, + "size_y": float, + "shape_x": int, + "shape_y": int, + "z_scaling": float, + } + return [], kwarg_types diff --git a/AFMReader/spm.py b/AFMReader/spm.py index 8fd7ff3..da29f9a 100644 --- a/AFMReader/spm.py +++ b/AFMReader/spm.py @@ -111,8 +111,23 @@ def load_spm(file_path: Path | str, channel: str) -> tuple: return (image, spm_pixel_to_nm_scaling(filename, channel_data)) -def get_spm_channels(file_path): + +def get_spm_channels(file_path: Path | str) -> list: + """ + Get the list of channels available in the .spm file. + + Parameters + ---------- + file_path : Path or str + Path to the .spm file. + + Returns + ------- + list + List of available channels. + """ labels = [] + file_path = Path(file_path) filename = file_path.stem try: scan = pySPM.Bruker(file_path) From e64620582b44b17faca680f3e613f336622c6406 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Sat, 2 May 2026 13:32:44 +0100 Subject: [PATCH 38/54] Minor formatting changes on tests --- tests/test_general_loader.py | 2 -- tests/test_h5jpk.py | 7 +++---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/test_general_loader.py b/tests/test_general_loader.py index d1c6861..276de15 100644 --- a/tests/test_general_loader.py +++ b/tests/test_general_loader.py @@ -142,8 +142,6 @@ def test_load(capsys: pytest.CaptureFixture, filepath: Path, channel: str, error assert message in captured.err - - @pytest.mark.parametrize( ("filepath"), [ diff --git a/tests/test_h5jpk.py b/tests/test_h5jpk.py index ff9ca8b..f6995d7 100644 --- a/tests/test_h5jpk.py +++ b/tests/test_h5jpk.py @@ -143,6 +143,7 @@ def test_load_h5jpk( for i in range(len(results_timestamps) - 1) ) + @pytest.mark.parametrize( ( "file_name", @@ -176,12 +177,9 @@ def test_load_h5jpk_curves( curve_coords: tuple[int, int], curve_direction: str, curve_targets: dict[str, tuple[int, float]], - ) -> None: - _, _, _, result_curve_data = h5_jpk.load_h5jpk( - RESOURCES / file_name, channel, flip_image - ) + _, _, _, result_curve_data = h5_jpk.load_h5jpk(RESOURCES / file_name, channel, flip_image) all_curves, _, _ = result_curve_data curve_at_coords = all_curves[curve_coords[0]][curve_coords[1]] for curve_channel, (expected_size, expected_sum) in curve_targets.items(): @@ -189,6 +187,7 @@ def test_load_h5jpk_curves( assert curve.shape == (expected_size,) assert curve.sum() == pytest.approx(expected_sum) + def test_load_h5jpk_file_not_found() -> None: """Ensure FileNotFound error is raised.""" with pytest.raises(FileNotFoundError): From a6388fd811cfb9e50c212fd7c8450127988ad68a Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Sat, 2 May 2026 14:09:25 +0100 Subject: [PATCH 39/54] Indices spelling correction --- AFMReader/h5_jpk.py | 40 ++++++++++----------- AFMReader/jpk_qi.py | 88 +++++++++++++-------------------------------- 2 files changed, 44 insertions(+), 84 deletions(-) diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index cf37cd7..1de0e3b 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -329,22 +329,22 @@ def __iter__(self): # noqa: C901 dict A dictionary containing the QI curve data for each channel and segment. """ - indicies_map = {} + indices_map = {} for segment, segment_group in self.qi_data_group["Curves"].items(): - for channel in segment_group["Indicies"]: - if channel not in indicies_map: - indicies_map[channel] = {} - indicies_map[channel][segment] = segment_group["Indicies"][channel][:] + for channel in segment_group["Indices"]: + if channel not in indices_map: + indices_map[channel] = {} + indices_map[channel][segment] = segment_group["Indices"][channel][:] for y_idx in range(self.shape_y): data = {} y = self.shape_y - 1 - y_idx if self.flip_image else y_idx for segment, segment_group in self.qi_data_group["Curves"].items(): - for channel in segment_group["Indicies"]: + for channel in segment_group["Indices"]: if channel not in data: data[channel] = {} - indicies = indicies_map[channel][segment] - start_idx = int(indicies[self.shape_x * y]) - end_idx = int(indicies[self.shape_x * (y + 1)]) + indices = indices_map[channel][segment] + start_idx = int(indices[self.shape_x * y]) + end_idx = int(indices[self.shape_x * (y + 1)]) data[channel][segment] = segment_group["Data"][channel][start_idx:end_idx] for x in range(self.shape_x): @@ -352,9 +352,9 @@ def __iter__(self): # noqa: C901 for channel in data: curve_data[channel] = {} for segment in data[channel]: - indicies = indicies_map[channel][segment] - start_idx = int(indicies[self.shape_x * y + x]) - int(indicies[self.shape_x * y]) - end_idx = int(indicies[self.shape_x * y + x + 1]) - int(indicies[self.shape_x * y]) + indices = indices_map[channel][segment] + start_idx = int(indices[self.shape_x * y + x]) - int(indices[self.shape_x * y]) + end_idx = int(indices[self.shape_x * y + x + 1]) - int(indices[self.shape_x * y]) curve_data[channel][segment] = data[channel][segment][start_idx:end_idx] yield curve_data @@ -381,9 +381,9 @@ def _fetch_curve(self, y: int, x: int): y = self.shape_y - 1 - y curve_num = self.shape_x * y + x for segment, segment_group in self.qi_data_group["Curves"].items(): - for channel in segment_group["Indicies"]: - start_idx = int(segment_group["Indicies"][channel][curve_num]) - end_idx = int(segment_group["Indicies"][channel][curve_num + 1]) + for channel in segment_group["Indices"]: + start_idx = int(segment_group["Indices"][channel][curve_num]) + end_idx = int(segment_group["Indices"][channel][curve_num + 1]) if channel not in curve_dict: curve_dict[channel] = {} curve_dict[channel][segment] = segment_group["Data"][channel][start_idx:end_idx] @@ -400,12 +400,12 @@ def load_all_curves(self): """ all_curves = [[{} for _ in range(self.shape_x)] for _ in range(self.shape_y)] for segment, segment_group in self.qi_data_group["Curves"].items(): - for channel in segment_group["Indicies"]: - indicies = segment_group["Indicies"][channel][:] + for channel in segment_group["Indices"]: + indices = segment_group["Indices"][channel][:] data = segment_group["Data"][channel][:] - for i in range(len(indicies) - 1): - start_idx = int(indicies[i]) - end_idx = int(indicies[i + 1]) + for i in range(len(indices) - 1): + start_idx = int(indices[i]) + end_idx = int(indices[i + 1]) x = i % self.shape_x y = i // self.shape_x if self.flip_image: diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index 95b3884..d3d53fc 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -355,8 +355,8 @@ def __init__( # Chunk size for H5 datasets self.DATA_CHUNKSIZE = 512 * 1024 - # Chunk size for indicies datasets - self.INDICIES_CHUNKSIZE = 64 * 1024 + # Chunk size for indices datasets + self.INDICES_CHUNKSIZE = 64 * 1024 # Chunk size for metadata datasets (if needed) self.META_CHUNKSIZE = 64 * 1024 # Maximum number of curves to check for changing metadata keys (to avoid checking every curve) @@ -385,15 +385,6 @@ def __init__( self.shape_y: int | None = None self.failed_curves: set[tuple[int, int | None, str | None]] = set() - # Timing counters for performance monitoring - self.t_load_data = 0.0 - self.t_proc_data = 0.0 - self.t_save_data = 0.0 - self.t_load_meta = 0.0 - self.t_proc_meta = 0.0 - self.t_save_meta = 0.0 - self.t_changing_keys = 0.0 - # Instantiate containers for data to be saved (so an exception is not caused if not saving) self.curve_groups = None @@ -523,18 +514,6 @@ def output_summary(self): # If there are no failed loads, log that all data was loaded successfully logger.info("Successfully loaded all curve data without any missing files.") - # Output the performance summary for the loading process - summary = ( - f"\n--- Performance Summary ---\n" - f"Changing Keys Detection: {self.t_changing_keys:.2f}s\n" - f"Raw Data - Loading: {self.t_load_data:.2f}s | Processing: {self.t_proc_data:.2f}s | " - f"Saving: {self.t_save_data:.2f}s\n" - f"Metadata - Loading: {self.t_load_meta:.2f}s | Processing: {self.t_proc_meta:.2f}s | " - f"Saving: {self.t_save_meta:.2f}s\n" - f"---------------------------" - ) - logger.info(summary) - def extract_data_to_h5( self, h5_datasets, h5_meta_datasets, h5_datasets_buffer, h5_meta_datasets_buffer, include_metadata: bool = True ): @@ -598,14 +577,14 @@ def extract_data_to_h5( # Extract and store the curve metadata for later saving self.extract_curve_metadata(curve_num=curve_num, curve_work=curve_work) - # Add the last index to the indicies datasets to mark the end of the last curve + # Add the last index to the indices datasets to mark the end of the last curve for direction in range(2): seg_name = f"Segment_{direction}" for chan in self.segment_channels: chan_name = chan["name"] current_dataset = h5_datasets[seg_name][chan_name]["Data"] - indicies_dataset = h5_datasets[seg_name][chan_name]["Indicies"] - indicies_dataset[-1] = current_dataset.shape[0] + indices_dataset = h5_datasets[seg_name][chan_name]["Indices"] + indices_dataset[-1] = current_dataset.shape[0] def save_to_h5( self, @@ -928,20 +907,17 @@ def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, dire scale = self.channel_scaling[chan_name] dat_path = f"index/{curve_num}/segments/{direction}/channels/{chan_name}.dat" data_set = h5_datasets[f"Segment_{direction}"][chan_name]["Data"] - indicies_set = h5_datasets[f"Segment_{direction}"][chan_name]["Indicies"] + indices_set = h5_datasets[f"Segment_{direction}"][chan_name]["Indices"] data_size = data_set.shape[0] buf = h5_datasets_buffer[f"Segment_{direction}"][chan_name] filled_size = self.points_for_channel_segment[direction][chan_name] start_offset = self.current_offsets[direction][chan_name] try: - t0 = time.perf_counter() with self.qi_archive.open(dat_path) as f: # Read binary data as big-endian 32-bit integers raw_bytes = f.read() - self.t_load_data += time.perf_counter() - t0 - t1 = time.perf_counter() raw_array = np.frombuffer(raw_bytes, dtype=">i4") # Apply scaling to convert raw values into real world values @@ -958,16 +934,12 @@ def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, dire buffered_data = np.concatenate(buf["Data"]) - self.t_proc_data += time.perf_counter() - t1 - start_time = time.perf_counter() - # Add the buffer to the dataset data_set[filled_size : filled_size + len(buffered_data)] = buffered_data # Update the filled size for this channel and segment self.points_for_channel_segment[direction][chan_name] += len(buffered_data) # Clear the buffer buf["Data"].clear() - self.t_save_data += time.perf_counter() - start_time except KeyError: self.failed_curves.add((curve_num, direction, chan_name)) @@ -983,23 +955,20 @@ def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, dire "Lots of missing files, further warnings will be suppressed. View summary at the end." ) - start_time = time.perf_counter() - # Append the new index to the indicies buffer - buf["Indicies"].append(start_offset) - - # If the indicies buffer is full add it to the indicies dataset and clear the buffer - if len(buf["Indicies"]) > 0 and len(buf["Indicies"]) % self.BUFFER_SIZE == 0: - indicies_set[curve_num - self.BUFFER_SIZE + 1 : curve_num + 1] = buf["Indicies"] - buf["Indicies"].clear() + # Append the new index to the indices buffer + buf["Indices"].append(start_offset) - # Or if this is the last curve and there are still indicies in the buffer - elif len(buf["Indicies"]) > 0 and curve_num == self.num_of_curves - 1: - # Add the remaining indicies to the indicies dataset and clear the buffer - items_in_buffer = len(buf["Indicies"]) - indicies_set[curve_num - items_in_buffer + 1 : curve_num + 1] = buf["Indicies"] - buf["Indicies"].clear() + # If the indices buffer is full add it to the indices dataset and clear the buffer + if len(buf["Indices"]) > 0 and len(buf["Indices"]) % self.BUFFER_SIZE == 0: + indices_set[curve_num - self.BUFFER_SIZE + 1 : curve_num + 1] = buf["Indices"] + buf["Indices"].clear() - self.t_save_data += time.perf_counter() - start_time + # Or if this is the last curve and there are still indices in the buffer + elif len(buf["Indices"]) > 0 and curve_num == self.num_of_curves - 1: + # Add the remaining indices to the indices dataset and clear the buffer + items_in_buffer = len(buf["Indices"]) + indices_set[curve_num - items_in_buffer + 1 : curve_num + 1] = buf["Indices"] + buf["Indices"].clear() else: # Log if curve failed @@ -1025,11 +994,9 @@ def extract_curve_metadata(self, curve_num: int, curve_work): meta_path = f"index/{curve_num}/header.properties" raw_bytes = b"" try: - start_time = time.perf_counter() # Read metadata file as raw bytes with self.qi_archive.open(meta_path) as f: raw_bytes = f.read() - self.t_load_meta += time.perf_counter() - start_time except KeyError: self.failed_curves.add((curve_num, None, None)) # Limit the number of warnings to avoid spamming the logs @@ -1052,14 +1019,11 @@ def extract_curve_metadata(self, curve_num: int, curve_work): # Save a no data value if the search term is not found in the metadata file else: value = "No data" - self.t_proc_meta += time.perf_counter() - start_time if meta_buffer is not None: - start_time = time.perf_counter() meta_buffer.append(value) if len(meta_buffer) >= self.BUFFER_SIZE or curve_num == self.num_of_curves - 1: meta_set[curve_num - len(meta_buffer) + 1 : curve_num + 1] = meta_buffer meta_buffer.clear() - self.t_save_meta += time.perf_counter() - start_time else: logger.error( f"Metadata dataset for key {search_term.decode('utf-8')} not found when trying to save " @@ -1082,10 +1046,8 @@ def extract_segment_metadata(self, curve_num: int, direction: int, seg_work): meta_path = f"index/{curve_num}/segments/{direction}/segment-header.properties" raw_content = b"" try: - start_time = time.perf_counter() with self.qi_archive.open(meta_path) as f: raw_content = f.read() - self.t_load_meta += time.perf_counter() - start_time except KeyError: self.failed_curves.add((curve_num, direction, None)) if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs @@ -1095,7 +1057,6 @@ def extract_segment_metadata(self, curve_num: int, direction: int, seg_work): ) elif len(self.failed_curves) == 10: logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.") - start_time = time.perf_counter() for search_term, meta_set, meta_buffer in seg_work: start = raw_content.find(search_term) if start != -1: @@ -1115,7 +1076,6 @@ def extract_segment_metadata(self, curve_num: int, direction: int, seg_work): f"Metadata dataset for key {search_term.decode('utf-8')} not found when trying to save " f"metadata for curve {curve_num}, direction {direction}" ) - self.t_proc_meta += time.perf_counter() - start_time def setup_h5_structure(self, h5file): """ @@ -1147,7 +1107,7 @@ def setup_h5_structure(self, h5file): curves_meta_group = qi_group.require_group("Curve_Metadata") curves_group = qi_group.require_group("Curves") - curve_groups = {"Data": {}, "Indicies": {}} + curve_groups = {"Data": {}, "Indices": {}} h5_datasets = {} h5_meta_datasets = {} h5_datasets_buffer = {} @@ -1177,9 +1137,9 @@ def setup_h5_structure(self, h5file): dir_group = curves_group.require_group(seg_name) h5_datasets[seg_name] = {} h5_datasets_buffer[seg_name] = {} - # Create the Data and Indicies subfolders and store their references + # Create the Data and Indices subfolders and store their references curve_groups["Data"][seg_name] = dir_group.require_group("Data") - curve_groups["Indicies"][seg_name] = dir_group.require_group("Indicies") + curve_groups["Indices"][seg_name] = dir_group.require_group("Indices") for chan in self.segment_channels: h5_datasets[seg_name][chan["name"]] = {} # For each channel, create an empty dataset @@ -1190,14 +1150,14 @@ def setup_h5_structure(self, h5file): chunks=(self.DATA_CHUNKSIZE,), dtype=np.float32, ) - h5_datasets[seg_name][chan["name"]]["Indicies"] = curve_groups["Indicies"][seg_name].create_dataset( + h5_datasets[seg_name][chan["name"]]["Indices"] = curve_groups["Indices"][seg_name].create_dataset( name=chan["name"], shape=(self.num_of_curves + 1,), maxshape=(None,), - chunks=(self.INDICIES_CHUNKSIZE,), + chunks=(self.INDICES_CHUNKSIZE,), dtype=np.int32, ) - h5_datasets_buffer[seg_name][chan["name"]] = {"Data": [], "Indicies": []} + h5_datasets_buffer[seg_name][chan["name"]] = {"Data": [], "Indices": []} return global_meta_group, h5_datasets, h5_meta_datasets, h5_datasets_buffer, h5_meta_datasets_buffer def get_saving_context(self): From b88ee260995354be686a99b27c0bad8f25125e04 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Sun, 3 May 2026 17:13:44 +0100 Subject: [PATCH 40/54] Skipping tests requiring large test files which cannot be added to repo --- tests/test_h5jpk.py | 18 ++++++++++++++++++ tests/test_jpk_qi.py | 5 +++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/tests/test_h5jpk.py b/tests/test_h5jpk.py index f6995d7..ae2acda 100644 --- a/tests/test_h5jpk.py +++ b/tests/test_h5jpk.py @@ -144,6 +144,7 @@ def test_load_h5jpk( ) +@pytest.mark.skip(reason="Test files are too large to store in the repo; a remote storage solution is needed.") @pytest.mark.parametrize( ( "file_name", @@ -178,7 +179,24 @@ def test_load_h5jpk_curves( curve_direction: str, curve_targets: dict[str, tuple[int, float]], ) -> None: + """ + Test loading of curve data from a .h5-jpk file. + Parameters + ---------- + file_name : str + The name of the .h5-jpk file to load (should be located in the test resources directory). + channel : str + The channel to load curve data for. + flip_image : bool + Whether to flip the image vertically. + curve_coords : tuple[int, int] + The coordinates of the curve to load. + curve_direction : str + The direction of the curve to load. + curve_targets : dict[str, tuple[int, float]] + A dictionary mapping curve channels to their expected size and sum, used for validating the loaded curve data. + """ _, _, _, result_curve_data = h5_jpk.load_h5jpk(RESOURCES / file_name, channel, flip_image) all_curves, _, _ = result_curve_data curve_at_coords = all_curves[curve_coords[0]][curve_coords[1]] diff --git a/tests/test_jpk_qi.py b/tests/test_jpk_qi.py index a1e6654..a786d57 100644 --- a/tests/test_jpk_qi.py +++ b/tests/test_jpk_qi.py @@ -11,6 +11,7 @@ RESOURCES = BASE_DIR / "tests" / "resources" +@pytest.mark.skip(reason="Test files are too large to store in the repo; a remote storage solution is needed.") @pytest.mark.parametrize( ( "file_name", @@ -150,5 +151,5 @@ def test_load_jpk_data_file_not_found() -> None: """Ensure FileNotFound error is raised.""" with pytest.raises(FileNotFoundError): jpk_qi_loader = jpk_qi.jpk_qi_loader("noexistant_file.jpk-qi-data", "TP") - jpk_qi_loader.load() # type: ignore - jpk_qi_loader.close() # type: ignore + jpk_qi_loader.load() # type: ignore + jpk_qi_loader.close() # type: ignore From be2840ed57553de1a75894653802adc1e62567ed Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 3 May 2026 16:59:59 +0000 Subject: [PATCH 41/54] [pre-commit.ci] Fixing issues with pre-commit --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index c902a84..8ea0fe9 100644 --- a/README.md +++ b/README.md @@ -166,8 +166,6 @@ frames, pixel_to_nanometre_scaling_factor, timestamp_dict, force_curves = load_h ``` -### .stp - You can open `.stp` files using the `load_stp` function. Just pass in the path to the file you want to use. From a1764fe6f6236ebcec1d75b51c11e53140be0fd8 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Mon, 4 May 2026 13:53:11 +0100 Subject: [PATCH 42/54] Fixing pre-commit problems --- AFMReader/general_loader.py | 10 +- AFMReader/h5_jpk.py | 54 ++++++- AFMReader/jpk.py | 50 +++++- AFMReader/jpk_qi.py | 142 +++++++++++++++--- .../{jpk_utils.py => lazy_data_classes.py} | 142 +++++++++++++++++- AFMReader/logging.py | 2 +- AFMReader/raw_bin.py | 8 +- README.md | 5 +- tests/test_general_loader.py | 14 +- tests/test_h5jpk.py | 6 +- tests/test_jpk_qi.py | 6 +- 11 files changed, 380 insertions(+), 59 deletions(-) rename AFMReader/{jpk_utils.py => lazy_data_classes.py} (57%) diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index e5efded..1991419 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -23,7 +23,7 @@ class LoadFile: channel : str Channel to extract from the AFM image. kwargs : dict, optional - Additional keyword arguments to pass to the specific loaders + Additional keyword arguments to pass to the specific loaders. """ def __init__(self, filepath: str | Path, channel: str, kwargs: dict | None = None): @@ -37,7 +37,7 @@ def __init__(self, filepath: str | Path, channel: str, kwargs: dict | None = Non channel : str Channel to extract from the AFM image. kwargs : dict, optional - Additional keyword arguments to pass to the specific loaders + Additional keyword arguments to pass to the specific loaders. """ self.filepath = Path(filepath) self.channel = channel @@ -59,7 +59,7 @@ def load( # noqa: C901 channel : str, optional Overriding channel to extract from the AFM image. kwargs : dict, optional - Additional keyword arguments to pass to the specific loaders + Additional keyword arguments to pass to the specific loaders. Returns ------- @@ -91,9 +91,9 @@ def load( # noqa: C901 elif self.suffix == ".h5-jpk": h5_returned = h5_jpk.load_h5jpk(self.filepath, self.channel, load_curves=not self.loaded_curves) if len(h5_returned) == 3: - image, pixel_to_nanometre_scaling_factor, _ = h5_returned + image, pixel_to_nanometre_scaling_factor, _ = h5_returned # type: ignore[misc] elif len(h5_returned) == 4: - image, pixel_to_nanometre_scaling_factor, _, curve_data = h5_returned + image, pixel_to_nanometre_scaling_factor, _, curve_data = h5_returned # type: ignore[misc] self.loaded_curves = True print( f"Loaded image with shape {image.shape} and pixel to nanometre " diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index 1de0e3b..21b8b8e 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -12,7 +12,7 @@ import numpy as np from AFMReader.logging import logger -from AFMReader.jpk_utils import ( +from AFMReader.lazy_data_classes import ( LazyMetaProxy, LazyMetadata, LazyQiData, @@ -300,6 +300,17 @@ class LazyH5QiData(LazyQiData): It behaves like a 2D array of shape (shape_y, shape_x) where each element is a dictionary containing the QI curve data for that pixel. + + Parameters + ---------- + qi_data_group : h5py.Group + The HDF5 group containing the QI curve data. + shape_x : int + The number of columns in the image. + shape_y : int + The number of rows in the image. + flip_image : bool, optional + Whether to flip the image vertically. Default is ``True``. """ def __init__(self, qi_data_group: h5py.Group, shape_x: int, shape_y: int, flip_image: bool = True): @@ -349,13 +360,13 @@ def __iter__(self): # noqa: C901 data[channel][segment] = segment_group["Data"][channel][start_idx:end_idx] for x in range(self.shape_x): curve_data = {} - for channel in data: + for channel, channel_data in data.items(): curve_data[channel] = {} - for segment in data[channel]: + for segment, segment_data in channel_data.items(): indices = indices_map[channel][segment] start_idx = int(indices[self.shape_x * y + x]) - int(indices[self.shape_x * y]) end_idx = int(indices[self.shape_x * y + x + 1]) - int(indices[self.shape_x * y]) - curve_data[channel][segment] = data[channel][segment][start_idx:end_idx] + curve_data[channel][segment] = segment_data[start_idx:end_idx] yield curve_data def _fetch_curve(self, y: int, x: int): @@ -376,7 +387,7 @@ def _fetch_curve(self, y: int, x: int): """ if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x: raise IndexError(f"Curve index out of bounds: ({x}, {y})") - curve_dict = {} + curve_dict: dict[str, dict[str, Any]] = {} if self.flip_image: y = self.shape_y - 1 - y curve_num = self.shape_x * y + x @@ -418,7 +429,22 @@ def load_all_curves(self): class LazyH5Metadata(LazyMetadata): - """A proxy class that fetches header.properties files on demand.""" + """ + A proxy class that fetches header.properties files on demand. + + Parameters + ---------- + qi_data_group : h5py.Group + The HDF5 group containing the QI curve data. + top_level_meta : dict + The top-level metadata dictionary. + shape_x : int + The number of columns in the image. + shape_y : int + The number of rows in the image. + flip_image : bool, optional + Whether to flip the image vertically. Default is ``True``. + """ def __init__( self, qi_data_group: h5py.Group, top_level_meta: dict, shape_x: int, shape_y: int, flip_image: bool = True @@ -471,6 +497,19 @@ class LazyH5MetaProxy(LazyMetaProxy): It behaves like a 2D array of shape (shape_y, shape_x) where each element is a dictionary containing the requested metadata for that pixel. + + Parameters + ---------- + qi_data_group : h5py.Group + The HDF5 group containing the QI curve data. + meta_type : str + The type of metadata to fetch ("curve" or "segment"). + shape_x : int + The number of columns in the image. + shape_y : int + The number of rows in the image. + flip_image : bool, optional + Whether to flip the image vertically. Default is ``True``. """ def __init__(self, qi_data_group: h5py.Group, meta_type: str, shape_x: int, shape_y: int, flip_image: bool = True): @@ -547,6 +586,8 @@ def load_h5jpk( The channel to extract from the .h5-jpk file. flip_image : bool, optional Whether to flip the images vertically. Default is ``True``. + load_curves : bool, optional + Whether to load QI curve data if present. Default is ``True``. Returns ------- @@ -627,7 +668,6 @@ def load_h5jpk( if load_curves: f = h5py.File(file_path, "r") - logger.debug(f"QI_Curve_Data group keys: {list(f.keys())}") logger.info(f"[{file_path.stem}] : Found Force Curves QI data in file.") qi_data_group = f["QI_Curve_Data"] channels_units = {} diff --git a/AFMReader/jpk.py b/AFMReader/jpk.py index f69cc1d..1bf18b8 100644 --- a/AFMReader/jpk.py +++ b/AFMReader/jpk.py @@ -175,6 +175,26 @@ def _get_z_scaling(tif: tifffile.tifffile, channel_idx: int, jpk_tags: dict[str, def _get_jpk_channels( file: Path | BytesIO, filename: str, file_path: Path | str, config_path: Path | str | None = None ): + """ + Retrieve the list of available channels from a JPK TIFF file. + + Parameters + ---------- + file : Path | BytesIO + Path to the JPK TIFF file. + filename : str + Name of the JPK TIFF file. + file_path : Path | str + Path to the JPK TIFF file. + config_path : Path | str | None, optional + Path to a configuration file. If ''None'' (default) then the packages + default configuration is loaded from ''default_config.yaml''. + + Returns + ------- + dict + Dictionary of available channels with their corresponding page indices. + """ jpk_tags = _load_jpk_tags(config_path) try: tif = tifffile.TiffFile(file) @@ -216,7 +236,7 @@ def get_jpk_channels(file_path: Path | str, config_path: Path | str | None = Non def load_jpk( - file_path: Path | str, channel: str, config_path: Path | str | None = None, flip_image: bool | None = True + file_path: Path | str, channel: str, config_path: Path | str | None = None, flip_image: bool = True ) -> tuple[np.ndarray, float]: """ Load image from JPK Instruments .jpk files. @@ -230,7 +250,7 @@ def load_jpk( config_path : Path | str | None Path to a configuration file. If ''None'' (default) then the packages default configuration is loaded from ''default_config.yaml''. - flip_image : bool, optional + flip_image : bool Whether to flip the image vertically. Default is ``True``. Returns @@ -277,6 +297,32 @@ def _load_jpk( flip_image: bool = True, convert_to_nm: bool = True, ) -> tuple[np.ndarray, float]: + """ + Load image data and pixel scaling from a JPK TIFF file for a given channel. + + Parameters + ---------- + file : Path | BytesIO + Path to the JPK TIFF file. + filename : str + Name of the JPK TIFF file. + channel : str + The channel to extract from the JPK TIFF file. + file_suffix : str + The file suffix of the JPK TIFF file. + config_path : Path | str | None, optional + Path to a configuration file. If ''None'' (default) then the packages default configuration is + loaded from ''default_config.yaml''. + flip_image : bool, optional + Whether to flip the image vertically. Default is True. + convert_to_nm : bool, optional + Whether to convert the image to nanometres. Default is True. + + Returns + ------- + tuple[np.ndarray, float] + A tuple containing the image and its pixel to nanometre scaling value. + """ jpk_tags = _load_jpk_tags(config_path) try: tif = tifffile.TiffFile(file) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index d3d53fc..56d5123 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -21,13 +21,30 @@ import h5py import psutil -from AFMReader.jpk_utils import LazyMetadata, LazyMetaProxy, LazyQiData +from AFMReader.lazy_data_classes import LazyMetadata, LazyMetaProxy, LazyQiData from AFMReader.logging import logger from AFMReader import jpk class LazyJpkQiData(LazyQiData): - """A proxy class that behaves like a 2D list of shape (shape_y, shape_x) but fetches .dat file data on demand.""" + """ + A proxy class that behaves like a 2D list of shape (shape_y, shape_x) but fetches .dat file data on demand. + + Parameters + ---------- + filepath : str + Path to the .jpk file. + shape_x : int + Number of columns in the image. + shape_y : int + Number of rows in the image. + channel_scaling : dict + Dictionary containing scaling information for each channel. + archive : zipfile.ZipFile + The opened ZIP archive containing the .dat files. + flip_image : bool, optional + Whether to flip the image vertically. Default is ``True``. + """ def __init__(self, filepath, shape_x: int, shape_y: int, channel_scaling, archive, flip_image: bool = True): """ @@ -72,7 +89,7 @@ def _fetch_curve(self, y: int, x: int): Returns ------- - curve_data : dict + dict Dictionary containing the curve data for the specified pixel. """ if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x: @@ -108,7 +125,7 @@ def load_all_curves(self): Returns ------- - all_curve_data : list + list A 2D list containing dictionaries with curve data for each pixel. """ all_curve_data = [[None for _ in range(self.shape_x)] for _ in range(self.shape_y)] @@ -130,6 +147,21 @@ class LazyQiMetadata(LazyMetadata): It behaves like a 2D array of shape (shape_y, shape_x) where each element is a dictionary containing the requested metadata for that pixel. + + Parameters + ---------- + filepath : str + Path to the .jpk file. + top_level_meta : dict + Dictionary containing the top-level metadata extracted from the header files. + archive : zipfile.ZipFile + The opened ZIP archive containing the JPK file contents. + shape_x : int + The number of columns in the image. + shape_y : int + The number of rows in the image. + flip_image : bool, optional + Whether to flip the image vertically. Default is True. """ def __init__(self, filepath, top_level_meta, archive, shape_x: int, shape_y: int, flip_image: bool = True): @@ -185,8 +217,25 @@ def __getitem__(self, key): class LazyQiMetaProxy(LazyMetaProxy): """ - A proxy class that behaves like a 2D array of shape (shape_y, shape_x) - but fetches header.properties files on demand for curves or segments. + A proxy class to represent curve and segment metadata. + + It behaves like a 2D list of shape (shape_y, shape_x) but fetches header.properties files on demand for + curves or segments. + + Parameters + ---------- + filepath : str + Path to the .jpk file. + meta_type : str + The type of metadata to fetch ('curve' or 'segment'). + archive : zipfile.ZipFile + The opened ZIP archive containing the JPK file contents. + shape_x : int + The number of columns in the image. + shape_y : int + The number of rows in the image. + flip_image : bool, optional + Whether to flip the image vertically. Default is True. """ def __init__(self, filepath, meta_type, archive, shape_x: int, shape_y: int, flip_image: bool = True): @@ -314,7 +363,22 @@ def _get_channel_scaling(props, channel_index): class jpk_qi_loader: - """Class for readability and improving modularity in the load jpk qi data function.""" + """ + Class for readability and improving modularity in the load jpk qi data function. + + Parameters + ---------- + filepath : Path | str + The path to the .jpk-qi file to be loaded. + channel : str | None, optional + The specific channel to be extracted (e.g., "measuredHeight"). Default is None. + config_path : Path | str | None, optional + The path to the configuration file, if any. Default is None. + flip_image : bool | None, optional + Whether to flip the image vertically. Default is True. + save_as_h5 : bool, optional + Whether to save the loaded data as an H5 file. Default is False. + """ def __init__( self, @@ -347,7 +411,7 @@ def __init__( self.save_as_h5 = save_as_h5 # Open the ZIP archive once and keep it open for the duration of the loading process - self.qi_archive = zipfile.ZipFile(self.filepath, "r") + self.qi_archive = zipfile.ZipFile(self.filepath, "r") # pylint: disable=consider-using-with logger.info(f"Opened JPK QI archive at {self.filepath}") self.namelist = self.qi_archive.namelist() # Set path to the .jpk-qi-image file within the archive for later use @@ -368,8 +432,8 @@ def __init__( # Just the top level metadata extracted from the header files self.top_level_meta: dict[str, Any] = {} - # A dictionary containing all metadata - self.full_metadata: dict[str, Any] = {} + # A lazy reference containing all metadata + self.full_metadata: LazyQiMetadata | None = None # A 2D list of curve data dictionaries self.curve_data: Any = None # A lookup for channel name to unit to be returned @@ -378,6 +442,7 @@ def __init__( self.segment_channels: list[dict[str, Any]] = [] self.curve_meta: dict[str, Any] = {} self.segment_meta: dict[str, Any] = {} + # Define the image shape and size attributes self.size_x: float | None = None self.size_y: float | None = None @@ -665,8 +730,10 @@ def get_curves_sample(self): def predict_total_points(self): """ - Predict the total number of points for each channel and segment by sampling a subset of curves and - extrapolating based on the maximum number of points found in the sample. + Predict the total number of points for each channel and segment. + + This is done by sampling a subset of curves and extrapolating based on the maximum number + of points found in the sample. Returns ------- @@ -710,10 +777,10 @@ def predict_total_points(self): def get_changing_keys(self): # noqa: C901 """ - Check a sample of curves to see which metadata keys change across curves and segments, so we can - extract only the changing keys for each curve and segment. + Check a sample of curves to see which metadata keys change across curves and segments. - None changing keys are moved to the top level metadata and not extracted for each curve/segment. + This allows us to extract only the changing keys for each curve and segment. + Non-changing keys are moved to the top-level metadata and not extracted for each curve/segment. Returns ------- @@ -778,7 +845,7 @@ def get_collated_metadata(self): Returns ------- - collated_meta : dict + dict A dictionary containing the collated metadata. """ collated_meta = {} @@ -794,10 +861,19 @@ def get_image( """ Process the flat curve data dictionary into a 2D list structure matching the image dimensions. + Parameters + ---------- + overide_channel : str | None, optional + Channel name to use instead of the instance default. Default is None. + convert_to_nm : bool, optional + Whether to convert the image data to nanometres. Default is True. + flip_image : bool | None, optional + Whether to flip the image vertically. Defaults to the instance setting if None. + Returns ------- - image : np.ndarray - A 2D array representing the image data. + tuple[np.ndarray, float] + A 2D array representing the image data and the pixel-to-nm scaling factor. """ # Get channel and flip_image parameters channel = str(overide_channel) if overide_channel else str(self.channel) @@ -876,7 +952,8 @@ def save_lite_data(self): # Format name and reshape image (flattened frame stack) dataset_name = h5_channel.split("_")[0].capitalize() # Include all the channels including the calculated channel - # TODO make this slightly faster by remembering we have load a channel already but difficult cause of scaling + # TODO make this slightly faster by remembering we have load a channel already but + # difficult cause of scaling channel_image, _ = self.get_image(overide_channel=h5_channel, convert_to_nm=False, flip_image=False) frame_stack = channel_image.flatten().reshape(-1, 1) @@ -887,14 +964,18 @@ def save_lite_data(self): def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, direction: int, chan_name: str): """ - Extract the data from a .dat file in the JPK QI archive, applies the appropriate scaling, and saves it to the internal data structure and h5 dataset if required. + Extract the data from a .dat file in the JPK QI archive. + + Applies the appropriate scaling and saves it to the internal data structure and h5 dataset if required. Parameters ---------- h5_datasets : dict - A dictionary containing the h5 datasets for each channel and segment direction, used for saving the data + A dictionary containing the h5 datasets for each channel and segment direction, used for saving the + data. h5_datasets_buffer : dict - A dictionary containing the buffer for each h5 dataset, used for temporary storage before writing to the dataset + A dictionary containing the buffer for each h5 dataset, used for temporary storage before writing to + the dataset. curve_num : int The curve number associated with the .dat file, parsed from the filename. direction : int @@ -1008,7 +1089,6 @@ def extract_curve_metadata(self, curve_num: int, curve_work): logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.") for search_term, meta_set, meta_buffer in curve_work: - start_time = time.perf_counter() # Find the location of the metadata value in the raw bytes start = raw_bytes.find(search_term) # If found, extract the actual value @@ -1255,13 +1335,27 @@ def close(self): self.curve_meta = {} self.segment_meta = {} self.top_level_meta = {} - self.full_metadata = {} self.failed_curves = set() self.points_for_channel_segment = {} self.namelist = [] def _make_num_min_characters(num: int, min_chars: int = 3): + """ + Zero-pad an integer to a minimum number of characters. + + Parameters + ---------- + num : int + The integer to pad. + min_chars : int + The minimum number of characters the resulting string should have. Default is 3. + + Returns + ------- + str + The zero-padded string. + """ string_num = str(num) if len(string_num) >= min_chars: return string_num diff --git a/AFMReader/jpk_utils.py b/AFMReader/lazy_data_classes.py similarity index 57% rename from AFMReader/jpk_utils.py rename to AFMReader/lazy_data_classes.py index 5cf494b..e474e34 100644 --- a/AFMReader/jpk_utils.py +++ b/AFMReader/lazy_data_classes.py @@ -9,6 +9,15 @@ class LazyQiData: It behaves like a 2D array of shape (shape_y, shape_x) where each element is a dictionary containing the QI curve data for that pixel. + + Parameters + ---------- + shape_x : int + The number of columns in the image. + shape_y : int + The number of rows in the image. + flip_image : bool, optional + Whether to flip the image vertically. Default is True. """ def __init__(self, shape_x: int, shape_y: int, flip_image: bool = True): @@ -47,11 +56,45 @@ def __getitem__(self, y: int): """ class RowProxy: + """ + A proxy class for a single row of the QI data that fetches curve data on demand. + + Parameters + ---------- + parent : LazyQiData + The parent LazyQiData instance. + y : int + The row index. + """ + def __init__(self, parent, y): + """ + Initialize RowProxy with parent LazyQiData and row index. + + Parameters + ---------- + parent : LazyQiData + The parent LazyQiData instance. + y : int + The row index. + """ self.parent = parent self.y = y def __getitem__(self, x: int): + """ + Fetch curve data for column x in this row. + + Parameters + ---------- + x : int + The column index. + + Returns + ------- + dict + The QI curve data for the specified pixel. + """ return self.parent._fetch_curve(self.y, x) return RowProxy(self, y) @@ -60,8 +103,7 @@ def _fetch_curve(self, y: int, x: int): """ Fetch the QI curve data for a specific pixel. - Should be implemented by subclasses to define how the curve data is retrieved - from the underlying data source. + Should be implemented by subclasses to define how the curve data is retrieved from the underlying data source. Parameters ---------- @@ -79,7 +121,20 @@ def _fetch_curve(self, y: int, x: int): class LazyMetadata: - """A proxy class that fetches metadata on demand. Superclass for metadata proxy classes.""" + """ + A proxy class that fetches metadata on demand. Superclass for metadata proxy classes. + + Parameters + ---------- + top_level_meta : dict + The top-level metadata dictionary. + shape_x : int + The number of columns in the image. + shape_y : int + The number of rows in the image. + flip_image : bool, optional + Whether to flip the image vertically. Default is True. + """ def __init__(self, top_level_meta: dict, shape_x: int, shape_y: int, flip_image: bool = True): """ @@ -108,6 +163,17 @@ class LazyMetaProxy: It behaves like a 2D array of shape (shape_y, shape_x) where each element is a dictionary containing the metadata for that pixel. + + Parameters + ---------- + meta_type : str + The type of metadata to fetch ("curve" or "segment"). + shape_x : int + The number of columns in the image. + shape_y : int + The number of rows in the image. + flip_image : bool, optional + Whether to flip the image vertically. Default is True. """ def __init__(self, meta_type: str, shape_x: int, shape_y: int, flip_image: bool = True): @@ -137,25 +203,94 @@ def __getitem__(self, y: int): This allows for lazy loading. Nested proxy objects are used to allow for fetching segment metadata which requires both x and y indices as well as the direction of the segment (approach or retract). + + Parameters + ---------- + y : int + The row index. + + Returns + ------- + RowProxy + A proxy object for the specified row. """ class RowProxy: + """ + A proxy class for a single row of the metadata that fetches metadata on demand. + + Parameters + ---------- + parent : LazyMetaProxy + The parent LazyMetaProxy instance. + y : int + The row index. + """ + def __init__(self, parent, y): + """ + Initialize RowProxy with parent LazyMetaProxy and row index. + + Parameters + ---------- + parent : LazyMetaProxy + The parent LazyMetaProxy instance. + y : int + The row index. + """ self.parent = parent self.y = y def __getitem__(self, x): + """ + Fetch metadata for column x in this row. + + Parameters + ---------- + x : int + The column index. + + Returns + ------- + dict or SegmentMetaProxy + The metadata for the specified column, or a proxy for segment metadata. + """ if self.parent.meta_type == "curve": return self.parent._fetch_meta(self.y, x) if self.parent.meta_type == "segment": class SegmentMetaProxy: + """A proxy class for a single pixel's segment metadata that fetches metadata on demand. + + Parameters + ---------- + parent : RowProxy + The parent RowProxy instance. + y : int + The row index. + x : int + The column index. + """ + def __init__(self, parent, y, x): self.parent = parent self.y = y self.x = x def __getitem__(self, direction): + """ + Fetch metadata for the specified segment direction. + + Parameters + ---------- + direction : int + The direction of the segment ("approach" or "retract"). + + Returns + ------- + dict + The metadata for the specified segment direction. + """ return self.parent.parent._fetch_meta(self.y, self.x, direction) return SegmentMetaProxy(self, self.y, x) @@ -166,6 +301,7 @@ def __getitem__(self, direction): def _fetch_meta(self, y: int, x: int, direction: int | None = None): """ Fetch the metadata for a specific pixel. + Should be implemented by subclasses to define how the metadata is retrieved from the underlying data source. diff --git a/AFMReader/logging.py b/AFMReader/logging.py index f722dff..1655df7 100644 --- a/AFMReader/logging.py +++ b/AFMReader/logging.py @@ -7,7 +7,7 @@ logger.remove() # Set the format to have blue time, green file, module, function and line, and white message logger.add( - lambda msg: sys.stderr.write(msg), + lambda msg: sys.stderr.write(msg), # pylint: disable=unnecessary-lambda colorize=True, format="{time:HH:mm:ss} | {level} |" "{file}:{module}:" diff --git a/AFMReader/raw_bin.py b/AFMReader/raw_bin.py index b963afe..269b62c 100644 --- a/AFMReader/raw_bin.py +++ b/AFMReader/raw_bin.py @@ -1,10 +1,12 @@ """Module to decode and load .bin AFM files into Python Numpy arrays.""" -import numpy as np import math -from .logging import logger from pathlib import Path +import numpy as np + +from .logging import logger + # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,fixme DTYPE_MAP = { @@ -79,6 +81,8 @@ def load_bin( if None in [shape_x, shape_y]: dimension = int(math.sqrt(len(flat_data))) shape_x, shape_y = dimension, dimension + assert shape_x is not None and shape_y is not None # noqa: PT018 + assert size_x is not None and size_y is not None # noqa: PT018 if shape_x * shape_y != len(flat_data): logger.error(f"Loading binary file {filepath.stem} did not receive a shape and is not square") image = flat_data.reshape((shape_x, shape_y)) diff --git a/README.md b/README.md index 8ea0fe9..9d48ca7 100644 --- a/README.md +++ b/README.md @@ -156,8 +156,8 @@ from AFMReader.h5_jpk import load_h5jpk frames, pixel_to_nanometre_scaling_factor, timestamp_dict = load_h5jpk(file_path="./my_jpk_file.h5-jpk", channel="height_trace") ``` -If your `.h5-jpk` file was created from a `.jpk-qi-data` file, then the curve data can be read like so. Note that reading force -curves like this will keep the file open as the force curves are lazy loaded from your hard drive. +If your `.h5-jpk` file was created from a `.jpk-qi-data` file, then the curve data can be read like so. Note that reading +force curves like this will keep the file open as the force curves are lazy loaded from your hard drive. ```python from AFMReader.h5_jpk import load_h5jpk @@ -165,6 +165,7 @@ from AFMReader.h5_jpk import load_h5jpk frames, pixel_to_nanometre_scaling_factor, timestamp_dict, force_curves = load_h5jpk(file_path="./my_jpk_file.h5-jpk", channel="height_trace") ``` +### .stp You can open `.stp` files using the `load_stp` function. Just pass in the path to the file you want to use. diff --git a/tests/test_general_loader.py b/tests/test_general_loader.py index 276de15..69c3891 100644 --- a/tests/test_general_loader.py +++ b/tests/test_general_loader.py @@ -1,5 +1,6 @@ """Test the general loader module.""" +import re from pathlib import Path import numpy as np @@ -129,14 +130,13 @@ def test_load(capsys: pytest.CaptureFixture, filepath: Path, channel: str, error: bool, message: str) -> None: """Test loading of all (asd, gwy, ibw, jpk, spm, stp, top, topostats) filetypes.""" loader = general_loader.LoadFile(filepath, channel) - try: - image, px2nm = loader.load() - # check array and px2nm returned + if error: + with pytest.raises(ValueError, match=re.escape(message)): + loader.load() + else: + image, px2nm = loader.load() # type: ignore[misc] assert isinstance(image, np.ndarray) assert isinstance(px2nm, float) - except ValueError as e: - if error: - assert message in str(e) # check output logs captured = capsys.readouterr() assert message in captured.err @@ -156,5 +156,5 @@ def test_load_filenotfounderror(filepath: Path) -> None: loader = general_loader.LoadFile(filepath, "channel") with pytest.raises(FileNotFoundError) as execinfo: # noqa: PT012 - _, _ = loader.load() + _, _ = loader.load() # type: ignore[misc] assert "[not_a_real_file] FileNotFoundError" in execinfo.value diff --git a/tests/test_h5jpk.py b/tests/test_h5jpk.py index ae2acda..a9bc2df 100644 --- a/tests/test_h5jpk.py +++ b/tests/test_h5jpk.py @@ -1,5 +1,7 @@ """Test the loading of .5h-jpk files.""" +# mypy: disable-error-code="arg-type,index" + from pathlib import Path import numpy as np @@ -127,7 +129,7 @@ def test_load_h5jpk( image_sum: float, ) -> None: """Test the normal operation of loading a .h5-jpk file.""" - result_image, result_pixel_to_nm_scaling, results_timestamps = h5_jpk.load_h5jpk( + result_image, result_pixel_to_nm_scaling, results_timestamps = h5_jpk.load_h5jpk( # type: ignore[misc] RESOURCES / file_name, channel, flip_image ) @@ -197,7 +199,7 @@ def test_load_h5jpk_curves( curve_targets : dict[str, tuple[int, float]] A dictionary mapping curve channels to their expected size and sum, used for validating the loaded curve data. """ - _, _, _, result_curve_data = h5_jpk.load_h5jpk(RESOURCES / file_name, channel, flip_image) + _, _, _, result_curve_data = h5_jpk.load_h5jpk(RESOURCES / file_name, channel, flip_image) # type: ignore[misc] all_curves, _, _ = result_curve_data curve_at_coords = all_curves[curve_coords[0]][curve_coords[1]] for curve_channel, (expected_size, expected_sum) in curve_targets.items(): diff --git a/tests/test_jpk_qi.py b/tests/test_jpk_qi.py index a786d57..5dc3f63 100644 --- a/tests/test_jpk_qi.py +++ b/tests/test_jpk_qi.py @@ -112,7 +112,7 @@ ), ], ) -def test_load_jpk_qi_data( +def test_load_jpk_qi_data( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals file_name: str, channel: str, pixel_to_nm_scaling: float, @@ -150,6 +150,4 @@ def test_load_jpk_qi_data( def test_load_jpk_data_file_not_found() -> None: """Ensure FileNotFound error is raised.""" with pytest.raises(FileNotFoundError): - jpk_qi_loader = jpk_qi.jpk_qi_loader("noexistant_file.jpk-qi-data", "TP") - jpk_qi_loader.load() # type: ignore - jpk_qi_loader.close() # type: ignore + jpk_qi.jpk_qi_loader("noexistant_file.jpk-qi-data", "TP") From 603d7bf4e595ae35a1d76d34bcef7720c0710a61 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Fri, 22 May 2026 18:12:29 +0100 Subject: [PATCH 43/54] chore: removing print statements --- AFMReader/general_loader.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index 1991419..1877648 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -95,11 +95,6 @@ def load( # noqa: C901 elif len(h5_returned) == 4: image, pixel_to_nanometre_scaling_factor, _, curve_data = h5_returned # type: ignore[misc] self.loaded_curves = True - print( - f"Loaded image with shape {image.shape} and pixel to nanometre " - f"scaling factor {pixel_to_nanometre_scaling_factor}" - ) - print(f"Image has max value {image.max()} and min value {image.min()}") return image, pixel_to_nanometre_scaling_factor, curve_data else: logger.error(f"Loading h5-jpk file returned unexpected number of items: {len(h5_returned)}") From 43640ac9b98e75c3da728608d9ba512fdfc9b180 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Sat, 23 May 2026 11:33:14 +0100 Subject: [PATCH 44/54] tests: adding tests for get channel functions --- tests/test_asd.py | 15 +++++ tests/test_general_loader.py | 118 +++++++++++++++++++++++++++++++++++ tests/test_gwy.py | 27 ++++++++ tests/test_h5jpk.py | 28 +++++++++ tests/test_ibw.py | 27 ++++++++ tests/test_jpk.py | 39 ++++++++++++ tests/test_spm.py | 26 ++++++++ 7 files changed, 280 insertions(+) diff --git a/tests/test_asd.py b/tests/test_asd.py index 3f411df..8af1b6e 100644 --- a/tests/test_asd.py +++ b/tests/test_asd.py @@ -35,3 +35,18 @@ def test_load_asd_file_not_found() -> None: """Ensure FileNotFound error is raised.""" with pytest.raises(FileNotFoundError): asd.load_asd("nonexistant_file.asd", channel="TP") + + +@pytest.mark.parametrize( + ("file_name", "expected_channels"), + [ + pytest.param("sample_0.asd", ["TP", "PH"], id="sample_0.asd"), + pytest.param("sample_1.asd", ["TP", "PH"], id="sample_1.asd"), + pytest.param("extra_sample.asd", ["TP", "PH"], id="extra_sample.asd"), + ], +) +def test_get_asd_channels(file_name: str, expected_channels: list[str]) -> None: + """Test get_asd_channels.""" + file_path = RESOURCES / file_name + channels = asd.get_asd_channels(file_path) + assert sorted(channels) == sorted(expected_channels) diff --git a/tests/test_general_loader.py b/tests/test_general_loader.py index 69c3891..66bb7a6 100644 --- a/tests/test_general_loader.py +++ b/tests/test_general_loader.py @@ -2,6 +2,7 @@ import re from pathlib import Path +from typing import Any import numpy as np import pytest @@ -158,3 +159,120 @@ def test_load_filenotfounderror(filepath: Path) -> None: with pytest.raises(FileNotFoundError) as execinfo: # noqa: PT012 _, _ = loader.load() # type: ignore[misc] assert "[not_a_real_file] FileNotFoundError" in execinfo.value + + +@pytest.mark.parametrize( + ("file_name", "expected"), + [ + pytest.param("sample_0.asd", ["TP", "PH"], id="asd"), + pytest.param( + "sample_0.gwy", + [ + "ZSensor", + "Peak Force Error", + "Stiffness", + "LogStiffness", + "Adhesion", + "Deformation", + "Dissipation", + "Height", + ], + id="gwy", + ), + pytest.param( + "sample_0.ibw", + [ + "HeightTracee", + "HeightRetrace", + "ZSensorTrace", + "ZSensorRetrace", + "UserIn0Trace", + "UserIn0Retrace", + "UserIn1Trace", + "UserIn1Retrace", + ], + id="ibw", + ), + pytest.param( + "sample_0.jpk", + { + "height_retrace": 1, + "measuredHeight_retrace": 2, + "amplitude_retrace": 3, + "phase_retrace": 4, + "error_retrace": 5, + "height_trace": 6, + "measuredHeight_trace": 7, + "amplitude_trace": 8, + "phase_trace": 9, + "error_trace": 10, + }, + id="jpk", + ), + pytest.param( + "sample_0.jpk-qi-image", + { + "measuredHeight_trace": 3, + "vDeflection_trace": 2, + "adhesion_trace": 4, + "height_trace": 5, + "slope_trace": 6, + }, + id="jpk-qi-image", + ), + pytest.param( + "sample_0.spm", + [ + "Height Sensor", + "Peak Force Error", + "DMTModulus", + "LogDMTModulus", + "Adhesion", + "Deformation", + "Dissipation", + "Height", + ], + id="spm", + ), + pytest.param( + "sample_0.h5-jpk", + [ + "error_trace", + "height_trace", + "phase_retrace", + "height_retrace", + "measuredheight_trace", + "error_retrace", + "amplitude_trace", + "amplitude_retrace", + "phase_trace", + ], + id="h5-jpk sample_0", + ), + pytest.param( + "sample_0_1.topostats", + ["image", "image_original"], + id="topostats 0.1", + ), + pytest.param( + "sample_0_2.topostats", + ["image", "image_original"], + id="topostats 0.2", + ), + ], +) +def test_get_available_channels_all_formats(file_name: str, expected: Any) -> None: + """Test get_available_channels for all formats.""" + file_path = RESOURCES / file_name + loader = general_loader.LoadFile(file_path, channel="") + channels = loader.get_available_channels() + + if isinstance(expected, list): + assert sorted(channels) == sorted(expected) + elif isinstance(expected, tuple) and len(expected) == 2: + assert isinstance(channels, tuple) + assert len(channels) == 2 + assert channels[0] == expected[0] + assert channels[1] == expected[1] + else: + assert channels == expected diff --git a/tests/test_gwy.py b/tests/test_gwy.py index 9b81a33..5d3aa1a 100644 --- a/tests/test_gwy.py +++ b/tests/test_gwy.py @@ -117,3 +117,30 @@ def test_load_gwy_file_not_found() -> None: """Ensure FileNotFound error is raised.""" with pytest.raises(FileNotFoundError): gwy.load_gwy("nonexistant_file.gwy", channel="TP") + + +@pytest.mark.parametrize( + ("file_name", "expected_channels"), + [ + pytest.param( + "sample_0.gwy", + [ + "ZSensor", + "Peak Force Error", + "Stiffness", + "LogStiffness", + "Adhesion", + "Deformation", + "Dissipation", + "Height", + ], + id="sample_0.gwy", + ), + ], +) +def test_get_gwy_channels(file_name: str, expected_channels: list[str]) -> None: + """Test get_gwy_channels.""" + file_path = RESOURCES / file_name + channels = gwy.get_gwy_channels(file_path) + # The order might not be guaranteed, so sort before comparing + assert sorted(channels) == sorted(expected_channels) diff --git a/tests/test_h5jpk.py b/tests/test_h5jpk.py index a9bc2df..955f876 100644 --- a/tests/test_h5jpk.py +++ b/tests/test_h5jpk.py @@ -212,3 +212,31 @@ def test_load_h5jpk_file_not_found() -> None: """Ensure FileNotFound error is raised.""" with pytest.raises(FileNotFoundError): h5_jpk.load_h5jpk("nonexistant_file.h5-jpk", channel="TP") + + +@pytest.mark.parametrize( + ("file_name", "expected_channels"), + [ + pytest.param( + "sample_0.h5-jpk", + [ + "error_trace", + "height_trace", + "phase_retrace", + "height_retrace", + "measuredheight_trace", + "error_retrace", + "amplitude_trace", + "amplitude_retrace", + "phase_trace", + ], + id="sample_0.h5-jpk", + ), + ], +) +def test_get_h5jpk_channels(file_name: str, expected_channels: list[str]) -> None: + """Test get_h5jpk_channels.""" + file_path = RESOURCES / file_name + channels = h5_jpk.get_h5jpk_channels(file_path) + # The order might not be guaranteed, so sort before comparing + assert sorted(channels) == sorted(expected_channels) diff --git a/tests/test_ibw.py b/tests/test_ibw.py index 330bdb4..34dfdc2 100644 --- a/tests/test_ibw.py +++ b/tests/test_ibw.py @@ -41,3 +41,30 @@ def test_load_ibw_file_not_found() -> None: """Ensure FileNotFound error is raised.""" with pytest.raises(FileNotFoundError): ibw.load_ibw("nonexistant_file.ibw", channel="TP") + + +@pytest.mark.parametrize( + ("file_name", "expected_channels"), + [ + pytest.param( + "sample_0.ibw", + [ + "HeightTracee", + "HeightRetrace", + "ZSensorTrace", + "ZSensorRetrace", + "UserIn0Trace", + "UserIn0Retrace", + "UserIn1Trace", + "UserIn1Retrace", + ], + id="sample_0.ibw", + ), + ], +) +def test_get_ibw_channels(file_name: str, expected_channels: list[str]) -> None: + """Test get_ibw_channels.""" + file_path = RESOURCES / file_name + channels = ibw.get_ibw_channels(file_path) + # The order might not be guaranteed, so sort before comparing + assert sorted(channels) == sorted(expected_channels) diff --git a/tests/test_jpk.py b/tests/test_jpk.py index 2e574f8..b84b05e 100644 --- a/tests/test_jpk.py +++ b/tests/test_jpk.py @@ -89,3 +89,42 @@ def test_load_jpk_file_not_found() -> None: """Ensure FileNotFound error is raised.""" with pytest.raises(FileNotFoundError): jpk.load_jpk("nonexistant_file.jpk", channel="TP") + + +@pytest.mark.parametrize( + ("file_name", "expected"), + [ + pytest.param( + "sample_0.jpk", + { + "height_retrace": 1, + "measuredHeight_retrace": 2, + "amplitude_retrace": 3, + "phase_retrace": 4, + "error_retrace": 5, + "height_trace": 6, + "measuredHeight_trace": 7, + "amplitude_trace": 8, + "phase_trace": 9, + "error_trace": 10, + }, + id="sample_0.jpk", + ), + pytest.param( + "sample_0.jpk-qi-image", + { + "measuredHeight_trace": 3, + "vDeflection_trace": 2, + "adhesion_trace": 4, + "height_trace": 5, + "slope_trace": 6, + }, + id="sample_0.jpk-qi-image", + ), + ], +) +def test_get_jpk_channels(file_name: str, expected: dict[str, int]) -> None: + """Test get_jpk_channels.""" + file_path = RESOURCES / file_name + channels = jpk.get_jpk_channels(file_path) + assert channels == expected diff --git a/tests/test_spm.py b/tests/test_spm.py index 466be2c..f24b276 100644 --- a/tests/test_spm.py +++ b/tests/test_spm.py @@ -145,3 +145,29 @@ def test_load_spm_channel_not_found( spm.load_spm(RESOURCES / "sample_0.spm", channel) captured = capsys.readouterr() assert message in captured.err + + +@pytest.mark.parametrize( + ("file_name", "expected_channels"), + [ + pytest.param( + "sample_0.spm", + [ + "Height Sensor", + "Peak Force Error", + "DMTModulus", + "LogDMTModulus", + "Adhesion", + "Deformation", + "Dissipation", + "Height", + ], + id="sample_0.spm", + ), + ], +) +def test_get_spm_channels(file_name: str, expected_channels: list[str]) -> None: + """Test get_spm_channels.""" + file_path = RESOURCES / file_name + channels = spm.get_spm_channels(file_path) + assert channels == expected_channels From 014654206097520c718be6079907259ebdbfc827 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Sat, 23 May 2026 13:00:10 +0100 Subject: [PATCH 45/54] fix: stopped unnecessary redefining of nested row proxy classes --- AFMReader/lazy_data_classes.py | 220 ++++++++++++++++----------------- 1 file changed, 109 insertions(+), 111 deletions(-) diff --git a/AFMReader/lazy_data_classes.py b/AFMReader/lazy_data_classes.py index e474e34..8301556 100644 --- a/AFMReader/lazy_data_classes.py +++ b/AFMReader/lazy_data_classes.py @@ -38,26 +38,21 @@ def __init__(self, shape_x: int, shape_y: int, flip_image: bool = True): self.dims = (shape_y, shape_x) self.flip_image = flip_image - def __getitem__(self, y: int): + class _RowProxy: """ - Return a proxy object for the specified row that can be indexed to fetch curve data for each pixel in that row. - - This allows for lazy loading. + A proxy class for a single row of the QI data that fetches curve data on demand. Parameters ---------- + parent : LazyQiData + The parent LazyQiData instance. y : int The row index. - - Returns - ------- - RowProxy - A proxy object for the specified row. """ - class RowProxy: + def __init__(self, parent, y): """ - A proxy class for a single row of the QI data that fetches curve data on demand. + Initialize _RowProxy with parent LazyQiData and row index. Parameters ---------- @@ -66,38 +61,42 @@ class RowProxy: y : int The row index. """ + self.parent = parent + self.y = y - def __init__(self, parent, y): - """ - Initialize RowProxy with parent LazyQiData and row index. + def __getitem__(self, x: int): + """ + Fetch curve data for column x in this row. - Parameters - ---------- - parent : LazyQiData - The parent LazyQiData instance. - y : int - The row index. - """ - self.parent = parent - self.y = y + Parameters + ---------- + x : int + The column index. - def __getitem__(self, x: int): - """ - Fetch curve data for column x in this row. + Returns + ------- + dict + The QI curve data for the specified pixel. + """ + return self.parent._fetch_curve(self.y, x) - Parameters - ---------- - x : int - The column index. + def __getitem__(self, y: int): + """ + Return a proxy object for the specified row that can be indexed to fetch curve data for each pixel in that row. - Returns - ------- - dict - The QI curve data for the specified pixel. - """ - return self.parent._fetch_curve(self.y, x) + This allows for lazy loading. - return RowProxy(self, y) + Parameters + ---------- + y : int + The row index. + + Returns + ------- + _RowProxy + A proxy object for the specified row. + """ + return self._RowProxy(self, y) def _fetch_curve(self, y: int, x: int): """ @@ -196,107 +195,106 @@ def __init__(self, meta_type: str, shape_x: int, shape_y: int, flip_image: bool self.shape_y = shape_y self.flip_image = flip_image - def __getitem__(self, y: int): + class _RowProxy: """ - Return a proxy object for the specified row that can be indexed to fetch metadata for each pixel in that row. - - This allows for lazy loading. - Nested proxy objects are used to allow for fetching segment metadata which requires both x and y indices - as well as the direction of the segment (approach or retract). + A proxy class for a single row of the metadata that fetches metadata on demand. Parameters ---------- + parent : LazyMetaProxy + The parent LazyMetaProxy instance. y : int The row index. - - Returns - ------- - RowProxy - A proxy object for the specified row. """ - class RowProxy: + class _SegmentMetaProxy: """ - A proxy class for a single row of the metadata that fetches metadata on demand. + A proxy class for a single pixel's segment metadata that fetches metadata on demand. Parameters ---------- - parent : LazyMetaProxy - The parent LazyMetaProxy instance. + parent : _RowProxy + The parent _RowProxy instance. y : int The row index. + x : int + The column index. """ - def __init__(self, parent, y): - """ - Initialize RowProxy with parent LazyMetaProxy and row index. - - Parameters - ---------- - parent : LazyMetaProxy - The parent LazyMetaProxy instance. - y : int - The row index. - """ + def __init__(self, parent, y, x): self.parent = parent self.y = y + self.x = x - def __getitem__(self, x): + def __getitem__(self, direction): """ - Fetch metadata for column x in this row. + Fetch metadata for the specified segment direction. Parameters ---------- - x : int - The column index. + direction : int + The direction of the segment ("approach" or "retract"). Returns ------- - dict or SegmentMetaProxy - The metadata for the specified column, or a proxy for segment metadata. + dict + The metadata for the specified segment direction. """ - if self.parent.meta_type == "curve": - return self.parent._fetch_meta(self.y, x) - if self.parent.meta_type == "segment": - - class SegmentMetaProxy: - """A proxy class for a single pixel's segment metadata that fetches metadata on demand. - - Parameters - ---------- - parent : RowProxy - The parent RowProxy instance. - y : int - The row index. - x : int - The column index. - """ - - def __init__(self, parent, y, x): - self.parent = parent - self.y = y - self.x = x - - def __getitem__(self, direction): - """ - Fetch metadata for the specified segment direction. - - Parameters - ---------- - direction : int - The direction of the segment ("approach" or "retract"). - - Returns - ------- - dict - The metadata for the specified segment direction. - """ - return self.parent.parent._fetch_meta(self.y, self.x, direction) - - return SegmentMetaProxy(self, self.y, x) - raise IndexError(f"Unknown metadata type '{self.parent.meta_type}'. Expected 'curve' or 'segment'.") - - return RowProxy(self, y) + return self.parent.parent._fetch_meta(self.y, self.x, direction) + + def __init__(self, parent, y): + """ + Initialize _RowProxy with parent LazyMetaProxy and row index. + + Parameters + ---------- + parent : LazyMetaProxy + The parent LazyMetaProxy instance. + y : int + The row index. + """ + self.parent = parent + self.y = y + + def __getitem__(self, x): + """ + Fetch metadata for column x in this row. + + Parameters + ---------- + x : int + The column index. + + Returns + ------- + dict or _SegmentMetaProxy + The metadata for the specified column, or a proxy for segment metadata. + """ + if self.parent.meta_type == "curve": + return self.parent._fetch_meta(self.y, x) + if self.parent.meta_type == "segment": + return self._SegmentMetaProxy(self, self.y, x) + raise IndexError(f"Unknown metadata type '{self.parent.meta_type}'. Expected 'curve' or 'segment'.") + + def __getitem__(self, y: int): + """ + Return a proxy object for the specified row that can be indexed to fetch metadata for each pixel in that row. + + This allows for lazy loading. + Nested proxy objects are used to allow for fetching segment metadata which requires both x and y indices + as well as the direction of the segment (approach or retract). + + Parameters + ---------- + y : int + The row index. + + Returns + ------- + _RowProxy + A proxy object for the specified row. + """ + return self._RowProxy(self, y) def _fetch_meta(self, y: int, x: int, direction: int | None = None): """ From 0f6dabfee2c86dc25ce2e355d30f226fa266d2ec Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Thu, 28 May 2026 10:01:43 +0100 Subject: [PATCH 46/54] docs: comments and improving variable naming --- AFMReader/general_loader.py | 4 ++-- AFMReader/jpk_qi.py | 21 +++++++++++---------- tests/test_jpk_qi.py | 4 ++-- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index 1877648..9e254d9 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -100,7 +100,7 @@ def load( # noqa: C901 logger.error(f"Loading h5-jpk file returned unexpected number of items: {len(h5_returned)}") elif self.suffix == ".jpk-qi-data": if "jpk_qi_loader" not in self.cached_data: - self.cached_data["jpk_qi_loader"] = jpk_qi.jpk_qi_loader( + self.cached_data["jpk_qi_loader"] = jpk_qi.JPKQILoader( filepath=self.filepath, channel=self.channel, **self.kwargs ) jpk_qi_returned = self.cached_data["jpk_qi_loader"].load(channel=self.channel, **self.kwargs) @@ -162,7 +162,7 @@ def get_available_channels(self): # noqa: C901 available_channels = h5_jpk.get_h5jpk_channels(self.filepath) elif self.suffix == ".jpk-qi-data": if "jpk_qi_loader" not in self.cached_data: - self.cached_data["jpk_qi_loader"] = jpk_qi.jpk_qi_loader(filepath=self.filepath, **self.kwargs) + self.cached_data["jpk_qi_loader"] = jpk_qi.JPKQILoader(filepath=self.filepath, **self.kwargs) available_channels = self.cached_data["jpk_qi_loader"].get_available_channels() elif self.suffix == ".topostats": available_channels = ["image", "image_original"] diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index 56d5123..fe65c3d 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -362,7 +362,7 @@ def _get_channel_scaling(props, channel_index): return final_multiplier, final_offset, unit -class jpk_qi_loader: +class JPKQILoader: """ Class for readability and improving modularity in the load jpk qi data function. @@ -413,8 +413,9 @@ def __init__( # Open the ZIP archive once and keep it open for the duration of the loading process self.qi_archive = zipfile.ZipFile(self.filepath, "r") # pylint: disable=consider-using-with logger.info(f"Opened JPK QI archive at {self.filepath}") - self.namelist = self.qi_archive.namelist() - # Set path to the .jpk-qi-image file within the archive for later use + # Store the list of all paths in the archive to avoid having to call namelist() multiple times + self.list_of_all_paths = self.qi_archive.namelist() + # For holding the reference to where the actual .jqk-qi image is (not the metadata). self.path_to_image = None # Chunk size for H5 datasets @@ -466,7 +467,7 @@ def get_available_channels(self): """ # Look for the jpk-qi-image file in the archive if self.path_to_image is None: - for file_name in self.namelist: + for file_name in self.list_of_all_paths: if file_name.endswith(".jpk-qi-image"): self.path_to_image = file_name @@ -599,7 +600,7 @@ def extract_data_to_h5( Whether to include metadata in the loading process, by default True. """ logger.info( - f"Loading all curve data from JPK QI archive with {len(self.namelist)} files " + f"Loading all curve data from JPK QI archive with {len(self.list_of_all_paths)} files " f"{'' if include_metadata else 'not '}including metadata" ) progress_counter = 0 @@ -883,7 +884,7 @@ def get_image( # Search through the namelist to find the .jpk-qi-image file path_to_image = None - for file_name in self.namelist: + for file_name in self.list_of_all_paths: if file_name.endswith(".jpk-qi-image"): path_to_image = file_name if path_to_image is None: @@ -921,7 +922,7 @@ def save_lite_data(self): h5_channels = [self.channel] # Look for the jpk-qi-image file in the archive path_to_image = None - for file_name in self.namelist: + for file_name in self.list_of_all_paths: if file_name.endswith(".jpk-qi-image"): path_to_image = file_name break @@ -1283,7 +1284,7 @@ def parse_dimension_data(self): def extract_global_metadata(self): """Extract global metadata and populate top level metadata dictionary and segment channels list.""" # Load the metadata from the global properties file - if "header.properties" in self.namelist: + if "header.properties" in self.list_of_all_paths: with self.qi_archive.open("header.properties") as archive_meta_file: props = javaproperties.load(archive_meta_file) @@ -1294,7 +1295,7 @@ def extract_global_metadata(self): logger.error(f"File {self.filepath} does not contain essential metadata and cannot be loaded") # Load the metadata from the shared header - if "shared-data/header.properties" in self.namelist: + if "shared-data/header.properties" in self.list_of_all_paths: with self.qi_archive.open("shared-data/header.properties") as shared_data_file: shared_meta = javaproperties.load(shared_data_file) channel_idx = 0 @@ -1337,7 +1338,7 @@ def close(self): self.top_level_meta = {} self.failed_curves = set() self.points_for_channel_segment = {} - self.namelist = [] + self.list_of_all_paths = [] def _make_num_min_characters(num: int, min_chars: int = 3): diff --git a/tests/test_jpk_qi.py b/tests/test_jpk_qi.py index 5dc3f63..c86a067 100644 --- a/tests/test_jpk_qi.py +++ b/tests/test_jpk_qi.py @@ -127,7 +127,7 @@ def test_load_jpk_qi_data( # pylint: disable=too-many-arguments,too-many-positi result_image = np.ndarray result_pixel_to_nm_scaling = float file_path = RESOURCES / file_name - jpk_qi_loader = jpk_qi.jpk_qi_loader(file_path, channel) + jpk_qi_loader = jpk_qi.JPKQILoader(file_path, channel) result_image, result_pixel_to_nm_scaling, result_curve_data = jpk_qi_loader.load() # type: ignore assert result_pixel_to_nm_scaling == pytest.approx(pixel_to_nm_scaling) @@ -150,4 +150,4 @@ def test_load_jpk_qi_data( # pylint: disable=too-many-arguments,too-many-positi def test_load_jpk_data_file_not_found() -> None: """Ensure FileNotFound error is raised.""" with pytest.raises(FileNotFoundError): - jpk_qi.jpk_qi_loader("noexistant_file.jpk-qi-data", "TP") + jpk_qi.JPKQILoader("noexistant_file.jpk-qi-data", "TP") From 612d4e963ec214d6ace5e215943af16ed19e7224 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Mon, 1 Jun 2026 10:50:50 +0100 Subject: [PATCH 47/54] refactor: simplifying dataclass structure by making them work closer to a numpy ndarray and grouping into unified CurveDataset structure --- AFMReader/h5_jpk.py | 158 ++++++--------- AFMReader/jpk_qi.py | 332 +++++++++++++----------------- AFMReader/lazy_data_classes.py | 358 ++++++++++++++------------------- 3 files changed, 349 insertions(+), 499 deletions(-) diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index 21b8b8e..bfccbf4 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -13,9 +13,9 @@ from AFMReader.logging import logger from AFMReader.lazy_data_classes import ( - LazyMetaProxy, - LazyMetadata, - LazyQiData, + CurvesDataset, + CurvesMetadata, + CurvesVolume, ) logger.enable(__package__) @@ -294,41 +294,44 @@ def get_h5jpk_channels(file_path: Path | str): return list(_available_channels(f)) -class LazyH5QiData(LazyQiData): +class CurvesH5Volume(CurvesVolume): """ - A proxy class that fetches QI curve data from the HDF5 file on demand. + A CurvesVolume implementation for HDF5 curve data that provides lazy loading of curve data for each pixel. - It behaves like a 2D array of shape (shape_y, shape_x) where each element - is a dictionary containing the QI curve data for that pixel. + Note that the curve data in the HDF5 file is usually copied from another format for fast access. Parameters ---------- - qi_data_group : h5py.Group - The HDF5 group containing the QI curve data. + name : str + The name of the curve volume. shape_x : int The number of columns in the image. shape_y : int The number of rows in the image. + qi_data_group : h5py.Group + The HDF5 group containing the QI curve data. flip_image : bool, optional - Whether to flip the image vertically. Default is ``True``. + Whether to flip the image vertically. Default is True. """ - def __init__(self, qi_data_group: h5py.Group, shape_x: int, shape_y: int, flip_image: bool = True): + def __init__(self, name: str, shape_x: int, shape_y: int, qi_data_group: h5py.Group, flip_image: bool = True): """ - Initialize the LazyH5QiData proxy. + Initialize the CurvesH5Volume instance. Parameters ---------- - qi_data_group : h5py.Group - The HDF5 group containing the QI curve data. + name : str + The name of the curve volume. shape_x : int The number of columns in the image. shape_y : int The number of rows in the image. + qi_data_group : h5py.Group + The HDF5 group containing the QI curve data. flip_image : bool, optional - Whether to flip the image vertically. Default is ``True``. + Whether to flip the image vertically. Default is True. """ - super().__init__(shape_x, shape_y, flip_image) + super().__init__(name, shape_x, shape_y, flip_image) self.qi_data_group = qi_data_group def __iter__(self): # noqa: C901 @@ -369,7 +372,7 @@ def __iter__(self): # noqa: C901 curve_data[channel][segment] = segment_data[start_idx:end_idx] yield curve_data - def _fetch_curve(self, y: int, x: int): + def get_curve(self, y: int, x: int): """ Fetch the QI curve data for a specific pixel (x, y) on demand. @@ -428,111 +431,58 @@ def load_all_curves(self): return all_curves -class LazyH5Metadata(LazyMetadata): +class CurvesH5Metadata(CurvesMetadata): """ - A proxy class that fetches header.properties files on demand. + Metadata class for H5 JPK data that provides access to metadata on demand. Parameters ---------- qi_data_group : h5py.Group The HDF5 group containing the QI curve data. - top_level_meta : dict + toplevel : dict[str, Any] The top-level metadata dictionary. shape_x : int The number of columns in the image. shape_y : int The number of rows in the image. + channel_units : dict[str, str] + A dictionary mapping channel names to their units. flip_image : bool, optional Whether to flip the image vertically. Default is ``True``. """ + # pylint: disable=too-many-positional-arguments def __init__( - self, qi_data_group: h5py.Group, top_level_meta: dict, shape_x: int, shape_y: int, flip_image: bool = True + self, + qi_data_group: h5py.Group, + toplevel: dict[str, Any], + shape_x: int, + shape_y: int, + channel_units: dict[str, str], + flip_image: bool = True, ): """ - Initialize the LazyH5Metadata proxy. + Initialize the CurvesH5Metadata instance. Parameters ---------- qi_data_group : h5py.Group The HDF5 group containing the QI curve data. - top_level_meta : dict + toplevel : dict[str, Any] The top-level metadata dictionary. shape_x : int The number of columns in the image. shape_y : int The number of rows in the image. + channel_units : dict[str, str] + A dictionary mapping channel names to their units. flip_image : bool, optional Whether to flip the image vertically. Default is ``True``. """ - self.qi_data_group = qi_data_group - super().__init__(top_level_meta, shape_x, shape_y, flip_image) - - def __getitem__(self, key): - """ - Fetch metadata based on the key. - - Parameters - ---------- - key : str - The key to fetch metadata for. - - Returns - ------- - object - The fetched metadata, a lazy object that can be further queried. - """ - if key == "top_level": - return self.top_level - if key == "curves": - return LazyH5MetaProxy(self.qi_data_group, "curve", self.shape_x, self.shape_y, self.flip_image) - if key == "segments": - return LazyH5MetaProxy(self.qi_data_group, "segment", self.shape_x, self.shape_y, self.flip_image) - raise KeyError(key) - - -class LazyH5MetaProxy(LazyMetaProxy): - """ - A proxy class that fetches curve or segment metadata from the HDF5 file on demand. - - It behaves like a 2D array of shape (shape_y, shape_x) where each element - is a dictionary containing the requested metadata for that pixel. - - Parameters - ---------- - qi_data_group : h5py.Group - The HDF5 group containing the QI curve data. - meta_type : str - The type of metadata to fetch ("curve" or "segment"). - shape_x : int - The number of columns in the image. - shape_y : int - The number of rows in the image. - flip_image : bool, optional - Whether to flip the image vertically. Default is ``True``. - """ - - def __init__(self, qi_data_group: h5py.Group, meta_type: str, shape_x: int, shape_y: int, flip_image: bool = True): - """ - Initialize the LazyH5MetaProxy. - - Parameters - ---------- - qi_data_group : h5py.Group - The HDF5 group containing the QI curve data. - meta_type : str - The type of metadata to fetch ("curve" or "segment"). - shape_x : int - The number of columns in the image. - shape_y : int - The number of rows in the image. - flip_image : bool, optional - Whether to flip the image vertically. Default is ``True``. - """ - super().__init__(meta_type, shape_x, shape_y, flip_image) + super().__init__(toplevel, shape_x, shape_y, channel_units, flip_image) self.qi_data_group = qi_data_group - def _fetch_meta(self, y: int, x: int, direction: int | None = None): + def get_pixel_metadata(self, y: int, x: int, direction: int | None = None): """ Fetch metadata for a specific pixel (x, y) on demand. @@ -559,7 +509,7 @@ def _fetch_meta(self, y: int, x: int, direction: int | None = None): idx = (idx * 2) + direction meta_dict = {} for key in self.qi_data_group["Curve_Metadata"]: - if key.startswith(f"{self.meta_type}."): + if key.startswith(f"{'segment' if direction is not None else 'curve'}."): new_key = key.split(".", 1)[1] if isinstance(self.qi_data_group["Curve_Metadata"][key], h5py.Dataset): meta_dict[new_key] = ( @@ -574,7 +524,7 @@ def _fetch_meta(self, y: int, x: int, direction: int | None = None): def load_h5jpk( file_path: Path | str, channel: str, flip_image: bool = True, load_curves: bool = True -) -> tuple[np.ndarray, float, dict[str, float]] | tuple[np.ndarray, float, dict[str, float], Any]: +) -> tuple[np.ndarray, float, dict[str, float]] | tuple[np.ndarray, float, dict[str, float], CurvesDataset]: """ Load image from JPK Instruments .h5-jpk files. @@ -597,9 +547,9 @@ def load_h5jpk( Scaling factor converting pixels to nanometers. timestamps : dict[str, float] Dictionary mapping frame labels (e.g., "frame 0") to timestamp values in seconds. - curves_data : tuple(LazyH5QiData, dict, LazyH5Metadata), optional - Tuple containing lazy-loaded QI curve data, channel units, and metadata. - Returned only if load_curves is True and QI curve data is present in the file. + curves_data : CurvesDataset, optional + An instance of CurvesDataset containing lazy-loaded curve data, channel units, and metadata. + Returned only if load_curves is True and curve data is present in the file. Raises ------ @@ -677,10 +627,20 @@ def load_h5jpk( channels_units[key.split(".")[-1]] = value top_level_meta[key] = value - full_metadata = LazyH5Metadata(qi_data_group, top_level_meta, shape_x, shape_y, flip_image) - - all_curve_data = LazyH5QiData(qi_data_group, shape_x, shape_y, flip_image) - - return (image_stack, px2nm, timestamps, (all_curve_data, channels_units, full_metadata)) + curves_volume = CurvesH5Volume( + name="Trace", shape_x=shape_x, shape_y=shape_y, qi_data_group=qi_data_group, flip_image=flip_image + ) + curves_metadata = CurvesH5Metadata( + qi_data_group=qi_data_group, + toplevel=top_level_meta, + shape_x=shape_x, + shape_y=shape_y, + channel_units=channels_units, + flip_image=flip_image, + ) + + curves_data = CurvesDataset(volumes={"Trace": curves_volume}, metadata=curves_metadata) + + return (image_stack, px2nm, timestamps, curves_data) return (image_stack, px2nm, timestamps) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index fe65c3d..7cd9720 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -21,283 +21,232 @@ import h5py import psutil -from AFMReader.lazy_data_classes import LazyMetadata, LazyMetaProxy, LazyQiData +from AFMReader.lazy_data_classes import CurvesMetadata, CurvesVolume, CurvesDataset from AFMReader.logging import logger from AFMReader import jpk -class LazyJpkQiData(LazyQiData): +class CurvesJPKDataset(CurvesDataset): """ - A proxy class that behaves like a 2D list of shape (shape_y, shape_x) but fetches .dat file data on demand. + A dataset class for JPK QI data that holds the raw data as well as metadata. Parameters ---------- - filepath : str - Path to the .jpk file. - shape_x : int - Number of columns in the image. - shape_y : int - Number of rows in the image. - channel_scaling : dict - Dictionary containing scaling information for each channel. + volumes : dict[str, CurvesVolume] + A dictionary mapping curve names to CurvesVolume instances that + provide access to the curve data for each pixel. + metadata : CurvesMetadata + An instance of CurvesMetadata that provides access to the metadata + for each curve. archive : zipfile.ZipFile - The opened ZIP archive containing the .dat files. - flip_image : bool, optional - Whether to flip the image vertically. Default is ``True``. + The ZIP archive containing the JPK data. """ - def __init__(self, filepath, shape_x: int, shape_y: int, channel_scaling, archive, flip_image: bool = True): + def __init__(self, volumes: dict[str, CurvesVolume], metadata: CurvesMetadata, archive: zipfile.ZipFile): """ - Initialize the LazyJpkQiData instance. + Initialise CurvesJPKDataset. Parameters ---------- - filepath : str - Path to the .jpk file. - shape_x : int - Number of columns in the image. - shape_y : int - Number of rows in the image. - channel_scaling : dict - Dictionary containing scaling information for each channel. + volumes : dict[str, CurvesVolume] + A dictionary mapping curve names to CurvesVolume instances that + provide access to the curve data for each pixel. + metadata : CurvesMetadata + An instance of CurvesMetadata that provides access to the metadata + for each curve. archive : zipfile.ZipFile - The opened ZIP archive containing the .dat files. - flip_image : bool, optional - Whether to flip the image vertically. Default is ``True``. + The ZIP archive containing the JPK data. """ - super().__init__(shape_x, shape_y, flip_image) - self.filepath = filepath - self.channel_scaling = channel_scaling + super().__init__(volumes, metadata) self.archive = archive - def __iter__(self): - """Yield the curve data for each pixel in the image, iterating in row-major order (y first, then x).""" - for y in range(self.shape_y): - for x in range(self.shape_x): - yield self._fetch_curve(y, x) - - def _fetch_curve(self, y: int, x: int): - """ - Fetch the curve data for a specific pixel. - - Parameters - ---------- - y : int - Row index of the pixel. - x : int - Column index of the pixel. - - Returns - ------- - dict - Dictionary containing the curve data for the specified pixel. - """ - if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x: - raise IndexError(f"Curve index out of bounds: ({x}, {y})") - if self.flip_image: - y = self.shape_y - 1 - y - curve_num = y * self.shape_x + x - curve_data: dict[str, Any] = {} - - for chan_name, scale in self.channel_scaling.items(): - curve_data[chan_name] = {} - for direction in (0, 1): - dat_path = f"index/{curve_num}/segments/{direction}/channels/{chan_name}.dat" - try: - # Access the file directly without re-parsing the ZIP directory - with self.archive.open(dat_path) as f: - raw_array = np.frombuffer(f.read(), dtype=">i4") - curve_data[chan_name][f"Segment_{direction}"] = (raw_array * scale["multiplier"]) + scale[ - "offset" - ] - except KeyError: - pass # File doesn't exist for this segment - - return curve_data - - def load_all_curves(self): - """ - Eagerly loads all curve data into a 2D list structure matching the image dimensions. - - This can be used if the user wants to have all the curve data available at once, but it is not recommended - for large datasets as it will consume a lot of memory. In this case, it is not notably faster as the zip - structure means each curve is effectively loaded individually anyway - - Returns - ------- - list - A 2D list containing dictionaries with curve data for each pixel. - """ - all_curve_data = [[None for _ in range(self.shape_x)] for _ in range(self.shape_y)] - for y in range(self.shape_y): - for x in range(self.shape_x): - all_curve_data[y][x] = self._fetch_curve(y, x) - # TODO may be good to just return self here as not faster and lots of memory - # return self - return all_curve_data - def close(self): """Close the ZIP archive when done to free up resources.""" self.archive.close() -class LazyQiMetadata(LazyMetadata): +class CurvesJPKMetadata(CurvesMetadata): """ - A proxy class that fetches header.properties files on demand. - - It behaves like a 2D array of shape (shape_y, shape_x) where each element - is a dictionary containing the requested metadata for that pixel. + A metadata class for JPK QI data that provides lazy loading of pixel metadata. Parameters ---------- - filepath : str - Path to the .jpk file. - top_level_meta : dict - Dictionary containing the top-level metadata extracted from the header files. + toplevel : dict + A dictionary containing the top-level metadata for the dataset. archive : zipfile.ZipFile - The opened ZIP archive containing the JPK file contents. + The ZIP archive containing the JPK data. shape_x : int - The number of columns in the image. + Number of columns in the image. shape_y : int - The number of rows in the image. + Number of rows in the image. + channel_units : dict[str, str] + A dictionary mapping channel names to their units. flip_image : bool, optional Whether to flip the image vertically. Default is True. """ - def __init__(self, filepath, top_level_meta, archive, shape_x: int, shape_y: int, flip_image: bool = True): + def __init__( + self, + toplevel: dict, + archive: zipfile.ZipFile, + shape_x: int, + shape_y: int, + channel_units: dict[str, str], + flip_image: bool = True, + ): """ - Initialize the LazyQiMetadata instance. + Initialize the CurvesJPKMetadata instance. Parameters ---------- - filepath : str - Path to the .jpk file. - top_level_meta : dict - Dictionary containing the top-level metadata extracted from the header files. + toplevel : dict + A dictionary containing the top-level metadata for the dataset. archive : zipfile.ZipFile - The opened ZIP archive containing the JPK file contents. + The ZIP archive containing the JPK data. shape_x : int - The number of columns in the image. + Number of columns in the image. shape_y : int - The number of rows in the image. + Number of rows in the image. + channel_units : dict[str, str] + Dictionary mapping channel names to their units. flip_image : bool, optional - Whether to flip the image vertically. Default is True. + Whether to flip the image vertically. Default is ``True``. """ - self.filepath = filepath - # Expose top_level so the frontend can still do `raw_metadata["top_level"]` + super().__init__(toplevel, shape_x, shape_y, channel_units, flip_image) self.archive = archive - super().__init__(top_level_meta, shape_x, shape_y, flip_image) - def __getitem__(self, key): + def get_pixel_metadata(self, y: int, x: int, direction: int | None = None): """ - Fetch requested metadata based on key. - - If the key is 'top_level', it returns the top-level metadata. - If the key is 'curves' or 'segments', it returns a LazyQiMetaProxy - that can be used to fetch curve or segment metadata on demand. + Fetch the metadata for a specific pixel or direction. Parameters ---------- - key : str - The key to fetch metadata for. + y : int + Row index of the pixel. + x : int + Column index of the pixel. + direction : int, optional + The index of the direction to fetch metadata for. If None, returns metadata for the entire pixel. Returns ------- - dict or LazyQiMetaProxy - The requested metadata. + dict + The metadata for the specified pixel (or direction, if provided). """ - if key == "top_level": - return self.top_level - if key == "curves": - return LazyQiMetaProxy(self.filepath, "curve", self.archive, self.shape_x, self.shape_y, self.flip_image) - if key == "segments": - return LazyQiMetaProxy(self.filepath, "segment", self.archive, self.shape_x, self.shape_y, self.flip_image) - raise KeyError(key) + if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x: + raise IndexError(f"Curve index out of bounds: ({x}, {y})") + if self.flip_image: + y = self.shape_y - 1 - y + idx = (y * self.shape_x) + x + if direction is None: + path = f"index/{idx}/header.properties" + else: + path = f"index/{idx}/segments/{direction}/segment-header.properties" + + try: + with self.archive.open(path) as f: + meta_dict = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} + except KeyError: + meta_dict = {} + return meta_dict -class LazyQiMetaProxy(LazyMetaProxy): - """ - A proxy class to represent curve and segment metadata. - It behaves like a 2D list of shape (shape_y, shape_x) but fetches header.properties files on demand for - curves or segments. +class CurvesJPKVolume(CurvesVolume): + """ + A CurvesVolume implementation for JPK QI curve data that provides lazy loading of curve data for each pixel. Parameters ---------- - filepath : str - Path to the .jpk file. - meta_type : str - The type of metadata to fetch ('curve' or 'segment'). - archive : zipfile.ZipFile - The opened ZIP archive containing the JPK file contents. + name : str + The name of the curve volume. shape_x : int The number of columns in the image. shape_y : int The number of rows in the image. + archive : zipfile.ZipFile + The ZIP archive containing the JPK data. + channel_scaling : dict[str, dict[str, float]] + A dictionary mapping channel names to their scaling factors. flip_image : bool, optional Whether to flip the image vertically. Default is True. """ - def __init__(self, filepath, meta_type, archive, shape_x: int, shape_y: int, flip_image: bool = True): + def __init__( + self, + name: str, + shape_x: int, + shape_y: int, + archive: zipfile.ZipFile, + channel_scaling: dict[str, dict[str, float]], + flip_image: bool = True, + ): """ - Initialize the LazyQiMetaProxy instance. + Initialise CurvesJPKVolume. Parameters ---------- - filepath : str - Path to the .jpk file. - meta_type : str - The type of metadata to fetch ('curve' or 'segment'). - archive : zipfile.ZipFile - The opened ZIP archive containing the JPK file contents. + name : str + The name of the curve volume. shape_x : int The number of columns in the image. shape_y : int The number of rows in the image. + archive : zipfile.ZipFile + The ZIP archive containing the JPK data. + channel_scaling : dict[str, dict[str, float]] + A dictionary mapping channel names to their scaling factors. flip_image : bool, optional Whether to flip the image vertically. Default is True. """ - self.filepath = filepath + super().__init__(name, shape_x, shape_y, flip_image) self.archive = archive - super().__init__(meta_type, shape_x, shape_y, flip_image) + self.channel_scaling = channel_scaling - def _fetch_meta(self, y: int, x: int, direction: int | None = None): + def __iter__(self): + """Yield the curve data for each pixel in the image, iterating in row-major order (y first, then x).""" + for y in range(self.shape_y): + for x in range(self.shape_x): + yield self.get_curve(y, x) + + def get_curve(self, y: int, x: int): """ - Fetch the metadata for a specific curve or segment. + Fetch the curve data for a specific pixel. Parameters ---------- y : int - Row index of the curve or segment. + Row index of the pixel. x : int - Column index of the curve or segment. - direction : int, optional - The direction index for segment metadata. Required if meta_type is 'segment'. + Column index of the pixel. Returns ------- dict - The metadata dictionary for the specified curve or segment. + Dictionary containing the curve data for the specified pixel. """ if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x: raise IndexError(f"Curve index out of bounds: ({x}, {y})") if self.flip_image: y = self.shape_y - 1 - y - idx = (y * self.shape_x) + x - if self.meta_type == "curve": - path = f"index/{idx}/header.properties" - else: - if direction is None: - raise ValueError("Direction must be provided for segment metadata") - path = f"index/{idx}/segments/{direction}/segment-header.properties" + curve_num = y * self.shape_x + x + curve_data: dict[str, Any] = {} - try: - with self.archive.open(path) as f: - meta_dict = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()} - except KeyError: - meta_dict = {} + for chan_name, scale in self.channel_scaling.items(): + curve_data[chan_name] = {} + for direction in (0, 1): + dat_path = f"index/{curve_num}/segments/{direction}/channels/{chan_name}.dat" + try: + # Access the file directly without re-parsing the ZIP directory + with self.archive.open(dat_path) as f: + raw_array = np.frombuffer(f.read(), dtype=">i4") + curve_data[chan_name][f"Segment_{direction}"] = (raw_array * scale["multiplier"]) + scale[ + "offset" + ] + except KeyError: + pass # File doesn't exist for this segment - return meta_dict + return curve_data def _get_channel_scaling(props, channel_index): @@ -434,9 +383,9 @@ def __init__( # Just the top level metadata extracted from the header files self.top_level_meta: dict[str, Any] = {} # A lazy reference containing all metadata - self.full_metadata: LazyQiMetadata | None = None + self.full_metadata: CurvesJPKMetadata | None = None # A 2D list of curve data dictionaries - self.curve_data: Any = None + self.curves_volume: CurvesJPKVolume | None = None # A lookup for channel name to unit to be returned self.channels_units: dict[str, str] = {} # The list of channels for the segments with their scaling information extracted from the shared header @@ -528,22 +477,25 @@ def load( # Establish the lazy loading structures for curve data and metadata. Note how lazy structure is used even if # all the data has been accessed and saved to H5 to prevent excessive memory usage - self.full_metadata = LazyQiMetadata( - self.filepath, + self.full_metadata = CurvesJPKMetadata( self.top_level_meta, self.qi_archive, self.shape_x or 0, self.shape_y or 0, + channel_units=self.channels_units, flip_image=bool(self.flip_image), ) - self.curve_data = LazyJpkQiData( - self.filepath, - self.shape_x or 0, - self.shape_y or 0, - self.channel_scaling, - self.qi_archive, + self.curves_volume = CurvesJPKVolume( + name="Trace", + shape_x=self.shape_x or 0, + shape_y=self.shape_y or 0, + archive=self.qi_archive, + channel_scaling=self.channel_scaling, flip_image=bool(self.flip_image), ) + self.curves_dataset = CurvesJPKDataset( + volumes={"Trace": self.curves_volume}, metadata=self.full_metadata, archive=self.qi_archive + ) # Load the image self.image, _ = self.get_image() @@ -552,7 +504,7 @@ def load( if self.save_as_h5: self.save_lite_data() - return (self.image, self.px2nm, (self.curve_data, self.channels_units, self.full_metadata)) + return (self.image, self.px2nm, self.curves_dataset) def output_summary(self): """Output a summary of the loading process, including any failed curve loads and their details.""" diff --git a/AFMReader/lazy_data_classes.py b/AFMReader/lazy_data_classes.py index 8301556..9c70054 100644 --- a/AFMReader/lazy_data_classes.py +++ b/AFMReader/lazy_data_classes.py @@ -1,108 +1,94 @@ -"""Utility classes for lazy loading of QI curve data and metadata from JPK files.""" +""" +Data classes for lazy loading of curve data and metadata from files. + +These classes provide a consistent interface for accessing curve data and metadata in a +lazy manner (i.e. loading data on demand rather than all at once) across different file +formats. This is necessary for handling large datasets with massive memory consumption. +""" # pylint: disable=too-few-public-methods,fixme -class LazyQiData: +class CurvesMetadata: """ - A proxy class that fetches QI curve data on demand. + A class representing the metadata for a dataset of curves, providing lazy loaded access to pixel metadata. - It behaves like a 2D array of shape (shape_y, shape_x) where each element is a dictionary - containing the QI curve data for that pixel. + This is a parent class that should be subclassed for specific file formats to implement + the get_pixel_metadata method, which defines how the metadata is retrieved from the + underlying data source. Parameters ---------- + toplevel : dict + A dictionary containing the top-level metadata for the dataset. shape_x : int The number of columns in the image. shape_y : int The number of rows in the image. + channel_units : dict[str, str] + A dictionary mapping channel names to their units. flip_image : bool, optional Whether to flip the image vertically. Default is True. """ - def __init__(self, shape_x: int, shape_y: int, flip_image: bool = True): + def __init__( + self, toplevel: dict, shape_x: int, shape_y: int, channel_units: dict[str, str], flip_image: bool = True + ): """ - Initialize the LazyQiData instance. + Initialise CurvesMetadata. Parameters ---------- + toplevel : dict + A dictionary containing the top-level metadata for the dataset. shape_x : int The number of columns in the image. shape_y : int The number of rows in the image. + channel_units : dict[str, str] + A dictionary mapping channel names to their units. flip_image : bool, optional Whether to flip the image vertically. Default is True. """ + self.toplevel = toplevel self.shape_x = shape_x self.shape_y = shape_y - self.dims = (shape_y, shape_x) + self.channel_units = channel_units self.flip_image = flip_image - class _RowProxy: - """ - A proxy class for a single row of the QI data that fetches curve data on demand. - - Parameters - ---------- - parent : LazyQiData - The parent LazyQiData instance. - y : int - The row index. - """ - - def __init__(self, parent, y): - """ - Initialize _RowProxy with parent LazyQiData and row index. - - Parameters - ---------- - parent : LazyQiData - The parent LazyQiData instance. - y : int - The row index. - """ - self.parent = parent - self.y = y - - def __getitem__(self, x: int): - """ - Fetch curve data for column x in this row. - - Parameters - ---------- - x : int - The column index. - - Returns - ------- - dict - The QI curve data for the specified pixel. - """ - return self.parent._fetch_curve(self.y, x) - - def __getitem__(self, y: int): + def __getitem__(self, keys): """ - Return a proxy object for the specified row that can be indexed to fetch curve data for each pixel in that row. + Fetch the metadata for a specific pixel or pixel direction. - This allows for lazy loading. + For example, metadata[y, x] would return the metadata for the pixel at row y and column x, while + metadata[y, x, 0] would return the metadata for specifically the first direction of that pixel + (usually the approach). Parameters ---------- - y : int - The row index. + keys : tuple + A tuple of (y, x) or (y, x, direction) representing the indices. Returns ------- - _RowProxy - A proxy object for the specified row. + dict + The metadata for the specified pixel or direction. """ - return self._RowProxy(self, y) - - def _fetch_curve(self, y: int, x: int): + if isinstance(keys, tuple) and len(keys) == 2: + y, x = keys + return self.get_pixel_metadata(y, x) + if isinstance(keys, tuple) and len(keys) == 3: + y, x, direction = keys + return self.get_pixel_metadata(y, x, direction) + raise IndexError( + f"Invalid indexing. Expected (y, x) or (y, x, direction) for pixel metadata indexing. Got {keys}." + ) + + def get_pixel_metadata(self, y: int, x: int, direction: int | None = None): """ - Fetch the QI curve data for a specific pixel. + Fetch the metadata for a specific pixel, optionally for a specific direction. - Should be implemented by subclasses to define how the curve data is retrieved from the underlying data source. + Should be implemented by subclasses to define how the metadata is retrieved from the underlying data source. Parameters ---------- @@ -110,63 +96,27 @@ def _fetch_curve(self, y: int, x: int): Row index of the pixel. x : int Column index of the pixel. + direction : int, optional + The index of the direction to fetch metadata for. If None, returns metadata for the entire pixel. Returns ------- dict - The QI curve data for the specified pixel. - """ - raise NotImplementedError("This method should be implemented by subclasses to fetch curve data on demand.") - - -class LazyMetadata: - """ - A proxy class that fetches metadata on demand. Superclass for metadata proxy classes. - - Parameters - ---------- - top_level_meta : dict - The top-level metadata dictionary. - shape_x : int - The number of columns in the image. - shape_y : int - The number of rows in the image. - flip_image : bool, optional - Whether to flip the image vertically. Default is True. - """ - - def __init__(self, top_level_meta: dict, shape_x: int, shape_y: int, flip_image: bool = True): - """ - Initialize the LazyMetadata instance. - - Parameters - ---------- - top_level_meta : dict - The top-level metadata dictionary. - shape_x : int - The number of columns in the image. - shape_y : int - The number of rows in the image. - flip_image : bool, optional - Whether to flip the image vertically. Default is True. + The metadata for the specified pixel (or direction, if provided). """ - self.top_level = top_level_meta - self.shape_x = shape_x - self.shape_y = shape_y - self.flip_image = flip_image + raise NotImplementedError("This method should be implemented by subclasses to fetch pixel metadata on demand.") -class LazyMetaProxy: +class CurvesVolume: """ - A proxy class that fetches metadata on demand. + A class representing a 2D map or volume of curves, providing lazy loaded access to curve data. - It behaves like a 2D array of shape (shape_y, shape_x) where each element is a dictionary - containing the metadata for that pixel. + An individual curve can be accessed using volume[y, x], which will load the curve data for that pixel on demand. Parameters ---------- - meta_type : str - The type of metadata to fetch ("curve" or "segment"). + name : str + The name of the curve volume. shape_x : int The number of columns in the image. shape_y : int @@ -175,14 +125,14 @@ class LazyMetaProxy: Whether to flip the image vertically. Default is True. """ - def __init__(self, meta_type: str, shape_x: int, shape_y: int, flip_image: bool = True): + def __init__(self, name: str, shape_x: int, shape_y: int, flip_image: bool = True): """ - Initialize the LazyMetaProxy instance. + Initialise CurvesVolume. Parameters ---------- - meta_type : str - The type of metadata to fetch ("curve" or "segment"). + name : str + The name of the curve volume. shape_x : int The number of columns in the image. shape_y : int @@ -190,131 +140,119 @@ def __init__(self, meta_type: str, shape_x: int, shape_y: int, flip_image: bool flip_image : bool, optional Whether to flip the image vertically. Default is True. """ - self.meta_type = meta_type self.shape_x = shape_x self.shape_y = shape_y + self.dims = (shape_y, shape_x) self.flip_image = flip_image + self.name = name - class _RowProxy: + def __getitem__(self, keys): """ - A proxy class for a single row of the metadata that fetches metadata on demand. + Allow numpy style indexing to fetch curve data for a specific pixel. Parameters ---------- - parent : LazyMetaProxy - The parent LazyMetaProxy instance. - y : int - The row index. - """ + keys : tuple + A tuple of (y, x) representing the row and column indices of the pixel. - class _SegmentMetaProxy: - """ - A proxy class for a single pixel's segment metadata that fetches metadata on demand. - - Parameters - ---------- - parent : _RowProxy - The parent _RowProxy instance. - y : int - The row index. - x : int - The column index. - """ - - def __init__(self, parent, y, x): - self.parent = parent - self.y = y - self.x = x - - def __getitem__(self, direction): - """ - Fetch metadata for the specified segment direction. - - Parameters - ---------- - direction : int - The direction of the segment ("approach" or "retract"). - - Returns - ------- - dict - The metadata for the specified segment direction. - """ - return self.parent.parent._fetch_meta(self.y, self.x, direction) - - def __init__(self, parent, y): - """ - Initialize _RowProxy with parent LazyMetaProxy and row index. - - Parameters - ---------- - parent : LazyMetaProxy - The parent LazyMetaProxy instance. - y : int - The row index. - """ - self.parent = parent - self.y = y - - def __getitem__(self, x): - """ - Fetch metadata for column x in this row. - - Parameters - ---------- - x : int - The column index. - - Returns - ------- - dict or _SegmentMetaProxy - The metadata for the specified column, or a proxy for segment metadata. - """ - if self.parent.meta_type == "curve": - return self.parent._fetch_meta(self.y, x) - if self.parent.meta_type == "segment": - return self._SegmentMetaProxy(self, self.y, x) - raise IndexError(f"Unknown metadata type '{self.parent.meta_type}'. Expected 'curve' or 'segment'.") - - def __getitem__(self, y: int): + Returns + ------- + dict + The QI curve data for the specified pixel. + """ + if not isinstance(keys, tuple) or len(keys) != 2: + raise IndexError(f"Invalid indexing. Expected (y, x) for pixel indexing. Got {keys}.") + y, x = keys + if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x: + raise IndexError(f"Pixel index ({y}, {x}) is out of bounds for image of shape {self.dims}.") + return self.get_curve(y, x) + + def get_curve(self, y: int, x: int): """ - Return a proxy object for the specified row that can be indexed to fetch metadata for each pixel in that row. + Fetch the QI curve data for a specific pixel. - This allows for lazy loading. - Nested proxy objects are used to allow for fetching segment metadata which requires both x and y indices - as well as the direction of the segment (approach or retract). + Should be implemented by subclasses to define how the curve data is retrieved from the underlying data source. Parameters ---------- y : int - The row index. + Row index of the pixel. + x : int + Column index of the pixel. Returns ------- - _RowProxy - A proxy object for the specified row. + dict + The QI curve data for the specified pixel. """ - return self._RowProxy(self, y) + raise NotImplementedError("This method should be implemented by subclasses to fetch curve data on demand.") + + +class CurvesDataset: + """ + A dataset containing multiple curve volumes and associated metadata. + + Parameters + ---------- + volumes : dict[str, CurvesVolume] + A dictionary mapping curve names to CurvesVolume instances that + provide access to the curve data for each pixel. + metadata : CurvesMetadata + An instance of CurvesMetadata that provides access to the metadata + for each curve. + default_volume_name : str, optional + The name of the default volume to use when accessing curve data. + If None, the first volume in the dictionary is used. + """ + + def __init__( + self, volumes: dict[str, CurvesVolume], metadata: CurvesMetadata, default_volume_name: str | None = None + ): + """ + Initialise CurvesDataset. - def _fetch_meta(self, y: int, x: int, direction: int | None = None): + Parameters + ---------- + volumes : dict[str, CurvesVolume] + A dictionary mapping curve names to CurvesVolume instances that + provide access to the curve data for each pixel. + metadata : CurvesMetadata + An instance of CurvesMetadata that provides access to the metadata + for each curve. + default_volume_name : str | None, optional + The name of the default volume to use when accessing curve data. + If None, the first volume in the dictionary is used. """ - Fetch the metadata for a specific pixel. + self.volumes = volumes + self.metadata = metadata + self.default_volume_name = default_volume_name or next( + iter(volumes) + ) # Use the first volume as default if not specified - Should be implemented by subclasses to define how the metadata is retrieved - from the underlying data source. + def add_volume(self, name: str, volume: CurvesVolume, default: bool = False): + """ + Add a CurvesVolume to the dataset. Parameters ---------- - y : int - The row index of the pixel. - x : int - The column index of the pixel. - direction : int, optional - The direction of the segment ("approach" or "retract"). Only used for segment metadata. + name : str + The name of the curve to add. + volume : CurvesVolume + The CurvesVolume instance containing the curve data for each pixel. + default : bool, optional + Whether to set this volume as the default volume. Default is False. + """ + self.volumes[name] = volume + if default: + self.default_volume_name = name + + def get_default_volume(self) -> CurvesVolume: + """ + Get the default CurvesVolume for this dataset. Returns ------- - dict - The metadata for the specified pixel. + CurvesVolume + The default CurvesVolume instance for this dataset. """ - raise NotImplementedError("This method should be implemented by subclasses to fetch metadata on demand.") + return self.volumes[self.default_volume_name] From b104b3b7b4602028b3b4f8b9f11ef0dc4b1bb5e2 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Mon, 1 Jun 2026 11:14:37 +0100 Subject: [PATCH 48/54] tests: updating curve loading tests (currently being skipped) for refactored dataclasses --- AFMReader/jpk_qi.py | 4 ++-- tests/test_h5jpk.py | 5 ++--- tests/test_jpk_qi.py | 5 ++--- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index 7cd9720..962982b 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -433,7 +433,7 @@ def load( config_path: Path | str | None = None, flip_image: bool | None = True, save_as_h5: bool | None = None, - ) -> tuple[np.ndarray, float, Any] | tuple[np.ndarray, float]: + ) -> tuple[np.ndarray, float, CurvesJPKDataset]: """ Load the .jpk-qi-data file. @@ -451,7 +451,7 @@ def load( Returns ------- tuple - A tuple containing image data, scaling factor, and optionally curve data. + A tuple containing image data, scaling factor, and curve data. """ # Update instance attributes based on provided parameters self.channel = channel if channel else self.channel diff --git a/tests/test_h5jpk.py b/tests/test_h5jpk.py index 955f876..3ba82ac 100644 --- a/tests/test_h5jpk.py +++ b/tests/test_h5jpk.py @@ -199,9 +199,8 @@ def test_load_h5jpk_curves( curve_targets : dict[str, tuple[int, float]] A dictionary mapping curve channels to their expected size and sum, used for validating the loaded curve data. """ - _, _, _, result_curve_data = h5_jpk.load_h5jpk(RESOURCES / file_name, channel, flip_image) # type: ignore[misc] - all_curves, _, _ = result_curve_data - curve_at_coords = all_curves[curve_coords[0]][curve_coords[1]] + _, _, _, curve_dataset = h5_jpk.load_h5jpk(RESOURCES / file_name, channel, flip_image) # type: ignore[misc] + curve_at_coords = curve_dataset.get_default_volume()[curve_coords[0], curve_coords[1]] for curve_channel, (expected_size, expected_sum) in curve_targets.items(): curve = curve_at_coords[curve_channel][curve_direction] assert curve.shape == (expected_size,) diff --git a/tests/test_jpk_qi.py b/tests/test_jpk_qi.py index c86a067..f06c918 100644 --- a/tests/test_jpk_qi.py +++ b/tests/test_jpk_qi.py @@ -128,7 +128,7 @@ def test_load_jpk_qi_data( # pylint: disable=too-many-arguments,too-many-positi result_pixel_to_nm_scaling = float file_path = RESOURCES / file_name jpk_qi_loader = jpk_qi.JPKQILoader(file_path, channel) - result_image, result_pixel_to_nm_scaling, result_curve_data = jpk_qi_loader.load() # type: ignore + result_image, result_pixel_to_nm_scaling, curve_dataset = jpk_qi_loader.load() # type: ignore assert result_pixel_to_nm_scaling == pytest.approx(pixel_to_nm_scaling) assert isinstance(result_image, np.ndarray) @@ -137,8 +137,7 @@ def test_load_jpk_qi_data( # pylint: disable=too-many-arguments,too-many-positi assert result_image.sum() == pytest.approx(image_sum) # Test curve data for all targets - all_curves, _, _ = result_curve_data - curve_at_coords = all_curves[curve_coords[0]][curve_coords[1]] + curve_at_coords = curve_dataset.get_default_volume()[curve_coords[0], curve_coords[1]] for curve_channel, (expected_size, expected_sum) in curve_targets.items(): curve = curve_at_coords[curve_channel][curve_direction] assert curve.shape == (expected_size,) From 2394556957344797b182d215edb661f754ded619 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Mon, 1 Jun 2026 15:52:16 +0100 Subject: [PATCH 49/54] chore: renaming get_pixel_metadata to get_point_metadata --- AFMReader/h5_jpk.py | 2 +- AFMReader/jpk_qi.py | 2 +- AFMReader/lazy_data_classes.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index bfccbf4..98352a5 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -482,7 +482,7 @@ def __init__( super().__init__(toplevel, shape_x, shape_y, channel_units, flip_image) self.qi_data_group = qi_data_group - def get_pixel_metadata(self, y: int, x: int, direction: int | None = None): + def get_point_metadata(self, y: int, x: int, direction: int | None = None): """ Fetch metadata for a specific pixel (x, y) on demand. diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index 962982b..b5e31d8 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -115,7 +115,7 @@ def __init__( super().__init__(toplevel, shape_x, shape_y, channel_units, flip_image) self.archive = archive - def get_pixel_metadata(self, y: int, x: int, direction: int | None = None): + def get_point_metadata(self, y: int, x: int, direction: int | None = None): """ Fetch the metadata for a specific pixel or direction. diff --git a/AFMReader/lazy_data_classes.py b/AFMReader/lazy_data_classes.py index 9c70054..4af6e28 100644 --- a/AFMReader/lazy_data_classes.py +++ b/AFMReader/lazy_data_classes.py @@ -76,17 +76,17 @@ def __getitem__(self, keys): """ if isinstance(keys, tuple) and len(keys) == 2: y, x = keys - return self.get_pixel_metadata(y, x) + return self.get_point_metadata(y, x) if isinstance(keys, tuple) and len(keys) == 3: y, x, direction = keys - return self.get_pixel_metadata(y, x, direction) + return self.get_point_metadata(y, x, direction) raise IndexError( f"Invalid indexing. Expected (y, x) or (y, x, direction) for pixel metadata indexing. Got {keys}." ) - def get_pixel_metadata(self, y: int, x: int, direction: int | None = None): + def get_point_metadata(self, y: int, x: int, direction: int | None = None): """ - Fetch the metadata for a specific pixel, optionally for a specific direction. + Fetch the metadata for a specific pixel/ point, optionally for a specific direction. Should be implemented by subclasses to define how the metadata is retrieved from the underlying data source. From 699c55b65459c24f88b469fe6d2e1bc9f014ec15 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Mon, 1 Jun 2026 17:48:48 +0100 Subject: [PATCH 50/54] feat: moving channel_units to be stored per volume rather than in metadata to align with potential differences in channels between volumes --- AFMReader/h5_jpk.py | 39 ++++++++++++++++++++++++---------- AFMReader/jpk_qi.py | 23 +++++++++++++------- AFMReader/lazy_data_classes.py | 24 ++++++++++----------- 3 files changed, 55 insertions(+), 31 deletions(-) diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index 98352a5..1f8c477 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -20,7 +20,7 @@ logger.enable(__package__) -# pylint: disable=too-few-public-methods,too-many-locals,fixme +# pylint: disable=too-few-public-methods,too-many-locals,fixme,too-many-positional-arguments def _parse_channel_name(channel: str) -> tuple[str, str]: @@ -310,11 +310,21 @@ class CurvesH5Volume(CurvesVolume): The number of rows in the image. qi_data_group : h5py.Group The HDF5 group containing the QI curve data. + channel_units : dict[str, str] + A dictionary mapping channel names to their units. flip_image : bool, optional Whether to flip the image vertically. Default is True. """ - def __init__(self, name: str, shape_x: int, shape_y: int, qi_data_group: h5py.Group, flip_image: bool = True): + def __init__( + self, + name: str, + shape_x: int, + shape_y: int, + qi_data_group: h5py.Group, + channel_units: dict[str, str], + flip_image: bool = True, + ): """ Initialize the CurvesH5Volume instance. @@ -328,10 +338,18 @@ def __init__(self, name: str, shape_x: int, shape_y: int, qi_data_group: h5py.Gr The number of rows in the image. qi_data_group : h5py.Group The HDF5 group containing the QI curve data. + channel_units : dict[str, str] + A dictionary mapping channel names to their units. flip_image : bool, optional Whether to flip the image vertically. Default is True. """ - super().__init__(name, shape_x, shape_y, flip_image) + super().__init__( + name=name, + shape_x=shape_x, + shape_y=shape_y, + channel_units=channel_units, + flip_image=flip_image, + ) self.qi_data_group = qi_data_group def __iter__(self): # noqa: C901 @@ -445,8 +463,6 @@ class CurvesH5Metadata(CurvesMetadata): The number of columns in the image. shape_y : int The number of rows in the image. - channel_units : dict[str, str] - A dictionary mapping channel names to their units. flip_image : bool, optional Whether to flip the image vertically. Default is ``True``. """ @@ -458,7 +474,6 @@ def __init__( toplevel: dict[str, Any], shape_x: int, shape_y: int, - channel_units: dict[str, str], flip_image: bool = True, ): """ @@ -474,12 +489,10 @@ def __init__( The number of columns in the image. shape_y : int The number of rows in the image. - channel_units : dict[str, str] - A dictionary mapping channel names to their units. flip_image : bool, optional Whether to flip the image vertically. Default is ``True``. """ - super().__init__(toplevel, shape_x, shape_y, channel_units, flip_image) + super().__init__(toplevel, shape_x, shape_y, flip_image) self.qi_data_group = qi_data_group def get_point_metadata(self, y: int, x: int, direction: int | None = None): @@ -628,14 +641,18 @@ def load_h5jpk( top_level_meta[key] = value curves_volume = CurvesH5Volume( - name="Trace", shape_x=shape_x, shape_y=shape_y, qi_data_group=qi_data_group, flip_image=flip_image + name="Trace", + shape_x=shape_x, + shape_y=shape_y, + qi_data_group=qi_data_group, + channel_units=channels_units, + flip_image=flip_image, ) curves_metadata = CurvesH5Metadata( qi_data_group=qi_data_group, toplevel=top_level_meta, shape_x=shape_x, shape_y=shape_y, - channel_units=channels_units, flip_image=flip_image, ) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index b5e31d8..ab2b3f7 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -7,6 +7,7 @@ # pylint: disable=too-many-lines,too-many-positional-arguments,too-few-public-methods,too-many-instance-attributes # pylint: disable=too-many-locals,too-many-branches,protected-access,attribute-defined-outside-init,fixme +# pylint: disable=too-many-arguments import os import io @@ -79,8 +80,6 @@ class CurvesJPKMetadata(CurvesMetadata): Number of columns in the image. shape_y : int Number of rows in the image. - channel_units : dict[str, str] - A dictionary mapping channel names to their units. flip_image : bool, optional Whether to flip the image vertically. Default is True. """ @@ -91,7 +90,6 @@ def __init__( archive: zipfile.ZipFile, shape_x: int, shape_y: int, - channel_units: dict[str, str], flip_image: bool = True, ): """ @@ -107,12 +105,10 @@ def __init__( Number of columns in the image. shape_y : int Number of rows in the image. - channel_units : dict[str, str] - Dictionary mapping channel names to their units. flip_image : bool, optional Whether to flip the image vertically. Default is ``True``. """ - super().__init__(toplevel, shape_x, shape_y, channel_units, flip_image) + super().__init__(toplevel, shape_x, shape_y, flip_image) self.archive = archive def get_point_metadata(self, y: int, x: int, direction: int | None = None): @@ -168,6 +164,8 @@ class CurvesJPKVolume(CurvesVolume): The ZIP archive containing the JPK data. channel_scaling : dict[str, dict[str, float]] A dictionary mapping channel names to their scaling factors. + channel_units : dict[str, str] + A dictionary mapping channel names to their units. flip_image : bool, optional Whether to flip the image vertically. Default is True. """ @@ -179,6 +177,7 @@ def __init__( shape_y: int, archive: zipfile.ZipFile, channel_scaling: dict[str, dict[str, float]], + channel_units: dict[str, str], flip_image: bool = True, ): """ @@ -196,10 +195,18 @@ def __init__( The ZIP archive containing the JPK data. channel_scaling : dict[str, dict[str, float]] A dictionary mapping channel names to their scaling factors. + channel_units : dict[str, str] + A dictionary mapping channel names to their units. flip_image : bool, optional Whether to flip the image vertically. Default is True. """ - super().__init__(name, shape_x, shape_y, flip_image) + super().__init__( + name=name, + shape_x=shape_x, + shape_y=shape_y, + channel_units=channel_units, + flip_image=flip_image, + ) self.archive = archive self.channel_scaling = channel_scaling @@ -482,7 +489,6 @@ def load( self.qi_archive, self.shape_x or 0, self.shape_y or 0, - channel_units=self.channels_units, flip_image=bool(self.flip_image), ) self.curves_volume = CurvesJPKVolume( @@ -491,6 +497,7 @@ def load( shape_y=self.shape_y or 0, archive=self.qi_archive, channel_scaling=self.channel_scaling, + channel_units=self.channels_units, flip_image=bool(self.flip_image), ) self.curves_dataset = CurvesJPKDataset( diff --git a/AFMReader/lazy_data_classes.py b/AFMReader/lazy_data_classes.py index 4af6e28..22f01e2 100644 --- a/AFMReader/lazy_data_classes.py +++ b/AFMReader/lazy_data_classes.py @@ -25,15 +25,11 @@ class CurvesMetadata: The number of columns in the image. shape_y : int The number of rows in the image. - channel_units : dict[str, str] - A dictionary mapping channel names to their units. flip_image : bool, optional Whether to flip the image vertically. Default is True. """ - def __init__( - self, toplevel: dict, shape_x: int, shape_y: int, channel_units: dict[str, str], flip_image: bool = True - ): + def __init__(self, toplevel: dict, shape_x: int, shape_y: int, flip_image: bool = True): """ Initialise CurvesMetadata. @@ -45,15 +41,12 @@ def __init__( The number of columns in the image. shape_y : int The number of rows in the image. - channel_units : dict[str, str] - A dictionary mapping channel names to their units. flip_image : bool, optional Whether to flip the image vertically. Default is True. """ self.toplevel = toplevel self.shape_x = shape_x self.shape_y = shape_y - self.channel_units = channel_units self.flip_image = flip_image def __getitem__(self, keys): @@ -81,14 +74,16 @@ def __getitem__(self, keys): y, x, direction = keys return self.get_point_metadata(y, x, direction) raise IndexError( - f"Invalid indexing. Expected (y, x) or (y, x, direction) for pixel metadata indexing. Got {keys}." + f"Invalid indexing. Expected (y, x) or (y, x, direction) for point metadata indexing. Got {keys}." ) + # pylint: disable=unused-argument def get_point_metadata(self, y: int, x: int, direction: int | None = None): """ Fetch the metadata for a specific pixel/ point, optionally for a specific direction. - Should be implemented by subclasses to define how the metadata is retrieved from the underlying data source. + Should be implemented by subclasses if there exists per point metadata to define how the metadata is retrieved + from the underlying data source. If there is no per point metadata, this can simply return an empty dict. Parameters ---------- @@ -104,7 +99,7 @@ def get_point_metadata(self, y: int, x: int, direction: int | None = None): dict The metadata for the specified pixel (or direction, if provided). """ - raise NotImplementedError("This method should be implemented by subclasses to fetch pixel metadata on demand.") + return {} class CurvesVolume: @@ -121,11 +116,13 @@ class CurvesVolume: The number of columns in the image. shape_y : int The number of rows in the image. + channel_units : dict[str, str] + A dictionary mapping channel names to their units. flip_image : bool, optional Whether to flip the image vertically. Default is True. """ - def __init__(self, name: str, shape_x: int, shape_y: int, flip_image: bool = True): + def __init__(self, name: str, shape_x: int, shape_y: int, channel_units: dict[str, str], flip_image: bool = True): """ Initialise CurvesVolume. @@ -137,6 +134,8 @@ def __init__(self, name: str, shape_x: int, shape_y: int, flip_image: bool = Tru The number of columns in the image. shape_y : int The number of rows in the image. + channel_units : dict[str, str] + A dictionary mapping channel names to their units. flip_image : bool, optional Whether to flip the image vertically. Default is True. """ @@ -145,6 +144,7 @@ def __init__(self, name: str, shape_x: int, shape_y: int, flip_image: bool = Tru self.dims = (shape_y, shape_x) self.flip_image = flip_image self.name = name + self.channel_units = channel_units def __getitem__(self, keys): """ From 773daca6b0d77936cc20a80ee256f95262aab878 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Fri, 5 Jun 2026 15:27:59 +0100 Subject: [PATCH 51/54] refactor: using a dataclass to represent all return data consistently --- AFMReader/asd.py | 26 +++---- .../{lazy_data_classes.py => data_classes.py} | 68 +++++++++++++++++++ AFMReader/general_loader.py | 64 ++++++----------- AFMReader/gwy.py | 13 ++-- AFMReader/h5_jpk.py | 30 ++++---- AFMReader/ibw.py | 13 ++-- AFMReader/jpk.py | 15 ++-- AFMReader/jpk_qi.py | 54 +++++++-------- AFMReader/raw_bin.py | 12 ++-- AFMReader/spm.py | 13 ++-- AFMReader/stp.py | 11 ++- AFMReader/top.py | 11 ++- AFMReader/topostats.py | 29 ++++++-- tests/test_asd.py | 12 ++-- tests/test_general_loader.py | 8 +-- tests/test_gwy.py | 12 ++-- tests/test_h5jpk.py | 28 ++++---- tests/test_ibw.py | 15 ++-- tests/test_jpk.py | 14 ++-- tests/test_jpk_qi.py | 18 ++--- tests/test_spm.py | 15 ++-- tests/test_stp.py | 12 ++-- tests/test_top.py | 12 ++-- tests/test_topostats.py | 20 +++--- 24 files changed, 287 insertions(+), 238 deletions(-) rename AFMReader/{lazy_data_classes.py => data_classes.py} (82%) diff --git a/AFMReader/asd.py b/AFMReader/asd.py index 4187e12..c0da406 100644 --- a/AFMReader/asd.py +++ b/AFMReader/asd.py @@ -10,6 +10,7 @@ import numpy.typing as npt from matplotlib import animation +from AFMReader.data_classes import AFMLoad from AFMReader.io import ( read_ascii, read_bool, @@ -182,7 +183,7 @@ def calculate_scaling_factor( raise ValueError(f"channel {channel} not known for .asd file type.") -def load_asd(file_path: str | Path, channel: str): +def load_asd(file_path: str | Path, channel: str) -> AFMLoad: """ Load a .asd file. @@ -196,17 +197,16 @@ def load_asd(file_path: str | Path, channel: str): Returns ------- - npt.NDArray - The .asd file frames data as a numpy 3D array N x W x H - (Number of frames x Width of each frame x height of each frame). - float - The number of nanometres per pixel for the .asd file. (AKA the resolution). - Enables converting between pixels and nanometres when working with the data, in order to use real-world length - scales. - dict - Metadata for the .asd file. The number of entries is too long to list here, and changes based on the file - version please either look into the `read_header_file_version_x` functions or print the keys too see what - metadata is available. + AFMLoad + An AFMLoad object containing: + - image : npt.NDArray + Shape (Number of frames x Width of each frame x height of each frame). + - px2nm : float + The number of nanometres per pixel for the .asd file. + - metadata : dict + Metadata for the .asd file. The number of entries is too long to list here, and changes based on the file + version please either look into the `read_header_file_version_x` functions or print the keys too see what + metadata is available. """ # Ensure the file path is a Path object file_path = Path(file_path) @@ -285,7 +285,7 @@ def load_asd(file_path: str | Path, channel: str): frames = np.array(frames) logger.info(f"[{filename}] : Extracted image.") - return frames, pixel_to_nanometre_scaling_factor, header_dict + return AFMLoad(image=frames, px2nm=pixel_to_nanometre_scaling_factor, metadata=header_dict) def get_asd_channels(file_path: Path): diff --git a/AFMReader/lazy_data_classes.py b/AFMReader/data_classes.py similarity index 82% rename from AFMReader/lazy_data_classes.py rename to AFMReader/data_classes.py index 22f01e2..7e4f2d5 100644 --- a/AFMReader/lazy_data_classes.py +++ b/AFMReader/data_classes.py @@ -6,6 +6,8 @@ formats. This is necessary for handling large datasets with massive memory consumption. """ +import numpy as np + # pylint: disable=too-few-public-methods,fixme @@ -146,6 +148,17 @@ def __init__(self, name: str, shape_x: int, shape_y: int, channel_units: dict[st self.name = name self.channel_units = channel_units + def __len__(self): + """ + Return the total number of pixels in the image. + + Returns + ------- + int + The total number of pixels in the image. + """ + return self.shape_x * self.shape_y + def __getitem__(self, keys): """ Allow numpy style indexing to fetch curve data for a specific pixel. @@ -256,3 +269,58 @@ def get_default_volume(self) -> CurvesVolume: The default CurvesVolume instance for this dataset. """ return self.volumes[self.default_volume_name] + + +class AFMLoad: + """ + A class representing the loaded AFM data, including the image and scaling factors. + + Parameters + ---------- + image : np.ndarray + The image data. + px2nm : float + The pixel to nanometer scaling factor. + timestamps : dict | None, optional + Timestamps associated with the data. Default is None. + metadata : dict | None, optional + Metadata associated with the data. Default is None. + curves_dataset : CurvesDataset | None, optional + Curves dataset associated with the data. Default is None. + """ + + image: np.ndarray + px2nm: float + timestamps: dict | None = None + metadata: dict | None = None + curves_dataset: CurvesDataset | None = None + + def __init__( + self, + image: np.ndarray, + px2nm: float, + timestamps: dict | None = None, + metadata: dict | None = None, + curves_dataset: CurvesDataset | None = None, + ): + """ + Initialise AFMLoad. + + Parameters + ---------- + image : np.ndarray + The image data. + px2nm : float + The pixel to nanometer scaling factor. + timestamps : dict | None, optional + Timestamps associated with the data. Default is None. + metadata : dict | None, optional + Metadata associated with the data. Default is None. + curves_dataset : CurvesDataset | None, optional + Curves dataset associated with the data. Default is None. + """ + self.image = image + self.px2nm = px2nm + self.timestamps = timestamps + self.metadata = metadata + self.curves_dataset = curves_dataset diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index 9e254d9..9a26039 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -3,9 +3,9 @@ from pathlib import Path from typing import Any -import numpy.typing as npt from AFMReader import asd, gwy, h5_jpk, ibw, jpk, raw_bin, spm, stp, top, topostats, jpk_qi +from AFMReader.data_classes import AFMLoad from AFMReader.logging import logger logger.enable(__package__) @@ -48,9 +48,7 @@ def __init__(self, filepath: str | Path, channel: str, kwargs: dict | None = Non # Store heavy loaded data in a dict to avoid having to reload it self.cached_data: dict[str, Any] = {} - def load( # noqa: C901 - self, channel: str | None = None, kwargs: dict | None = None - ) -> tuple[npt.NDArray | str, float | None] | tuple[npt.NDArray | str, float | None, Any]: + def load(self, channel: str | None = None, kwargs: dict | None = None) -> AFMLoad: # noqa: C901 """ Generally loads a file type that can be handled by AFMReader. @@ -63,15 +61,13 @@ def load( # noqa: C901 Returns ------- - tuple - The image data (stack if ''.asd'' or ''.h5-jpk'') and the pixel to nanometre scaling ratio. - If curve data is found, also return the curve data (a large dict of all the curves). + AFMLoad + An AFMLoad object containing the loaded AFM image data and metadata. Raises ------ ValueError - Where the channel is not found, returned as a tuple of "error message" and "None" so that this can be - propagated to Napari without outright failing. + Where the channel is not found. """ if channel: self.channel = channel @@ -79,61 +75,39 @@ def load( # noqa: C901 self.kwargs = kwargs try: if self.suffix == ".asd": - image, pixel_to_nanometre_scaling_factor, _ = asd.load_asd(self.filepath, self.channel) + afm_load = asd.load_asd(self.filepath, self.channel) elif self.suffix == ".gwy": - image, pixel_to_nanometre_scaling_factor = gwy.load_gwy(self.filepath, self.channel) + afm_load = gwy.load_gwy(self.filepath, self.channel) elif self.suffix == ".ibw": - image, pixel_to_nanometre_scaling_factor = ibw.load_ibw(self.filepath, self.channel) + afm_load = ibw.load_ibw(self.filepath, self.channel) elif self.suffix in [".jpk", ".jpk-qi-image"]: - image, pixel_to_nanometre_scaling_factor = jpk.load_jpk(self.filepath, self.channel) + afm_load = jpk.load_jpk(self.filepath, self.channel) elif self.suffix == ".spm": - image, pixel_to_nanometre_scaling_factor = spm.load_spm(self.filepath, self.channel) + afm_load = spm.load_spm(self.filepath, self.channel) elif self.suffix == ".h5-jpk": - h5_returned = h5_jpk.load_h5jpk(self.filepath, self.channel, load_curves=not self.loaded_curves) - if len(h5_returned) == 3: - image, pixel_to_nanometre_scaling_factor, _ = h5_returned # type: ignore[misc] - elif len(h5_returned) == 4: - image, pixel_to_nanometre_scaling_factor, _, curve_data = h5_returned # type: ignore[misc] + afm_load = h5_jpk.load_h5jpk(self.filepath, self.channel, load_curves=not self.loaded_curves) + if afm_load.curves_dataset is not None: self.loaded_curves = True - return image, pixel_to_nanometre_scaling_factor, curve_data - else: - logger.error(f"Loading h5-jpk file returned unexpected number of items: {len(h5_returned)}") elif self.suffix == ".jpk-qi-data": if "jpk_qi_loader" not in self.cached_data: self.cached_data["jpk_qi_loader"] = jpk_qi.JPKQILoader( filepath=self.filepath, channel=self.channel, **self.kwargs ) - jpk_qi_returned = self.cached_data["jpk_qi_loader"].load(channel=self.channel, **self.kwargs) - if len(jpk_qi_returned) == 2: - image, pixel_to_nanometre_scaling_factor = jpk_qi_returned - elif len(jpk_qi_returned) == 3: - image, pixel_to_nanometre_scaling_factor, curve_data = jpk_qi_returned + afm_load = self.cached_data["jpk_qi_loader"].load(channel=self.channel, **self.kwargs) + if afm_load.curves_dataset is not None: self.loaded_curves = True - return image, pixel_to_nanometre_scaling_factor, curve_data - else: - logger.error(f"Loading h5-jpk file returned unexpected number of items: {len(jpk_qi_returned)}") elif self.suffix == ".stp": - image, pixel_to_nanometre_scaling_factor = stp.load_stp(self.filepath) + afm_load = stp.load_stp(self.filepath) elif self.suffix == ".top": - image, pixel_to_nanometre_scaling_factor = top.load_top(self.filepath) + afm_load = top.load_top(self.filepath) elif self.suffix == ".topostats": - ts_dict = topostats.load_topostats(self.filepath) - try: - image = ts_dict[self.channel] - pixel_to_nanometre_scaling_factor = ts_dict["pixel_to_nm_scaling"] - except KeyError as exc: - image_keys = ["image", "image_original"] - topostats_keys = list(ts_dict.keys()) - raise ValueError( - f"'{self.channel}' not in available image keys: " - f"{[im for im in image_keys if im in topostats_keys]}" - ) from exc + afm_load = topostats.load_topostats(self.filepath, self.channel) elif self.suffix == ".bin": - image, pixel_to_nanometre_scaling_factor = raw_bin.load_bin(self.filepath, **self.kwargs) + afm_load = raw_bin.load_bin(self.filepath, **self.kwargs) else: raise ValueError(f"File type '{self.suffix}' is not currently handled by AFMReader.") - return image, pixel_to_nanometre_scaling_factor + return afm_load except ValueError as e: logger.error(f"{e}") diff --git a/AFMReader/gwy.py b/AFMReader/gwy.py index e36f50b..653cfd9 100644 --- a/AFMReader/gwy.py +++ b/AFMReader/gwy.py @@ -7,6 +7,7 @@ import numpy as np from loguru import logger +from AFMReader.data_classes import AFMLoad from AFMReader.io import read_char, read_double, read_null_terminated_string, read_uint32 @@ -36,7 +37,7 @@ def get_gwy_channels(file_path): return list(channel_ids) -def load_gwy(file_path: Path | str, channel: str) -> tuple[np.ndarray[Any, np.float64], float]: +def load_gwy(file_path: Path | str, channel: str) -> AFMLoad: """ Extract image and pixel to nm scaling from the .gwy file. @@ -49,8 +50,8 @@ def load_gwy(file_path: Path | str, channel: str) -> tuple[np.ndarray[Any, np.fl Returns ------- - tuple(np.ndarray, float) - A tuple containing the image and its pixel to nanometre scaling value. + AFMLoad + An AFMLoad object containing the image and its pixel to nanometre scaling value. Raises ------ @@ -65,7 +66,9 @@ def load_gwy(file_path: Path | str, channel: str) -> tuple[np.ndarray[Any, np.fl Sensor'. >>> from AFMReader.gwy import load_gwy - >>> image, pixel_to_nm = load_gwy(file_path="path/to/file.gwy", channel="Height") + >>> afm_load = load_gwy(file_path="path/to/file.gwy", channel="Height") + >>> image = afm_load.image + >>> px2nm = afm_load.px2nm ``` """ logger.info(f"Loading image from : {file_path}") @@ -113,7 +116,7 @@ def load_gwy(file_path: Path | str, channel: str) -> tuple[np.ndarray[Any, np.fl raise ValueError(f"'{channel}' not found in {file_path.suffix} channel list: {channel_ids}") from e logger.info(f"[{filename}] : Extracted image.") - return (image, px_to_nm) + return AFMLoad(image=image, px2nm=px_to_nm) def gwy_read_object(open_file: BinaryIO, data_dict: dict) -> None: diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py index 1f8c477..facff01 100644 --- a/AFMReader/h5_jpk.py +++ b/AFMReader/h5_jpk.py @@ -12,7 +12,8 @@ import numpy as np from AFMReader.logging import logger -from AFMReader.lazy_data_classes import ( +from AFMReader.data_classes import ( + AFMLoad, CurvesDataset, CurvesMetadata, CurvesVolume, @@ -535,9 +536,7 @@ def get_point_metadata(self, y: int, x: int, direction: int | None = None): return meta_dict -def load_h5jpk( - file_path: Path | str, channel: str, flip_image: bool = True, load_curves: bool = True -) -> tuple[np.ndarray, float, dict[str, float]] | tuple[np.ndarray, float, dict[str, float], CurvesDataset]: +def load_h5jpk(file_path: Path | str, channel: str, flip_image: bool = True, load_curves: bool = True) -> AFMLoad: """ Load image from JPK Instruments .h5-jpk files. @@ -554,15 +553,10 @@ def load_h5jpk( Returns ------- - image : np.ndarray - 3D array of shape (frames, height, width) with image data. - pixel_to_nm_scaling : float - Scaling factor converting pixels to nanometers. - timestamps : dict[str, float] - Dictionary mapping frame labels (e.g., "frame 0") to timestamp values in seconds. - curves_data : CurvesDataset, optional - An instance of CurvesDataset containing lazy-loaded curve data, channel units, and metadata. - Returned only if load_curves is True and curve data is present in the file. + AFMLoad + An AFMLoad object containing the image, its pixel to nanometre scaling value, timestamps, and + optionally the curves dataset. Curves dataset only if load_curves is True and curve data is + present in the file. Raises ------ @@ -576,9 +570,9 @@ def load_h5jpk( Load height trace channel from the .jpk file. 'height_trace' is the default channel name. >>> from AFMReader.jpk import load_h5jpk - >>> frames, pixel_to_nanometre_scaling_factor, timestamps = load_h5jpk(file_path="./my_jpk_file.jpk", - >>> channel="height_trace", - >>> flip_image=True) + >>> afm_load = load_h5jpk(file_path="./my_jpk_file.jpk", channel="height_trace", flip_image=True) + >>> image = afm_load.image + >>> pixel_to_nm_scaling = afm_load.px2nm """ logger.info(f"Loading H5-JPK file from : {file_path}") file_path = Path(file_path) @@ -658,6 +652,6 @@ def load_h5jpk( curves_data = CurvesDataset(volumes={"Trace": curves_volume}, metadata=curves_metadata) - return (image_stack, px2nm, timestamps, curves_data) + return AFMLoad(image=image_stack, px2nm=px2nm, timestamps=timestamps, curves_dataset=curves_data) - return (image_stack, px2nm, timestamps) + return AFMLoad(image=image_stack, px2nm=px2nm, timestamps=timestamps) diff --git a/AFMReader/ibw.py b/AFMReader/ibw.py index 624d2f6..442d459 100644 --- a/AFMReader/ibw.py +++ b/AFMReader/ibw.py @@ -7,6 +7,7 @@ import numpy as np from igor2 import binarywave +from AFMReader.data_classes import AFMLoad from AFMReader.logging import logger logger.enable(__package__) @@ -65,7 +66,7 @@ def get_ibw_channels(file_path: Path | str): return labels -def load_ibw(file_path: Path | str, channel: str) -> tuple[np.ndarray, float]: +def load_ibw(file_path: Path | str, channel: str) -> AFMLoad: """ Load image from Asylum Research (Igor) .ibw files. @@ -78,8 +79,8 @@ def load_ibw(file_path: Path | str, channel: str) -> tuple[np.ndarray, float]: Returns ------- - tuple[np.ndarray, float] - A tuple containing the image and its pixel to nanometre scaling value. + AFMLoad + An AFMLoad object containing the image and its pixel to nanometre scaling value. Raises ------ @@ -94,7 +95,9 @@ def load_ibw(file_path: Path | str, channel: str) -> tuple[np.ndarray, float]: not a typo!). >>> from AFMReader.ibw import load_ibw - >>> image, pixel_to_nanometre_scaling_factor = load_ibw(file_path="./my_ibw_file.ibw", channel="HeightTracee") + >>> afm_load = load_ibw(file_path="./my_ibw_file.ibw", channel="HeightTracee") + >>> image = afm_load.image + >>> pixel_to_nanometre_scaling_factor = afm_load.px2nm """ logger.info(f"Loading image from : {file_path}") file_path = Path(file_path) @@ -126,4 +129,4 @@ def load_ibw(file_path: Path | str, channel: str) -> tuple[np.ndarray, float]: raise e logger.info(f"[{filename}] : Extracted image.") - return (image, _ibw_pixel_to_nm_scaling(scan)) + return AFMLoad(image=image, px2nm=_ibw_pixel_to_nm_scaling(scan)) diff --git a/AFMReader/jpk.py b/AFMReader/jpk.py index 1bf18b8..9f1d185 100644 --- a/AFMReader/jpk.py +++ b/AFMReader/jpk.py @@ -9,6 +9,7 @@ from AFMReader.io import read_yaml from AFMReader.logging import logger +from AFMReader.data_classes import AFMLoad logger.enable(__package__) @@ -237,7 +238,7 @@ def get_jpk_channels(file_path: Path | str, config_path: Path | str | None = Non def load_jpk( file_path: Path | str, channel: str, config_path: Path | str | None = None, flip_image: bool = True -) -> tuple[np.ndarray, float]: +) -> AFMLoad: """ Load image from JPK Instruments .jpk files. @@ -255,8 +256,8 @@ def load_jpk( Returns ------- - tuple[npt.NDArray, float] - A tuple containing the image and its pixel to nanometre scaling value. + AFMLoad + An AFMLoad object containing the image and its pixel to nanometre scaling value. Raises ------ @@ -270,9 +271,9 @@ def load_jpk( Load height trace channel from the .jpk file. 'height_trace' is the default channel name. >>> from AFMReader.jpk import load_jpk - >>> image, pixel_to_nanometre_scaling_factor = load_jpk(file_path="./my_jpk_file.jpk", - >>> channel="height_trace", - >>> flip_image=True) + >>> afm_load = load_jpk(file_path="./my_jpk_file.jpk", channel="height_trace", flip_image=True) + >>> image = afm_load.image + >>> pixel_to_nanometre_scaling_factor = afm_load.px2nm """ logger.info(f"Loading image from : {file_path}") file_path = Path(file_path) @@ -285,7 +286,7 @@ def load_jpk( config_path=config_path, flip_image=flip_image, ) - return (image, px2nm) + return AFMLoad(image=image, px2nm=px2nm) def _load_jpk( diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index ab2b3f7..5345558 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -22,7 +22,7 @@ import h5py import psutil -from AFMReader.lazy_data_classes import CurvesMetadata, CurvesVolume, CurvesDataset +from AFMReader.data_classes import AFMLoad, CurvesMetadata, CurvesVolume, CurvesDataset from AFMReader.logging import logger from AFMReader import jpk @@ -318,6 +318,28 @@ def _get_channel_scaling(props, channel_index): return final_multiplier, final_offset, unit +def _make_num_min_characters(num: int, min_chars: int = 3): + """ + Zero-pad an integer to a minimum number of characters. + + Parameters + ---------- + num : int + The integer to pad. + min_chars : int + The minimum number of characters the resulting string should have. Default is 3. + + Returns + ------- + str + The zero-padded string. + """ + string_num = str(num) + if len(string_num) >= min_chars: + return string_num + return "0" * (min_chars - len(string_num)) + string_num + + class JPKQILoader: """ Class for readability and improving modularity in the load jpk qi data function. @@ -440,7 +462,7 @@ def load( config_path: Path | str | None = None, flip_image: bool | None = True, save_as_h5: bool | None = None, - ) -> tuple[np.ndarray, float, CurvesJPKDataset]: + ) -> AFMLoad: """ Load the .jpk-qi-data file. @@ -457,8 +479,8 @@ def load( Returns ------- - tuple - A tuple containing image data, scaling factor, and curve data. + AFMLoad + An AFMLoad object containing the image, its pixel to nanometre scaling value, and curves dataset. """ # Update instance attributes based on provided parameters self.channel = channel if channel else self.channel @@ -511,7 +533,7 @@ def load( if self.save_as_h5: self.save_lite_data() - return (self.image, self.px2nm, self.curves_dataset) + return AFMLoad(image=self.image, px2nm=self.px2nm, curves_dataset=self.curves_dataset) def output_summary(self): """Output a summary of the loading process, including any failed curve loads and their details.""" @@ -1298,25 +1320,3 @@ def close(self): self.failed_curves = set() self.points_for_channel_segment = {} self.list_of_all_paths = [] - - -def _make_num_min_characters(num: int, min_chars: int = 3): - """ - Zero-pad an integer to a minimum number of characters. - - Parameters - ---------- - num : int - The integer to pad. - min_chars : int - The minimum number of characters the resulting string should have. Default is 3. - - Returns - ------- - str - The zero-padded string. - """ - string_num = str(num) - if len(string_num) >= min_chars: - return string_num - return "0" * (min_chars - len(string_num)) + string_num diff --git a/AFMReader/raw_bin.py b/AFMReader/raw_bin.py index 269b62c..40bcd24 100644 --- a/AFMReader/raw_bin.py +++ b/AFMReader/raw_bin.py @@ -5,6 +5,8 @@ import numpy as np +from AFMReader.data_classes import AFMLoad + from .logging import logger # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,fixme @@ -35,7 +37,7 @@ def load_bin( shape_x: int | None = None, shape_y: int | None = None, z_scaling: float = 1.0, -): +) -> AFMLoad: """ Load image from binary file. Parameters to interpret the binary file must be provided. @@ -60,10 +62,8 @@ def load_bin( Returns ------- - image : np.ndarray - 2D array of shape (height, width) with image data. - px2nm : float - Scaling factor converting pixels to nanometers. + AFMLoad + An AFMLoad object containing the image and its pixel to nanometre scaling value. """ filepath = Path(filepath) dt_key = str(data_type).strip() @@ -90,7 +90,7 @@ def load_bin( pixel_to_nm_scaling_factor_x = size_x / shape_x if shape_x > 0 else 1.0 pixel_to_nm_scaling_factor_y = size_y / shape_y if shape_y > 0 else 1.0 px2nm = (pixel_to_nm_scaling_factor_x + pixel_to_nm_scaling_factor_y) / 2 - return image, px2nm + return AFMLoad(image=image, px2nm=px2nm) def get_bin_channels(): diff --git a/AFMReader/spm.py b/AFMReader/spm.py index da29f9a..8bc4355 100644 --- a/AFMReader/spm.py +++ b/AFMReader/spm.py @@ -5,6 +5,7 @@ import numpy as np import pySPM +from AFMReader.data_classes import AFMLoad from AFMReader.logging import logger logger.enable(__package__) @@ -54,7 +55,7 @@ def spm_pixel_to_nm_scaling(filename: str, channel_data: pySPM.SPM.SPM_image) -> return pixel_to_nm_scaling -def load_spm(file_path: Path | str, channel: str) -> tuple: +def load_spm(file_path: Path | str, channel: str) -> AFMLoad: """ Extract image and pixel to nm scaling from the Bruker .spm file. @@ -67,8 +68,8 @@ def load_spm(file_path: Path | str, channel: str) -> tuple: Returns ------- - tuple(np.ndarray, float) - A tuple containing the image and its pixel to nanometre scaling value. + AFMLoad + An AFMLoad object containing the image and its pixel to nanometre scaling value. Raises ------ @@ -83,7 +84,9 @@ def load_spm(file_path: Path | str, channel: str) -> tuple: Sensor'. >>> from AFMReader.spm import load_spm - >>> image, pixel_to_nm = load_spm(file_path="path/to/file.spm", channel="Height") + >>> afm_load = load_spm(file_path="path/to/file.spm", channel="Height") + >>> image = afm_load.image + >>> pixel_to_nm = afm_load.px2nm ``` """ logger.info(f"Loading image from : {file_path}") @@ -109,7 +112,7 @@ def load_spm(file_path: Path | str, channel: str) -> tuple: raise ValueError(f"'{channel}' not in {file_path.suffix} channel list: {labels}") from e raise e - return (image, spm_pixel_to_nm_scaling(filename, channel_data)) + return AFMLoad(image=image, px2nm=spm_pixel_to_nm_scaling(filename, channel_data)) def get_spm_channels(file_path: Path | str) -> list: diff --git a/AFMReader/stp.py b/AFMReader/stp.py index ffda53b..2eeb5c5 100644 --- a/AFMReader/stp.py +++ b/AFMReader/stp.py @@ -5,6 +5,7 @@ import numpy as np +from AFMReader.data_classes import AFMLoad from AFMReader.io import read_double from AFMReader.logging import logger @@ -13,9 +14,7 @@ # pylint: disable=too-many-locals # pylint: disable=too-many-statements -def load_stp( # noqa: C901 (ignore too complex) - file_path: Path | str, header_encoding: str = "latin-1" -) -> tuple[np.ndarray, float]: +def load_stp(file_path: Path | str, header_encoding: str = "latin-1") -> AFMLoad: # noqa: C901 (ignore too complex) """ Load image from STP files. @@ -28,8 +27,8 @@ def load_stp( # noqa: C901 (ignore too complex) Returns ------- - tuple[np.ndarray, float] - A tuple containing the image and its pixel to nanometre scaling value. + AFMLoad + An AFMLoad object containing the image and its pixel to nanometre scaling value. Raises ------ @@ -108,4 +107,4 @@ def load_stp( # noqa: C901 (ignore too complex) raise e logger.info(f"[{filename}] : Extracted image.") - return (image, pixel_to_nm_scaling) + return AFMLoad(image=image, px2nm=pixel_to_nm_scaling) diff --git a/AFMReader/top.py b/AFMReader/top.py index 6219a68..c555e23 100644 --- a/AFMReader/top.py +++ b/AFMReader/top.py @@ -5,6 +5,7 @@ import numpy as np +from AFMReader.data_classes import AFMLoad from AFMReader.io import read_int16 from AFMReader.logging import logger @@ -14,9 +15,7 @@ # pylint: disable=too-many-locals # pylint: disable=too-many-statements # pylint: disable=too-many-branches -def load_top( # noqa: C901 (ignore too complex) - file_path: Path | str, header_encoding: str = "latin-1" -) -> tuple[np.ndarray, float]: +def load_top(file_path: Path | str, header_encoding: str = "latin-1") -> AFMLoad: # noqa: C901 (ignore too complex) """ Load image from TOP files. @@ -29,8 +28,8 @@ def load_top( # noqa: C901 (ignore too complex) Returns ------- - tuple[np.ndarray, float] - A tuple containing the image and its pixel to nanometre scaling value. + AFMLoad + An AFMLoad object containing the image and its pixel to nanometre scaling value. Raises ------ @@ -117,4 +116,4 @@ def load_top( # noqa: C901 (ignore too complex) raise e logger.info(f"[{filename}] : Extracted image.") - return (image, pixel_to_nm_scaling) + return AFMLoad(image=image, px2nm=pixel_to_nm_scaling) diff --git a/AFMReader/topostats.py b/AFMReader/topostats.py index 9b3ac44..1a26c17 100644 --- a/AFMReader/topostats.py +++ b/AFMReader/topostats.py @@ -1,30 +1,32 @@ """For decoding and loading .topostats (HDF5 format) AFM file format into Python Nympy arrays.""" from pathlib import Path -from typing import Any import h5py from packaging.version import parse as parse_version +from AFMReader.data_classes import AFMLoad from AFMReader.io import unpack_hdf5 from AFMReader.logging import logger logger.enable(__package__) -def load_topostats(file_path: Path | str) -> dict[str, Any]: +def load_topostats(file_path: Path | str, channel: str) -> AFMLoad: """ Extract image and pixel to nm scaling from the .topostats (HDF5 format) file. Parameters ---------- - file_path : Path or str + file_path : Path | str Path to the .topostats file. + channel : str + The channel to load. Returns ------- - dict[str, Any] - A dictionary containing the image, its pixel to nm scaling factor and nested Numpy arrays representing the + AFMLoad + An AFMLoad object containing the image, its pixel to nm scaling factor, and nested Numpy arrays representing the analyses performed on the data. Raises @@ -34,7 +36,9 @@ def load_topostats(file_path: Path | str) -> dict[str, Any]: Examples -------- - >>> image, pixel_to_nm_scaling = load_topostats("path/to/topostats_file.topostats") + >>> afm_load = load_topostats("path/to/topostats_file.topostats", channel="image") + >>> image = afm_load.image + >>> pixel_to_nm_scaling = afm_load.px2nm """ logger.info(f"Loading image from : {file_path}") file_path = Path(file_path) @@ -58,4 +62,15 @@ def load_topostats(file_path: Path | str) -> dict[str, Any]: raise e logger.info(f"[{filename}] : Extracted .topostats dictionary.") - return data + try: + image = data.pop(channel) + pixel_to_nanometre_scaling_factor = data.pop("pixel_to_nm_scaling") + except KeyError as exc: + image_keys = ["image", "image_original"] + topostats_keys = list(data.keys()) + raise ValueError( + f"'{channel}' not in available image keys: " f"{[im for im in image_keys if im in topostats_keys]}" + ) from exc + + # Analyses are stored to metadata - this might be a bit clunky and potentially should be stored to their own attr + return AFMLoad(image=image, px2nm=pixel_to_nanometre_scaling_factor, metadata=data) diff --git a/tests/test_asd.py b/tests/test_asd.py index 8af1b6e..a85bbf0 100644 --- a/tests/test_asd.py +++ b/tests/test_asd.py @@ -19,16 +19,12 @@ ) def test_load_asd(file_name: str, channel: str, number_of_frames: int, pixel_to_nm_scaling: float) -> None: """Test the normal operation of loading a .asd file.""" - result_frames = list - result_pixel_to_nm_scaling = float - result_metadata = dict - file_path = RESOURCES / file_name - result_frames, result_pixel_to_nm_scaling, result_metadata = asd.load_asd(file_path, channel) + afm_load = asd.load_asd(file_path, channel) - assert len(result_frames) == number_of_frames # type: ignore - assert result_pixel_to_nm_scaling == pixel_to_nm_scaling - assert isinstance(result_metadata, dict) + assert len(afm_load.image) == number_of_frames # type: ignore + assert afm_load.px2nm == pixel_to_nm_scaling + assert isinstance(afm_load.metadata, dict) def test_load_asd_file_not_found() -> None: diff --git a/tests/test_general_loader.py b/tests/test_general_loader.py index 66bb7a6..ccc0ac3 100644 --- a/tests/test_general_loader.py +++ b/tests/test_general_loader.py @@ -135,9 +135,9 @@ def test_load(capsys: pytest.CaptureFixture, filepath: Path, channel: str, error with pytest.raises(ValueError, match=re.escape(message)): loader.load() else: - image, px2nm = loader.load() # type: ignore[misc] - assert isinstance(image, np.ndarray) - assert isinstance(px2nm, float) + afm_load = loader.load() + assert isinstance(afm_load.image, np.ndarray) + assert isinstance(afm_load.px2nm, float) # check output logs captured = capsys.readouterr() assert message in captured.err @@ -157,7 +157,7 @@ def test_load_filenotfounderror(filepath: Path) -> None: loader = general_loader.LoadFile(filepath, "channel") with pytest.raises(FileNotFoundError) as execinfo: # noqa: PT012 - _, _ = loader.load() # type: ignore[misc] + loader.load() assert "[not_a_real_file] FileNotFoundError" in execinfo.value diff --git a/tests/test_gwy.py b/tests/test_gwy.py index 5d3aa1a..eb8e942 100644 --- a/tests/test_gwy.py +++ b/tests/test_gwy.py @@ -16,12 +16,12 @@ def test_load_gwy() -> None: """Test the normal operation of loading a .gwy file.""" channel = "ZSensor" file_path = RESOURCES / "sample_0.gwy" - result_image, result_pixel_to_nm_scaling = gwy.load_gwy(file_path, channel=channel) - assert isinstance(result_image, np.ndarray) - assert result_image.shape == (512, 512) - assert result_image.sum() == pytest.approx(33836850.232917726) - assert isinstance(result_pixel_to_nm_scaling, float) - assert result_pixel_to_nm_scaling == pytest.approx(0.8468632812499975) + afm_load = gwy.load_gwy(file_path, channel=channel) + assert isinstance(afm_load.image, np.ndarray) + assert afm_load.image.shape == (512, 512) + assert afm_load.image.sum() == pytest.approx(33836850.232917726) + assert isinstance(afm_load.px2nm, float) + assert afm_load.px2nm == pytest.approx(0.8468632812499975) def test_gwy_read_object() -> None: diff --git a/tests/test_h5jpk.py b/tests/test_h5jpk.py index 3ba82ac..a5dde28 100644 --- a/tests/test_h5jpk.py +++ b/tests/test_h5jpk.py @@ -129,20 +129,18 @@ def test_load_h5jpk( image_sum: float, ) -> None: """Test the normal operation of loading a .h5-jpk file.""" - result_image, result_pixel_to_nm_scaling, results_timestamps = h5_jpk.load_h5jpk( # type: ignore[misc] - RESOURCES / file_name, channel, flip_image - ) - - assert result_pixel_to_nm_scaling == pytest.approx(pixel_to_nm_scaling) - assert isinstance(result_image, np.ndarray) - assert result_image.shape == image_shape - assert result_image.dtype == np.dtype(image_dtype) - assert isinstance(results_timestamps, timestamps_dtype) - assert result_image.sum() == pytest.approx(image_sum) - assert len(results_timestamps) == result_image.shape[0] + afm_load = h5_jpk.load_h5jpk(RESOURCES / file_name, channel, flip_image) + + assert afm_load.px2nm == pytest.approx(pixel_to_nm_scaling) + assert isinstance(afm_load.image, np.ndarray) + assert afm_load.image.shape == image_shape + assert afm_load.image.dtype == np.dtype(image_dtype) + assert isinstance(afm_load.timestamps, timestamps_dtype) + assert afm_load.image.sum() == pytest.approx(image_sum) + assert len(afm_load.timestamps) == afm_load.image.shape[0] assert all( - results_timestamps[f"frame {i}"] < results_timestamps[f"frame {i + 1}"] - for i in range(len(results_timestamps) - 1) + afm_load.timestamps[f"frame {i}"] < afm_load.timestamps[f"frame {i + 1}"] + for i in range(len(afm_load.timestamps) - 1) ) @@ -199,7 +197,9 @@ def test_load_h5jpk_curves( curve_targets : dict[str, tuple[int, float]] A dictionary mapping curve channels to their expected size and sum, used for validating the loaded curve data. """ - _, _, _, curve_dataset = h5_jpk.load_h5jpk(RESOURCES / file_name, channel, flip_image) # type: ignore[misc] + afm_load = h5_jpk.load_h5jpk(RESOURCES / file_name, channel, flip_image) + curve_dataset = afm_load.curves_dataset + assert curve_dataset is not None curve_at_coords = curve_dataset.get_default_volume()[curve_coords[0], curve_coords[1]] for curve_channel, (expected_size, expected_sum) in curve_targets.items(): curve = curve_at_coords[curve_channel][curve_direction] diff --git a/tests/test_ibw.py b/tests/test_ibw.py index 34dfdc2..2fa745a 100644 --- a/tests/test_ibw.py +++ b/tests/test_ibw.py @@ -24,17 +24,14 @@ def test_load_ibw( image_sum: float, ) -> None: """Test the normal operation of loading an .ibw file.""" - result_image = np.ndarray - result_pixel_to_nm_scaling = float - file_path = RESOURCES / file_name - result_image, result_pixel_to_nm_scaling = ibw.load_ibw(file_path, channel) # type: ignore + afm_load = ibw.load_ibw(file_path, channel) - assert result_pixel_to_nm_scaling == pytest.approx(pixel_to_nm_scaling) - assert isinstance(result_image, np.ndarray) - assert result_image.shape == image_shape - assert result_image.dtype == image_dtype - assert result_image.sum() == pytest.approx(image_sum) + assert afm_load.px2nm == pytest.approx(pixel_to_nm_scaling) + assert isinstance(afm_load.image, np.ndarray) + assert afm_load.image.shape == image_shape + assert afm_load.image.dtype == image_dtype + assert afm_load.image.sum() == pytest.approx(image_sum) def test_load_ibw_file_not_found() -> None: diff --git a/tests/test_jpk.py b/tests/test_jpk.py index b84b05e..e9fea19 100644 --- a/tests/test_jpk.py +++ b/tests/test_jpk.py @@ -73,16 +73,14 @@ def test_load_jpk( image_sum: float, ) -> None: """Test the normal operation of loading a .jpk file.""" - result_image = np.ndarray - result_pixel_to_nm_scaling = float file_path = RESOURCES / file_name - result_image, result_pixel_to_nm_scaling = jpk.load_jpk(file_path, channel) # type: ignore + afm_load = jpk.load_jpk(file_path, channel) - assert result_pixel_to_nm_scaling == pytest.approx(pixel_to_nm_scaling) - assert isinstance(result_image, np.ndarray) - assert result_image.shape == image_shape - assert result_image.dtype == image_dtype - assert result_image.sum() == pytest.approx(image_sum) + assert afm_load.px2nm == pytest.approx(pixel_to_nm_scaling) + assert isinstance(afm_load.image, np.ndarray) + assert afm_load.image.shape == image_shape + assert afm_load.image.dtype == image_dtype + assert afm_load.image.sum() == pytest.approx(image_sum) def test_load_jpk_file_not_found() -> None: diff --git a/tests/test_jpk_qi.py b/tests/test_jpk_qi.py index f06c918..febbd73 100644 --- a/tests/test_jpk_qi.py +++ b/tests/test_jpk_qi.py @@ -124,26 +124,26 @@ def test_load_jpk_qi_data( # pylint: disable=too-many-arguments,too-many-positi curve_targets: dict[str, tuple[int, float]], ) -> None: """Test the normal operation of loading a .jpk-qi-data file.""" - result_image = np.ndarray - result_pixel_to_nm_scaling = float file_path = RESOURCES / file_name jpk_qi_loader = jpk_qi.JPKQILoader(file_path, channel) - result_image, result_pixel_to_nm_scaling, curve_dataset = jpk_qi_loader.load() # type: ignore + afm_load = jpk_qi_loader.load() - assert result_pixel_to_nm_scaling == pytest.approx(pixel_to_nm_scaling) - assert isinstance(result_image, np.ndarray) - assert result_image.shape == image_shape - assert result_image.dtype == image_dtype - assert result_image.sum() == pytest.approx(image_sum) + assert afm_load.px2nm == pytest.approx(pixel_to_nm_scaling) + assert isinstance(afm_load.image, np.ndarray) + assert afm_load.image.shape == image_shape + assert afm_load.image.dtype == image_dtype + assert afm_load.image.sum() == pytest.approx(image_sum) # Test curve data for all targets + curve_dataset = afm_load.curves_dataset + assert curve_dataset is not None, "Curves data not found/ is None" curve_at_coords = curve_dataset.get_default_volume()[curve_coords[0], curve_coords[1]] for curve_channel, (expected_size, expected_sum) in curve_targets.items(): curve = curve_at_coords[curve_channel][curve_direction] assert curve.shape == (expected_size,) assert curve.sum() == pytest.approx(expected_sum) - jpk_qi_loader.close() # type: ignore + jpk_qi_loader.close() def test_load_jpk_data_file_not_found() -> None: diff --git a/tests/test_spm.py b/tests/test_spm.py index f24b276..7866911 100644 --- a/tests/test_spm.py +++ b/tests/test_spm.py @@ -33,17 +33,14 @@ def test_load_spm( image_sum: float, ) -> None: """Test the normal operation of loading a .spm file.""" - result_image = np.ndarray - result_pixel_to_nm_scaling = float - file_path = RESOURCES / file_name - result_image, result_pixel_to_nm_scaling = spm.load_spm(file_path, channel=channel) + afm_load = spm.load_spm(file_path, channel=channel) - assert result_pixel_to_nm_scaling == pytest.approx(pixel_to_nm_scaling) - assert isinstance(result_image, np.ndarray) - assert result_image.shape == image_shape - assert result_image.dtype == image_dtype - assert result_image.sum() == pytest.approx(image_sum) + assert afm_load.px2nm == pytest.approx(pixel_to_nm_scaling) + assert isinstance(afm_load.image, np.ndarray) + assert afm_load.image.shape == image_shape + assert afm_load.image.dtype == image_dtype + assert afm_load.image.sum() == pytest.approx(image_sum) @patch("pySPM.SPM.SPM_image") diff --git a/tests/test_stp.py b/tests/test_stp.py index 271a85a..ba928b8 100644 --- a/tests/test_stp.py +++ b/tests/test_stp.py @@ -33,10 +33,10 @@ def test_load_stp( ) -> None: """Test the normal operation of loading a .stp file.""" file_path = RESOURCES / file_name - result_image, result_pixel_to_nm_scaling = load_stp(file_path=file_path) + afm_load = load_stp(file_path=file_path) - assert result_pixel_to_nm_scaling == pytest.approx(expected_pixel_to_nm_scaling) - assert isinstance(result_image, np.ndarray) - assert result_image.shape == expected_image_shape - assert result_image.dtype == expected_image_dtype - assert result_image.sum() == pytest.approx(expected_image_sum) + assert afm_load.px2nm == pytest.approx(expected_pixel_to_nm_scaling) + assert isinstance(afm_load.image, np.ndarray) + assert afm_load.image.shape == expected_image_shape + assert afm_load.image.dtype == expected_image_dtype + assert afm_load.image.sum() == pytest.approx(expected_image_sum) diff --git a/tests/test_top.py b/tests/test_top.py index e44b5ce..ae20b36 100644 --- a/tests/test_top.py +++ b/tests/test_top.py @@ -33,10 +33,10 @@ def test_load_top( ) -> None: """Test the normal operation of loading a .top file.""" file_path = RESOURCES / file_name - result_image, result_pixel_to_nm_scaling = load_top(file_path=file_path) + afm_load = load_top(file_path=file_path) - assert result_pixel_to_nm_scaling == pytest.approx(expected_pixel_to_nm_scaling) - assert isinstance(result_image, np.ndarray) - assert result_image.shape == expected_image_shape - assert result_image.dtype == expected_image_dtype - assert result_image.sum() == pytest.approx(expected_image_sum) + assert afm_load.px2nm == pytest.approx(expected_pixel_to_nm_scaling) + assert isinstance(afm_load.image, np.ndarray) + assert afm_load.image.shape == expected_image_shape + assert afm_load.image.dtype == expected_image_dtype + assert afm_load.image.sum() == pytest.approx(expected_image_sum) diff --git a/tests/test_topostats.py b/tests/test_topostats.py index 66a7989..c2c2a18 100644 --- a/tests/test_topostats.py +++ b/tests/test_topostats.py @@ -89,21 +89,23 @@ def test_load_topostats( ) -> None: """Test the normal operation of loading a .topostats (HDF5 format) file.""" file_path = RESOURCES / file_name - topostats_data = topostats.load_topostats(file_path) + afm_load = topostats.load_topostats(file_path, channel="image") - assert set(topostats_data.keys()) == data_keys # type: ignore + expected_metadata_keys = data_keys - {"image", "pixel_to_nm_scaling"} + assert afm_load.metadata is not None + assert set(afm_load.metadata.keys()) == expected_metadata_keys if version_key == "topostats_file_version": - assert topostats_data[version_key] == float(version) + assert afm_load.metadata[version_key] == float(version) else: - assert topostats_data[version_key] == version - assert topostats_data["pixel_to_nm_scaling"] == pytest.approx(pixel_to_nm_scaling) - assert topostats_data["image"].shape == image_shape - assert topostats_data["image"].sum() == pytest.approx(image_sum) + assert afm_load.metadata[version_key] == version + assert afm_load.px2nm == pytest.approx(pixel_to_nm_scaling) + assert afm_load.image.shape == image_shape + assert afm_load.image.sum() == pytest.approx(image_sum) if version > "0.2": - assert isinstance(topostats_data["img_path"], Path) + assert isinstance(afm_load.metadata["img_path"], Path) def test_load_topostats_file_not_found() -> None: """Ensure FileNotFound error is raised.""" with pytest.raises(FileNotFoundError): - topostats.load_topostats("nonexistant_file.topostats") + topostats.load_topostats("nonexistant_file.topostats", channel="image") From 0b29bbbb6c82d9295fca7f4c5942d00a34374eaf Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Fri, 5 Jun 2026 16:42:01 +0100 Subject: [PATCH 52/54] refactor: moving jpk-qi-data logic in general_loader into jpk_qi --- AFMReader/general_loader.py | 18 ++++-------- AFMReader/jpk_qi.py | 56 +++++++++++++++++++++++++++++++++---- 2 files changed, 56 insertions(+), 18 deletions(-) diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index 9a26039..bc34a4f 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -85,17 +85,11 @@ def load(self, channel: str | None = None, kwargs: dict | None = None) -> AFMLoa elif self.suffix == ".spm": afm_load = spm.load_spm(self.filepath, self.channel) elif self.suffix == ".h5-jpk": - afm_load = h5_jpk.load_h5jpk(self.filepath, self.channel, load_curves=not self.loaded_curves) - if afm_load.curves_dataset is not None: - self.loaded_curves = True + afm_load = h5_jpk.load_h5jpk(self.filepath, self.channel) elif self.suffix == ".jpk-qi-data": - if "jpk_qi_loader" not in self.cached_data: - self.cached_data["jpk_qi_loader"] = jpk_qi.JPKQILoader( - filepath=self.filepath, channel=self.channel, **self.kwargs - ) - afm_load = self.cached_data["jpk_qi_loader"].load(channel=self.channel, **self.kwargs) - if afm_load.curves_dataset is not None: - self.loaded_curves = True + afm_load = jpk_qi.load_jpk_data( + filepath=self.filepath, channel=self.channel, cached_data=self.cached_data, **self.kwargs + ) elif self.suffix == ".stp": afm_load = stp.load_stp(self.filepath) elif self.suffix == ".top": @@ -135,9 +129,7 @@ def get_available_channels(self): # noqa: C901 elif self.suffix == ".h5-jpk": available_channels = h5_jpk.get_h5jpk_channels(self.filepath) elif self.suffix == ".jpk-qi-data": - if "jpk_qi_loader" not in self.cached_data: - self.cached_data["jpk_qi_loader"] = jpk_qi.JPKQILoader(filepath=self.filepath, **self.kwargs) - available_channels = self.cached_data["jpk_qi_loader"].get_available_channels() + available_channels = jpk_qi.get_jpk_data_channels(filepath=self.filepath, cached_data=self.cached_data) elif self.suffix == ".topostats": available_channels = ["image", "image_original"] elif self.suffix == ".bin": diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index 5345558..f7d8234 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -431,6 +431,7 @@ def __init__( # Instantiate containers for data to be saved (so an exception is not caused if not saving) self.curve_groups = None + self.saved_to_h5 = False def get_available_channels(self): """ @@ -501,7 +502,7 @@ def load( self.parse_dimension_data() # Setup H5 Data structures if needed - if self.save_as_h5: + if self.save_as_h5 and not self.saved_to_h5: self.save_to_h5() # Establish the lazy loading structures for curve data and metadata. Note how lazy structure is used even if @@ -529,10 +530,6 @@ def load( # Load the image self.image, _ = self.get_image() - # Save a lite form of the images (precalculated) if saving to a file - if self.save_as_h5: - self.save_lite_data() - return AFMLoad(image=self.image, px2nm=self.px2nm, curves_dataset=self.curves_dataset) def output_summary(self): @@ -692,6 +689,9 @@ def save_to_h5( global_meta_group.attrs[key] = str(value).encode("utf-8") logger.info(f"QI data copied to h5 data {file.filename}") + # Save a lite form of the images (precalculated) if saving to a file + self.save_lite_data() + self.saved_to_h5 = True def get_curves_sample(self): """ @@ -1320,3 +1320,49 @@ def close(self): self.failed_curves = set() self.points_for_channel_segment = {} self.list_of_all_paths = [] + + +def load_jpk_data(filepath: str | Path, channel: str, cached_data: dict, save_as_h5: bool = False) -> AFMLoad: + """ + Load the JPK QI data using the JPKQILoader. + + Parameters + ---------- + filepath : str | Path + Path to the JPK QI file. + channel : str + The channel to load from the file. + cached_data : dict + Cached data to avoid reloading heavy data. + save_as_h5 : bool, optional + Whether to save the loaded data as an h5 file for faster future loading. Default is False. + + Returns + ------- + AFMLoad + The loaded JPK QI data. + """ + if "jpk_qi_loader" not in cached_data: + cached_data["jpk_qi_loader"] = JPKQILoader(filepath=filepath, channel=channel, save_as_h5=save_as_h5) + return cached_data["jpk_qi_loader"].load(channel=channel, save_as_h5=save_as_h5) + + +def get_jpk_data_channels(filepath: str | Path, cached_data: dict) -> list[str]: + """ + Get the available channels in the JPK QI data. + + Parameters + ---------- + filepath : str | Path + Path to the JPK QI file. + cached_data : dict + Cached data to avoid reloading heavy data. + + Returns + ------- + list[str] + A list of available channels in the JPK QI data. + """ + if "jpk_qi_loader" not in cached_data: + cached_data["jpk_qi_loader"] = JPKQILoader(filepath=filepath) + return cached_data["jpk_qi_loader"].get_available_channels() From 23523fbce14a8f16685144fa4de74ed3fd906055 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Fri, 5 Jun 2026 16:48:47 +0100 Subject: [PATCH 53/54] chore: moving channels logic out of general_loader and into topostats --- AFMReader/general_loader.py | 2 +- AFMReader/topostats.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py index bc34a4f..2e201a6 100644 --- a/AFMReader/general_loader.py +++ b/AFMReader/general_loader.py @@ -131,7 +131,7 @@ def get_available_channels(self): # noqa: C901 elif self.suffix == ".jpk-qi-data": available_channels = jpk_qi.get_jpk_data_channels(filepath=self.filepath, cached_data=self.cached_data) elif self.suffix == ".topostats": - available_channels = ["image", "image_original"] + available_channels = topostats.get_topostats_channels() elif self.suffix == ".bin": available_channels = raw_bin.get_bin_channels() elif self.suffix in [".stp", ".top"]: diff --git a/AFMReader/topostats.py b/AFMReader/topostats.py index 1a26c17..46ec2d2 100644 --- a/AFMReader/topostats.py +++ b/AFMReader/topostats.py @@ -74,3 +74,15 @@ def load_topostats(file_path: Path | str, channel: str) -> AFMLoad: # Analyses are stored to metadata - this might be a bit clunky and potentially should be stored to their own attr return AFMLoad(image=image, px2nm=pixel_to_nanometre_scaling_factor, metadata=data) + + +def get_topostats_channels() -> list[str]: + """ + Get the available channels for a .topostats file. + + Returns + ------- + list[str] + A list of available channels in the .topostats file. + """ + return ["image", "image_original"] From 724460fbcce1f38c965aa89bd35b8fa39654c119 Mon Sep 17 00:00:00 2001 From: ahobbs7 Date: Sun, 7 Jun 2026 22:35:14 +0100 Subject: [PATCH 54/54] fix: resizing of datasets while saving not including current buffered data --- AFMReader/jpk_qi.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py index f7d8234..2338ff4 100644 --- a/AFMReader/jpk_qi.py +++ b/AFMReader/jpk_qi.py @@ -991,16 +991,16 @@ def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, dire buf["Data"].append(segment_array) if len(buf["Data"]) >= self.BUFFER_SIZE or curve_num == self.num_of_curves - 1: - if self.points_for_channel_segment[direction][chan_name] > data_size: - # Fetch and resize the existing dataset for this channel and segment to fit the new data - data_set.resize((self.points_for_channel_segment[direction][chan_name],)) - buffered_data = np.concatenate(buf["Data"]) + required_size = filled_size + len(buffered_data) + if required_size > data_size: + # Fetch and resize the existing dataset for this channel and segment to fit the new data + data_set.resize((required_size,)) # Add the buffer to the dataset - data_set[filled_size : filled_size + len(buffered_data)] = buffered_data + data_set[filled_size:required_size] = buffered_data # Update the filled size for this channel and segment - self.points_for_channel_segment[direction][chan_name] += len(buffered_data) + self.points_for_channel_segment[direction][chan_name] = required_size # Clear the buffer buf["Data"].clear()