Source code for eegunity.utils.h5

import json
import datetime
import h5py
import numpy as np
from pathlib import Path


class h5Dataset:
    """
    Thin write-only wrapper around an ``h5py.File``.

    This class is adapted from:
    https://github.com/935963004/LaBraM/blob/main/dataset_maker/shock/utils/h5.py#L8.

    Write protocol
    --------------
    All data goes into ``<name>.hdf5.tmp``. The temporary file is renamed to
    ``<name>.hdf5`` only when :meth:`save` is called. If ``<name>.hdf5`` is
    already present when the object is constructed, ``FileExistsError`` is
    raised so the caller has to remove the old file explicitly.
    """

    def __init__(self, path: Path, name: str) -> None:
        """
        Open a fresh temporary HDF5 file for writing.

        Parameters
        ----------
        path : Path
            Directory that will contain the HDF5 file.
        name : str
            File name without the ``.hdf5`` extension.

        Raises
        ------
        FileExistsError
            If ``<path>/<name>.hdf5`` already exists.
        """
        directory = Path(path)
        final_path = directory / f'{name}.hdf5'
        tmp_path = directory / f'{name}.hdf5.tmp'

        self.__name = name
        self.__final_path = final_path
        self.__tmp_path = tmp_path

        if final_path.exists():
            message = (
                f"HDF5 file already exists: {self.__final_path}. "
                "Delete it manually before re-exporting."
            )
            raise FileExistsError(message)

        # A leftover temporary file from an earlier run is discarded.
        if tmp_path.exists():
            tmp_path.unlink()

        self.__f = h5py.File(tmp_path, 'w')

    def addGroup(self, grpName: str):
        """
        Create a group in the open file.

        Parameters
        ----------
        grpName : str
            Name of the group to create.

        Returns
        -------
        h5py.Group
            The newly created group.
        """
        new_group = self.__f.create_group(grpName)
        return new_group

    def addDataset(self, grp: h5py.Group, dsName: str, arr: np.array,
                   chunks: tuple = None, **kwargs):
        """
        Create a dataset inside ``grp`` from ``arr``.

        Parameters
        ----------
        grp : h5py.Group
            Group that receives the dataset.
        dsName : str
            Name of the dataset.
        arr : np.array
            Data written into the dataset.
        chunks : tuple, optional
            Chunk shape forwarded to ``create_dataset``.
        **kwargs
            Extra keyword arguments forwarded to ``create_dataset``.

        Returns
        -------
        h5py.Dataset
            The newly created dataset.
        """
        new_dataset = grp.create_dataset(
            dsName, data=arr, chunks=chunks, **kwargs
        )
        return new_dataset

    def addAttributes(self, src: 'h5py.Dataset|h5py.Group', attrName: str,
                      attrValue):
        """
        Set one attribute on a dataset or group.

        Parameters
        ----------
        src : h5py.Dataset or h5py.Group
            Object whose ``attrs`` mapping is updated.
        attrName : str
            Attribute name (formatted to a string before use).
        attrValue : any
            Attribute value.
        """
        attributes = src.attrs
        attributes[f'{attrName}'] = attrValue

    def save(self):
        """Close the temporary file, then rename it to the final path."""
        handle = self.__f
        handle.close()
        self.__tmp_path.rename(self.__final_path)

    @property
    def name(self):
        """
        Name given at construction time.

        Returns
        -------
        str
            The HDF5 file name without extension.
        """
        return self.__name
