File size: 6,473 Bytes
a0a2f13
 
 
4be3af4
a0a2f13
b2b9de7
4be3af4
 
a0a2f13
b2b9de7
4be3af4
b2b9de7
4be3af4
b2b9de7
a0a2f13
4be3af4
 
a0a2f13
4be3af4
 
 
 
 
 
 
a0a2f13
 
 
 
 
 
4be3af4
a0a2f13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4be3af4
a0a2f13
 
 
 
 
 
4be3af4
a0a2f13
 
4be3af4
a0a2f13
 
 
 
4be3af4
a0a2f13
 
 
 
 
b2b9de7
4be3af4
 
a0a2f13
b2b9de7
 
a0a2f13
4be3af4
b2b9de7
 
a0a2f13
 
 
 
 
b2b9de7
 
 
a0a2f13
 
b2b9de7
 
 
 
a0a2f13
 
b2b9de7
 
4be3af4
 
a0a2f13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4be3af4
a0a2f13
4be3af4
 
a0a2f13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4be3af4
a0a2f13
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# # module3.py

import re
import requests
from typing import Optional, Tuple
import logging
from dotenv import load_dotenv
import os
from collections import Counter

# Configure logging at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load variables from the .env file into the process environment.
load_dotenv()

# Hugging Face Inference API endpoint and the token read from the environment.
API_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
API_KEY = os.getenv("HUGGINGFACE_API_KEY")

# Fail fast at import time: every request sends this key as a bearer token.
# The message below was stored mis-encoded (UTF-8 bytes decoded with a Greek
# code page); it is restored to the Korean text it decodes to, which reads
# "API_KEY is not set. Check the .env file."
if not API_KEY:
    raise ValueError("API_KEY가 설정되지 않았습니다. .env 파일을 확인하세요.")

