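"""Source code for eegunity.modules.llm_booster.eeg_llm_des_parser.

Reads dataset description files and uses a Large Language Model (LLM) to extract
EEG sampling rates and channel names from them.
"""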

import csv
import os
import pandas as pd
from scipy.io import loadmat
import json
import re


def _read_files(directory, max_size=3 * 1024 * 1024):
    """
    Recursively collect the text content of candidate description files under ``directory``.

    Every file smaller than ``max_size`` bytes is considered. Its extension (matched
    case-insensitively) decides how it is read:

    - ``.txt`` / ``.md``: plain text (UTF-8, falling back to Latin-1).
    - ``.docx``: paragraph text via ``python-docx`` (imported lazily).
    - ``.pdf``: page text via ``pdfplumber`` (imported lazily).
    - ``.csv``: rows re-joined with commas (UTF-8, falling back to Latin-1).
    - ``.xls`` / ``.xlsx``: first sheet converted to CSV text via ``pandas.read_excel``.
    - ``.mat``: ``str()`` of the dictionary returned by ``scipy.io.loadmat``.
    - any other file whose name contains ``readme``, ``annotation`` or ``record``:
      plain text (UTF-8, falling back to Latin-1).

    Files that cannot be read are reported on stdout and skipped.

    Args:
        directory (str): Root directory to traverse.
        max_size (int): Files of this size in bytes or larger are skipped.
            Defaults to 3 MiB (the previously hard-coded limit).

    Returns:
        List[Dict[str, str]]: One ``{"file_path": ..., "content": ...}`` dict per file read.
    """

    def open_with_fallback(path, parse, newline=None):
        # Try UTF-8 first; Latin-1 maps every byte to a character, so it never fails to decode.
        try:
            with open(path, 'r', encoding='utf-8', newline=newline) as f:
                return parse(f)
        except UnicodeDecodeError:
            with open(path, 'r', encoding='latin-1', newline=newline) as f:
                return parse(f)

    def read_all(f):
        return f.read()

    def join_csv_rows(f):
        return "\n".join(",".join(row) for row in csv.reader(f))

    files = []
    print(f"Traversing the directory:{directory}")
    for root, _dirs, names in os.walk(directory):
        print(f"Current directory:{root}")
        for file in names:
            file_path = os.path.join(root, file)
            lower_name = file.lower()
            print(f"Find file:{file_path}")
            try:
                size = os.path.getsize(file_path)
                print(f"File size:{size} byte")
                if size >= max_size:
                    continue

                if lower_name.endswith(('.txt', '.md')):
                    content = open_with_fallback(file_path, read_all)
                elif lower_name.endswith('.docx'):
                    import docx
                    doc = docx.Document(file_path)
                    content = "\n".join(para.text for para in doc.paragraphs)
                elif lower_name.endswith('.pdf'):
                    import pdfplumber
                    with pdfplumber.open(file_path) as pdf:
                        content = "\n".join(page.extract_text() or "" for page in pdf.pages)
                elif lower_name.endswith('.csv'):
                    # newline='' is what the csv module expects for correct quoted-field handling.
                    content = open_with_fallback(file_path, join_csv_rows, newline='')
                elif lower_name.endswith(('.xls', '.xlsx')):
                    content = pd.read_excel(file_path).to_csv(index=False)
                elif lower_name.endswith('.mat'):
                    content = str(loadmat(file_path))
                elif any(keyword in lower_name for keyword in ('readme', 'annotation', 'record')):
                    # Extension-less or unusual files that are likely to be descriptions.
                    content = open_with_fallback(file_path, read_all)
                else:
                    continue
            except Exception as e:
                # Best effort: one unreadable file must not abort the whole traversal.
                print(f"Error processing file {file_path}: {e}")
                continue

            files.append({"file_path": file_path, "content": content})
            print(f"read {file} complete, length:{len(content)}")
    return files

def _filter_files_with_gpt(files: list, client_paras: dict, client_type: str, completion_para: dict, client=None):
    """
    Ask an LLM to extract the sampling rate and channel names from each file's content.

    Each file is sent in its own chat-completion request. Some large models may return
    mixed text, so the reply is parsed with fault-tolerant logic. A file is reported on
    stdout and skipped when its request fails or its reply does not contain a JSON object.

    Args:
        files (List[Dict[str, Any]]): File records, each containing:
            - "file_path" (str): Path to the file.
            - "content" (str): File content.
        client_paras (Dict[str, Any]): Keyword arguments for the client constructor.
        client_type (str): Type of client to create. Supported types:
            - "AzureOpenAI"
            - "OpenAI"
        completion_para (Dict[str, Any]): Keyword arguments forwarded to
            ``client.chat.completions.create`` (e.g. ``model``). A ``response_format``
            given here replaces the default ``{"type": "json_object"}``. A ``messages``
            entry is replaced by the messages built by this function.
        client (optional): Pre-built client exposing ``chat.completions.create``.
            When given, ``client_type`` and ``client_paras`` are not used.

    Returns:
        List[Tuple[Dict[str, Any], Dict[str, Any]]]: ``(file_info, decision)`` pairs, where
        ``decision`` is the JSON object (dict) parsed from the LLM reply. The list is empty
        if the client could not be created.
    """

    def parse_json_with_fallback(response_text: str) -> dict:
        """
        Extract a JSON object (dict) from an LLM reply.

        Candidates are tried in order:
        1. every block enclosed in ```json ... ``` markers;
        2. the span from the first '{' to the last '}' of the whole reply.
        The first candidate that parses to a dict is returned. Other JSON values
        (lists, numbers, null) are rejected because callers look keys up in the result.

        Raises:
            ValueError: If no candidate parses to a JSON object.
        """
        if response_text is None:
            raise ValueError("The GPT response has no text content.")

        candidates = [block.strip() for block in re.findall(r'```json\s*(.*?)\s*```', response_text, re.DOTALL)]
        text = response_text.strip()
        start_index = text.find('{')
        end_index = text.rfind('}')
        if start_index != -1 and end_index != -1 and start_index < end_index:
            candidates.append(text[start_index:end_index + 1])

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                continue
            if isinstance(parsed, dict):
                return parsed

        raise ValueError("No valid JSON could be extracted from the GPT response.")

    system_prompt = (
        "You are a highly capable assistant. Your task is to analyze the provided files and "
        "extract the **sampling rates** and **channel names** in a well-structured JSON format. "
        "The output should follow this structure:\n\n"
        "{\n"
        '    "sampling_rate": <number or null>,\n'
        '    "channels": [<channel_name_1>, <channel_name_2>, ...] or null\n'
        "}\n\n"
        "If either sampling rates or channel names are not available in the file, return `null` for the corresponding field. "
        "Ensure the response strictly follows the specified JSON format without additional text or explanations."
    )

    processed_files = []
    try:
        if client is None:
            if client_type == "AzureOpenAI":
                from openai import AzureOpenAI
                client = AzureOpenAI(**client_paras)
            elif client_type == "OpenAI":
                from openai import OpenAI
                client = OpenAI(**client_paras)
            else:
                raise ValueError("Unsupported client_type. Supported types are 'AzureOpenAI' and 'OpenAI'.")
    except Exception as e:
        print(f"Error using GPT to filter files: {e}")
        return processed_files

    for file_info in files:
        try:
            file_path = file_info["file_path"]
            content = file_info["content"]
            print(f"Files being processed: {file_path}")

            messages = [
                {"role": "system", "content": system_prompt},
                {
                    "role": "user",
                    "content": json.dumps({
                        "file_path": file_path,
                        "content": content,
                        "format": "json"
                    })
                },
            ]
            # Merge into one dict so a caller-supplied response_format overrides the default
            # instead of raising "got multiple values for keyword argument".
            request = {"response_format": {"type": "json_object"}, **(completion_para or {}), "messages": messages}
            response = client.chat.completions.create(**request)
            print(f"LLM response: {response}")

            decision = parse_json_with_fallback(response.choices[0].message.content)
            processed_files.append((file_info, decision))
            print(f"Parsing LLM response: {decision}")
        except Exception as e:
            # Best effort: one failed request or unparsable reply only skips that file.
            print(f"Error LLM response: {e}")

    return processed_files


def _resolve_sampling_rate_conflict(sampling_rate_list):
    """
    Choose one sampling rate from the candidates extracted from different files.

    Candidates whose value is ``None`` (the LLM reported nothing) are ignored. If all
    remaining candidates are equal, that value is returned without prompting. Otherwise
    the user is asked on stdin to pick one candidate or to discard them all.

    Args:
        sampling_rate_list (List[Tuple[Dict[str, Any], Any]]): ``(file_info, sampling_rate)``
            pairs; each ``file_info`` must contain ``"file_path"``.

    Returns:
        Tuple[Any, Optional[str]]: ``(sampling_rate, None)`` when a value was selected, or
        ``(None, message)`` when nothing usable was found or the user discarded the data.
    """
    # The prompt tells the LLM to return null when a file has no sampling rate.
    candidates = [(file_info, rate) for file_info, rate in sampling_rate_list if rate is not None]
    if not candidates:
        return None, "Sample rate information not found"

    first_rate = candidates[0][1]
    if all(rate == first_rate for _, rate in candidates):
        # Every file agrees, so there is no conflict to resolve interactively.
        return first_rate, None

    print("Multiple different sample rate information was detected:")
    for i, (file_info, sampling_rate) in enumerate(candidates, 1):
        print(f"{i}: From file {file_info['file_path']}")
        print(f"   Sampling rate: {sampling_rate}")
    discard_index = len(candidates) + 1
    print(f"{discard_index}: Using no sample rate information")

    while True:
        user_input = input("Please select a number for the sample rate information (input format is '1'): ")
        try:
            chosen_index = int(user_input)
        except ValueError:
            chosen_index = None
        if chosen_index is not None and 1 <= chosen_index <= discard_index:
            break
        print("Invalid selection, please re-enter.")

    if chosen_index == discard_index:
        return None, "Sample rate data is discarded"
    return candidates[chosen_index - 1][1], None


def _resolve_channel_names_conflict(channel_info_list):
    """
    Choose one channel-name list from the candidates extracted from different files.

    Empty candidates (``None`` or an empty list, which the LLM reports when nothing was
    found) are ignored. If all remaining candidates are equal, that value is returned
    without prompting. Otherwise the user is asked on stdin to pick one candidate or to
    discard them all.

    Args:
        channel_info_list (List[Tuple[Dict[str, Any], Any]]): ``(file_info, channel_names)``
            pairs; each ``file_info`` must contain ``"file_path"``.

    Returns:
        Tuple[Any, Optional[str]]: ``(channel_names, None)`` when a value was selected, or
        ``([], message)`` when nothing usable was found or the user discarded the data.
    """
    candidates = [(file_info, names) for file_info, names in channel_info_list if names]
    if not candidates:
        return [], "No channel name information found"

    first_names = candidates[0][1]
    if all(names == first_names for _, names in candidates):
        # Every file agrees, so there is no conflict to resolve interactively.
        return first_names, None

    print("Several different channel name information was detected:")
    for i, (file_info, channel_names) in enumerate(candidates, 1):
        print(f"{i}: From file {file_info['file_path']}")
        print(f"   Channel names: {channel_names}")
    discard_index = len(candidates) + 1
    print(f"{discard_index}: Using no channel names information")

    while True:
        user_input = input("Please select a number for the channel name information (input format is '1'): ")
        try:
            chosen_index = int(user_input)
        except ValueError:
            chosen_index = None
        if chosen_index is not None and 1 <= chosen_index <= discard_index:
            break
        print("Invalid selection, please re-enter.")

    if chosen_index == discard_index:
        return [], "The channel name data was discarded."
    return candidates[chosen_index - 1][1], None


def llm_description_file_parser(directory: str, client_type: str, client_paras: dict, completion_para: dict):
    """
    Parse files in a specified directory to extract sampling rate and channel information
    using a Large Language Model (LLM) API.

    This function traverses a directory to read various file formats. It extracts sampling
    rates and channel names from the files using an LLM API (e.g., GPT-4). When different
    files disagree, the user is asked on stdin to resolve the conflict.

    Contributor: Jingyi Ding (Jingyi.Ding21@student.xjtlu.edu.cn), on 2024-07-26.
    EEGUnity Team modified it on 2025-02-23

    Parameters
    ----------
    directory : str
        The directory path where the files are stored for processing. Generally speaking,
        it will be the root directory of the dataset.
    client_type : str
        The type of LLM client to use ("AzureOpenAI" or "OpenAI").
    client_paras : dict
        Keyword arguments used to initialize the LLM API client. Please refer to the
        OpenAI documentation.
    completion_para : dict
        Keyword arguments forwarded to the chat-completion request (e.g. ``model``).

    Returns
    -------
    dict
        ``{"sampling_rate": ..., "channels": [...]}`` on success. ``{"error": <message>}``
        if no file produced a usable LLM answer, or if both fields end up empty (not
        found or discarded by the user). Errors are reported through the returned dict
        rather than raised.

    Examples
    --------
    >>> directory = 'path/to/description/directory'
    >>> client_paras = {"api_key": "your_api_key", "api_version": "2023-03-15-preview"}
    >>> client_type = "AzureOpenAI"
    >>> completion_para = {"model": "your_model_or_deployment_name"}
    >>> result = llm_description_file_parser(directory, client_type, client_paras, completion_para)  # doctest: +SKIP
    >>> print("The end result:", json.dumps(result, indent=4, ensure_ascii=False))  # doctest: +SKIP
    """
    files = _read_files(directory)
    processed_files = _filter_files_with_gpt(files, client_paras=client_paras, client_type=client_type,
                                             completion_para=completion_para)
    if not processed_files:
        return {"error": "No files were selected for further analysis"}

    # Accept several spellings of each key, presumably because LLM replies do not always
    # use the exact key names requested in the prompt. The first matching key wins.
    channel_keys = ["channel_names", "channels", "channel name", "names of channel"]
    sampling_rate_keys = ["sampling_rate", "sampling rates", "sample rate", "samplingrate"]

    sampling_rate_list = []
    channel_info_list = []
    for file_info, decision in processed_files:
        if not isinstance(decision, dict):
            # Guard against a non-object LLM reply; key lookups below require a dict.
            continue
        rate_key = next((key for key in sampling_rate_keys if key in decision), None)
        if rate_key is not None:
            sampling_rate_list.append((file_info, decision[rate_key]))
        channel_key = next((key for key in channel_keys if key in decision), None)
        if channel_key is not None:
            channel_info_list.append((file_info, decision[channel_key]))

    sampling_rate, sampling_rate_msg = _resolve_sampling_rate_conflict(sampling_rate_list)
    if sampling_rate_msg:
        print(f" Sampling rate data: {sampling_rate_msg}")
        sampling_rate = None

    channels, channel_names_msg = _resolve_channel_names_conflict(channel_info_list)
    if channel_names_msg:
        print(f" Channel name data: {channel_names_msg}")
        channels = []

    if sampling_rate is None and not channels:
        return {"error": "All data is discarded"}
    return {"sampling_rate": sampling_rate, "channels": channels}