Skip to content
86 changes: 51 additions & 35 deletions neurokit2/data/read_xdf.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,37 @@
# -*- coding: utf-8 -*-
import io
import urllib

from typing import TypedDict

import numpy as np
import pandas as pd
import requests


class ReadXDFInfo(TypedDict):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need a new class?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's to document the structure of info in the return type. It doesn't do anything at runtime. Couldn't find a more elegant solution than using a class unfortunately

sampling_rates_original: list[float]
sampling_rates_effective: list[float]
sampling_rate: int
datetime: str
data: list[pd.DataFrame]


def read_xdf(filename, upsample=2, fillmissing=None):
def read_xdf(
filename: str, upsample: float = 2.0, fillmissing: float | None = None
) -> tuple[pd.DataFrame, ReadXDFInfo]:
"""**Read and tidy an XDF file**

Reads and tidies an XDF file with multiple streams into a Pandas DataFrame.
The function outputs both the dataframe and the information (such as the sampling rate).

Note that, as XDF can store streams with different sampling rates and different time stamps,
**the function will resample all streams to 2 times (default) the highest sampling rate** (to
minimize aliasing). The final sampling rate can be found in the ``info`` dictionary.
minimize aliasing) and then interpolate based on an evenly spaced index. While this is generally safe, it
may produce unexpected results, particularly if the original stream has large gaps in its time series.
For more discussion, see `here <https://github.com/xdf-modules/pyxdf/pull/1>`_.

The final upsampled sampling rate can be found in the ``info`` dictionary.

.. note::

Expand All @@ -21,11 +41,12 @@ def read_xdf(filename, upsample=2, fillmissing=None):
Parameters
----------
filename : str
Path (with the extension) of an XDF file (e.g., ``"data.xdf"``).
Path (with the extension) or URL pointing to an XDF file (e.g., ``"data.xdf"``).
upsample : float
Factor by which to upsample the data. Default is 2, which means that the data will be
Factor by which to upsample the data. Default is 2.0, which means that the data will be
resampled to 2 times the highest sampling rate. You can increase that to further reduce
edge-distortion, especially for high frequency signals like EEG.
edge-distortion, especially for high frequency signals like EEG. ``1.0`` disables upsampling
(but not interpolation).
fillmissing : float
The maximum duration in seconds of missing data to fill. ``None`` (default) will
interpolate all missing values and prevent issues with NaNs. However, it might be important
Expand All @@ -35,9 +56,9 @@ def read_xdf(filename, upsample=2, fillmissing=None):

Returns
----------
df : DataFrame, dict
The BITalino file as a pandas dataframe if one device was read, or a dictionary
of pandas dataframes (one dataframe per device) if multiple devices are read.
df : DataFrame
The XDF data as a pandas dataframe. If multiple streams are read,
they will be merged into one dataframe.
info : dict
The metadata information containing the sampling rate(s).

Expand All @@ -53,21 +74,32 @@ def read_xdf(filename, upsample=2, fillmissing=None):

# data, info = nk.read_xdf("data.xdf")
# sampling_rate = info["sampling_rate"]

"""
try:
import pyxdf
except ImportError:
except ImportError as e:
raise ImportError(
"The 'pyxdf' module is required for this function to run. ",
"Please install it first (`pip install pyxdf`).",
)
) from e

# Load file
# TODO: would be nice to be able to stream a file from URL
# if filename is a URL, stream bytes from file
if urllib.parse.urlparse(filename).scheme != "":
try:
req = requests.get(filename, stream=True, timeout=10)
req.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)

req.raw.decode_content = True
filename = io.BytesIO(req.content)
except requests.exceptions.RequestException as e:
raise IOError(f"Failed to read XDF file from URL: {filename}") from e

streams, header = pyxdf.load_xdf(filename)

# Get smaller time stamp to later use as offset (zero point)
min_ts = min([min(s["time_stamps"]) for s in streams])
min_ts = min(min(s["time_stamps"]) for s in streams)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not s ure this will work does it 🤔
the goal is to create a vector of mins and then get its min

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I double checked and it produces identical results. It's just a little micro-optimisation that reduces memory usage a bit, but I mainly did it to appease pylint


# Loop through all the streams and convert to dataframes
dfs = []
Expand All @@ -83,22 +115,16 @@ def read_xdf(filename, upsample=2, fillmissing=None):
if stream["info"]["type"][0] == "GYRO":
dat = dat.rename(columns={"X": "GYRO_X", "Y": "GYRO_Y", "Z": "GYRO_Z"})
# Compute movement
dat["GYRO"] = np.sqrt(
dat["GYRO_X"] ** 2 + dat["GYRO_Y"] ** 2 + dat["GYRO_Z"] ** 2
)
dat["GYRO"] = np.sqrt(dat["GYRO_X"] ** 2 + dat["GYRO_Y"] ** 2 + dat["GYRO_Z"] ** 2)

if stream["info"]["type"][0] == "ACC":
dat = dat.rename(columns={"X": "ACC_X", "Y": "ACC_Y", "Z": "ACC_Z"})
# Compute acceleration
dat["ACC"] = np.sqrt(
dat["ACC_X"] ** 2 + dat["ACC_Y"] ** 2 + dat["ACC_Z"] ** 2
)
dat["ACC"] = np.sqrt(dat["ACC_X"] ** 2 + dat["ACC_Y"] ** 2 + dat["ACC_Z"] ** 2)

# Muse - PPG data has three channels: ambient, infrared, red
if stream["info"]["type"][0] == "PPG":
dat = dat.rename(
columns={"PPG1": "LUX", "PPG2": "PPG", "PPG3": "RED", "IR": "PPG"}
)
dat = dat.rename(columns={"PPG1": "LUX", "PPG2": "PPG", "PPG3": "RED", "IR": "PPG"})
# Zeros suggest interruptions, better to replace with NaNs (I think?)
dat["PPG"] = dat["PPG"].replace(0, value=np.nan)
dat["LUX"] = dat["LUX"].replace(0, value=np.nan)
Expand All @@ -111,12 +137,8 @@ def read_xdf(filename, upsample=2, fillmissing=None):

# Store metadata
info = {
"sampling_rates_original": [
float(s["info"]["nominal_srate"][0]) for s in streams
],
"sampling_rates_effective": [
float(s["info"]["effective_srate"]) for s in streams
],
"sampling_rates_original": [float(s["info"]["nominal_srate"][0]) for s in streams],
"sampling_rates_effective": [float(s["info"]["effective_srate"]) for s in streams],
"datetime": header["info"]["datetime"][0],
"data": dfs,
}
Expand All @@ -137,14 +159,8 @@ def read_xdf(filename, upsample=2, fillmissing=None):
fillmissing = int(info["sampling_rate"] * fillmissing)

# Create new index with evenly spaced timestamps
idx = pd.date_range(
df.index.min(), df.index.max(), freq=str(1000 / info["sampling_rate"]) + "ms"
)
idx = pd.date_range(df.index.min(), df.index.max(), freq=str(1000 / info["sampling_rate"]) + "ms")
# https://stackoverflow.com/questions/47148446/pandas-resample-interpolate-is-producing-nans
df = (
df.reindex(df.index.union(idx))
.interpolate(method="index", limit=fillmissing)
.reindex(idx)
)
df = df.reindex(df.index.union(idx)).interpolate(method="index", limit=fillmissing).reindex(idx)

return df, info
Loading