Spaces:
Sleeping
Sleeping
import nltk | |
import numpy as np | |
from config import metric | |
from utils import refine_candidate_text | |
from texts.bart_score import ( | |
bart_score, | |
check_bart_score, | |
) | |
def compute_metrics(evaluation_predictions):
    """
    Score a batch of model outputs with the module-level ``metric``.

    Parameters:
        evaluation_predictions (tuple): A two-element iterable holding, in
            order, the raw prediction scores produced by the model and the
            true labels for the same examples.

    Returns:
        The value returned by ``metric.compute`` for the macro-averaged
        comparison of predicted labels against the true labels.
    """
    scores, references = evaluation_predictions

    # The index of the largest score along axis 1 is taken as the
    # predicted class label.
    predicted = np.argmax(scores, axis=1)

    compute_kwargs = {
        "prediction_scores": predicted,
        "references": references,
        "average": "macro",
    }
    return metric.compute(**compute_kwargs)
def extract_by_best_similarity(input_text, raw_text):
    """
    Pick the candidate from ``raw_text`` that scores highest against
    ``input_text`` according to the BART score.

    The raw text is refined, split into sentences, and each sentence is
    split again on newlines. Every resulting piece is refined and, if it
    passes ``check_bart_score``, scored with ``bart_score``; the piece with
    the strictly highest score wins.

    Args:
        input_text (str): The original text.
        raw_text (str): The raw text containing multiple candidate strings.

    Returns:
        str: The refined candidate with the highest score, or
        ``input_text`` when no candidate was selected (or the selected
        candidate is the empty string).
    """
    cleaned_text = refine_candidate_text(input_text, raw_text)

    # Sentence-tokenize first, then break each sentence on newlines so that
    # every line becomes its own candidate.
    pieces = [
        piece
        for sentence in nltk.sent_tokenize(cleaned_text)
        for piece in sentence.split("\n")
    ]

    # Sentinel floor: a candidate must score above this to be selected.
    top_score = -9999
    winner = ""

    for piece in pieces:
        cleaned_piece = refine_candidate_text(input_text, piece)

        # Candidates rejected by the BART-score check are never scored.
        if not check_bart_score(input_text, cleaned_piece):
            continue

        current_score = bart_score(input_text, cleaned_piece)[0]
        if current_score > top_score:
            top_score, winner = current_score, cleaned_piece

    print(f"best_candidate = {winner}")

    # An empty winner means nothing usable was found; fall back to the input.
    return input_text if winner == "" else winner