import pandas as pd
import requests
from typing import Tuple, Optional
from dataclasses import dataclass
import logging
from dotenv import load_dotenv
import os
import time
import re
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load environment variables from the .env file
load_dotenv()
# Hugging Face Inference API configuration
API_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
API_KEY = os.getenv("HUGGINGFACE_API_KEY")
base_path = os.path.dirname(os.path.abspath(__file__))
misconception_csv_path = os.path.join(base_path, 'misconception_mapping.csv')
if not API_KEY:
    raise ValueError("HUGGINGFACE_API_KEY is not set. Check your .env file.")
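# The .env file is expected to define the token read above, e.g.:
#   HUGGINGFACE_API_KEY=hf_xxxxxxxxxxxxxxxx   (placeholder, not a real token)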
#์œ ์‚ฌ ๋ฌธ์ œ ์ƒ์„ฑ๊ธฐ ํด๋ž˜์Šค
@dataclass
class GeneratedQuestion:
question: str
choices: dict
correct_answer: str
explanation: str
class SimilarQuestionGenerator:
def __init__(self, misconception_csv_path: str = 'misconception_mapping.csv'):
"""
Initialize the generator by loading the misconception mapping and the language model.
"""
self._load_data(misconception_csv_path)
def _load_data(self, misconception_csv_path: str):
logger.info("Loading misconception mapping...")
self.misconception_df = pd.read_csv(misconception_csv_path)
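        # The mapping CSV is assumed to contain at least the two columns used
        # below: MisconceptionId (integer key) and MisconceptionName (text).
        # Hypothetical row: 1,"Confuses the numerator and denominator"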
def get_misconception_text(self, misconception_id: float) -> Optional[str]:
        """Retrieve the misconception description text matching the given MisconceptionId."""
        if pd.isna(misconception_id):  # NaN check
logger.warning("Received NaN for misconception_id.")
return "No misconception provided."
try:
row = self.misconception_df[self.misconception_df['MisconceptionId'] == int(misconception_id)]
if not row.empty:
return row.iloc[0]['MisconceptionName']
except ValueError as e:
logger.error(f"Error processing misconception_id: {e}")
logger.warning(f"No misconception found for ID: {misconception_id}")
return "Misconception not found."
def generate_prompt(self, construct_name: str, subject_name: str, question_text: str, correct_answer_text: str, wrong_answer_text: str, misconception_text: str) -> str:
"""Create a prompt for the language model."""
#๋ฌธ์ œ ์ƒ์„ฑ์„ ์œ„ํ•œ ํ”„๋กฌํ”„ํŠธ ํ…์ŠคํŠธ๋ฅผ ์ƒ์„ฑ
logger.info("Generating prompt...")
        # Note: get_misconception_text returns the sentinel strings checked below,
        # never "There is no misconception", so match against those sentinels.
        misconception_clause = (
            f"that targets the following misconception: \"{misconception_text}\"."
            if misconception_text not in ("No misconception provided.", "Misconception not found.")
            else ""
        )
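        # The <|begin_of_text|> / <|start_header_id|> / <|eot_id|> markers below
        # follow the Llama 3 chat template expected by Meta-Llama-3-8B-Instruct.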
prompt = f"""
<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
You are an educational assistant designed to generate multiple-choice questions {misconception_clause}
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
You need to create a similar multiple-choice question based on the following details:
Construct Name: {construct_name}
Subject Name: {subject_name}
Question Text: {question_text}
Correct Answer: {correct_answer_text}
Wrong Answer: {wrong_answer_text}
Please follow this output format:
---
Question: <Your Question Text>
A) <Choice A>
B) <Choice B>
C) <Choice C>
D) <Choice D>
Correct Answer: <Correct Choice (e.g., A)>
Explanation: <Brief explanation for the correct answer>
---
Ensure that the question is conceptually similar but not identical to the original. Ensure clarity and educational value.
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
""".strip()
logger.debug(f"Generated prompt: {prompt}")
return prompt
def call_model_api(self, prompt: str) -> str:
"""Hugging Face API ํ˜ธ์ถœ"""
logger.info("Calling Hugging Face API...")
headers = {"Authorization": f"Bearer {API_KEY}"}
try:
            response = requests.post(API_URL, headers=headers, json={"inputs": prompt}, timeout=60)
response.raise_for_status()
response_data = response.json()
logger.debug(f"Raw API response: {response_data}")
# API ์‘๋‹ต์ด ๋ฆฌ์ŠคํŠธ์ธ ๊ฒฝ์šฐ ์ฒ˜๋ฆฌ
if isinstance(response_data, list):
if response_data and isinstance(response_data[0], dict):
generated_text = response_data[0].get('generated_text', '')
else:
generated_text = response_data[0] if response_data else ''
# API ์‘๋‹ต์ด ๋”•์…”๋„ˆ๋ฆฌ์ธ ๊ฒฝ์šฐ ์ฒ˜๋ฆฌ
elif isinstance(response_data, dict):
generated_text = response_data.get('generated_text', '')
else:
generated_text = str(response_data)
logger.info(f"Generated text: {generated_text}")
return generated_text
except requests.exceptions.RequestException as e:
logger.error(f"API request failed: {e}")
raise
except Exception as e:
logger.error(f"Unexpected error in call_model_api: {e}")
raise
def parse_model_output(self, output: str) -> GeneratedQuestion:
"""Parse the model output with improved extraction of the question components."""
if not isinstance(output, str):
logger.error(f"Invalid output format: {type(output)}. Expected string.")
raise ValueError("Model output is not a string")
logger.info("Parsing model output...")
        # 1) Split the full text into lines
lines = output.splitlines()
        # 2) Temporary variables for locating the last Question..Explanation block
question = ""
choices = {}
correct_answer = ""
explanation = ""
# ์ด ๋ธ”๋ก์„ ์—ฌ๋Ÿฌ ๋ฒˆ ๋งŒ๋‚  ์ˆ˜ ์žˆ์œผ๋‹ˆ, ์ผ๋‹จ ๋ฐœ๊ฒฌํ•  ๋•Œ๋งˆ๋‹ค ์ €์žฅํ•ด๋‘๊ณ  ๋ฎ์–ด์”Œ์šฐ๋Š” ๋ฐฉ์‹.
# ์ตœ์ข…์ ์œผ๋กœ "๋งˆ์ง€๋ง‰์— ๋ฐœ๊ฒฌ๋œ" Question ๋ธ”๋ก์ด ์•„๋ž˜ ๋ณ€์ˆ˜๋ฅผ ๋ฎ์–ด์“ฐ๊ฒŒ ๋จ
temp_question = ""
temp_choices = {}
temp_correct = ""
temp_explanation = ""
for line in lines:
line = line.strip()
if not line:
continue
# Question:
if line.lower().startswith("question:"):
# ์ง€๊ธˆ๊นŒ์ง€ ์ €์žฅํ•ด๋‘” ์ด์ „ ๋ธ”๋ก๋“ค์„ ์ตœ์ข… ์ €์žฅ ์˜์—ญ์— ๋ฎ์–ด์”Œ์šด๋‹ค
if temp_question:
question = temp_question
choices = temp_choices
correct_answer = temp_correct
explanation = temp_explanation
# ์ƒˆ ๋ธ”๋ก์„ ์‹œ์ž‘
temp_question = line.split(":", 1)[1].strip()
temp_choices = {}
temp_correct = ""
temp_explanation = ""
# A) / B) / C) / D)
elif re.match(r"^[ABCD]\)", line):
# "A) ์„ ํƒ์ง€ ๋‚ด์šฉ"
letter = line[0] # A, B, C, D
choice_text = line[2:].strip()
temp_choices[letter] = choice_text
# Correct Answer:
elif line.lower().startswith("correct answer:"):
# "Correct Answer: A)" ํ˜•ํƒœ์—์„œ A๋งŒ ์ถ”์ถœ
ans_part = line.split(":", 1)[1].strip()
temp_correct = ans_part[0].upper() if ans_part else ""
# Explanation:
elif line.lower().startswith("explanation:"):
temp_explanation = line.split(":", 1)[1].strip()
# ๋ฃจํ”„๊ฐ€ ๋๋‚œ ๋’ค, ํ•œ ๋ฒˆ ๋” ์ตœ์‹  ๋ธ”๋ก์„ ์ตœ์ข… ๋ณ€์ˆ˜์— ๋ฐ˜์˜
if temp_question:
question = temp_question
choices = temp_choices
correct_answer = temp_correct
explanation = temp_explanation
# ์ด์ œ question, choices, correct_answer, explanation์ด ์ตœ์ข… ํŒŒ์‹ฑ ๊ฒฐ๊ณผ
logger.debug(f"Parsed components - Question: {question}, Choices: {choices}, "
f"Correct Answer: {correct_answer}, Explanation: {explanation}")
return GeneratedQuestion(question, choices, correct_answer, explanation)
def validate_generated_question(self, question: GeneratedQuestion) -> bool:
"""Validate if all components of the generated question are present and valid."""
logger.info("Validating generated question...")
try:
# Check if question text exists and is not too short
if not question.question or len(question.question.strip()) < 10:
logger.warning("Question text is missing or too short")
return False
# Check if all four choices exist and are not empty
            required_choices = {'A', 'B', 'C', 'D'}
if set(question.choices.keys()) != required_choices:
logger.warning(f"Missing choices. Found: {set(question.choices.keys())}")
return False
if not all(choice.strip() for choice in question.choices.values()):
logger.warning("Empty choice text found")
return False
# Check if correct answer is valid (should be just A, B, C, or D)
if not question.correct_answer or question.correct_answer not in required_choices:
logger.warning(f"Invalid correct answer: {question.correct_answer}")
return False
# Check if explanation exists and is not too short
if not question.explanation or len(question.explanation.strip()) < 20:
logger.warning("Explanation is missing or too short")
return False
logger.info("Question validation passed")
return True
except Exception as e:
logger.error(f"Error during validation: {e}")
return False
def generate_similar_question_with_text(self, construct_name: str, subject_name: str,
question_text: str, correct_answer_text: str,
wrong_answer_text: str, misconception_id: float,
max_retries: int = 3) -> Tuple[Optional[GeneratedQuestion], Optional[str]]:
"""Generate a similar question with validation and retry mechanism."""
logger.info("generate_similar_question_with_text initiated")
# Get misconception text
try:
misconception_text = self.get_misconception_text(misconception_id)
logger.info(f"Misconception text retrieved: {misconception_text}")
            # get_misconception_text always returns a non-empty string, so check
            # for its NaN sentinel explicitly rather than for falsiness.
            if not misconception_text or misconception_text == "No misconception provided.":
                logger.info("Skipping question generation due to lack of misconception.")
                return None, None
except Exception as e:
logger.error(f"Error retrieving misconception text: {e}")
return None, None
# Generate prompt once since it doesn't change between retries
prompt = self.generate_prompt(construct_name, subject_name, question_text,
correct_answer_text, wrong_answer_text, misconception_text)
# Attempt generation with retries
for attempt in range(max_retries):
try:
logger.info(f"Attempt {attempt + 1} of {max_retries}")
# Call API
generated_text = self.call_model_api(prompt)
logger.info(f"Generated text from API: {generated_text}")
# Parse output
generated_question = self.parse_model_output(generated_text)
# Validate the generated question
if self.validate_generated_question(generated_question):
logger.info("Successfully generated valid question")
return generated_question, generated_text
else:
logger.warning(f"Generated question failed validation on attempt {attempt + 1}")
# If this was the last attempt, return None
if attempt == max_retries - 1:
logger.error("Max retries reached without generating valid question")
return None, generated_text
# Add delay between retries to avoid rate limiting
time.sleep(2) # 2 second delay between retries
except Exception as e:
logger.error(f"Error during question generation attempt {attempt + 1}: {e}")
if attempt == max_retries - 1:
return None, None
time.sleep(2) # Add delay before retry
return None, None
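
# Minimal usage sketch (illustrative): assumes misconception_mapping.csv sits
# next to this file and HUGGINGFACE_API_KEY is set in .env. The sample values
# below are hypothetical, not rows from the real dataset.
if __name__ == "__main__":
    generator = SimilarQuestionGenerator(misconception_csv_path)
    question, raw_text = generator.generate_similar_question_with_text(
        construct_name="Adding fractions",      # hypothetical
        subject_name="Mathematics",             # hypothetical
        question_text="What is 1/2 + 1/3?",     # hypothetical
        correct_answer_text="5/6",
        wrong_answer_text="2/5",
        misconception_id=1.0,                   # hypothetical ID
    )
    if question is not None:
        print(question.question)
        for letter, text in sorted(question.choices.items()):
            print(f"{letter}) {text}")
        print(f"Correct Answer: {question.correct_answer}")
        print(f"Explanation: {question.explanation}")
    else:
        print("Generation failed; raw model output:")
        print(raw_text)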