import csv
import logging
import os
import random

import nltk
import numpy as np
import pandas as pd
from config import (  # LOG_FILE,
    API_ERROR,
    IGNORE_BY_API_ERROR,
    SEED,
)
from datasets import load_dataset

def print_and_log(message: str):
    # TODO: redefine logging
    """
    Prints and logs a message.

    Args:
        message (str): The message to be printed and logged.
    """
    print(message)
    logging.info(message)

def write_to_file(filename: str, content: str):
    """
    Writes the given content to a specified file.

    Args:
        filename (str): The path to the file to write content.
        content (str): The content to be written.
    """
    print(content)
    # Open in append mode, creating the file if it doesn't exist
    with open(filename, "a+", encoding="utf-8") as file:
        file.write(content)

def write_new_data(
    output_file: str,
    current_data: dict,
    column_names: list,
) -> None:
    """
    Writes a new row of data to a CSV file.

    Args:
        output_file (str): The path to the output CSV file.
        current_data (dict): A dictionary containing the data to be written.
        column_names (list): A list of column names in the desired order.

    Returns:
        None
    """
    # Extract data in the specified order based on column names
    data_row = [current_data[column] for column in column_names]
    # Write the data row to the CSV file
    write_to_csv(output_file, data_row)

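# Illustrative sketch (file name and keys hypothetical): the dict values are
# pulled out in the order given by column_names, so the CSV row matches the
# header layout regardless of dict ordering:
#   write_new_data(
#       "results.csv",
#       {"gpt": "b", "human": "a"},
#       ["human", "gpt"],
#   )  # appends the row: a,b
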
def write_to_csv(filename: str, row_data: list) -> None:
    """
    Appends a row of data to a CSV file.

    Args:
        filename (str): The name of the CSV file.
        row_data (list): A list of values to be written as a row.

    Returns:
        None
    """
    # Open the CSV file in append mode, creating it if it doesn't exist
    with open(filename, "a+", encoding="utf-8", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(row_data)

def count_csv_lines(filename: str) -> int:
    """Counts the number of lines in a CSV file, excluding the header row.

    Args:
        filename (str): The path to the CSV file.

    Returns:
        int: The number of lines in the CSV file, excluding the header row.
    """
    file_data = pd.read_csv(filename, sep=",").values
    return len(file_data)

def read_csv_data(input_file: str) -> np.ndarray:
    """
    Reads data from a specified CSV file.

    Args:
        input_file (str): The path to the CSV file.

    Returns:
        numpy.ndarray: The data from the CSV file.
    """
    file_data = pd.read_csv(
        input_file,
        dtype="string",
        keep_default_na=False,
        sep=",",
    ).values
    return file_data

def get_column(input_file: str, column_name: str) -> np.ndarray:
    """
    Retrieves a specific column from a CSV file as a NumPy array.

    Args:
        input_file (str): The path to the CSV file.
        column_name (str): The name of the column to extract.

    Returns:
        np.ndarray: Values from the specified column.
    """
    # Read CSV, preserving string data types and handling missing values
    df = pd.read_csv(
        input_file,
        dtype="string",
        keep_default_na=False,
        sep=",",
    )
    # Extract the specified column as a NumPy array
    column_data = df[column_name].values
    return column_data

def generate_column_names(categories: list) -> list:
    """
    Generates column names for a pairwise comparison matrix.

    Args:
        categories (list): A list of categories.

    Returns:
        list: A list of column names,
        including a 'human' column and pairwise combinations.
    """
    column_names = ["human"]
    # Add individual category names as column names
    column_names.extend(categories)
    # Add every ordered pair of categories as column names; the full n x n
    # grid (including self-pairs) is what calculate_required_models and
    # parse_multimodal_data expect downstream
    for i in categories:
        for j in categories:
            column_names.append(f"{i}_{j}")
    return column_names

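# Illustrative sketch of the resulting layout (model names hypothetical):
#   generate_column_names(["gpt", "llama"])
#   -> ["human", "gpt", "llama",
#       "gpt_gpt", "gpt_llama", "llama_gpt", "llama_llama"]
# i.e. 1 human column + n single-model columns + n*n ordered-pair columns.
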
def normalize_text(input_text: str) -> str:
    """
    Normalizes the given text by removing unnecessary characters and
    formatting it for better readability.

    Args:
        input_text (str): The input text to be normalized.

    Returns:
        The normalized text.

    This function performs the following transformations:
    1. Strips leading and trailing whitespace
    2. Removes double asterisks (`**`)
    3. Collapses newlines and runs of spaces into single spaces
    """
    processed_text = input_text.strip()
    processed_text = processed_text.replace("**", "")
    # split() handles any number of consecutive whitespace characters,
    # including newlines, so runs of three or more spaces collapse too
    processed_text = " ".join(processed_text.split())
    return processed_text

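# Illustrative sketch:
#   normalize_text("  **Bold**\ntext   here  ")
#   -> "Bold text here"
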
def refine_candidate_text(input_text: str, candidate_text: str) -> str:
    # TODO: clarify how this differs from normalize_text
    """
    Removes specific surrounding marks from the candidate text if they are
    present in the input text with an excess of exactly two occurrences.

    Args:
        input_text (str): The original text.
        candidate_text (str): The candidate text to be refined.

    Returns:
        str: The refined candidate text.
    """
    # Create a copy of the candidate string and strip whitespace
    refined_candidate = candidate_text.strip()
    # Iterate through each surrounding mark
    for mark in ["```", "'", '"']:
        # Count occurrences of the mark in input_text and refined_candidate
        count_input_text = input_text.count(mark)
        count_refined_candidate = refined_candidate.count(mark)
        # Strip the mark only if the candidate has exactly two extra
        # occurrences and is wrapped in it on both ends
        if (
            count_refined_candidate == count_input_text + 2
            and refined_candidate.startswith(mark)
            and refined_candidate.endswith(mark)
        ):
            # Slice off exactly one mark from each end; str.strip(mark)
            # would remove any run of the mark's characters, which is
            # wrong for multi-character marks such as "```"
            refined_candidate = refined_candidate[len(mark):-len(mark)]
    return refined_candidate

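# Illustrative sketch: the input has no backticks, the candidate has the two
# extra wrapping marks, so they are stripped:
#   refine_candidate_text("print(1)", "```print(1)```")
#   -> "print(1)"
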
def generate_file_name(
    existing_data_file: str,
    existing_kinds: list,
    new_kinds: list,
) -> str:
    """
    Generates a new file name based on the path of an existing data file and a
    combination of existing and new kinds.

    Args:
        existing_data_file (str): The path to the existing data file.
        existing_kinds (list): A list of existing kinds.
        new_kinds (list): A list of new kinds.

    Returns:
        str: The generated file name with the full path.
    """
    # Combine existing and new kinds into a single list
    combined_kinds = existing_kinds + new_kinds
    # Get the directory path of the existing data file
    directory_path = os.path.dirname(existing_data_file)
    # Create a new file name by joining the kinds with underscores and adding
    # a suffix
    # TODO: move to config file
    new_file_name = "_".join(combined_kinds) + "_with_best_similarity.csv"
    # Combine the directory path with the new file name to get the full output
    # file path
    output_file_path = os.path.join(directory_path, new_file_name)
    return output_file_path

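# Illustrative sketch (paths and kind names hypothetical):
#   generate_file_name("data/input.csv", ["gpt"], ["llama"])
#   -> "data/gpt_llama_with_best_similarity.csv"
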
def shuffle(data: list[list], seed: int) -> None:
    """
    Shuffles the elements within each sublist of the given data structure.

    Args:
        data (list of lists): The array containing sublists to shuffle.
        seed (int): The seed value for the random number generator.

    Returns:
        None
    """
    # A fresh Random(seed) is created per sublist, so the shuffle is
    # reproducible across runs
    for sublist in data:
        random.Random(seed).shuffle(sublist)

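# Illustrative sketch: because each sublist gets a fresh generator with the
# same seed, equal-length sublists end up in the same permuted order, which
# keeps parallel columns aligned:
#   rows = [[1, 2, 3], ["a", "b", "c"]]
#   shuffle(rows, seed=42)
#   # rows[0] and rows[1] are now permuted identically
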
def generate_human_with_shuffle(
    dataset_name: str,
    column_name: str,
    num_samples: int,
    output_file: str,
) -> None:
    """
    Generates a shuffled list of sentences from the dataset and writes them to
    a CSV file.

    Args:
        dataset_name (str): The name of the dataset to load.
        column_name (str): The column name to extract sentences from.
        num_samples (int): The number of samples to process.
        output_file (str): The path to the output CSV file.

    Returns:
        None
    """
    # Load the dataset
    dataset = load_dataset(dataset_name)
    data = dataset["train"]
    lines = []
    # Tokenize sentences and add to the lines list
    for sample in data:
        nltk_tokens = nltk.sent_tokenize(sample[column_name])
        lines.extend(nltk_tokens)
    # Filter out empty lines
    lines = [line for line in lines if line != ""]
    # Shuffle the lines
    shuffle([lines], seed=SEED)
    # Ensure the output file exists and write the header if it doesn't
    if not os.path.exists(output_file):
        header = ["human"]
        write_to_csv(output_file, header)
    # Get the number of lines already processed in the output file
    number_of_processed_lines = count_csv_lines(output_file)
    # Print the initial lines to be processed
    print(f"Lines before processing: {lines[:num_samples]}")
    # Slice the lines list to get the unprocessed lines
    lines = lines[number_of_processed_lines:num_samples]
    # Print the lines after slicing
    print(f"Lines after slicing: {lines}")
    # Process each line and write to the output file
    for index, human in enumerate(lines):
        normalized_text = normalize_text(human)
        output_data = [normalized_text]
        write_to_csv(output_file, output_data)
        # Adjacent f-strings avoid the stray indentation that backslash
        # continuations would embed in the message
        print(
            f"Processed {index + 1} / {len(lines)}; "
            f"Total processed: "
            f"{number_of_processed_lines + index + 1} / {num_samples}",
        )

def split_data(data: list, train_ratio: float) -> tuple[list, list]:
    """
    Splits a dataset into training and testing sets.

    Args:
        data (list): The input dataset.
        train_ratio (float): The proportion of data to use for training.

    Returns:
        tuple[list, list]: The training and testing sets.
    """
    # Calculate the number of samples for training
    train_size = int(len(data) * train_ratio)
    # Split the data into training and testing sets
    train_data = data[:train_size]
    test_data = data[train_size:]
    return train_data, test_data

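# Illustrative sketch:
#   split_data([1, 2, 3, 4, 5], train_ratio=0.8)
#   -> ([1, 2, 3, 4], [5])
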
def combine_text_with_BERT_format(text_list: list[str]) -> str:
    """
    Formats a list of texts into a single string suitable for BERT input.

    Args:
        text_list (list[str]): A list of text strings.

    Returns:
        str: A single string formatted with BERT's special tokens.
    """
    # Join the texts so that every segment is wrapped in <s>...</s>
    formatted_text = "<s>" + "</s><s>".join(text_list) + "</s>"
    return formatted_text

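# Illustrative sketch:
#   combine_text_with_BERT_format(["first", "second"])
#   -> "<s>first</s><s>second</s>"
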
def check_api_error(data: list) -> bool:
    """
    Checks if the given data contains an API error or an indication to ignore
    an API error.

    Args:
        data (list): A list of items to check.

    Returns:
        bool: True if an API error or ignore indication is found,
        False otherwise.
    """
    for item in data:
        # Check for API error indicators
        if item in (API_ERROR, IGNORE_BY_API_ERROR):
            return True  # Return True as soon as an error indicator is found
    return False  # Return False if no error indicators are found

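# Illustrative sketch, assuming API_ERROR and IGNORE_BY_API_ERROR are
# sentinel strings defined in config (their actual values are
# project-specific):
#   check_api_error(["some text", API_ERROR])  -> True
#   check_api_error(["some text"])             -> False
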
def calculate_required_models(num_columns: int) -> int:
    """
    Calculates the minimum number of models required to generate the
    specified number of columns.

    Args:
        num_columns (int): The total number of columns to generate.

    Returns:
        int: The minimum number of models required.

    Raises:
        ValueError: If the number of columns cannot be achieved with the
        current model configuration.
    """
    num_models = 0
    count_human = 1  # Initial count representing human input
    # Total columns for n models: 1 (human) + n (single) + n^2 (pairs)
    # TODO: simplify this function
    while True:
        count_single = num_models  # One column per single model
        count_pair = num_models * num_models  # One column per ordered pair
        total_count = count_human + count_single + count_pair
        if total_count == num_columns:
            return num_models
        if total_count > num_columns:
            raise ValueError(
                "Cannot calculate the number of models to match the number of columns",  # noqa: E501
            )
        num_models += 1

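# Illustrative sketch: the column count for n models is 1 + n + n^2, so
#   calculate_required_models(7)   -> 2  (1 + 2 + 4)
#   calculate_required_models(13)  -> 3  (1 + 3 + 9)
#   calculate_required_models(8)   raises ValueError (no integer n fits)
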
def parse_multimodal_data(multimodel_csv_file: str) -> list:
    """
    Parses multimodal data from a CSV file into a structured format.

    Args:
        multimodel_csv_file (str): Path to the CSV file.

    Returns:
        list: A list of dictionaries, each containing 'human', 'single', and
        'pair' keys.

    Raises:
        ValueError: If the number of columns does not match any model count,
        or if there is an error in reading the CSV file.
    """
    # TODO: simplify this function
    # Read CSV data into a list of lists
    input_data = read_csv_data(multimodel_csv_file)
    # Initialize the result list
    structured_data = []
    # Calculate the number of models from the column count of the first row
    num_models = calculate_required_models(len(input_data[0]))
    # Process each row in the input data
    for row in input_data:
        row_data = {}
        index = 0
        # Extract human data
        row_data["human"] = row[index]
        index += 1
        # Extract single model data
        single_model_data = []
        for _ in range(num_models):
            single_model_data.append(row[index])
            index += 1
        row_data["single"] = single_model_data
        # Extract pair model data as a num_models x num_models grid
        pair_model_data = []
        for _ in range(num_models):
            sub_pair_data = []
            for _ in range(num_models):
                sub_pair_data.append(row[index])
                index += 1
            pair_model_data.append(sub_pair_data)
        row_data["pair"] = pair_model_data
        # Append the structured row data to the result list
        structured_data.append(row_data)
    return structured_data

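# Illustrative sketch for a 7-column row (two models, values hypothetical):
#   row = ["h", "s1", "s2", "p11", "p12", "p21", "p22"]
# parses into:
#   {"human": "h",
#    "single": ["s1", "s2"],
#    "pair": [["p11", "p12"], ["p21", "p22"]]}
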
def check_error(data_item: dict) -> bool:
    """
    Checks if the given data item contains any API errors.

    An API error is indicated by a specific error message
    or code within the text.

    Args:
        data_item (dict): A dictionary containing 'human', 'single',
        and 'pair' fields.

    Returns:
        bool: True if an API error is found, otherwise False.
    """
    # Check for API error in the 'human' field; the value is wrapped in a
    # list because check_api_error expects a list of values, and passing a
    # bare string would iterate over its characters instead
    if check_api_error([data_item["human"]]):
        return True
    # Check for API error in the 'single' model data
    for single_text in data_item["single"]:
        if check_api_error([single_text]):
            return True
    # Get the number of models from the 'single' model data
    num_models = len(data_item["single"])
    # Check for API error in the 'pair' model data
    for i in range(num_models):
        for j in range(num_models):
            if check_api_error([data_item["pair"][i][j]]):
                return True
    # No errors found
    return False