import csv
import logging
import os
import random
import nltk
import numpy as np
import pandas as pd
from config import ( # LOG_FILE,
API_ERROR,
IGNORE_BY_API_ERROR,
SEED,
)
from datasets import load_dataset
def print_and_log(message: str):
    # TODO: redefine logging
    """
    Prints a message to stdout and logs it.
    Args:
        message (str): The message to be printed and logged.
    """
    print(message)
    logging.info(message)
def write_to_file(filename: str, content: str):
    """
    Appends the given content to a specified file.
    Args:
        filename (str): The path to the file to write to.
        content (str): The content to be written.
    """
    # Echo the content to stdout, then append it to the file
    print(content)
    with open(filename, "a+", encoding="utf-8") as file:
        file.write(content)
def write_new_data(
output_file: str,
current_data: dict,
column_names: list,
) -> None:
"""
Writes a new row of data to a CSV file.
Args:
output_file (str): The path to the output CSV file.
current_data (dict): A dictionary containing the data to be written.
column_names (list): A list of column names in the desired order.
Returns:
None
"""
# Extract data in the specified order based on column names
data_row = [current_data[column] for column in column_names]
# Write the data row to the CSV file
write_to_csv(output_file, data_row)
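# Example (hypothetical path and values): the row order is fixed by
# column_names, not by the dict's insertion order.
# write_new_data(
#     "results.csv",
#     {"gpt": "generated text", "human": "original text"},
#     ["human", "gpt"],
# )  # appends the row: original text,generated text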
def write_to_csv(filename: str, row_data: list) -> None:
"""
Appends a row of data to a CSV file.
Args:
filename (str): The name of the CSV file.
row_data: A list of values to be written as a row.
Returns:
None
"""
# Open the CSV file in append mode, creating it if it doesn't exist
with open(filename, "a+", encoding="UTF8", newline="") as file:
writer = csv.writer(file)
writer.writerow(row_data)
def count_csv_lines(filename: str) -> int:
"""Counts the number of lines in a CSV file, excluding the header row.
Args:
filename (str): The path to the CSV file.
Returns:
int: The number of lines in the CSV file, excluding the header row.
"""
file_data = pd.read_csv(filename, sep=",").values
return len(file_data)
def read_csv_data(input_file: str) -> np.ndarray:
"""
Reads data from a specified CSV file.
Args:
        input_file (str): The path to the CSV file.
Returns:
numpy.ndarray: The data from the CSV file.
"""
file_data = pd.read_csv(
input_file,
dtype="string",
keep_default_na=False,
sep=",",
).values
return file_data
def get_column(input_file: str, column_name: str) -> np.ndarray:
"""
Retrieves a specific column from a CSV file as a NumPy array.
Args:
input_file (str): The path to the CSV file.
column_name (str): The name of the column to extract.
Returns:
np.ndarray: Values from the specified column.
"""
# Read CSV, preserving string data types and handling missing values
df = pd.read_csv(
input_file,
dtype="string",
keep_default_na=False,
sep=",",
)
    # Extract the specified column as a NumPy array (to_numpy guarantees an
    # ndarray even for pandas extension dtypes such as "string")
    column_data = df[column_name].to_numpy()
return column_data
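# Example (hypothetical file): read a single column of string values.
# human_texts = get_column("data.csv", "human")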
def generate_column_names(categories: list) -> list:
"""
Generates column names for a pairwise comparison matrix.
Args:
categories (list): A list of categories.
    Returns:
        list: A list of column names: a 'human' column, the individual
        categories, and all ordered pairwise combinations.
"""
column_names = ["human"]
# Add individual category names as column names
column_names.extend(categories)
# Add pairwise combinations of categories as column names
for i in categories:
for j in categories:
column_names.append(f"{i}_{j}")
# TODO: improve?
# for i in range(len(categories)):
# for j in range(i + 1, len(categories)):
# column_names.append(f"{categories[i]}_{categories[j]}")
return column_names
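# Example (hypothetical categories): generate_column_names(["gpt", "llama"])
# returns ["human", "gpt", "llama", "gpt_gpt", "gpt_llama", "llama_gpt",
# "llama_llama"].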
def normalize_text(input_text: str) -> str:
"""
Normalizes the given text by removing unnecessary characters and
formatting it for better readability.
Args:
input_text (str): The input text to be normalized.
Returns:
The normalized text.
    This function performs the following transformations:
        1. Strips leading and trailing whitespace
        2. Removes double asterisks (`**`)
        3. Collapses runs of whitespace (spaces, newlines, tabs) into
           single spaces
    """
    processed_text = input_text.strip()
    processed_text = processed_text.replace("**", "")
    # str.split() splits on any whitespace run, so rejoining with single
    # spaces also collapses newlines, tabs, and runs of 3+ spaces
    processed_text = " ".join(processed_text.split())
    return processed_text
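# Example: normalize_text("  **Bold**\ntext   here ") returns "Bold text here".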
def refine_candidate_text(input_text: str, candidate_text: str) -> str:
    # TODO: clarify how this differs from normalize_text
"""
Removes specific surrounding marks from the candidate text if they are
present in the input text with an excess of exactly two occurrences.
Args:
input_text (str): The original text.
        candidate_text (str): The candidate text to be refined.
Returns:
str: The refined candidate text.
"""
# Create a copy of the candidate string and strip whitespace
refined_candidate = candidate_text.strip()
# Iterate through each mark
for mark in ["```", "'", '"']:
# Count occurrences of the mark in input_text and refined_candidate
count_input_text = input_text.count(mark)
count_refined_candidate = refined_candidate.count(mark)
# Check if the mark should be stripped
if (
count_refined_candidate == count_input_text + 2
and refined_candidate.startswith(mark)
and refined_candidate.endswith(mark)
):
            # Remove exactly one mark from each end (str.strip would also
            # remove repeated mark characters)
            refined_candidate = refined_candidate[len(mark) : -len(mark)]
return refined_candidate
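# Example: with input_text = "Say hi." (no quotes) and candidate_text =
# "'Bonjour.'" (two quotes), the guard 0 + 2 == 2 holds, so the surrounding
# quotes are removed and "Bonjour." is returned.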
def generate_file_name(
existing_data_file: str,
existing_kinds: list,
new_kinds: list,
) -> str:
"""
Generates a new file name based on the path of an existing data file and a
combination of existing and new kinds.
Args:
existing_data_file (str): The path to the existing data file.
existing_kinds (list): A list of existing kinds.
new_kinds (list): A list of new kinds.
Returns:
str: The generated file name with the full path.
"""
# Combine existing and new kinds into a single list
combined_kinds = existing_kinds + new_kinds
# Get the directory path of the existing data file
directory_path = os.path.dirname(existing_data_file)
# Create a new file name by joining the kinds with underscores and adding
# a suffix
# TODO: move to config file
new_file_name = "_".join(combined_kinds) + "_with_best_similarity.csv"
# Combine the directory path with the new file name to get the full output
# file path
output_file_path = os.path.join(directory_path, new_file_name)
return output_file_path
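# Example (hypothetical kinds): generate_file_name("data/base.csv", ["gpt"],
# ["llama"]) returns "data/gpt_llama_with_best_similarity.csv".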
def shuffle(data: list[list], seed: int) -> None:
"""
Shuffles the elements within each sublist of the given data structure.
Args:
data (list of lists): The array containing sublists to shuffle.
seed (int): The seed value for the random number generator.
Returns:
None
"""
for sublist in data:
random.Random(seed).shuffle(sublist)
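# Note: shuffle() seeds a fresh random.Random per sublist, so sublists of the
# same length receive the same permutation and parallel lists stay aligned.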
def generate_human_with_shuffle(
dataset_name: str,
column_name: str,
num_samples: int,
output_file: str,
) -> None:
"""
Generates a shuffled list of sentences from the dataset and writes them to
a CSV file.
Args:
dataset_name (str): The name of the dataset to load.
column_name (str): The column name to extract sentences from.
num_samples (int): The number of samples to process.
output_file (str): The path to the output CSV file.
Returns:
None
"""
# Load the dataset
dataset = load_dataset(dataset_name)
data = dataset["train"]
lines = []
    # Tokenize sentences and add to the lines list
    # (nltk.sent_tokenize needs the NLTK "punkt" tokenizer data; run
    # nltk.download("punkt") once if it is missing)
for sample in data:
nltk_tokens = nltk.sent_tokenize(sample[column_name])
lines.extend(nltk_tokens)
# Filter out empty lines
filtered_lines = [line for line in lines if line != ""]
lines = filtered_lines
# Shuffle the lines
shuffle([lines], seed=SEED)
# Ensure the output file exists and write the header if it doesn't
if not os.path.exists(output_file):
header = ["human"]
write_to_csv(output_file, header)
# Get the number of lines already processed in the output file
number_of_processed_lines = count_csv_lines(output_file)
# Print the initial lines to be processed
print(f"Lines before processing: {lines[:num_samples]}")
# Slice the lines list to get the unprocessed lines
lines = lines[number_of_processed_lines:num_samples]
# Print the lines after slicing
print(f"Lines after slicing: {lines}")
# Process each line and write to the output file
for index, human in enumerate(lines):
normalized_text = normalize_text(human)
output_data = [normalized_text]
write_to_csv(output_file, output_data)
        print(
            f"Processed {index + 1} / {len(lines)}; "
            f"Total processed: "
            f"{number_of_processed_lines + index + 1} / {num_samples}",
        )
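# Example (hypothetical dataset and path): because previously written rows are
# counted and skipped, interrupted runs resume where they left off.
# generate_human_with_shuffle("xsum", "document", 1000, "human_data.csv")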
def split_data(data: list, train_ratio: float) -> tuple[list, list]:
"""
Splits a dataset into training and testing sets.
Args:
data (list): The input dataset.
train_ratio (float): The proportion of data to use for training.
Returns:
The training and testing sets.
"""
# Calculate the number of samples for training
train_size = int(len(data) * train_ratio)
# Split the data into training and testing sets
train_data = data[:train_size]
test_data = data[train_size:]
return train_data, test_data
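# Example: split_data([1, 2, 3, 4, 5], 0.8) returns ([1, 2, 3, 4], [5]).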
def combine_text_with_BERT_format(text_list: list[str]) -> str:
"""
Formats a list of texts into a single string suitable for BERT input.
Args:
text_list (list[str]): A list of text strings.
Returns:
str: A single string formatted with BERT's special tokens.
"""
    # Wrap each text in <s>...</s> and join into one string
formatted_text = "<s>" + "</s><s>".join(text_list) + "</s>"
return formatted_text
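# Example: combine_text_with_BERT_format(["x", "y"]) returns "<s>x</s><s>y</s>".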
def check_api_error(data: list) -> bool:
"""
Checks if the given data contains an API error or an indication to ignore
an API error.
Args:
data (list): A list of items to check.
Returns:
bool: True if an API error or ignore indication is found,
False otherwise.
"""
    for item in data:
        # Check for API error indicators
        if item in (API_ERROR, IGNORE_BY_API_ERROR):
            return True  # At least one error indicator was found
    return False  # No error indicators were found
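# Example: check_api_error([API_ERROR, "some text"]) returns True, while
# check_api_error(["some text"]) returns False.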
def calculate_required_models(num_columns: int) -> int:
"""
Calculates the minimum number of models required to generate the specified number of columns.
Args:
num_columns (int): The total number of columns to generate.
Returns:
int: The minimum number of models required.
Raises:
ValueError: If the number of columns cannot be achieved with the current model configuration.
"""
num_models = 0
count_human = 1 # Initial count representing human input
# TODO: simplify this function
while True:
count_single = num_models # Single model count
count_pair = num_models * num_models # Pair model count
total_count = count_human + count_single + count_pair
if total_count == num_columns:
return num_models
        elif total_count > num_columns:
            raise ValueError(
                "Cannot calculate the number of models to match the number of columns",  # noqa: E501
            )
num_models += 1
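# Note: with n models the expected column count is 1 + n + n**2 (one human
# column, n single-model columns, and n*n pair columns), so for example
# calculate_required_models(7) returns 2.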
def parse_multimodal_data(multimodel_csv_file: str) -> list:
"""
Parses multimodal data from a CSV file into a structured format.
Args:
multimodel_csv_file (str): Path to the CSV file.
Returns:
list: A list of dictionaries, each containing 'human', 'single', and
'pair' keys.
Raises:
Exception: If there is an error in reading the CSV file or processing
the data.
"""
# TODO: simplify this function
# Read CSV data into a list of lists
input_data = read_csv_data(multimodel_csv_file)
# Initialize the result list
structured_data = []
# Calculate the number of models based on the number of columns in the first row # noqa: E501
num_models = calculate_required_models(len(input_data[0]))
# Process each row in the input data
for row in input_data:
row_data = {}
index = 0
# Extract human data
row_data["human"] = row[index]
index += 1
# Extract single model data
single_model_data = []
for _ in range(num_models):
single_model_data.append(row[index])
index += 1
row_data["single"] = single_model_data
# Extract pair model data
pair_model_data = []
for _ in range(num_models):
sub_pair_data = []
for _ in range(num_models):
sub_pair_data.append(row[index])
index += 1
pair_model_data.append(sub_pair_data)
row_data["pair"] = pair_model_data
# Append the structured row data to the result list
structured_data.append(row_data)
return structured_data
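# Example shape for 2 models (7 columns per row): each entry looks like
# {"human": h, "single": [s1, s2], "pair": [[p11, p12], [p21, p22]]}.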
def check_error(data_item: dict) -> bool:
"""
Checks if the given data item contains any API errors.
An API error is indicated by a specific error message
or code within the text.
Args:
data_item (dict): A dictionary containing 'human', 'single',
and 'pair' fields.
Returns:
bool: True if an API error is found, otherwise False.
"""
    # check_api_error expects a list, so wrap the single 'human' string
    if check_api_error([data_item["human"]]):
        return True
    # Check for API error in the 'single' model data (already a list)
    if check_api_error(data_item["single"]):
        return True
    # Check for API error in each row of the 'pair' model data
    for pair_row in data_item["pair"]:
        if check_api_error(pair_row):
            return True
# No errors found
return False