import logging
import os
import re
import time
from dataclasses import dataclass
from typing import Optional, Tuple

import pandas as pd
import requests
from dotenv import load_dotenv

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables from the .env file
load_dotenv()

# Hugging Face Inference API configuration
API_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
API_KEY = os.getenv("HUGGINGFACE_API_KEY")

base_path = os.path.dirname(os.path.abspath(__file__))
misconception_csv_path = os.path.join(base_path, 'misconception_mapping.csv')

if not API_KEY:
    raise ValueError("API_KEY is not set. Check your .env file.")
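
# The .env file is expected to supply the token, e.g. (illustrative value only):
#   HUGGINGFACE_API_KEY=hf_xxxxxxxxxxxxxxxx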


# Result type returned by the similar-question generator
@dataclass
class GeneratedQuestion:
    question: str
    choices: dict
    correct_answer: str
    explanation: str


class SimilarQuestionGenerator:
    def __init__(self, misconception_csv_path: str = 'misconception_mapping.csv'):
        """
        Initialize the generator by loading the misconception mapping.
        The language model itself is accessed remotely via the Inference API.
        """
        self._load_data(misconception_csv_path)

    def _load_data(self, misconception_csv_path: str):
        logger.info("Loading misconception mapping...")
        self.misconception_df = pd.read_csv(misconception_csv_path)

    def get_misconception_text(self, misconception_id: float) -> Optional[str]:
        """Return the misconception description text matching the given MisconceptionId."""
        if pd.isna(misconception_id):  # Guard against NaN ids
            logger.warning("Received NaN for misconception_id.")
            return "No misconception provided."
        try:
            row = self.misconception_df[self.misconception_df['MisconceptionId'] == int(misconception_id)]
            if not row.empty:
                return row.iloc[0]['MisconceptionName']
        except ValueError as e:
            logger.error(f"Error processing misconception_id: {e}")
        logger.warning(f"No misconception found for ID: {misconception_id}")
        return "Misconception not found."

    def generate_prompt(self, construct_name: str, subject_name: str, question_text: str,
                        correct_answer_text: str, wrong_answer_text: str,
                        misconception_text: str) -> str:
        """Create a prompt for the language model."""
        # Build the prompt text used for question generation
        logger.info("Generating prompt...")
        # Mention the misconception only when a real one was found;
        # get_misconception_text returns sentinel strings otherwise.
        _sentinels = ("No misconception provided.", "Misconception not found.")
        misconception_clause = (
            f'that targets the following misconception: "{misconception_text}".'
            if misconception_text not in _sentinels else ""
        )
prompt = f""" | |
<|begin_of_text|> | |
<|start_header_id|>system<|end_header_id|> | |
You are an educational assistant designed to generate multiple-choice questions {misconception_clause} | |
<|eot_id|> | |
<|start_header_id|>user<|end_header_id|> | |
You need to create a similar multiple-choice question based on the following details: | |
Construct Name: {construct_name} | |
Subject Name: {subject_name} | |
Question Text: {question_text} | |
Correct Answer: {correct_answer_text} | |
Wrong Answer: {wrong_answer_text} | |
Please follow this output format: | |
--- | |
Question: <Your Question Text> | |
A) <Choice A> | |
B) <Choice B> | |
C) <Choice C> | |
D) <Choice D> | |
Correct Answer: <Correct Choice (e.g., A)> | |
Explanation: <Brief explanation for the correct answer> | |
--- | |
Ensure that the question is conceptually similar but not identical to the original. Ensure clarity and educational value. | |
<|eot_id|> | |
<|start_header_id|>assistant<|end_header_id|> | |
""".strip() | |
logger.debug(f"Generated prompt: {prompt}") | |
return prompt | |
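
    # The special tokens in the template above (<|begin_of_text|>,
    # <|start_header_id|>...<|end_header_id|>, <|eot_id|>) follow the Llama 3
    # instruct chat format, matching the model referenced in API_URL.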

    def call_model_api(self, prompt: str) -> str:
        """Call the Hugging Face Inference API and return the generated text."""
        logger.info("Calling Hugging Face API...")
        headers = {"Authorization": f"Bearer {API_KEY}"}
        try:
            response = requests.post(API_URL, headers=headers, json={"inputs": prompt}, timeout=60)
            response.raise_for_status()
            response_data = response.json()
            logger.debug(f"Raw API response: {response_data}")
            # Handle a list-shaped API response
            if isinstance(response_data, list):
                if response_data and isinstance(response_data[0], dict):
                    generated_text = response_data[0].get('generated_text', '')
                else:
                    generated_text = response_data[0] if response_data else ''
            # Handle a dict-shaped API response
            elif isinstance(response_data, dict):
                generated_text = response_data.get('generated_text', '')
            else:
                generated_text = str(response_data)
            logger.info(f"Generated text: {generated_text}")
            return generated_text
        except requests.exceptions.RequestException as e:
            logger.error(f"API request failed: {e}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in call_model_api: {e}")
            raise
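
    # A typical successful response from the hosted text-generation endpoint
    # looks like [{"generated_text": "..."}]; the branches above also tolerate
    # a bare dict or other shapes. This shape may differ for other deployments.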

    # --- excerpt from module2.py ---
    def parse_model_output(self, output: str) -> GeneratedQuestion:
        """Parse the model output with improved extraction of the question components."""
        if not isinstance(output, str):
            logger.error(f"Invalid output format: {type(output)}. Expected string.")
            raise ValueError("Model output is not a string")
        logger.info("Parsing model output...")

        # 1) Split the full text into lines
        lines = output.splitlines()

        # 2) Working variables for locating the last Question..Explanation block
        question = ""
        choices = {}
        correct_answer = ""
        explanation = ""

        # The block may appear several times; each one found is stored and then
        # overwritten, so the last Question block encountered ends up in the
        # final variables.
        temp_question = ""
        temp_choices = {}
        temp_correct = ""
        temp_explanation = ""

        for line in lines:
            line = line.strip()
            if not line:
                continue
            # Question:
            if line.lower().startswith("question:"):
                # Commit the previously collected block to the final variables
                if temp_question:
                    question = temp_question
                    choices = temp_choices
                    correct_answer = temp_correct
                    explanation = temp_explanation
                # Start a new block
                temp_question = line.split(":", 1)[1].strip()
                temp_choices = {}
                temp_correct = ""
                temp_explanation = ""
            # A) / B) / C) / D)
            elif re.match(r"^[ABCD]\)", line):
                # e.g. "A) <choice text>"
                letter = line[0]  # A, B, C, or D
                choice_text = line[2:].strip()
                temp_choices[letter] = choice_text
            # Correct Answer:
            elif line.lower().startswith("correct answer:"):
                # Extract just the letter from a form like "Correct Answer: A)"
                ans_part = line.split(":", 1)[1].strip()
                temp_correct = ans_part[0].upper() if ans_part else ""
            # Explanation:
            elif line.lower().startswith("explanation:"):
                temp_explanation = line.split(":", 1)[1].strip()

        # After the loop, commit the latest block once more
        if temp_question:
            question = temp_question
            choices = temp_choices
            correct_answer = temp_correct
            explanation = temp_explanation

        # question, choices, correct_answer, and explanation now hold the final parse
        logger.debug(f"Parsed components - Question: {question}, Choices: {choices}, "
                     f"Correct Answer: {correct_answer}, Explanation: {explanation}")
        return GeneratedQuestion(question, choices, correct_answer, explanation)
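
    # A block the parser above accepts, matching the format requested in the
    # prompt (values are illustrative):
    #   Question: What is 3 + 4?
    #   A) 6
    #   B) 7
    #   C) 12
    #   D) 1
    #   Correct Answer: B
    #   Explanation: Adding 3 and 4 gives 7, since no carrying is needed.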

    def validate_generated_question(self, question: GeneratedQuestion) -> bool:
        """Validate that all components of the generated question are present and valid."""
        logger.info("Validating generated question...")
        try:
            # Check that the question text exists and is not too short
            if not question.question or len(question.question.strip()) < 10:
                logger.warning("Question text is missing or too short")
                return False
            # Check that all four choices exist and are non-empty
            required_choices = {'A', 'B', 'C', 'D'}
            if set(question.choices.keys()) != required_choices:
                logger.warning(f"Missing choices. Found: {set(question.choices.keys())}")
                return False
            if not all(choice.strip() for choice in question.choices.values()):
                logger.warning("Empty choice text found")
                return False
            # Check that the correct answer is valid (just A, B, C, or D)
            if not question.correct_answer or question.correct_answer not in required_choices:
                logger.warning(f"Invalid correct answer: {question.correct_answer}")
                return False
            # Check that the explanation exists and is not too short
            if not question.explanation or len(question.explanation.strip()) < 20:
                logger.warning("Explanation is missing or too short")
                return False
            logger.info("Question validation passed")
            return True
        except Exception as e:
            logger.error(f"Error during validation: {e}")
            return False

    def generate_similar_question_with_text(self, construct_name: str, subject_name: str,
                                            question_text: str, correct_answer_text: str,
                                            wrong_answer_text: str, misconception_id: float,
                                            max_retries: int = 3) -> Tuple[Optional[GeneratedQuestion], Optional[str]]:
        """Generate a similar question with validation and a retry mechanism."""
        logger.info("generate_similar_question_with_text initiated")

        # Get the misconception text
        try:
            misconception_text = self.get_misconception_text(misconception_id)
            logger.info(f"Misconception text retrieved: {misconception_text}")
            if not misconception_text:
                logger.info("Skipping question generation due to lack of misconception.")
                return None, None
        except Exception as e:
            logger.error(f"Error retrieving misconception text: {e}")
            return None, None

        # Generate the prompt once, since it does not change between retries
        prompt = self.generate_prompt(construct_name, subject_name, question_text,
                                      correct_answer_text, wrong_answer_text, misconception_text)

        # Attempt generation with retries
        for attempt in range(max_retries):
            try:
                logger.info(f"Attempt {attempt + 1} of {max_retries}")
                # Call the API
                generated_text = self.call_model_api(prompt)
                logger.info(f"Generated text from API: {generated_text}")
                # Parse the output
                generated_question = self.parse_model_output(generated_text)
                # Validate the generated question
                if self.validate_generated_question(generated_question):
                    logger.info("Successfully generated valid question")
                    return generated_question, generated_text
                logger.warning(f"Generated question failed validation on attempt {attempt + 1}")
                # If this was the last attempt, return the raw text for inspection
                if attempt == max_retries - 1:
                    logger.error("Max retries reached without generating valid question")
                    return None, generated_text
                # Delay between retries to avoid rate limiting
                time.sleep(2)
            except Exception as e:
                logger.error(f"Error during question generation attempt {attempt + 1}: {e}")
                if attempt == max_retries - 1:
                    return None, None
                time.sleep(2)  # Delay before retry
        return None, None
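

# --- Usage sketch (illustrative; not part of the original modules) ---
# A minimal example of driving the full pipeline end to end. The construct,
# subject, question values, and misconception_id below are hypothetical
# placeholders; the ID is assumed to exist in misconception_mapping.csv.
if __name__ == "__main__":
    generator = SimilarQuestionGenerator(misconception_csv_path)
    question, raw_text = generator.generate_similar_question_with_text(
        construct_name="Simplify algebraic fractions",
        subject_name="Algebra",
        question_text="Simplify: (2x + 4) / 2",
        correct_answer_text="x + 2",
        wrong_answer_text="2x + 2",
        misconception_id=1.0,  # hypothetical ID
    )
    if question:
        print(question.question)
        for letter, text in sorted(question.choices.items()):
            print(f"{letter}) {text}")
        print(f"Correct Answer: {question.correct_answer}")
        print(f"Explanation: {question.explanation}")
    else:
        print("Generation failed; raw model output:", raw_text)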