class h5EpochDatasetV2:
    """
    EEGUnity v2 epoch HDF5 format writer (v2.1 schema).

    Flat-array layout optimised for PyTorch random access:

    Structure
    ---------
    / (root)
      attrs: version ("2.1"), sfreq, ch_names (JSON), n_channels, n_times,
             label_map (JSON: code->event_name), n_epochs_total,
             info_fields (JSON list[str]; v2.1), created_by, created_at
    ├── data                      (N, n_ch, n_times) float32
    │                             chunk=(1, n_ch, n_times), gzip level-1
    ├── epoch_meta/
    │   ├── source_group          (N,) variable-length UTF-8 string
    │   └── event_code            (N,) int16
    ├── misc_meta/                (v2.1, optional)
    │   ├── {misc_channel_name}   (N,) float32
    │   └── attrs.names: JSON list[str]
    └── source_meta/
        └── {group_name}/
            attrs: file_path, n_epochs_in_source, sfreq, ch_names (JSON),
                   age, gender, amplifier, cap, handedness,
                   <any additional participants.tsv columns>
            └── info              (uint8) pickled mne.Info bytes

    Backward compat
    ---------------
    Readers of v2.0 can still open v2.1 files if they ignore /misc_meta/
    and the extra source attrs. The dataset_api in eeg_kernel_agent accepts
    both "2.0" and "2.1" as valid version strings.

    Usage
    -----
        writer = h5EpochDatasetV2(output_dir, "MyDataset")
        writer.add_epochs(group_name, event_name, epoch_array_float32,
                          info_bytes, source_attrs, sfreq, ch_names,
                          misc_values={"accuracy": np.array([1, 0, 1])})  # v2.1
        writer.save()

    Reading in PyTorch
    ------------------
    Use eeg_kernel_agent.dataset_api.{EventDataset, MISCDataset, InfoDataset}.
    """

    # Source attrs that are always written (missing values become 'unknown').
    _CANONICAL_SOURCE_ATTRS = (
        'file_path', 'age', 'gender', 'amplifier', 'cap', 'handedness',
    )

    def __init__(self, path: Path, name: str) -> None:
        """
        Prepare a writer; the HDF5 file itself is opened lazily.

        Parameters
        ----------
        path : Path
            Directory that will contain the HDF5 file.
        name : str
            File name without the ``.hdf5`` extension.

        Raises
        ------
        FileExistsError
            If ``<path>/<name>.hdf5`` already exists.
        """
        self._final_path = Path(path) / f'{name}.hdf5'
        self._tmp_path = Path(path) / f'{name}.hdf5.tmp'
        if self._final_path.exists():
            raise FileExistsError(
                f"HDF5 file already exists: {self._final_path}. "
                "Delete it manually before re-exporting."
            )
        # Discard a stale temporary file left by an earlier run.
        if self._tmp_path.exists():
            self._tmp_path.unlink()

        self._f = None                        # h5py.File, opened on first non-empty add_epochs
        self._label_map: dict = {}            # event_name -> int code
        self._next_code: int = 0
        self._n_ch = None                     # int once initialised
        self._n_times = None                  # int once initialised
        self._source_epoch_counts: dict = {}  # group_name -> cumulative count
        self._info_fields: set = set()        # v2.1: non-canonical source attr keys
        self._misc_names: list = []           # v2.1: preserves insertion order

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _ensure_initialized(self, n_ch: int, n_times: int,
                            sfreq: float, ch_names) -> None:
        """Open the temporary file and create the fixed layout (first call only)."""
        if self._f is not None:
            return
        self._n_ch = n_ch
        self._n_times = n_times
        self._f = h5py.File(self._tmp_path, 'w')

        root_attrs = self._f.attrs
        root_attrs['version'] = '2.1'
        root_attrs['sfreq'] = float(sfreq)
        root_attrs['ch_names'] = json.dumps(list(ch_names))
        root_attrs['n_channels'] = int(n_ch)
        root_attrs['n_times'] = int(n_times)
        root_attrs['created_by'] = 'EEGUnity'
        root_attrs['created_at'] = datetime.datetime.now().isoformat()

        # Growable along axis 0; one chunk per epoch.
        self._f.create_dataset(
            'data',
            shape=(0, n_ch, n_times),
            maxshape=(None, n_ch, n_times),
            dtype='float32',
            chunks=(1, n_ch, n_times),
            compression='gzip',
            compression_opts=1,
        )

        em = self._f.create_group('epoch_meta')
        em.create_dataset(
            'source_group',
            shape=(0,),
            maxshape=(None,),
            dtype=h5py.string_dtype(encoding='utf-8'),
        )
        em.create_dataset(
            'event_code',
            shape=(0,),
            maxshape=(None,),
            dtype='int16',
        )

        self._f.create_group('source_meta')

    def _get_or_create_code(self, event_name: str) -> int:
        """Return the integer code for ``event_name``, allocating the next one if new."""
        if event_name not in self._label_map:
            self._label_map[event_name] = self._next_code
            self._next_code += 1
        return self._label_map[event_name]

    def _append_1d(self, ds_name: str, values) -> None:
        """Grow the 1-D dataset ``ds_name`` by ``len(values)`` and write ``values`` at the end."""
        ds = self._f[ds_name]
        old = ds.shape[0]
        n = len(values)
        ds.resize(old + n, axis=0)
        ds[old:old + n] = values

    @staticmethod
    def _validate_misc(misc_values: dict, n: int,
                       group_name: str, event_name: str) -> dict:
        """
        Convert every misc entry to a float32 array and check its shape.

        Nothing is written here, so a failure leaves the file untouched.

        Returns
        -------
        dict
            ``{name: float32 ndarray of shape (n,)}`` in the input order.

        Raises
        ------
        ValueError
            If any entry does not have shape ``(n,)``.
        """
        checked = {}
        for name, values in misc_values.items():
            arr = np.asarray(values, dtype='float32')
            if arr.shape != (n,):
                raise ValueError(
                    f"misc_values['{name}'] length {arr.shape} does not match "
                    f"epoch count {n} for group '{group_name}', event '{event_name}'."
                )
            checked[name] = arr
        return checked

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def add_epochs(
        self,
        group_name: str,
        event_name: str,
        epoch_data: np.ndarray,
        info_bytes: bytes,
        source_attrs: dict,
        sfreq: float,
        ch_names,
        misc_values: dict = None,
    ) -> None:
        """
        Append epochs for one (source_file, event) pair.

        All inputs are validated before anything is written, so a rejected
        call does not leave ``/data``, ``/epoch_meta`` and ``/misc_meta``
        with different lengths.

        Parameters
        ----------
        group_name : str
            Unique identifier for the source file (e.g. basename without
            extension).
        event_name : str
            Human-readable event / class label.
        epoch_data : np.ndarray, shape (n_epochs, n_ch, n_times)
            Epoch array; will be cast to float32. An array with zero epochs
            makes the call a no-op.
        info_bytes : bytes
            ``pickle.dumps(raw.info)`` from the source file. Stored as raw
            uint8 bytes; this class never unpickles it.
        source_attrs : dict
            Scalar metadata stored as HDF5 attrs on the source_meta group.
            Canonical keys (optional): file_path, age, gender, amplifier,
            cap, handedness. Any additional keys (e.g. participants.tsv
            columns like ``p_factor``) are stored as strings and their names
            recorded in the root ``info_fields`` attr. ``None`` is treated
            as an empty dict.
        sfreq : float
            Sampling frequency. Written to the root attrs on the first call
            and to the source_meta group when ``group_name`` is first seen.
        ch_names : list[str]
            Channel names. Written in the same two places as ``sfreq``.
        misc_values : dict, optional
            Mapping ``{misc_channel_name: array-like of length n}`` for v2.1
            per-epoch misc values (e.g. reaction_time, accuracy). Stored
            under ``/misc_meta/<name>`` as float32.

        Raises
        ------
        ValueError
            If ``epoch_data`` is not 3-D, if its ``(n_ch, n_times)`` differs
            from the shape fixed by the first call, or if a ``misc_values``
            entry does not have exactly one value per epoch.
        """
        epoch_data = np.asarray(epoch_data)
        if epoch_data.ndim != 3:
            raise ValueError(
                f"epoch_data for group '{group_name}', event '{event_name}' must be "
                f"3-D (n_epochs, n_ch, n_times); got shape {epoch_data.shape}."
            )
        n, n_ch, n_times = epoch_data.shape
        if n == 0:
            return

        # Validate misc values up front: failing after /data has been
        # extended would leave the file internally inconsistent.
        misc_arrays = self._validate_misc(
            misc_values or {}, n, group_name, event_name
        )
        source_attrs = source_attrs or {}

        self._ensure_initialized(n_ch, n_times, sfreq, ch_names)

        if n_ch != self._n_ch or n_times != self._n_times:
            raise ValueError(
                f"Epoch shape mismatch for group '{group_name}', event '{event_name}': "
                f"expected ({self._n_ch}, {self._n_times}), got ({n_ch}, {n_times}). "
                "Ensure all recordings have the same channel count and epoch length."
            )

        # ---- Register source_meta (once per source file) ----
        sm = self._f['source_meta']
        if group_name not in sm:
            grp = sm.create_group(group_name)
            grp.create_dataset(
                'info',
                data=np.frombuffer(info_bytes, dtype='uint8'),
                chunks=None,
            )
            grp.attrs['sfreq'] = float(sfreq)
            grp.attrs['ch_names'] = json.dumps(list(ch_names))
            for key in self._CANONICAL_SOURCE_ATTRS:
                val = source_attrs.get(key, 'unknown')
                grp.attrs[key] = str(val) if val is not None else 'unknown'
            # v2.1: persist arbitrary extra attrs (participants.tsv columns etc.)
            for key, val in source_attrs.items():
                if key in self._CANONICAL_SOURCE_ATTRS or val is None:
                    continue
                grp.attrs[key] = str(val)
                self._info_fields.add(key)

        # ---- Append epoch data ----
        data_ds = self._f['data']
        old_size = data_ds.shape[0]
        data_ds.resize(old_size + n, axis=0)
        # copy=False avoids duplicating arrays that are already float32.
        data_ds[old_size:old_size + n] = epoch_data.astype('float32', copy=False)

        # ---- Append epoch_meta ----
        code = self._get_or_create_code(event_name)
        self._append_1d('epoch_meta/source_group',
                        np.array([group_name] * n, dtype=object))
        self._append_1d('epoch_meta/event_code',
                        np.full(n, code, dtype='int16'))

        # ---- Append misc_meta (v2.1) ----
        # Always call if *any* misc channel has been registered, so channels
        # not supplied this time are extended with NaN and stay aligned.
        if misc_arrays or self._misc_names:
            self._append_misc(misc_arrays, n, group_name, event_name)

        # ---- Update per-source epoch count ----
        self._source_epoch_counts[group_name] = (
            self._source_epoch_counts.get(group_name, 0) + n
        )

    def _append_misc(self, misc_values: dict, n: int,
                     group_name: str, event_name: str) -> None:
        """
        Extend every ``/misc_meta`` channel by ``n`` entries.

        Must be called after ``/data`` has been resized for the same batch.
        Channels present in ``misc_values`` receive the supplied values;
        channels registered earlier but absent now are extended with the
        NaN fill value. A channel seen for the first time is created with
        NaN for all earlier epochs.

        Raises
        ------
        ValueError
            If any entry does not have shape ``(n,)``; raised before any
            misc dataset is created or resized.
        """
        arrays = self._validate_misc(misc_values, n, group_name, event_name)

        if 'misc_meta' not in self._f:
            self._f.create_group('misc_meta')
        mm = self._f['misc_meta']

        current_total = self._f['data'].shape[0]  # already resized
        old_total = current_total - n

        for name, arr in arrays.items():
            if name not in mm:
                mm.create_dataset(
                    name,
                    shape=(old_total,),
                    maxshape=(None,),
                    dtype='float32',
                    fillvalue=float('nan'),
                )
                self._misc_names.append(name)
            ds = mm[name]
            # Pad any prior epochs that lacked this misc channel with NaN.
            if ds.shape[0] < old_total:
                ds.resize(old_total, axis=0)
            ds.resize(old_total + n, axis=0)
            ds[old_total:old_total + n] = arr

        # Other misc channels not supplied this call: extend with NaN so the
        # final array length stays aligned with /data.
        for name in self._misc_names:
            if name in arrays:
                continue
            ds = mm[name]
            if ds.shape[0] < old_total + n:
                ds.resize(old_total + n, axis=0)

    def save(self) -> None:
        """
        Finalise the tmp file, then rename it to the final path.

        Writes ``label_map``, ``n_epochs_total`` and ``info_fields`` to the
        root attrs, ``n_epochs_in_source`` to each source group and, when
        ``/misc_meta`` exists, the channel order to its ``names`` attr.

        If no non-empty batch was ever added the file was never opened and
        this method returns without creating anything.
        """
        if self._f is None:
            return

        # Reverse map: int code -> event_name string
        reverse_map = {int(v): k for k, v in self._label_map.items()}
        self._f.attrs['label_map'] = json.dumps(reverse_map)
        self._f.attrs['n_epochs_total'] = int(self._f['data'].shape[0])
        # v2.1: non-canonical source attr keys (participants.tsv columns etc.)
        self._f.attrs['info_fields'] = json.dumps(sorted(self._info_fields))

        # Persist per-source epoch counts into source_meta attrs
        sm = self._f['source_meta']
        for grp_name, count in self._source_epoch_counts.items():
            if grp_name in sm:
                sm[grp_name].attrs['n_epochs_in_source'] = int(count)

        # v2.1: persist misc channel order on /misc_meta
        if 'misc_meta' in self._f:
            self._f['misc_meta'].attrs['names'] = json.dumps(self._misc_names)

        self._f.close()
        self._f = None
        self._tmp_path.rename(self._final_path)

    @property
    def name(self) -> str:
        """File name without the ``.hdf5`` extension."""
        return self._final_path.stem