Jintonic92's picture
Update src/ThirdModule/module3.py
a0a2f13 verified
# # module3.py
import re
import requests
from typing import Optional, Tuple
import logging
from dotenv import load_dotenv
import os
from collections import Counter
# Set up logging.
# NOTE(review): basicConfig configures the *root* logger at import time, which
# affects the whole application importing this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load variables (e.g. HUGGINGFACE_API_KEY) from a .env file into os.environ.
load_dotenv()

# Hugging Face Inference API endpoint and token used by AnswerVerifier requests.
API_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
API_KEY = os.getenv("HUGGINGFACE_API_KEY")
# Fail fast at import time: every request sends this token as a Bearer header.
if not API_KEY:
    raise ValueError("API_KEY가 설정되지 않았습니다. .env 파일을 확인하세요.")
class AnswerVerifier:
    """Verify a multiple-choice answer by querying an LLM several times and voting.

    A simple self-consistency check: the same prompt is sent ``num_checks``
    times to the Hugging Face Inference API (``API_URL``), a letter A-D is
    extracted from each reply, and the majority letter is accepted only if it
    is unique and reaches ``min_confidence`` percent of the valid votes.
    """

    # Seconds to wait for each HTTP request; `requests` has no default
    # timeout, so without one a stalled connection blocks forever.
    request_timeout: float = 60.0
    # Sampling temperature for each trial. Trials must be able to differ for
    # the vote to carry any information.
    sampling_temperature: float = 0.7
    # Minimum share (in percent) of valid votes the winning letter needs.
    min_confidence: float = 60.0

    # "Answer: X", also tolerating "Correct answer :", markdown emphasis
    # ("**Answer:** X") and parentheses ("Answer: (X)"). The trailing \b stops
    # "Answer: about ..." from being read as "A".
    _EXPLICIT_ANSWER_RE = re.compile(r"ANSWER\s*:\s*[*_]*\s*\(?([ABCD])\b", re.IGNORECASE)
    # A lone option letter surrounded by word boundaries.
    _STANDALONE_LETTER_RE = re.compile(r"\b([ABCD])\b")
    _ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>"

    def verify_answer(self, question: str, choices: dict, num_checks: int = 5) -> Optional[str]:
        """Return the self-consistent answer letter for a multiple-choice question.

        Args:
            question: The question text.
            choices: Mapping with keys 'A', 'B', 'C' and 'D' to the option texts.
            num_checks: Number of independent model queries to vote over.

        Returns:
            The winning letter ('A'-'D'), or None when no answer could be
            verified: no parsable replies, a tie, confidence below
            ``min_confidence``, or an unexpected error. The reason is logged.
        """
        try:
            # The prompt is identical for every trial, so build it once.
            prompt = self._create_prompt(question, choices)
            answers = []
            for trial in range(1, num_checks + 1):
                generated_text = self._query_model(prompt, trial)
                if generated_text is None:
                    continue  # request failed; keep the other trials' votes
                answer = self._extract_answer(generated_text)
                logger.debug("Trial %d extracted answer: %s", trial, answer)
                if answer:
                    answers.append(answer)
            if not answers:
                logger.warning("No valid answers extracted")
                return None
            final_answer, explanation = self._get_majority_vote(answers, self.min_confidence)
            logger.info("Final verified answer: %s (%s)", final_answer, explanation)
            # Always a bare letter or None (never a tuple), so callers can
            # simply test the result for truthiness.
            return final_answer
        except Exception as exc:
            logger.exception("Error in verify_answer: %s", exc)
            return None

    def _query_model(self, prompt: str, trial: int) -> Optional[str]:
        """Send one generation request and return the assistant's reply text.

        Returns None (after logging a warning) when the request fails, the
        server answers with an error status, or the body is not valid JSON,
        so a single failed trial does not discard the other trials' votes.
        """
        headers = {"Authorization": f"Bearer {API_KEY}"}
        payload = {
            "inputs": prompt,
            "parameters": {
                # Greedy decoding would give the same reply on every trial.
                "do_sample": True,
                "temperature": self.sampling_temperature,
                # Request only the continuation, not the echoed prompt.
                "return_full_text": False,
            },
            # The Inference API caches identical requests by default, which
            # would also make every trial return the same reply.
            "options": {"use_cache": False},
        }
        try:
            response = requests.post(
                API_URL,
                headers=headers,
                json=payload,
                timeout=self.request_timeout,
            )
            response.raise_for_status()
            response_data = response.json()
        except (requests.RequestException, ValueError) as exc:
            # ValueError covers JSON decoding errors on older `requests`.
            logger.warning("Trial %d failed: %s", trial, exc)
            return None
        logger.debug("Trial %d raw API response: %s", trial, response_data)
        generated_text = self._process_response(response_data)
        logger.debug("Trial %d generated text: %s", trial, generated_text)
        return generated_text

    def _create_prompt(self, question: str, choices: dict) -> str:
        """Build a Llama-3 chat-formatted prompt that asks for a single "Answer: X" reply.

        Raises:
            KeyError: if ``choices`` is missing any of the keys 'A'-'D'.
        """
        return f"""
<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
You are an expert mathematics teacher evaluating multiple-choice answers.
Analyze the question and options carefully to select the correct answer.
IMPORTANT: You must respond ONLY with "Answer: X" where X is A, B, C, or D.
Do not include any explanation or additional text.
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Question: {question}
Options:
A) {choices['A']}
B) {choices['B']}
C) {choices['C']}
D) {choices['D']}
Provide your answer in the format: "Answer: X" (where X is A, B, C, or D)
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
""".strip()

    def _process_response(self, response_data) -> str:
        """Pull the generated text out of an API response and strip any echoed prompt.

        Accepts a list whose first item is a dict with 'generated_text' (or a
        plain value), a dict with 'generated_text', or anything else (which is
        stringified). Only the text after the last assistant header is kept.
        """
        if isinstance(response_data, list):
            first = response_data[0] if response_data else ""
            generated_text = first.get("generated_text", "") if isinstance(first, dict) else first
        elif isinstance(response_data, dict):
            generated_text = response_data.get("generated_text", "")
        else:
            generated_text = str(response_data)
        # Guard against non-string payload values (e.g. a null generated_text).
        if not isinstance(generated_text, str):
            generated_text = "" if generated_text is None else str(generated_text)
        # rpartition leaves the whole text in the tail when the header is absent.
        _, _, reply = generated_text.rpartition(self._ASSISTANT_HEADER)
        return reply.strip()

    def _extract_answer(self, response: str) -> Optional[str]:
        """Extract an option letter ('A'-'D') from a model reply, or None.

        Order of preference:
        1. An explicit "Answer: X" (case-insensitive).
        2. The last standalone *uppercase* A-D. Case matters here so the
           English article "a" is not mistaken for option A.
        3. The last standalone letter in the upper-cased text, so an
           all-lowercase reply like "the answer is c" still yields "C".
        """
        text = response.strip()
        explicit = self._EXPLICIT_ANSWER_RE.search(text)
        if explicit:
            return explicit.group(1).upper()
        for candidate_text in (text, text.upper()):
            letters = self._STANDALONE_LETTER_RE.findall(candidate_text)
            if letters:
                # The final answer usually comes last in the reply.
                return letters[-1]
        return None

    def _get_majority_vote(self, answers: list, min_confidence: float = 60.0) -> Tuple[Optional[str], str]:
        """Pick the majority letter and explain the decision.

        Args:
            answers: Extracted letters, one per successful trial.
            min_confidence: Minimum percentage of ``answers`` the winner must hold.

        Returns:
            ``(letter, explanation)`` on success; ``(None, reason)`` when
            ``answers`` is empty, the top count is tied, or the winner's share
            is below ``min_confidence``.
        """
        if not answers:
            return None, "No valid answers extracted"
        votes = Counter(answers)
        max_count = max(votes.values())
        top_answers = [ans for ans, count in votes.items() if count == max_count]
        if len(top_answers) > 1:
            return None, f"Tie between answers: {top_answers}"
        final_answer = top_answers[0]
        total_votes = len(answers)
        # NOTE(review): the share is over *valid* votes only, so a single
        # parsable reply out of many trials reports 100% confidence.
        confidence = votes[final_answer] / total_votes * 100
        if confidence < min_confidence:
            return None, f"Low confidence ({confidence:.1f}%) for answer {final_answer}"
        explanation = (f"Answer '{final_answer}' selected with {confidence:.1f}% confidence "
                       f"({votes[final_answer]}/{total_votes} votes). "
                       f"Distribution: {dict(votes)}")
        return final_answer, explanation