class AnswerVerifier:
    """Verify a multiple-choice answer by majority vote over repeated LLM queries.

    The same prompt is sent to the inference endpoint several times
    (a self-consistency check); the letters extracted from the replies are
    tallied and a winner is accepted only when it is unique and reaches a
    minimum share of the votes.
    """

    # Seconds to wait for each inference request. Without a timeout a stalled
    # connection would block the caller indefinitely.
    REQUEST_TIMEOUT = 30

    # Marker that precedes the assistant's reply in the prompt template.
    _ASSISTANT_HEADER = '<|start_header_id|>assistant<|end_header_id|>'

    # "Answer: X" (this also covers "Correct answer: X"). The trailing \b stops
    # words such as "Because" from being read as choice B.
    _LABELLED_ANSWER_RE = re.compile(r'ANSWER:\s*([ABCD])\b', re.IGNORECASE)
    # A choice letter standing alone as its own word.
    _UPPER_LETTER_RE = re.compile(r'\b([ABCD])\b')
    _ANY_CASE_LETTER_RE = re.compile(r'\b([ABCD])\b', re.IGNORECASE)

    def verify_answer(self, question: str, choices: dict, num_checks: int = 5) -> Optional[str]:
        """Query the model ``num_checks`` times and return the majority answer.

        Args:
            question: The question text inserted into the prompt.
            choices: Mapping with the keys 'A', 'B', 'C' and 'D' holding the
                option texts.
            num_checks: How many times the same prompt is sent.

        Returns:
            The winning choice letter, or ``None`` when no answer could be
            extracted, the vote was tied or below the confidence threshold, or
            any error occurred. Every path returns a single value (the
            success path already did); the reason for a ``None`` is logged
            rather than returned, so a failure can no longer look like a
            truthy ``(None, message)`` tuple to the caller.
        """
        try:
            # The prompt and headers do not change between trials.
            prompt = self._create_prompt(question, choices)
            headers = {"Authorization": f"Bearer {API_KEY}"}

            # NOTE(review): the payload sets no sampling parameters and does
            # not disable any server-side response cache. If the endpoint is
            # deterministic or cached, every trial may return the same text
            # and the vote adds no information -- confirm against the API docs.
            answers = []
            for trial in range(1, num_checks + 1):
                response = requests.post(
                    API_URL,
                    headers=headers,
                    json={"inputs": prompt},
                    timeout=self.REQUEST_TIMEOUT,
                )
                response.raise_for_status()

                response_data = response.json()
                logger.debug("Raw API response: %s", response_data)

                generated_text = self._process_response(response_data)
                logger.debug("Trial %d:", trial)
                logger.debug("Generated text: %s", generated_text)

                answer = self._extract_answer(generated_text)
                logger.debug("Extracted answer: %s", answer)

                if answer:
                    answers.append(answer)

            if not answers:
                logger.warning("No valid answers extracted")
                return None

            # Decide the final answer by majority vote; the explanation is
            # only logged, not returned.
            final_answer, explanation = self._get_majority_vote(answers)
            logger.info("Final verified answer: %s (%s)", final_answer, explanation)
            return final_answer

        except Exception as e:
            # Boundary handler: network, HTTP, JSON and prompt-building errors
            # all end up here and are reported as "no verified answer".
            logger.error(f"Error in verify_answer: {e}")
            return None

    def _create_prompt(self, question: str, choices: dict) -> str:
        """Build the chat-formatted prompt that demands an "Answer: X" reply.

        Args:
            question: The question text.
            choices: Mapping with the keys 'A', 'B', 'C' and 'D'.

        Returns:
            The prompt with leading and trailing whitespace stripped. The
            per-line indentation inside the template is part of the text that
            is sent and is kept as is.

        Raises:
            KeyError: If one of the four choice keys is missing.
        """
        return f"""
        <|begin_of_text|>
        <|start_header_id|>system<|end_header_id|>
        You are an expert mathematics teacher evaluating multiple-choice answers.
        Analyze the question and options carefully to select the correct answer.
        
        IMPORTANT: You must respond ONLY with "Answer: X" where X is A, B, C, or D.
        Do not include any explanation or additional text.
        <|eot_id|>
        <|start_header_id|>user<|end_header_id|>
        Question: {question}
        
        Options:
        A) {choices['A']}
        B) {choices['B']}
        C) {choices['C']}
        D) {choices['D']}
        
        Provide your answer in the format: "Answer: X" (where X is A, B, C, or D)
        <|eot_id|>
        <|start_header_id|>assistant<|end_header_id|>
        """.strip()

    def _process_response(self, response_data) -> str:
        """Pull the generated text out of the decoded API response.

        Accepts a list (the first element is used), a dict, or any other
        value (which is stringified). A missing or ``None`` 'generated_text'
        yields an empty string, and a non-string value is converted with
        ``str`` instead of failing later on ``.split``.

        Returns:
            The text after the last assistant header if the response contains
            the header, otherwise the whole text; stripped in both cases.
        """
        if isinstance(response_data, list):
            first = response_data[0] if response_data else ''
            if isinstance(first, dict):
                generated_text = first.get('generated_text', '')
            else:
                generated_text = first
        elif isinstance(response_data, dict):
            generated_text = response_data.get('generated_text', '')
        else:
            generated_text = str(response_data)

        if generated_text is None:
            generated_text = ''
        elif not isinstance(generated_text, str):
            generated_text = str(generated_text)

        # Keep only the assistant's part of the text. When the header is
        # absent, split() returns the whole text as its only element.
        return generated_text.split(self._ASSISTANT_HEADER)[-1].strip()

    def _extract_answer(self, response: str) -> Optional[str]:
        """Extract the chosen letter (A-D) from a model reply.

        Lookup order:
          1. The first "Answer: X" occurrence, case-insensitive, where X is a
             whole word.
          2. The last upper-case letter A-D standing alone. The search runs on
             the original casing so the English article "a" is not taken for
             choice A when a real upper-case letter is present.
          3. The last standalone letter in any case, as a fallback for
             replies written entirely in lower case.

        Returns:
            The upper-case letter, or ``None`` if nothing matches or the
            reply is empty or ``None``.
        """
        if not response:
            return None

        labelled = self._LABELLED_ANSWER_RE.search(response)
        if labelled:
            return labelled.group(1).upper()

        # The final answer usually comes last, so take the last match.
        standalone = (self._UPPER_LETTER_RE.findall(response)
                      or self._ANY_CASE_LETTER_RE.findall(response))
        if standalone:
            return standalone[-1].upper()

        return None

    def _get_majority_vote(self, answers: list,
                           min_confidence: float = 60.0) -> Tuple[Optional[str], str]:
        """Pick the majority answer and describe how the vote went.

        Args:
            answers: Extracted choice letters, one per successful trial.
            min_confidence: Minimum share of the votes, in percent, that the
                winner must reach. The default matches the threshold that
                used to be hard-coded.

        Returns:
            ``(answer, explanation)``. ``answer`` is ``None`` when the list
            is empty, when several answers tie for first place, or when the
            winner's share is below ``min_confidence``.
        """
        if not answers:
            return None, "No valid answers extracted"

        counter = Counter(answers)

        # Handle ties: more than one answer holding the top count.
        max_count = max(counter.values())
        top_answers = [ans for ans, count in counter.items() if count == max_count]
        if len(top_answers) > 1:
            return None, f"Tie between answers: {top_answers}"

        final_answer = top_answers[0]
        total_votes = len(answers)
        # Multiply before dividing so exact shares such as 3/5 give exactly
        # 60.0 and are not pushed under the threshold by float rounding.
        confidence = max_count * 100 / total_votes

        # Apply the confidence threshold.
        if confidence < min_confidence:
            return None, f"Low confidence ({confidence:.1f}%) for answer {final_answer}"

        explanation = (f"Answer '{final_answer}' selected with {confidence:.1f}% confidence "
                       f"({max_count}/{total_votes} votes). "
                       f"Distribution: {dict(counter)}")

        return final_answer, explanation