Source code for eegunity.utils.channel_align_raw

import numpy as np
import mne
from eegunity.modules.parser.eeg_parser import set_montage_any
from eegunity.utils.label_channel import misc_channel_indices, stim_channel_indices


def channel_align_raw(mne_raw, channel_order, min_matched_channel=1):
    """
    Align and order the channels of an MNE Raw object to a given channel order.

    The returned object contains exactly the channels listed in
    `channel_order`, in that order, followed by any preserved ``misc`` /
    ``stim`` channels. Channels from `channel_order` that are absent from the
    recording are added as zero-valued EEG channels, marked as bad and
    interpolated.

    ``misc`` label channels and ``stim`` trigger channels are temporarily
    removed before alignment so they do not interfere with EEG-specific
    operations (montage fitting, bad-channel interpolation). They are
    re-appended at the end of the channel list after alignment is complete.

    Parameters
    ----------
    mne_raw : mne.io.Raw
        The raw EEG/MEG data in an MNE Raw object. It is modified in place
        (data may be loaded, channels are dropped/added).
    channel_order : list of str
        The desired order of channels. Should contain only channels to align
        (typically EEG). ``misc`` and ``stim`` channels are handled
        separately and must not be listed here.
    min_matched_channel : int, optional
        The minimum required number of channels from `channel_order` that
        must already exist in the recording, by default 1.

    Returns
    -------
    mne.io.Raw
        The raw object with channels aligned to `channel_order`, missing
        channels interpolated, and preserved ``misc``/``stim`` channels
        appended at the end.

    Raises
    ------
    ValueError
        If the number of matched channels is less than `min_matched_channel`.
        The check runs before `mne_raw` is modified, so the input is left
        untouched when this error is raised.

    Notes
    -----
    - Missing channels are created as zero data, marked as 'bad', and then
      interpolated; they are kept in the output.
    - ``misc``/``stim`` channels are not interpolated and are not included in
      the alignment order.
    - Electrode positions for interpolation come from ``set_montage_any``;
      presumably every name in `channel_order` must be resolvable by that
      helper -- TODO confirm against its implementation.

    Examples
    --------
    >>> import mne
    >>> raw = mne.io.read_raw_fif('sample_raw.fif', preload=True)
    >>> desired_order = ['Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2']
    >>> aligned_raw = channel_align_raw(raw, desired_order, min_matched_channel=5)
    """
    # -----------------------------------------------------------------------
    # Step 1: Identify preserved (misc/stim) channels and the channels that
    # take part in alignment. Nothing is mutated yet.
    # -----------------------------------------------------------------------
    preserve_idx = sorted(
        set(misc_channel_indices(mne_raw)) | set(stim_channel_indices(mne_raw))
    )
    preserve_lookup = set(preserve_idx)  # O(1) membership for the filter below
    align_names = [
        ch for i, ch in enumerate(mne_raw.ch_names) if i not in preserve_lookup
    ]

    # -----------------------------------------------------------------------
    # Step 2: Match against the requested order and validate BEFORE touching
    # the input, so a failed call does not leave `mne_raw` half-modified.
    # -----------------------------------------------------------------------
    available = set(align_names)
    matched_channels = [ch for ch in channel_order if ch in available]
    missing_channels = [ch for ch in channel_order if ch not in available]

    if len(matched_channels) < min_matched_channel:
        raise ValueError(
            f"Error: Matched channels ({len(matched_channels)}) are less than the required minimum ({min_matched_channel})")

    # -----------------------------------------------------------------------
    # Step 3: Extract and remove preserved channels. They survive alignment
    # unchanged and are re-appended at the end.
    # -----------------------------------------------------------------------
    if preserve_idx:
        mne_raw.load_data()
        preserve_names = [mne_raw.ch_names[idx] for idx in preserve_idx]
        preserve_data = mne_raw._data[preserve_idx, :].copy()
        preserve_types = [mne.channel_type(mne_raw.info, idx) for idx in preserve_idx]
        # Drop preserved channels so they don't interfere with EEG alignment.
        mne_raw.pick_channels(align_names)

    # -----------------------------------------------------------------------
    # Step 4: Add and interpolate channels that are requested but absent.
    # -----------------------------------------------------------------------
    if missing_channels:
        mne_raw.load_data()

        # Create minimal info for the missing channels (only basic info required)
        missing_info = mne.create_info(
            missing_channels, sfreq=mne_raw.info['sfreq'], ch_types='eeg')

        # Create the missing channels with zero data
        missing_raw = mne.io.RawArray(
            np.zeros((len(missing_channels), len(mne_raw.times))), missing_info)

        # Add the missing channels to the raw data
        mne_raw.add_channels([missing_raw], force_update_info=True)

        # Use set_montage_any to add the montage (coordinates) for missing channels
        mne_raw = set_montage_any(mne_raw, verbose='CRITICAL')

        # Manually mark the missing channels as 'bads' so they can be interpolated
        mne_raw.info['bads'].extend(missing_channels)

        # Interpolate missing channels
        mne_raw.interpolate_bads(
            reset_bads=True,
            origin='auto',
            method=dict(meg="MNE", eeg="MNE", fnirs="nearest"),
        )

    # -----------------------------------------------------------------------
    # Step 5: Keep exactly the requested channels, in the requested order.
    # Every name in `channel_order` is present now (either matched or just
    # added), so the full list is used here -- selecting only the originally
    # matched channels would throw away the interpolated ones.
    # -----------------------------------------------------------------------
    final_order = list(channel_order)
    mne_raw.pick_channels(final_order)
    mne_raw.reorder_channels(final_order)

    # -----------------------------------------------------------------------
    # Step 6: Re-append preserved channels at the end of the channel list.
    # -----------------------------------------------------------------------
    if preserve_idx:
        preserved_info = mne.create_info(
            preserve_names,
            sfreq=mne_raw.info['sfreq'],
            ch_types=preserve_types,
        )
        preserved_raw = mne.io.RawArray(preserve_data, preserved_info, verbose=False)
        mne_raw.add_channels([preserved_raw], force_update_info=True)

    return mne_raw