import csv
import logging
import os
import random

import nltk
import numpy as np
import pandas as pd
from config import (  # LOG_FILE,
    API_ERROR,
    IGNORE_BY_API_ERROR,
    SEED,
)
from datasets import load_dataset
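
# Note: nltk.sent_tokenize (used in generate_human_with_shuffle) needs the
# NLTK sentence tokenizer models; if they are missing, download them once,
# e.g. with nltk.download("punkt").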


def print_and_log(message: str):
    """
    Prints a message to stdout and records it in the log.

    Args:
        message (str): The message to be printed and logged.
    """
    print(message)
    logging.info(message)


def write_to_file(filename: str, content: str):
    """
    Writes the given content to a specified file.

    Args:
        filename (str): The path to the file to write content.
        content (str): The content to be written.
    """
    print(content)
    with open(filename, "a+", encoding="utf-8") as file:
        file.write(content)


def write_new_data(
    output_file: str,
    current_data: dict,
    column_names: list,
) -> None:
    """
    Writes a new row of data to a CSV file.

    Args:
        output_file (str): The path to the output CSV file.
        current_data (dict): A dictionary containing the data to be written.
        column_names (list): A list of column names in the desired order.

    Returns:
        None
    """
    # Extract data in the specified order based on column names
    data_row = [current_data[column] for column in column_names]

    # Write the data row to the CSV file
    write_to_csv(output_file, data_row)


def write_to_csv(filename: str, row_data: list) -> None:
    """
    Appends a row of data to a CSV file.

    Args:
        filename (str): The name of the CSV file.
        row_data: A list of values to be written as a row.

    Returns:
        None
    """
    # Open the CSV file in append mode, creating it if it doesn't exist
    with open(filename, "a+", encoding="UTF8", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(row_data)
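

# Illustrative usage (the path "out.csv" is hypothetical): write a header
# row once, then append data rows in the same column order:
#   write_to_csv("out.csv", ["human", "gpt"])
#   write_new_data("out.csv", {"gpt": "b", "human": "a"}, ["human", "gpt"])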


def count_csv_lines(filename: str) -> int:
    """Counts the number of lines in a CSV file, excluding the header row.

    Args:
        filename (str): The path to the CSV file.

    Returns:
        int: The number of lines in the CSV file, excluding the header row.
    """
    file_data = pd.read_csv(filename, sep=",").values
    return len(file_data)


def read_csv_data(input_file: str) -> np.ndarray:
    """
    Reads data from a specified CSV file.

    Args:
        input_file (str): The path to the CSV file.

    Returns:
        numpy.ndarray: The data from the CSV file.
    """
    file_data = pd.read_csv(
        input_file,
        dtype="string",
        keep_default_na=False,
        sep=",",
    ).values
    return file_data


def get_column(input_file: str, column_name: str) -> np.ndarray:
    """
    Retrieves a specific column from a CSV file as a NumPy array.

    Args:
        input_file (str): The path to the CSV file.
        column_name (str): The name of the column to extract.

    Returns:
        np.ndarray: Values from the specified column.
    """
    # Read CSV, preserving string data types and handling missing values
    df = pd.read_csv(
        input_file,
        dtype="string",
        keep_default_na=False,
        sep=",",
    )

    # Extract the specified column as a NumPy array
    column_data = df[column_name].values
    return column_data
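

# Illustrative usage (the file "data.csv" is hypothetical):
#   get_column("data.csv", "human")  # -> array of string values; empty
#   # cells stay "" rather than NaN because keep_default_na=False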


def generate_column_names(categories: list) -> list:
    """
    Generates column names for a pairwise comparison matrix.

    Args:
        categories (list): A list of categories.

    Returns:
        list: A list of column names,
            including a 'human' column and pairwise combinations.
    """
    column_names = ["human"]

    # Add individual category names as column names
    column_names.extend(categories)

    # Add pairwise combinations of categories as column names
    for i in categories:
        for j in categories:
            column_names.append(f"{i}_{j}")

    # Note: the full ordered n x n grid (including i == j) is intentional;
    # parse_multimodal_data expects exactly n * n pair columns after the
    # single-model columns.

    return column_names
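

# Illustrative example of the resulting header layout:
#   >>> generate_column_names(["a", "b"])
#   ['human', 'a', 'b', 'a_a', 'a_b', 'b_a', 'b_b']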


def normalize_text(input_text: str) -> str:
    """
    Normalizes the given text by removing unnecessary characters and
        formatting it for better readability.

    Args:
        input_text (str): The input text to be normalized.

    Returns:
        The normalized text.

    This function performs the following transformations:
        1. Strips leading and trailing whitespace
        2. Removes double asterisks (`**`)
        3. Replaces newlines with spaces
        4. Removes extra spaces
    """
    processed_text = input_text.strip()
    processed_text = processed_text.replace("**", "")
    processed_text = processed_text.replace("\n", " ")
    # Collapse runs of two or more spaces down to a single space
    while "  " in processed_text:
        processed_text = processed_text.replace("  ", " ")
    return processed_text
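

# Illustrative example:
#   >>> normalize_text("  **Bold**\ntext   here ")
#   'Bold text here'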


def refine_candidate_text(input_text: str, candidate_text: str) -> str:
    """
    Removes specific surrounding marks from the candidate text if they are
    present with an excess of exactly two occurrences relative to the input
    text. Unlike normalize_text, this only strips wrapper marks (backticks
    and quotes) added around the candidate; the inner text is left as-is.

    Args:
        input_text (str): The original text.
        candidate_text (str): The candidate text to be refined.

    Returns:
        str: The refined candidate text.
    """

    # Create a copy of the candidate string and strip whitespace
    refined_candidate = candidate_text.strip()

    # Iterate through each mark
    for mark in ["```", "'", '"']:
        # Count occurrences of the mark in input_text and refined_candidate
        count_input_text = input_text.count(mark)
        count_refined_candidate = refined_candidate.count(mark)

        # Check if the mark should be stripped
        if (
            count_refined_candidate == count_input_text + 2
            and refined_candidate.startswith(mark)
            and refined_candidate.endswith(mark)
        ):
            # Strip the mark from both ends of the refined_candidate
            refined_candidate = refined_candidate.strip(mark)

    return refined_candidate
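

# Illustrative example: a quote wrapper added around the candidate is
# stripped, since the input contained no quotes of its own:
#   >>> refine_candidate_text("Say hi", '"hi"')
#   'hi'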


def generate_file_name(
    existing_data_file: str,
    existing_kinds: list,
    new_kinds: list,
) -> str:
    """
    Generates a new file name based on the path of an existing data file and a
    combination of existing and new kinds.

    Args:
        existing_data_file (str): The path to the existing data file.
        existing_kinds (list): A list of existing kinds.
        new_kinds (list): A list of new kinds.

    Returns:
        str: The generated file name with the full path.
    """

    # Combine existing and new kinds into a single list
    combined_kinds = existing_kinds + new_kinds

    # Get the directory path of the existing data file
    directory_path = os.path.dirname(existing_data_file)

    # Create a new file name by joining the kinds with underscores and adding
    # a suffix
    # TODO: move to config file
    new_file_name = "_".join(combined_kinds) + "_with_best_similarity.csv"

    # Combine the directory path with the new file name to get the full output
    # file path
    output_file_path = os.path.join(directory_path, new_file_name)

    return output_file_path
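

# Illustrative example (POSIX-style paths):
#   >>> generate_file_name("data/base.csv", ["human"], ["gpt"])
#   'data/human_gpt_with_best_similarity.csv'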


def shuffle(data: list[list], seed: int) -> None:
    """
    Shuffles the elements within each sublist of the given data structure.

    Args:
        data (list of lists): The array containing sublists to shuffle.
        seed (int): The seed value for the random number generator.

    Returns:
        None
    """
    for sublist in data:
        random.Random(seed).shuffle(sublist)


def generate_human_with_shuffle(
    dataset_name: str,
    column_name: str,
    num_samples: int,
    output_file: str,
) -> None:
    """
    Generates a shuffled list of sentences from the dataset and writes them to
    a CSV file.

    Args:
        dataset_name (str): The name of the dataset to load.
        column_name (str): The column name to extract sentences from.
        num_samples (int): The number of samples to process.
        output_file (str): The path to the output CSV file.

    Returns:
        None
    """
    # Load the dataset
    dataset = load_dataset(dataset_name)
    data = dataset["train"]

    lines = []
    # Tokenize sentences and add to the lines list
    for sample in data:
        nltk_tokens = nltk.sent_tokenize(sample[column_name])
        lines.extend(nltk_tokens)

    # Filter out empty lines
    lines = [line for line in lines if line != ""]

    # Shuffle the lines
    shuffle([lines], seed=SEED)

    # Ensure the output file exists and write the header if it doesn't
    if not os.path.exists(output_file):
        header = ["human"]
        write_to_csv(output_file, header)

    # Get the number of lines already processed in the output file
    number_of_processed_lines = count_csv_lines(output_file)

    # Print the initial lines to be processed
    print(f"Lines before processing: {lines[:num_samples]}")

    # Slice the lines list to get the unprocessed lines
    lines = lines[number_of_processed_lines:num_samples]

    # Print the lines after slicing
    print(f"Lines after slicing: {lines}")

    # Process each line and write to the output file
    for index, human in enumerate(lines):
        normalized_text = normalize_text(human)
        output_data = [normalized_text]
        write_to_csv(output_file, output_data)
        print(
            f"Processed {index + 1} / {len(lines)}; "
            "Total processed: "
            f"{number_of_processed_lines + index + 1} / {num_samples}",
        )
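

# Illustrative call (dataset and file names are hypothetical):
#   generate_human_with_shuffle("some/dataset", "text", 100, "human.csv")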


def split_data(data: list, train_ratio: float) -> tuple[list, list]:
    """
    Splits a dataset into training and testing sets.

    Args:
        data (list): The input dataset.
        train_ratio (float): The proportion of data to use for training.

    Returns:
        tuple[list, list]: The training and testing sets.
    """

    # Calculate the number of samples for training
    train_size = int(len(data) * train_ratio)

    # Split the data into training and testing sets
    train_data = data[:train_size]
    test_data = data[train_size:]

    return train_data, test_data
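

# Illustrative example:
#   >>> split_data([1, 2, 3, 4, 5], 0.8)
#   ([1, 2, 3, 4], [5])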


def combine_text_with_BERT_format(text_list: list[str]) -> str:
    """
    Formats a list of texts into a single string suitable for BERT input.

    Args:
        text_list (list[str]): A list of text strings.

    Returns:
        str: A single string formatted with BERT's special tokens.
    """
    # Wrap each text in <s>...</s> and concatenate the segments
    formatted_text = "<s>" + "</s><s>".join(text_list) + "</s>"
    return formatted_text
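

# Illustrative example:
#   >>> combine_text_with_BERT_format(["first", "second"])
#   '<s>first</s><s>second</s>'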


def check_api_error(data: list) -> bool:
    """
    Checks if the given data contains an API error or an indication to ignore
    an API error.

    Args:
        data (list): A list of items to check.

    Returns:
        bool: True if an API error or ignore indication is found,
            False otherwise.
    """
    for item in data:
        # Check for API error indicators
        if item in (API_ERROR, IGNORE_BY_API_ERROR):
            return True  # An error indicator was found
    return False  # No error indicators were found


def calculate_required_models(num_columns: int) -> int:
    """
    Calculates the minimum number of models required to generate the specified number of columns.

    Args:
        num_columns (int): The total number of columns to generate.

    Returns:
        int: The minimum number of models required.

    Raises:
        ValueError: If the number of columns cannot be achieved with the current model configuration.
    """

    num_models = 0
    count_human = 1  # Initial count representing human input

    # One 'human' column, one column per model, and one column per ordered
    # pair of models: total = 1 + n + n^2
    while True:
        count_single = num_models  # One column per single model
        count_pair = num_models * num_models  # One column per ordered pair

        total_count = count_human + count_single + count_pair

        if total_count == num_columns:
            return num_models
        if total_count > num_columns:
            raise ValueError(
                "Cannot calculate the number of models to match the number of columns",  # noqa: E501
            )

        num_models += 1
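

# Illustrative example: 7 columns = 1 'human' + 2 singles + 2 * 2 pairs, so
#   >>> calculate_required_models(7)
#   2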


def parse_multimodal_data(multimodel_csv_file: str) -> list:
    """
    Parses multimodal data from a CSV file into a structured format.

    Args:
        multimodel_csv_file (str): Path to the CSV file.

    Returns:
        list: A list of dictionaries, each containing 'human', 'single', and
        'pair' keys.

    Raises:
        ValueError: If the number of columns does not match any model
        configuration (see calculate_required_models).
    """
    # TODO: simplify this function

    # Read CSV data into a list of lists
    input_data = read_csv_data(multimodel_csv_file)

    # Initialize the result list
    structured_data = []

    # Calculate the number of models based on the number of columns in the first row  # noqa: E501
    num_models = calculate_required_models(len(input_data[0]))

    # Process each row in the input data
    for row in input_data:
        row_data = {}
        index = 0

        # Extract human data
        row_data["human"] = row[index]
        index += 1

        # Extract single model data
        single_model_data = []
        for _ in range(num_models):
            single_model_data.append(row[index])
            index += 1
        row_data["single"] = single_model_data

        # Extract pair model data
        pair_model_data = []
        for _ in range(num_models):
            sub_pair_data = []
            for _ in range(num_models):
                sub_pair_data.append(row[index])
                index += 1
            pair_model_data.append(sub_pair_data)
        row_data["pair"] = pair_model_data

        # Append the structured row data to the result list
        structured_data.append(row_data)

    return structured_data
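

# Illustrative layout for two models m1 and m2 (matching the column order
# from generate_column_names): a row
#   human, m1, m2, m1_m1, m1_m2, m2_m1, m2_m2
# parses into
#   {"human": ..., "single": [.., ..], "pair": [[.., ..], [.., ..]]}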


def check_error(data_item: dict) -> bool:
    """
    Checks if the given data item contains any API errors.
    An API error is indicated by a specific error message
    or code within the text.

    Args:
        data_item (dict): A dictionary containing 'human', 'single',
            and 'pair' fields.

    Returns:
        bool: True if an API error is found, otherwise False.
    """
    # Check for API error in the 'human' field (wrapped in a list because
    # check_api_error expects a list of texts, not a bare string)
    if check_api_error([data_item["human"]]):
        return True

    # Check for API error in the 'single' model data
    if check_api_error(data_item["single"]):
        return True

    # Check for API error in each row of the 'pair' model data
    for pair_row in data_item["pair"]:
        if check_api_error(pair_row):
            return True

    # No errors found
    return False