import csv
import logging
import os
import random

import nltk
import numpy as np
import pandas as pd
from config import (
    # LOG_FILE,
    API_ERROR,
    IGNORE_BY_API_ERROR,
    SEED,
)
from datasets import load_dataset


def print_and_log(message: str):  # TODO: redefine logging
    """
    Log message.

    Args:
        message (str): The message to be printed and logged.
    """
    print(message)
    logging.info(message)


def write_to_file(filename: str, content: str):
    """
    Writes the given content to a specified file.

    Args:
        filename (str): The path to the file to write content.
        content (str): The content to be written.
    """
    print(content)
    with open(filename, "a+", encoding="utf-8") as file:
        file.write(content)


def write_new_data(
    output_file: str,
    current_data: dict,
    column_names: list,
) -> None:
    """
    Writes a new row of data to a CSV file.

    Args:
        output_file (str): The path to the output CSV file.
        current_data (dict): A dictionary containing the data to be written.
        column_names (list): A list of column names in the desired order.

    Returns:
        None
    """
    # Extract data in the specified order based on column names
    data_row = [current_data[column] for column in column_names]

    # Write the data row to the CSV file
    write_to_csv(output_file, data_row)


def write_to_csv(filename: str, row_data: list) -> None:
    """
    Appends a row of data to a CSV file.

    Args:
        filename (str): The name of the CSV file.
        row_data: A list of values to be written as a row.

    Returns:
        None
    """
    # Open the CSV file in append mode, creating it if it doesn't exist
    with open(filename, "a+", encoding="UTF8", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(row_data)


def count_csv_lines(filename: str) -> int:
    """Counts the number of lines in a CSV file, excluding the header row.

    Args:
        filename (str): The path to the CSV file.

    Returns:
        int: The number of lines in the CSV file, excluding the header row.
    """
    file_data = pd.read_csv(filename, sep=",").values
    return len(file_data)


def read_csv_data(input_file: str) -> np.ndarray:
    """
    Reads data from a specified CSV file.

    Args:
        input_file (str): The path to the CSV file.

    Returns:
        numpy.ndarray: The data from the CSV file.
    """
    file_data = pd.read_csv(
        input_file,
        dtype="string",
        keep_default_na=False,
        sep=",",
    ).values
    return file_data


def get_column(input_file: str, column_name: str) -> np.ndarray:
    """
    Retrieves a specific column from a CSV file as a NumPy array.

    Args:
        input_file (str): The path to the CSV file.
        column_name (str): The name of the column to extract.

    Returns:
        np.ndarray: Values from the specified column.
    """
    # Read CSV, preserving string data types and handling missing values
    df = pd.read_csv(
        input_file,
        dtype="string",
        keep_default_na=False,
        sep=",",
    )

    # Extract the specified column as a NumPy array
    column_data = df[column_name].values

    return column_data


def generate_column_names(categories: list) -> list:
    """
    Generates column names for a pairwise comparison matrix.

    Args:
        categories (list): A list of categories.

    Returns:
        list: A list of column names, including a 'human' column and pairwise
        combinations.
    """
    column_names = ["human"]

    # Add individual category names as column names
    column_names.extend(categories)

    # Add pairwise combinations of categories as column names
    for i in categories:
        for j in categories:
            column_names.append(f"{i}_{j}")  # TODO: improve?
    # for i in range(len(categories)):
    #     for j in range(i + 1, len(categories)):
    #         column_names.append(f"{categories[i]}_{categories[j]}")

    return column_names
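

# Illustrative example (the model names are hypothetical): for two categories,
# generate_column_names yields one "human" column, one column per model, and
# one column per ordered model pair (including self-pairs):
#
#     generate_column_names(["gpt", "llama"])
#     # -> ["human", "gpt", "llama",
#     #     "gpt_gpt", "gpt_llama", "llama_gpt", "llama_llama"]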


def normalize_text(input_text: str) -> str:
    """
    Normalizes the given text by removing unnecessary characters and
    formatting it for better readability.

    Args:
        input_text (str): The input text to be normalized.

    Returns:
        The normalized text.

    This function performs the following transformations:
    1. Strips leading and trailing whitespace
    2. Removes double asterisks (`**`)
    3. Replaces newlines with spaces
    4. Removes extra spaces
    """
    processed_text = input_text.strip()
    processed_text = processed_text.replace("**", "")
    processed_text = processed_text.replace("\n", " ")

    # Collapse runs of two or more spaces into a single space
    while "  " in processed_text:
        processed_text = processed_text.replace("  ", " ")

    return processed_text


def refine_candidate_text(input_text: str, candidate_text: str) -> str:
    # TODO: how different with processing text
    """
    Removes specific surrounding marks from the candidate text if they are
    present in the input text with an excess of exactly two occurrences.

    Args:
        input_text (str): The original text.
        candidate_text (str): The candidate text to be refined.

    Returns:
        str: The refined candidate text.
    """
    # Create a copy of the candidate string and strip whitespace
    refined_candidate = candidate_text.strip()

    # Iterate through each mark
    for mark in ["```", "'", '"']:
        # Count occurrences of the mark in input_text and refined_candidate
        count_input_text = input_text.count(mark)
        count_refined_candidate = refined_candidate.count(mark)

        # Check if the mark should be stripped
        if (
            count_refined_candidate == count_input_text + 2
            and refined_candidate.startswith(mark)
            and refined_candidate.endswith(mark)
        ):
            # Strip the mark from both ends of the refined_candidate
            refined_candidate = refined_candidate.strip(mark)

    return refined_candidate


def generate_file_name(
    existing_data_file: str,
    existing_kinds: list,
    new_kinds: list,
) -> str:
    """
    Generates a new file name based on the path of an existing data file and
    a combination of existing and new kinds.

    Args:
        existing_data_file (str): The path to the existing data file.
        existing_kinds (list): A list of existing kinds.
        new_kinds (list): A list of new kinds.

    Returns:
        str: The generated file name with the full path.
    """
    # Combine existing and new kinds into a single list
    combined_kinds = existing_kinds + new_kinds

    # Get the directory path of the existing data file
    directory_path = os.path.dirname(existing_data_file)

    # Create a new file name by joining the kinds with underscores and adding
    # a suffix
    # TODO: move to config file
    new_file_name = "_".join(combined_kinds) + "_with_best_similarity.csv"

    # Combine the directory path with the new file name to get the full output
    # file path
    output_file_path = os.path.join(directory_path, new_file_name)

    return output_file_path
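

# Illustrative example (the path and kinds are hypothetical): generate_file_name
# keeps the directory of the existing data file and joins the kinds with
# underscores:
#
#     generate_file_name("data/run1/base.csv", ["gpt"], ["llama"])
#     # -> "data/run1/gpt_llama_with_best_similarity.csv"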


def shuffle(data: list[list], seed: int) -> None:
    """
    Shuffles the elements within each sublist of the given data structure.

    Args:
        data (list of lists): The array containing sublists to shuffle.
        seed (int): The seed value for the random number generator.

    Returns:
        None
    """
    for sublist in data:
        random.Random(seed).shuffle(sublist)


def generate_human_with_shuffle(
    dataset_name: str,
    column_name: str,
    num_samples: int,
    output_file: str,
) -> None:
    """
    Generates a shuffled list of sentences from the dataset and writes them
    to a CSV file.

    Args:
        dataset_name (str): The name of the dataset to load.
        column_name (str): The column name to extract sentences from.
        num_samples (int): The number of samples to process.
        output_file (str): The path to the output CSV file.

    Returns:
        None
    """
    # Load the dataset
    dataset = load_dataset(dataset_name)
    data = dataset["train"]

    lines = []

    # Tokenize sentences and add to the lines list
    for sample in data:
        nltk_tokens = nltk.sent_tokenize(sample[column_name])
        lines.extend(nltk_tokens)

    # Filter out empty lines
    lines = [line for line in lines if line != ""]

    # Shuffle the lines
    shuffle([lines], seed=SEED)

    # Ensure the output file exists and write the header if it doesn't
    if not os.path.exists(output_file):
        header = ["human"]
        write_to_csv(output_file, header)

    # Get the number of lines already processed in the output file
    number_of_processed_lines = count_csv_lines(output_file)

    # Print the initial lines to be processed
    print(f"Lines before processing: {lines[:num_samples]}")

    # Slice the lines list to get the unprocessed lines
    lines = lines[number_of_processed_lines:num_samples]

    # Print the lines after slicing
    print(f"Lines after slicing: {lines}")

    # Process each line and write to the output file
    for index, human in enumerate(lines):
        normalized_text = normalize_text(human)
        output_data = [normalized_text]
        write_to_csv(output_file, output_data)
        print(
            f"Processed {index + 1} / {len(lines)}; "
            f"Total processed: "
            f"{number_of_processed_lines + index + 1} / {num_samples}",
        )


def split_data(data: list, train_ratio: float) -> tuple[list, list]:
    """
    Splits a dataset into training and testing sets.

    Args:
        data (list): The input dataset.
        train_ratio (float): The proportion of data to use for training.

    Returns:
        The training and testing sets.
    """
    # Calculate the number of samples for training
    train_size = int(len(data) * train_ratio)

    # Split the data into training and testing sets
    train_data = data[:train_size]
    test_data = data[train_size:]

    return train_data, test_data


def combine_text_with_BERT_format(text_list: list[str]) -> str:
    """
    Formats a list of texts into a single string suitable for BERT input.

    Args:
        text_list (list[str]): A list of text strings.

    Returns:
        str: A single string formatted with BERT's special tokens.
    """
    # TODO: simplify this function
    # combined_text = f"{text_list[0]}"
    # for i in range(1, len(text_list)):
    #     combined_text += f"{text_list[i]}"
    # return combined_text
    formatted_text = "" + "".join(text_list) + ""
    return formatted_text


def check_api_error(data: list) -> bool:
    """
    Checks if the given data contains an API error or an indication to ignore
    an API error.

    Args:
        data (list): A list of items to check.

    Returns:
        bool: True if an API error or ignore indication is found, False
        otherwise.
    """
    for item in data:
        # Check for API error indicators
        if item in (API_ERROR, IGNORE_BY_API_ERROR):
            return True  # Return True if an error indicator is found

    return False  # Return False if no error indicators are found


def calculate_required_models(num_columns: int) -> int:
    """
    Calculates the minimum number of models required to generate the
    specified number of columns.

    Args:
        num_columns (int): The total number of columns to generate.

    Returns:
        int: The minimum number of models required.

    Raises:
        ValueError: If the number of columns cannot be achieved with the
        current model configuration.
    """
    num_models = 0
    count_human = 1  # Initial count representing human input

    # TODO: simplify this function
    while True:
        count_single = num_models  # Single model count
        count_pair = num_models * num_models  # Pair model count
        total_count = count_human + count_single + count_pair

        if total_count == num_columns:
            return num_models
        elif total_count > num_columns:
            raise ValueError(
                "Cannot calculate the number of models to match the number of columns",  # noqa: E501
            )

        num_models += 1
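

# Illustrative check (the column counts are hypothetical): with n models the
# expected column count is 1 (human) + n (single) + n * n (ordered pairs), so
#
#     calculate_required_models(7)   # -> 2  (1 + 2 + 4)
#     calculate_required_models(13)  # -> 3  (1 + 3 + 9)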
""" num_models = 0 count_human = 1 # Initial count representing human input # TODO: simplify this function while True: count_single = num_models # Single model count count_pair = num_models * num_models # Pair model count total_count = count_human + count_single + count_pair if total_count == num_columns: return num_models elif total_count > num_columns: raise Exception( "Cannot calculate the number of models to match the number of columns", # noqa: E501 ) num_models += 1 def parse_multimodal_data(multimodel_csv_file: list) -> list: """ Parses multimodal data from a CSV file into a structured format. Args: multimodel_csv_file (str): Path to the CSV file. Returns: list: A list of dictionaries, each containing 'human', 'single', and 'pair' keys. Raises: Exception: If there is an error in reading the CSV file or processing the data. """ # TODO: simplify this function # Read CSV data into a list of lists input_data = read_csv_data(multimodel_csv_file) # Initialize the result list structured_data = [] # Calculate the number of models based on the number of columns in the first row # noqa: E501 num_models = calculate_required_models(len(input_data[0])) # Process each row in the input data for row in input_data: row_data = {} index = 0 # Extract human data row_data["human"] = row[index] index += 1 # Extract single model data single_model_data = [] for _ in range(num_models): single_model_data.append(row[index]) index += 1 row_data["single"] = single_model_data # Extract pair model data pair_model_data = [] for _ in range(num_models): sub_pair_data = [] for _ in range(num_models): sub_pair_data.append(row[index]) index += 1 pair_model_data.append(sub_pair_data) row_data["pair"] = pair_model_data # Append the structured row data to the result list structured_data.append(row_data) return structured_data def check_error(data_item: dict) -> bool: """ Checks if the given data item contains any API errors. An API error is indicated by a specific error message or code within the text. Args: data_item (dict): A dictionary containing 'human', 'single', and 'pair' fields. Returns: bool: True if an API error is found, otherwise False. """ # Check for API error in the 'human' field if check_api_error(data_item["human"]): return True # Check for API error in the 'single' model data for single_text in data_item["single"]: if check_api_error(single_text): return True # Get the number of models from the 'single' model data num_models = len(data_item["single"]) # Check for API error in the 'pair' model data for i in range(num_models): for j in range(num_models): if check_api_error(data_item["pair"][i][j]): return True # No errors found return False