# -*- coding: utf-8 -*-
"""Yet another copy of MCQ, Toxic, Bias.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1_4-bS633DBVMc5-jBLCmyUaXzAi5RL6f

# MCQ Generation Using T5
"""

# mcq_generator.py (corrected)
import nltk
import random
import re
import tempfile
import torch
import spacy
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import (
    pipeline,
    T5ForConditionalGeneration,
    T5Tokenizer,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
from nltk.corpus import wordnet as wn
from nltk.corpus import stopwords
from nltk import pos_tag, word_tokenize
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer, util
from rouge import Rouge

# ❌ DO NOT include these here (they are only needed in the performance-metrics section below):
# import matplotlib.pyplot as plt
# from IPython.display import display

# Download required NLTK packages
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger_eng')
nltk.download('wordnet')
nltk.download('stopwords')
nltk.download('punkt_tab')

# Load safety models
toxicity_model = pipeline("text-classification", model="unitary/toxic-bert")
bias_model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
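# Note (added for clarity; the values shown are illustrative, not real outputs): both
# pipelines return plain Python structures, which is why the safety check below indexes
# them the way it does, e.g.
#   toxicity_model("some text")[0]
#       -> {'label': 'toxic', 'score': 0.97}
#   bias_model("some text", candidate_labels=["neutral", "biased"])
#       -> {'sequence': 'some text', 'labels': ['biased', 'neutral'], 'scores': [0.71, 0.29]}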
# Enhanced safety-check function with comprehensive bias detection
def is_suitable_for_students(text):
    """Comprehensive content check for appropriateness in educational settings."""
    text = text.strip()
    if not text:
        print("⚠️ Empty paragraph provided.")
        return False

    # Check text length
    if len(text.split()) < 20:
        print("⚠️ Text too short for meaningful MCQ generation.")
        return False

    # Check toxicity
    toxicity = toxicity_model(text[:512])[0]
    tox_label, tox_score = toxicity['label'].lower(), toxicity['score']

    # COMPREHENSIVE BIAS DETECTION
    # 1. Gender bias
    gender_bias_keywords = [
        "women are", "men are", "boys are", "girls are",
        "females are", "males are", "better at", "worse at",
        "naturally better", "suited for", "belong in",
        "should be", "can't do", "always", "never"
    ]
    # 2. Racial bias
    racial_bias_keywords = [
        "race", "racial", "racist", "ethnicity", "ethnic",
        "black people", "white people", "asian people", "latinos",
        "minorities", "majority", "immigrants", "foreigners"
    ]
    # 3. Political bias
    political_bias_keywords = [
        "liberal", "conservative", "democrat", "republican",
        "left-wing", "right-wing", "socialism", "capitalism",
        "government", "politician", "corrupt", "freedom", "rights",
        "policy", "policies", "taxes", "taxation"
    ]
    # 4. Religious bias
    religious_bias_keywords = [
        "christian", "muslim", "jewish", "hindu", "buddhist",
        "atheist", "religion", "religious", "faith", "belief",
        "worship", "sacred", "holy"
    ]
    # 5. Socioeconomic bias
    socioeconomic_bias_keywords = [
        "poor", "rich", "wealthy", "poverty", "privileged",
        "underprivileged", "class", "elite", "welfare", "lazy",
        "hardworking", "deserve", "entitled"
    ]

    # Combined bias keywords
    all_bias_keywords = (gender_bias_keywords + racial_bias_keywords +
                         political_bias_keywords + religious_bias_keywords +
                         socioeconomic_bias_keywords)

    # Additional problematic generalizations
    problematic_phrases = [
        "more aggressive", "less educated", "less intelligent", "more violent",
        "inferior", "superior", "better", "smarter", "worse", "dumber",
        "tend to be more", "tend to be less", "are naturally", "by nature",
        "all people", "those people", "these people", "that group",
        "always", "never", "inherently", "genetically"
    ]

    # Check whether any bias keywords or problematic phrases are present
    contains_bias_keywords = any(keyword in text.lower() for keyword in all_bias_keywords)
    contains_problematic_phrases = any(phrase in text.lower() for phrase in problematic_phrases)

    # Advanced bias detection using the BART zero-shot model:
    # a general label set first, then more specific label sets if needed
    general_bias_labels = ["neutral", "biased", "discriminatory", "prejudiced", "stereotyping"]
    gender_bias_labels = ["gender neutral", "gender biased", "sexist"]
    racial_bias_labels = ["racially neutral", "racially biased", "racist"]
    political_bias_labels = ["politically neutral", "politically biased", "partisan"]

    # Run general bias detection first
    bias_result = bias_model(text[:512], candidate_labels=general_bias_labels)
    bias_label = bias_result['labels'][0].lower()
    bias_score = bias_result['scores'][0]

    # If the general check is uncertain, run more specific checks
    if bias_score < 0.7 and contains_bias_keywords:
        # Determine which specific bias checks to run
        if any(keyword in text.lower() for keyword in gender_bias_keywords):
            specific_result = bias_model(text[:512], candidate_labels=gender_bias_labels)
            if specific_result['labels'][0] != gender_bias_labels[0] and specific_result['scores'][0] > 0.6:
                bias_label = "gender biased"
                bias_score = specific_result['scores'][0]
        if any(keyword in text.lower() for keyword in racial_bias_keywords):
            specific_result = bias_model(text[:512], candidate_labels=racial_bias_labels)
            if specific_result['labels'][0] != racial_bias_labels[0] and specific_result['scores'][0] > 0.6:
                bias_label = "racially biased"
                bias_score = specific_result['scores'][0]
        if any(keyword in text.lower() for keyword in political_bias_keywords):
            specific_result = bias_model(text[:512], candidate_labels=political_bias_labels)
            if specific_result['labels'][0] != political_bias_labels[0] and specific_result['scores'][0] > 0.6:
                bias_label = "politically biased"
                bias_score = specific_result['scores'][0]

    # Decision thresholds
    bias_threshold = 0.55      # lower threshold to catch more subtle bias
    toxicity_threshold = 0.60

    # Decision logic with detailed reporting
    if tox_label == "toxic" and tox_score > toxicity_threshold:
        print(f"⚠️ Toxicity Detected ({tox_score:.2f}) — ❌ Not Suitable for Students")
        return False
    elif bias_label in ["biased", "discriminatory", "prejudiced", "stereotyping",
                        "gender biased", "racially biased", "politically biased"] and bias_score > bias_threshold:
        print(f"⚠️ {bias_label.title()} Content Detected ({bias_score:.2f}) — ❌ Not Suitable for Students")
        return False
    elif contains_problematic_phrases:
        print("⚠️ Problematic Generalizations Detected — ❌ Not Suitable for Students")
        return False
    else:
        print("✅ Passed Safety Check — 🟢 Proceeding to Generate MCQs")
        return True
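# Quick illustration (not part of the pipeline): passages shorter than 20 words are
# rejected before any model is consulted, so a call like
#   is_suitable_for_students("Too short to use.")
# prints the "too short" warning and returns False.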
class ImprovedMCQGenerator:
    def __init__(self):
        # Initialize a QG-specific model for better question generation
        self.qg_model_name = "lmqg/t5-base-squad-qg"  # specialized question-generation model
        try:
            self.qg_tokenizer = AutoTokenizer.from_pretrained(self.qg_model_name)
            self.qg_model = AutoModelForSeq2SeqLM.from_pretrained(self.qg_model_name)
            self.has_qg_model = True
        except Exception:
            # Fall back to T5 if the specialized model fails to load
            self.has_qg_model = False
            print("Could not load specialized QG model, falling back to T5")

        # Initialize a T5 model for distractors and fallback question generation
        self.t5_model_name = "google/flan-t5-base"  # base model for better quality
        self.t5_tokenizer = T5Tokenizer.from_pretrained(self.t5_model_name)
        self.t5_model = T5ForConditionalGeneration.from_pretrained(self.t5_model_name)

        # Configuration
        self.max_length = 128
        self.stop_words = set(stopwords.words('english'))
    def clean_text(self, text):
        """Clean and normalize text."""
        text = re.sub(r'\s+', ' ', text)  # collapse extra whitespace
        text = text.strip()
        return text
    def generate_question(self, context, answer):
        """Generate a question for a given context and answer using the specialized QG model."""
        # Find the sentence containing the answer for better context
        sentences = sent_tokenize(context)
        relevant_sentences = []
        for sentence in sentences:
            if answer.lower() in sentence.lower():
                relevant_sentences.append(sentence)

        if not relevant_sentences:
            # If the answer is not found in any sentence, use a random sentence
            if sentences:
                relevant_sentences = [random.choice(sentences)]
            else:
                relevant_sentences = [context]

        # Use up to 3 sentences for context (the sentence with the answer + its neighbours)
        if len(relevant_sentences) == 1 and len(sentences) > 1:
            # Find the index of the relevant sentence
            idx = sentences.index(relevant_sentences[0])
            if idx > 0:
                relevant_sentences.append(sentences[idx - 1])
            if idx < len(sentences) - 1:
                relevant_sentences.append(sentences[idx + 1])

        # Join the relevant sentences
        focused_context = ' '.join(relevant_sentences)

        if self.has_qg_model:
            # Use the specialized QG model
            input_text = f"answer: {answer} context: {focused_context}"
            inputs = self.qg_tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
            outputs = self.qg_model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=self.max_length,
                num_beams=5,
                top_k=120,
                top_p=0.95,
                temperature=1.0,
                do_sample=True,
                num_return_sequences=3,
                no_repeat_ngram_size=2
            )
            # Get multiple questions and pick the best one
            questions = [self.qg_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
            valid_questions = [q for q in questions if q.endswith('?') and answer.lower() not in q.lower()]
            if valid_questions:
                return self.clean_text(valid_questions[0])

        # Fall back to the T5 model if the specialized model fails or is unavailable
        input_text = f"generate question for answer: {answer} from context: {focused_context}"
        inputs = self.t5_tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
        outputs = self.t5_model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=self.max_length,
            num_beams=5,
            top_k=120,
            top_p=0.95,
            temperature=1.0,
            do_sample=True,
            num_return_sequences=3,
            no_repeat_ngram_size=2
        )
        questions = [self.t5_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

        # Clean and validate the questions
        valid_questions = []
        for q in questions:
            # Format the question properly
            q = self.clean_text(q)
            if not q.endswith('?'):
                q += '?'
            # Avoid questions that contain the answer directly
            if answer.lower() not in q.lower():
                valid_questions.append(q)

        if valid_questions:
            return valid_questions[0]

        # If all else fails, create a simple question
        return f"Which of the following best describes {answer}?"
    def extract_key_entities(self, text, n=8):
        """Extract key entities from the text that would make good answers."""
        # Tokenize into sentences
        sentences = sent_tokenize(text)

        # Get noun phrases and named entities
        key_entities = []
        for sentence in sentences:
            words = word_tokenize(sentence)
            pos_tags = pos_tag(words)

            # Extract noun phrases (consecutive nouns and adjectives)
            i = 0
            while i < len(pos_tags):
                if pos_tags[i][1].startswith('NN') or pos_tags[i][1].startswith('JJ'):
                    phrase = pos_tags[i][0]
                    j = i + 1
                    while j < len(pos_tags) and (pos_tags[j][1].startswith('NN') or pos_tags[j][1] == 'JJ'):
                        phrase += ' ' + pos_tags[j][0]
                        j += 1
                    if len(phrase.split()) >= 1 and not all(w.lower() in self.stop_words for w in phrase.split()):
                        key_entities.append(phrase)
                    i = j
                else:
                    i += 1

        # Extract important terms based on POS tags
        important_terms = []
        for sentence in sentences:
            words = word_tokenize(sentence)
            pos_tags = pos_tag(words)
            # Keep nouns, verbs, and adjectives
            terms = [word for word, pos in pos_tags if
                     (pos.startswith('NN') or pos.startswith('VB') or pos.startswith('JJ'))
                     and word.lower() not in self.stop_words
                     and len(word) > 2]
            important_terms.extend(terms)

        # Combine and remove duplicates
        all_candidates = key_entities + important_terms
        unique_candidates = []
        for candidate in all_candidates:
            # Clean the candidate
            candidate = candidate.strip()
            candidate = re.sub(r'[^\w\s]', '', candidate)
            # Skip if empty or made up of stopwords only
            if not candidate or all(w.lower() in self.stop_words for w in candidate.split()):
                continue
            # Check for duplicates
            if candidate.lower() not in [c.lower() for c in unique_candidates]:
                unique_candidates.append(candidate)

        # Use TF-IDF to rank entities by importance
        if len(unique_candidates) > n:
            try:
                vectorizer = TfidfVectorizer()
                tfidf_matrix = vectorizer.fit_transform([text] + unique_candidates)
                document_vector = tfidf_matrix[0:1]
                entity_vectors = tfidf_matrix[1:]
                # Calculate similarity to the document
                similarities = cosine_similarity(document_vector, entity_vectors).flatten()
                # Keep the top n entities
                ranked_entities = [entity for _, entity in sorted(zip(similarities, unique_candidates), reverse=True)]
                return ranked_entities[:n]
            except Exception:
                # Fallback if TF-IDF ranking fails
                return random.sample(unique_candidates, min(n, len(unique_candidates)))

        return unique_candidates[:n]
    def generate_distractors(self, answer, context, n=3):
        """Generate plausible distractors for a given answer."""
        # Extract potential distractors from the context
        potential_distractors = self.extract_key_entities(context, n=15)

        # Remove the correct answer and options that are too similar to it
        filtered_distractors = []
        answer_lower = answer.lower()
        for distractor in potential_distractors:
            distractor_lower = distractor.lower()
            # Skip if it is the answer or too similar to the answer
            if distractor_lower == answer_lower:
                continue
            if answer_lower in distractor_lower or distractor_lower in answer_lower:
                continue
            if len(set(distractor_lower.split()) & set(answer_lower.split())) > len(answer_lower.split()) / 2:
                continue
            filtered_distractors.append(distractor)

        # If we need more distractors, generate them with T5
        if len(filtered_distractors) < n:
            input_text = f"generate alternatives for: {answer} context: {context}"
            inputs = self.t5_tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
            outputs = self.t5_model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=64,
                num_beams=5,
                top_k=50,
                top_p=0.95,
                temperature=1.2,
                do_sample=True,
                num_return_sequences=5
            )
            model_distractors = [self.t5_tokenizer.decode(out, skip_special_tokens=True) for out in outputs]

            # Clean and validate the model-generated distractors
            for distractor in model_distractors:
                distractor = self.clean_text(distractor)
                # Skip if it is the answer or too similar
                if distractor.lower() == answer.lower():
                    continue
                if answer.lower() in distractor.lower() or distractor.lower() in answer.lower():
                    continue
                filtered_distractors.append(distractor)

        # Ensure uniqueness
        unique_distractors = []
        for d in filtered_distractors:
            if d.lower() not in [x.lower() for x in unique_distractors]:
                unique_distractors.append(d)

        # If we still don't have enough, create simple variations
        while len(unique_distractors) < n:
            if not unique_distractors and not potential_distractors:
                # No existing distractors to work with, fall back to generic options
                unique_distractors.append("None of the above")
                unique_distractors.append("All of the above")
                unique_distractors.append("Not mentioned in the text")
            else:
                base = answer if not unique_distractors else random.choice(unique_distractors)
                words = base.split()
                if len(words) > 1:
                    # Modify one word of a multi-word distractor
                    modified = words.copy()
                    pos_to_change = random.randint(0, len(words) - 1)
                    # Make sure the new distractor is different
                    modification = f"alternative_{modified[pos_to_change]}"
                    while modification in [x.lower() for x in unique_distractors]:
                        modification += "_variant"
                    modified[pos_to_change] = modification
                    unique_distractors.append(" ".join(modified))
                else:
                    # Modify a single word
                    modification = f"alternative_{base}"
                    while modification in [x.lower() for x in unique_distractors]:
                        modification += "_variant"
                    unique_distractors.append(modification)

        # Return the required number of distractors
        return unique_distractors[:n]
    def validate_mcq(self, mcq, context):
        """Validate whether an MCQ meets quality standards."""
        # The question must end with a question mark
        if not mcq['question'].endswith('?'):
            return False
        # The question must not be too short
        if len(mcq['question'].split()) < 5:
            return False
        # The question must not contain the answer (too obvious)
        if mcq['answer'].lower() in mcq['question'].lower():
            return False
        # Options must be sufficiently different from one another
        if len(set([o.lower() for o in mcq['options']])) < len(mcq['options']):
            return False
        # The answer must appear in the context
        if mcq['answer'].lower() not in context.lower():
            return False
        return True
    def generate_mcqs(self, paragraph, num_questions=5):
        """Generate multiple-choice questions from a paragraph."""
        paragraph = self.clean_text(paragraph)
        mcqs = []

        # Extract potential answers and shuffle them
        potential_answers = self.extract_key_entities(paragraph, n=num_questions * 3)
        random.shuffle(potential_answers)

        # Try to generate an MCQ for each potential answer
        attempts = 0
        max_attempts = num_questions * 3  # try more potential answers than needed
        while len(mcqs) < num_questions and attempts < max_attempts and potential_answers:
            answer = potential_answers.pop(0)
            attempts += 1

            # Generate the question and its distractors
            question = self.generate_question(paragraph, answer)
            distractors = self.generate_distractors(answer, paragraph)

            # Create the MCQ
            mcq = {
                'question': question,
                'options': [answer] + distractors,
                'answer': answer
            }

            # Validate the MCQ
            if self.validate_mcq(mcq, paragraph):
                # Shuffle the options
                shuffled_options = mcq['options'].copy()
                random.shuffle(shuffled_options)
                # Find the index of the correct answer
                correct_index = shuffled_options.index(answer)
                # Update the MCQ with the shuffled options
                mcq['options'] = shuffled_options
                mcq['answer_index'] = correct_index
                mcqs.append(mcq)

        return mcqs[:num_questions]
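# Each MCQ returned by ImprovedMCQGenerator.generate_mcqs is a dict of the form
# (example values are illustrative only):
#   {'question': 'What regulates the movement of substances in and out of the cell?',
#    'options': ['cytoplasm', 'cell membrane', 'nucleus', 'mitochondria'],
#    'answer': 'cell membrane',
#    'answer_index': 1}
# where 'answer_index' is the position of the correct answer inside the shuffled 'options'.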
# Helper functions
def format_mcq(mcq, index):
    """Format an MCQ for display."""
    question = f"Q{index + 1}: {mcq['question']}"
    options = [f"  {chr(65 + i)}. {option}" for i, option in enumerate(mcq['options'])]
    answer = f"Answer: {chr(65 + mcq['answer_index'])}"
    return "\n".join([question] + options + [answer, ""])

def generate_mcqs_from_paragraph(paragraph, num_questions=5):
    """Generate and format MCQs from a paragraph."""
    generator = ImprovedMCQGenerator()
    mcqs = generator.generate_mcqs(paragraph, num_questions)
    formatted_mcqs = []
    for i, mcq in enumerate(mcqs):
        formatted_mcqs.append(format_mcq(mcq, i))
    return "\n".join(formatted_mcqs)
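# Minimal usage sketch (illustrative; `sample_text` is a placeholder, not part of this file):
#   sample_text = ("Photosynthesis is the process by which green plants use sunlight, "
#                  "water, and carbon dioxide to produce glucose and oxygen in their chloroplasts.")
#   print(generate_mcqs_from_paragraph(sample_text, num_questions=3))
# Note that generate_mcqs_from_paragraph builds a fresh ImprovedMCQGenerator (and therefore
# reloads the models) on every call; reuse a single generator instance if you need to
# process many paragraphs.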
# Example paragraphs
example_paragraphs = [
    """
    The cell is the basic structural and functional unit of all living organisms. Cells can be classified into two main types: prokaryotic and eukaryotic.
    Prokaryotic cells, found in bacteria and archaea, lack a defined nucleus and membrane-bound organelles. In contrast, eukaryotic cells, which make up plants,
    animals, fungi, and protists, contain a nucleus that houses the cell’s DNA, as well as various organelles like mitochondria and the endoplasmic reticulum.
    The cell membrane regulates the movement of substances in and out of the cell, while the cytoplasm supports the internal structures.
    """,
    """
    The Industrial Revolution was a major historical transformation that began in Great Britain in the late 18th century. It marked the shift from manual labor and
    hand-made goods to machine-based manufacturing and mass production. This shift significantly increased productivity and efficiency. The textile industry was the
    first to implement modern industrial methods, including the use of spinning machines and mechanized looms. A key innovation during this period was the development
    of steam power, notably improved by Scottish engineer James Watt. Steam engines enabled factories to operate away from rivers, which had previously been the main
    power source. Additional advancements included the invention of machine tools and the emergence of large-scale factory systems. These changes revolutionized industrial
    labor and contributed to the rise of new social classes, including the industrial working class and the capitalist class. The Industrial Revolution also led to rapid
    urbanization, a sharp rise in population, and eventually, improvements in living standards and economic growth.
    """
]
# Main execution
if __name__ == "__main__":
    print("MCQ Generator - Testing with Example Paragraphs")
    print("=" * 80)

    for i, paragraph in enumerate(example_paragraphs):
        print(f"\nExample {i + 1}:")
        print("-" * 40)
        if is_suitable_for_students(paragraph):
            print(generate_mcqs_from_paragraph(paragraph))
        else:
            print("❌ Content not suitable for MCQ generation. Please provide different content.")
        print("=" * 80)

    # Interactive mode
    print("\n--- MCQ Generator ---")
    print("Enter a paragraph to generate MCQs (or type 'exit' to quit):")
    while True:
        user_input = input("> ")
        if user_input.lower() == 'exit':
            break
        if is_suitable_for_students(user_input):
            print(generate_mcqs_from_paragraph(user_input))
        else:
            print("❌ Content not suitable for MCQ generation. Please provide different content.")

"""# Performance Metrics"""
import time
import psutil
import numpy as np
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge import Rouge
import matplotlib.pyplot as plt
try:
    from IPython.display import display
except ImportError:
    # Create a dummy display function for non-notebook environments
    def display(obj):
        pass
import pandas as pd
from nltk.tokenize import sent_tokenize
import tracemalloc
import gc
import re
import random
import warnings
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
class MCQPerformanceMetrics:
    def __init__(self, mcq_generator):
        """Initialize the performance-metrics class with the MCQ generator."""
        self.mcq_generator = mcq_generator
        self.rouge = Rouge()
        # NLTK smoothing function to handle zero n-gram counts
        self.smoothing = SmoothingFunction().method1
        # For semantic similarity
        self.tfidf_vectorizer = TfidfVectorizer(stop_words='english')
    def measure_execution_time(self, paragraphs, num_questions=5, repetitions=3):
        """Measure execution time for generating MCQs."""
        execution_times = []
        questions_per_second = []

        for paragraph in paragraphs:
            paragraph_times = []
            for _ in range(repetitions):
                start_time = time.time()
                mcqs = self.mcq_generator.generate_mcqs(paragraph, num_questions)
                end_time = time.time()

                execution_time = end_time - start_time
                paragraph_times.append(execution_time)

                # Calculate questions per second
                if len(mcqs) > 0:
                    qps = len(mcqs) / execution_time
                    questions_per_second.append(qps)

            execution_times.append(np.mean(paragraph_times))

        return {
            'avg_execution_time': np.mean(execution_times),
            'min_execution_time': np.min(execution_times),
            'max_execution_time': np.max(execution_times),
            'avg_questions_per_second': np.mean(questions_per_second) if questions_per_second else 0
        }
    def measure_memory_usage(self, paragraph, num_questions=5):
        """Measure peak memory usage during MCQ generation."""
        # Clear memory before the test
        gc.collect()
        # Start memory tracking
        tracemalloc.start()
        # Generate MCQs
        self.mcq_generator.generate_mcqs(paragraph, num_questions)
        # Get peak memory usage
        current, peak = tracemalloc.get_traced_memory()
        # Stop tracking
        tracemalloc.stop()
        return {
            'current_memory_MB': current / (1024 * 1024),
            'peak_memory_MB': peak / (1024 * 1024)
        }
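    # Note: tracemalloc only tracks Python-heap allocations made between start() and
    # get_traced_memory(), so the figures above exclude memory already held by the loaded
    # model weights and any GPU memory; treat them as a relative, not absolute, measure.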
    def compute_semantic_similarity(self, text1, text2):
        """Compute semantic similarity between two texts using TF-IDF and cosine similarity."""
        try:
            # Handle empty strings
            if not text1.strip() or not text2.strip():
                return 0
            # Fit and transform the texts
            tfidf_matrix = self.tfidf_vectorizer.fit_transform([text1, text2])
            # Compute cosine similarity
            similarity = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0]
            return similarity
        except Exception as e:
            print(f"Error computing semantic similarity: {e}")
            return 0
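    # Rough intuition for the score above: two identical strings give a similarity of 1.0,
    # while two texts sharing no (non-stopword) vocabulary give 0.0; partially overlapping
    # texts fall in between. TF-IDF captures lexical overlap rather than true semantics,
    # so paraphrases with different wording can still score low.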
    def evaluate_question_quality(self, mcqs, reference_questions=None):
        """Evaluate the quality of generated questions, with improved reference handling."""
        if not mcqs:
            return {'avg_question_length': 0, 'has_question_mark': 0}

        # Basic metrics
        question_lengths = [len(mcq['question'].split()) for mcq in mcqs]
        has_question_mark = [int(mcq['question'].endswith('?')) for mcq in mcqs]

        # Option distinctiveness - average cosine distance between options
        option_distinctiveness = []
        for mcq in mcqs:
            options = mcq['options']
            if len(options) < 2:
                continue
            # Distinctiveness based on TF-IDF cosine similarity between option pairs
            distinctiveness_scores = []
            for i in range(len(options)):
                for j in range(i + 1, len(options)):
                    if not options[i].strip() or not options[j].strip():
                        continue
                    # Semantic similarity between the two options
                    similarity = self.compute_semantic_similarity(options[i], options[j])
                    distinctiveness_scores.append(1 - similarity)  # higher is better (more distinct)
            if distinctiveness_scores:
                option_distinctiveness.append(np.mean(distinctiveness_scores))

        # Compare with reference questions if provided
        bleu_scores = []
        modified_bleu_scores = []  # using the smoothing function
        rouge_scores = {'rouge-1': [], 'rouge-2': [], 'rouge-l': []}
        semantic_similarities = []  # semantic similarity to the reference

        if reference_questions and len(reference_questions) > 0:
            # Debug info
            print(f"Number of MCQs: {len(mcqs)}")
            print(f"Number of reference questions: {len(reference_questions)}")

            # Align MCQs with reference questions based on semantic similarity
            aligned_pairs = []
            if len(mcqs) <= len(reference_questions):
                # If we have enough reference questions, find the best match for each MCQ
                for mcq in mcqs:
                    best_match_idx = -1
                    best_similarity = -1
                    for i, ref in enumerate(reference_questions):
                        if i in [pair[1] for pair in aligned_pairs]:
                            continue  # skip already-matched references
                        similarity = self.compute_semantic_similarity(
                            mcq['question'],
                            ref if isinstance(ref, str) else ""
                        )
                        if similarity > best_similarity:
                            best_similarity = similarity
                            best_match_idx = i
                    if best_match_idx >= 0:
                        aligned_pairs.append((mcq, best_match_idx))
                    else:
                        # If no match was found, use the first available reference
                        for i, ref in enumerate(reference_questions):
                            if i not in [pair[1] for pair in aligned_pairs]:
                                aligned_pairs.append((mcq, i))
                                break
            else:
                # If we have more MCQs than references, match each reference to its best MCQ
                used_mcqs = set()
                for i, ref in enumerate(reference_questions):
                    best_match_idx = -1
                    best_similarity = -1
                    for j, mcq in enumerate(mcqs):
                        if j in used_mcqs:
                            continue  # skip already-matched MCQs
                        similarity = self.compute_semantic_similarity(
                            mcq['question'],
                            ref if isinstance(ref, str) else ""
                        )
                        if similarity > best_similarity:
                            best_similarity = similarity
                            best_match_idx = j
                    if best_match_idx >= 0:
                        aligned_pairs.append((mcqs[best_match_idx], i))
                        used_mcqs.add(best_match_idx)
                # Pair any remaining MCQs by cycling through the references
                for i, mcq in enumerate(mcqs):
                    if i not in used_mcqs:
                        ref_idx = i % len(reference_questions)
                        aligned_pairs.append((mcq, ref_idx))

            # Calculate metrics for the aligned pairs
            for mcq, ref_idx in aligned_pairs:
                reference = reference_questions[ref_idx] if isinstance(reference_questions[ref_idx], str) else ""
                if not reference:
                    continue

                ref_tokens = reference.split()
                hyp_tokens = mcq['question'].split()

                # Debug output
                print(f"\nReference ({ref_idx}): {reference}")
                print(f"Generated: {mcq['question']}")

                # Semantic similarity
                sem_sim = self.compute_semantic_similarity(mcq['question'], reference)
                semantic_similarities.append(sem_sim)
                print(f"Semantic similarity: {sem_sim:.4f}")

                try:
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        # Standard BLEU
                        bleu_score = sentence_bleu([ref_tokens], hyp_tokens, weights=(0.25, 0.25, 0.25, 0.25))
                        bleu_scores.append(bleu_score)
                        # BLEU with smoothing to handle zero n-gram counts
                        modified_bleu = sentence_bleu(
                            [ref_tokens],
                            hyp_tokens,
                            weights=(0.25, 0.25, 0.25, 0.25),
                            smoothing_function=self.smoothing
                        )
                        modified_bleu_scores.append(modified_bleu)
                        print(f"Smoothed BLEU: {modified_bleu:.4f}")
                except Exception as e:
                    print(f"BLEU score calculation error: {e}")

                # ROUGE scores
                try:
                    if len(reference) > 0 and len(mcq['question']) > 0:
                        rouge_result = self.rouge.get_scores(mcq['question'], reference)[0]
                        rouge_scores['rouge-1'].append(rouge_result['rouge-1']['f'])
                        rouge_scores['rouge-2'].append(rouge_result['rouge-2']['f'])
                        rouge_scores['rouge-l'].append(rouge_result['rouge-l']['f'])
                        print(f"ROUGE-1: {rouge_result['rouge-1']['f']:.4f}, ROUGE-L: {rouge_result['rouge-l']['f']:.4f}")
                except Exception as e:
                    print(f"ROUGE score calculation error: {e}")

        results = {
            'avg_question_length': np.mean(question_lengths),
            'has_question_mark': np.mean(has_question_mark) * 100,  # as a percentage
            'option_distinctiveness': np.mean(option_distinctiveness) if option_distinctiveness else 0
        }

        if modified_bleu_scores:
            results['avg_smoothed_bleu_score'] = np.mean(modified_bleu_scores)
        if semantic_similarities:
            results['avg_semantic_similarity'] = np.mean(semantic_similarities)
        for rouge_type, scores in rouge_scores.items():
            if scores:
                results[f'avg_{rouge_type}'] = np.mean(scores)

        return results
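    # Why the smoothed BLEU is the one reported: with the default (0.25, 0.25, 0.25, 0.25)
    # weights, sentence_bleu collapses to 0 whenever the generated question shares no 4-gram
    # with the reference, which is common for short questions. SmoothingFunction().method1
    # assigns a small non-zero count to empty n-gram matches, so the smoothed score stays informative.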
    def analyze_distractor_quality(self, mcqs, context):
        """Analyze the quality of distractors with improved semantic analysis."""
        if not mcqs:
            return {}

        # Check whether distractors appear in the context
        context_presence = []
        semantic_relevance = []  # semantic relevance of distractors to the context
        for mcq in mcqs:
            try:
                correct_answer = mcq['options'][mcq['answer_index']]
                distractors = [opt for i, opt in enumerate(mcq['options']) if i != mcq['answer_index']]

                distractor_in_context = []
                distractor_semantic_relevance = []
                for distractor in distractors:
                    # Semantic relevance to the context
                    semantic_sim = self.compute_semantic_similarity(distractor, context)
                    distractor_semantic_relevance.append(semantic_sim)

                    # Traditional word-overlap check
                    distractor_words = set(distractor.lower().split())
                    context_words = set(context.lower().split())
                    if distractor_words:
                        overlap_ratio = len(distractor_words.intersection(context_words)) / len(distractor_words)
                        distractor_in_context.append(overlap_ratio >= 0.5)  # at least 50% of words in context

                if distractor_in_context:
                    context_presence.append(sum(distractor_in_context) / len(distractor_in_context))
                if distractor_semantic_relevance:
                    semantic_relevance.append(np.mean(distractor_semantic_relevance))
            except Exception as e:
                print(f"Error in distractor context analysis: {e}")

        # Semantic similarity between distractors and the correct answer
        distractor_answer_similarity = []
        distractor_plausibility = []  # plausibility of the distractors
        for mcq in mcqs:
            try:
                correct_answer = mcq['options'][mcq['answer_index']]
                distractors = [opt for i, opt in enumerate(mcq['options']) if i != mcq['answer_index']]

                similarities = []
                plausibility_scores = []
                for distractor in distractors:
                    # Semantic similarity
                    similarity = self.compute_semantic_similarity(correct_answer, distractor)
                    similarities.append(similarity)
                    # Plausibility: a distractor should be somewhat similar to the correct answer,
                    # but not too similar; the sweet spot is around 0.3-0.7 similarity
                    plausibility = 1.0 - abs(0.5 - similarity)  # 1.0 at 0.5 similarity, decreasing on both sides
                    plausibility_scores.append(plausibility)

                if similarities:
                    distractor_answer_similarity.append(np.mean(similarities))
                if plausibility_scores:
                    distractor_plausibility.append(np.mean(plausibility_scores))
            except Exception as e:
                print(f"Error in distractor similarity analysis: {e}")

        results = {
            'context_presence': np.mean(context_presence) * 100 if context_presence else 0,  # as a percentage
            'distractor_answer_similarity': np.mean(distractor_answer_similarity) * 100 if distractor_answer_similarity else 0  # as a percentage
        }
        # Additional metrics
        if semantic_relevance:
            results['distractor_semantic_relevance'] = np.mean(semantic_relevance)
        if distractor_plausibility:
            results['distractor_plausibility'] = np.mean(distractor_plausibility)

        return results
    def calculate_readability_scores(self, mcqs):
        """Calculate readability scores for the generated questions."""
        try:
            import textstat
        except ImportError:
            print("textstat package not found - readability metrics will be skipped")
            return {}

        if not mcqs:
            return {}

        readability_scores = {
            'flesch_reading_ease': [],
            'flesch_kincaid_grade': [],
            'automated_readability_index': [],
            'smog_index': [],          # SMOG Index
            'coleman_liau_index': []   # Coleman-Liau Index
        }

        for mcq in mcqs:
            question_text = mcq['question']
            # Add the options to build the full MCQ text for readability analysis
            full_mcq_text = question_text + "\n"
            for i, option in enumerate(mcq['options']):
                full_mcq_text += f"{chr(65 + i)}. {option}\n"

            try:
                readability_scores['flesch_reading_ease'].append(textstat.flesch_reading_ease(full_mcq_text))
                readability_scores['flesch_kincaid_grade'].append(textstat.flesch_kincaid_grade(full_mcq_text))
                readability_scores['automated_readability_index'].append(textstat.automated_readability_index(full_mcq_text))
                readability_scores['smog_index'].append(textstat.smog_index(full_mcq_text))
                readability_scores['coleman_liau_index'].append(textstat.coleman_liau_index(full_mcq_text))
            except Exception as e:
                print(f"Error calculating readability: {e}")

        result = {}
        for metric, scores in readability_scores.items():
            if scores:
                result[f'avg_{metric}'] = np.mean(scores)
        return result
    def evaluate_question_diversity(self, mcqs):
        """Evaluate the diversity of the generated questions."""
        if not mcqs or len(mcqs) < 2:
            return {'question_diversity': 0}

        # Pairwise similarity between questions
        similarities = []
        for i in range(len(mcqs)):
            for j in range(i + 1, len(mcqs)):
                similarity = self.compute_semantic_similarity(mcqs[i]['question'], mcqs[j]['question'])
                similarities.append(similarity)

        # Diversity is the inverse of the average similarity
        avg_similarity = np.mean(similarities) if similarities else 0
        diversity = 1 - avg_similarity
        return {'question_diversity': diversity}
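    # Worked example for the diversity score: pairwise question similarities of
    # [0.2, 0.4, 0.3] average to 0.3, giving a diversity of 1 - 0.3 = 0.7
    # (1.0 would mean completely unrelated questions, 0.0 near-duplicates).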
    def evaluate_contextual_relevance(self, mcqs, context):
        """Evaluate how relevant the questions are to the context."""
        if not mcqs:
            return {'contextual_relevance': 0}

        relevance_scores = []
        for mcq in mcqs:
            # Similarity between the question and the context
            similarity = self.compute_semantic_similarity(mcq['question'], context)
            relevance_scores.append(similarity)

        return {'contextual_relevance': np.mean(relevance_scores) if relevance_scores else 0}
    def evaluate(self, paragraphs, num_questions=5, reference_questions=None):
        """Run a comprehensive evaluation of the MCQ generator."""
        try:
            # Generate one set of MCQs for the quality evaluation
            sample_paragraph = paragraphs[0] if isinstance(paragraphs, list) else paragraphs
            sample_mcqs = self.mcq_generator.generate_mcqs(sample_paragraph, num_questions)
            print(f"Generated {len(sample_mcqs)} MCQs for evaluation")

            # Execution time
            timing_metrics = self.measure_execution_time(
                paragraphs if isinstance(paragraphs, list) else [paragraphs],
                num_questions
            )
            # Memory usage
            memory_metrics = self.measure_memory_usage(sample_paragraph, num_questions)
            # Question quality
            quality_metrics = self.evaluate_question_quality(sample_mcqs, reference_questions)
            # Distractor quality
            distractor_metrics = self.analyze_distractor_quality(sample_mcqs, sample_paragraph)
            # Readability
            readability_metrics = self.calculate_readability_scores(sample_mcqs)
            # Diversity and contextual relevance
            diversity_metrics = self.evaluate_question_diversity(sample_mcqs)
            relevance_metrics = self.evaluate_contextual_relevance(sample_mcqs, sample_paragraph)

            # Combine all metrics
            all_metrics = {
                **timing_metrics,
                **memory_metrics,
                **quality_metrics,
                **distractor_metrics,
                **readability_metrics,
                **diversity_metrics,
                **relevance_metrics
            }
            return all_metrics
        except Exception as e:
            print(f"Error during evaluation: {e}")
            import traceback
            traceback.print_exc()
            return {"error": str(e)}
    def visualize_results(self, metrics):
        """Visualize the evaluation results with enhanced charts."""
        try:
            # Create a dataframe for better display
            metrics_df = pd.DataFrame({k: [v] for k, v in metrics.items()})

            # Format the numbers
            for col in metrics_df.columns:
                if 'time' in col:
                    metrics_df[col] = metrics_df[col].round(2).astype(str) + ' sec'
                elif 'memory' in col:
                    metrics_df[col] = metrics_df[col].round(2).astype(str) + ' MB'
                elif col in ['has_question_mark', 'context_presence', 'distractor_answer_similarity']:
                    metrics_df[col] = metrics_df[col].round(1).astype(str) + '%'
                else:
                    metrics_df[col] = metrics_df[col].round(3)

            display(metrics_df.T.rename(columns={0: 'Value'}))

            # Create enhanced visualizations: 3 rows x 2 columns of charts
            fig = plt.figure(figsize=(16, 14))
            gs = fig.add_gridspec(3, 2)

            # Filter out metrics that should not be plotted
            plottable_metrics = {k: v for k, v in metrics.items() if isinstance(v, (int, float))}

            # 1. Performance metrics
            ax1 = fig.add_subplot(gs[0, 0])
            performance_keys = ['avg_execution_time', 'avg_questions_per_second']
            performance_metrics = [plottable_metrics.get(k, 0) for k in performance_keys]
            bars = ax1.bar(performance_keys, performance_metrics, color=['#3498db', '#2ecc71'])
            ax1.set_title('Performance Metrics', fontsize=14, fontweight='bold')
            ax1.set_xticklabels(performance_keys, rotation=45, ha='right')
            # Add value labels on the bars
            for bar in bars:
                height = bar.get_height()
                ax1.text(bar.get_x() + bar.get_width() / 2., height + 0.1,
                         f'{height:.2f}', ha='center', va='bottom')

            # 2. Memory usage
            ax2 = fig.add_subplot(gs[0, 1])
            memory_keys = ['current_memory_MB', 'peak_memory_MB']
            memory_metrics = [plottable_metrics.get(k, 0) for k in memory_keys]
            bars = ax2.bar(memory_keys, memory_metrics, color=['#9b59b6', '#34495e'])
            ax2.set_title('Memory Usage (MB)', fontsize=14, fontweight='bold')
            # Add value labels
            for bar in bars:
                height = bar.get_height()
                ax2.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                         f'{height:.2f}', ha='center', va='bottom')

            # 3. Question quality
            ax3 = fig.add_subplot(gs[1, 0])
            quality_labels = ['Avg Length', 'Question Mark', 'Option Distinct.', 'Diversity', 'Relevance']
            quality_metrics = [
                plottable_metrics.get('avg_question_length', 0),
                plottable_metrics.get('has_question_mark', 0) / 100,  # convert from percentage
                plottable_metrics.get('option_distinctiveness', 0),
                plottable_metrics.get('question_diversity', 0),
                plottable_metrics.get('contextual_relevance', 0)
            ]
            bars = ax3.bar(quality_labels, quality_metrics,
                           color=['#f39c12', '#d35400', '#c0392b', '#16a085', '#27ae60'])
            ax3.set_title('Question Quality Metrics', fontsize=14, fontweight='bold')
            ax3.set_xticklabels(quality_labels, rotation=45, ha='right')
            # Add value labels
            for bar in bars:
                height = bar.get_height()
                ax3.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                         f'{height:.2f}', ha='center', va='bottom')

            # 4. Distractor quality
            ax4 = fig.add_subplot(gs[1, 1])
            distractor_labels = ['Context', 'Answer Sim.', 'Semantic Rel.', 'Plausibility']
            distractor_metrics = [
                plottable_metrics.get('context_presence', 0) / 100,  # convert from percentage
                plottable_metrics.get('distractor_answer_similarity', 0) / 100,  # convert from percentage
                plottable_metrics.get('distractor_semantic_relevance', 0),
                plottable_metrics.get('distractor_plausibility', 0)
            ]
            bars = ax4.bar(distractor_labels, distractor_metrics,
                           color=['#1abc9c', '#e74c3c', '#3498db', '#f1c40f'])
            ax4.set_title('Distractor Quality Metrics', fontsize=14, fontweight='bold')
            ax4.set_xticklabels(distractor_labels, rotation=45, ha='right')
            # Add value labels
            for bar in bars:
                height = bar.get_height()
                ax4.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                         f'{height:.2f}', ha='center', va='bottom')

            # 5. NLP metrics
            ax5 = fig.add_subplot(gs[2, 0])
            nlp_labels = ['Smooth BLEU', 'Semantic', 'ROUGE-1', 'ROUGE-2', 'ROUGE-L']
            nlp_metrics = [
                plottable_metrics.get('avg_smoothed_bleu_score', 0),
                plottable_metrics.get('avg_semantic_similarity', 0),
                plottable_metrics.get('avg_rouge-1', 0),
                plottable_metrics.get('avg_rouge-2', 0),
                plottable_metrics.get('avg_rouge-l', 0)
            ]
            bars = ax5.bar(nlp_labels, nlp_metrics,
                           color=['#3498db', '#2980b9', '#9b59b6', '#e74c3c', '#c0392b'])
            ax5.set_title('NLP Evaluation Metrics', fontsize=14, fontweight='bold')
            ax5.set_xticklabels(nlp_labels, rotation=45, ha='right')
            # Add value labels
            for bar in bars:
                height = bar.get_height()
                ax5.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                         f'{height:.3f}', ha='center', va='bottom')

            # 6. Readability metrics
            ax6 = fig.add_subplot(gs[2, 1])
            readability_labels = ['Flesch Ease', 'Kincaid', 'ARI', 'SMOG', 'Coleman-Liau']
            readability_metrics = [
                plottable_metrics.get('avg_flesch_reading_ease', 0),
                plottable_metrics.get('avg_flesch_kincaid_grade', 0),
                plottable_metrics.get('avg_automated_readability_index', 0),
                plottable_metrics.get('avg_smog_index', 0),
                plottable_metrics.get('avg_coleman_liau_index', 0)
            ]
            bars = ax6.bar(readability_labels, readability_metrics,
                           color=['#27ae60', '#2ecc71', '#16a085', '#1abc9c', '#2980b9'])
            ax6.set_title('Readability Metrics', fontsize=14, fontweight='bold')
            ax6.set_xticklabels(readability_labels, rotation=45, ha='right')
            # Add value labels
            for bar in bars:
                height = bar.get_height()
                ax6.text(bar.get_x() + bar.get_width() / 2., height + 0.1,
                         f'{height:.2f}', ha='center', va='bottom')

            plt.tight_layout()
            plt.show()
            return fig
        except Exception as e:
            print(f"Error in visualization: {e}")
            import traceback
            traceback.print_exc()
# Example usage function with improved error handling
def run_performance_evaluation():
    # Import the MCQ generator
    try:
        # First try to import it from the module
        from improved_mcq_generator import ImprovedMCQGenerator
    except ImportError:
        # If that fails, try to load the class from the current namespace
        try:
            # This assumes the class is defined in the current session
            ImprovedMCQGenerator = globals().get('ImprovedMCQGenerator')
            if ImprovedMCQGenerator is None:
                raise ImportError("ImprovedMCQGenerator class not found")
        except Exception as e:
            print(f"Error importing ImprovedMCQGenerator: {e}")
            return

    # Test paragraphs - use a variety for better assessment
    test_paragraphs = [
        """The cell is the basic structural and functional unit of all living organisms. Cells can be classified into two main types: prokaryotic and eukaryotic.
        Prokaryotic cells, found in bacteria and archaea, lack a defined nucleus and membrane-bound organelles. In contrast, eukaryotic cells, which make up plants,
        animals, fungi, and protists, contain a nucleus that houses the cell’s DNA, as well as various organelles like mitochondria and the endoplasmic reticulum.
        The cell membrane regulates the movement of substances in and out of the cell, while the cytoplasm supports the internal structures."""
    ]

    # Reference questions for comparison (optional)
    reference_questions = [
        "What do prokaryotic cells lack?",
        "Which cell structures are missing in prokaryotic cells compared to eukaryotic cells?",
        "What type of cells are found in bacteria and archaea?",
        "What is the basic structural and functional unit of all living organisms?",
        "What controls the movement of substances in and out of a cell?"
    ]

    try:
        # Initialize the MCQ generator and the performance-metrics evaluator
        mcq_generator = ImprovedMCQGenerator()
        metrics_evaluator = MCQPerformanceMetrics(mcq_generator)

        # Run the evaluation
        print("Running performance evaluation...")
        results = metrics_evaluator.evaluate(test_paragraphs, num_questions=5, reference_questions=reference_questions)

        # Visualize the results
        metrics_evaluator.visualize_results(results)

        # Print detailed results
        print("\nDetailed Performance Metrics:")
        for metric, value in results.items():
            # Format the value based on the metric type
            if isinstance(value, (int, float)):
                if 'time' in metric:
                    print(f"{metric}: {value:.2f} seconds")
                elif 'memory' in metric:
                    print(f"{metric}: {value:.2f} MB")
                elif metric in ['has_question_mark', 'context_presence', 'distractor_answer_similarity']:
                    print(f"{metric}: {value:.1f}%")
                else:
                    print(f"{metric}: {value:.3f}")
            else:
                print(f"{metric}: {value}")
    except Exception as e:
        print(f"Error in performance evaluation: {e}")
        import traceback
        traceback.print_exc()
if __name__ == "__main__":
    run_performance_evaluation()