awacke1 committed on
Commit
a509ab0
·
verified ·
1 Parent(s): c1b0332

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +236 -0
app.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import openai
4
+ import pandas as pd
5
+ from typing import List, Tuple
6
+ from uuid import uuid4
7
+ import time
8
+
9
+ # 🔑 Set the OpenAI API key from an environment variable
10
# Yields None when OPENAI_API_KEY is not set in the environment.
openai.api_key = os.environ.get("OPENAI_API_KEY")
11
+
12
+ # 🆔 Function to generate a unique session ID for caching
13
def get_session_id():
    """Return the ID stored in Streamlit session state, creating it on first use.

    The ID is a random UUID4 rendered as a string and kept under the
    ``session_id`` key of ``st.session_state``.
    """
    state = st.session_state
    if 'session_id' in state:
        return state.session_id
    state.session_id = str(uuid4())
    return state.session_id
17
+
18
+ # 🧠 STaR Algorithm Implementation
19
class SelfTaughtReasoner:
    """Demonstration of the Self-Taught Reasoner (STaR) loop.

    For every problem the model is asked for a rationale and then an answer.
    Results are recorded in ``generated_data``; when the answer is wrong the
    model is re-prompted with the correct answer as a hint ("rationalization")
    and successful rationalizations are recorded in ``rationalized_data``.
    Fine-tuning is only simulated.

    NOTE(review): the completion calls use ``openai.Completion``, the
    interface of the pre-1.0 ``openai`` package -- confirm the pinned version.
    """

    # Column layout shared by the generated and rationalized result tables.
    RESULT_COLUMNS = ['Problem', 'Rationale', 'Answer', 'Is_Correct']

    def __init__(self, model_engine="text-davinci-003"):
        """Create a reasoner with no few-shot examples and empty result tables.

        Args:
            model_engine: Engine name passed to every completion request.
        """
        self.model_engine = model_engine
        self.prompt_examples = []  # few-shot examples: dicts with Problem/Rationale/Answer
        self.iterations = 0  # number of completed run_iteration() calls
        self.generated_data = pd.DataFrame(columns=self.RESULT_COLUMNS)
        self.rationalized_data = pd.DataFrame(columns=self.RESULT_COLUMNS)
        self.fine_tuned_model = None  # set by fine_tune_model()

    def add_prompt_example(self, problem: str, rationale: str, answer: str):
        """Append one worked example to the few-shot prompt prefix."""
        self.prompt_examples.append({
            'Problem': problem,
            'Rationale': rationale,
            'Answer': answer,
        })

    def construct_prompt(self, problem: str, include_answer: bool = False, answer: str = "") -> str:
        """Build the completion prompt for ``problem``.

        Args:
            problem: The problem the model should reason about.
            include_answer: When true, ``answer`` is inserted as a hint line
                before the rationale cue (used for rationalization).
            answer: The hint text; only used when ``include_answer`` is true.

        Returns:
            All few-shot examples followed by the problem and a trailing
            ``Rationale:`` cue for the model to continue.
        """
        parts = [
            f"Problem: {example['Problem']}\n"
            f"Rationale: {example['Rationale']}\n"
            f"Answer: {example['Answer']}\n\n"
            for example in self.prompt_examples
        ]
        parts.append(f"Problem: {problem}\n")
        if include_answer:
            parts.append(f"Answer (as hint): {answer}\n")
        parts.append("Rationale:")
        return "".join(parts)

    def _complete(self, prompt: str, max_tokens: int, temperature: float, stop) -> str:
        """Run one completion request and return the stripped text of the first choice."""
        response = openai.Completion.create(
            engine=self.model_engine,
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=stop,
        )
        return response.choices[0].text.strip()

    def _rationale_then_answer(self, prompt: str) -> Tuple[str, str]:
        """Ask for a rationale, then for an answer conditioned on that rationale.

        Any failure is reported through ``st.error`` and yields ``("", "")`` so
        that one failed request does not abort the whole Streamlit run.
        """
        try:
            rationale = self._complete(
                prompt,
                max_tokens=150,
                temperature=0.7,
                stop=["\n\n", "Problem:", "Answer:"],
            )
            # The answer is decoded greedily and cut at the first newline.
            answer = self._complete(
                f"{prompt} {rationale}\nAnswer:",
                max_tokens=10,
                temperature=0,
                stop=["\n", "\n\n", "Problem:"],
            )
        except Exception as e:
            st.error(f"❌ Error generating rationale and answer: {e}")
            return "", ""
        return rationale, answer

    def generate_rationale_and_answer(self, problem: str) -> Tuple[str, str]:
        """Generate a rationale and an answer for ``problem`` without any hint.

        Returns:
            ``(rationale, answer)``, or ``("", "")`` if a request failed.
        """
        return self._rationale_then_answer(self.construct_prompt(problem))

    def rationalize(self, problem: str, correct_answer: str) -> Tuple[str, str]:
        """Generate a rationale and answer with the correct answer given as a hint.

        Returns:
            ``(rationale, answer)``, or ``("", "")`` if a request failed.
        """
        hinted_prompt = self.construct_prompt(
            problem, include_answer=True, answer=str(correct_answer)
        )
        return self._rationale_then_answer(hinted_prompt)

    def fine_tune_model(self):
        """Simulate fine-tuning and store a synthetic model name.

        No training happens: the method sleeps briefly and derives the name
        from the engine name and the current session ID.
        """
        time.sleep(1)  # stand-in for the time a real fine-tuning job would take
        self.fine_tuned_model = f"{self.model_engine}-fine-tuned-{get_session_id()}"
        st.success(f"✅ Model fine-tuned: {self.fine_tuned_model}")

    @staticmethod
    def _answers_match(predicted, expected) -> bool:
        """Compare two answers ignoring case and surrounding whitespace.

        Both sides are coerced to ``str`` so non-string dataset values
        (e.g. numbers) do not raise.
        """
        return str(predicted).strip().lower() == str(expected).strip().lower()

    @staticmethod
    def _append_record(frame: pd.DataFrame, record: dict) -> pd.DataFrame:
        """Return ``frame`` with ``record`` added as a new last row.

        Replaces ``DataFrame.append``, which no longer exists in pandas 2.x.
        """
        row = pd.DataFrame([record], columns=frame.columns)
        if frame.empty:
            # Avoid concatenating with an empty frame; the single row is the result.
            return row
        return pd.concat([frame, row], ignore_index=True)

    def run_iteration(self, dataset: pd.DataFrame):
        """Run one STaR iteration over ``dataset``.

        Args:
            dataset: Frame with ``Problem`` and ``Answer`` columns.
        """
        st.write(f"### Iteration {self.iterations + 1}")
        progress_bar = st.progress(0)
        total = len(dataset)
        # Progress is driven by row position, not the index label, so it stays
        # within 0..1 for any DataFrame index.
        for position, (_, row) in enumerate(dataset.iterrows(), start=1):
            problem = row['Problem']
            correct_answer = row['Answer']

            rationale, answer = self.generate_rationale_and_answer(problem)
            is_correct = self._answers_match(answer, correct_answer)
            self.generated_data = self._append_record(self.generated_data, {
                'Problem': problem,
                'Rationale': rationale,
                'Answer': answer,
                'Is_Correct': is_correct,
            })

            if not is_correct:
                # Retry with the correct answer as a hint; keep the result only
                # if the hinted attempt reaches the correct answer.
                rationale, answer = self.rationalize(problem, correct_answer)
                if self._answers_match(answer, correct_answer):
                    self.rationalized_data = self._append_record(self.rationalized_data, {
                        'Problem': problem,
                        'Rationale': rationale,
                        'Answer': answer,
                        'Is_Correct': True,
                    })
            progress_bar.progress(position / total)

        st.write("🔄 Fine-tuning the model on correct rationales...")
        self.fine_tune_model()
        self.iterations += 1
133
+
134
+ # Predefined problem and answer list for dataset
135
# Seed problems with reference answers; used both as selectable few-shot
# examples and as the default content of the dataset text area.
EXAMPLE_PROBLEM_ANSWERS = [
    {"Problem": problem, "Answer": answer}
    for problem, answer in (
        (
            "What is deductive reasoning?",
            "It is a logical process that draws specific conclusions from general principles.",
        ),
        (
            "What is inductive reasoning?",
            "It is reasoning that forms general principles from specific examples.",
        ),
        (
            "Explain abductive reasoning.",
            "It involves finding the best explanation for incomplete observations.",
        ),
        ("What is the capital of France?", "Paris."),
        ("Who wrote Hamlet?", "William Shakespeare."),
    )
]
142
+
143
+ # Additional problem set for testing fine-tuned model
144
# Held-out questions offered in the "test" dropdown; none of them appear in
# EXAMPLE_PROBLEM_ANSWERS.
TEST_PROBLEM_SET = list((
    "What is the Pythagorean theorem?",
    "Who developed the theory of relativity?",
    "What is the main ingredient in bread?",
    "Who is the author of 1984?",
    "What is the boiling point of water?",
))
151
+
152
+ # Convert the example list into 'Problem | Answer' format
153
def format_examples_for_text_area(examples):
    """Render examples as newline-separated ``Problem | Answer`` lines.

    Each element of ``examples`` must provide ``'Problem'`` and ``'Answer'``
    keys. An empty input yields an empty string.
    """
    rendered = []
    for entry in examples:
        rendered.append(f"{entry['Problem']} | {entry['Answer']}")
    return '\n'.join(rendered)
155
+
156
+ # 🖥️ Streamlit App
157
def parse_dataset_text(text):
    """Parse ``Problem | Answer`` lines into a list of row dicts.

    Only the first ``|`` on a line separates problem from answer, so answers
    may themselves contain ``|``. Lines without a separator, or with an empty
    problem or answer after trimming, are skipped.

    Returns:
        A list of ``{'Problem': ..., 'Answer': ...}`` dicts in input order.
    """
    rows = []
    for line in text.splitlines():  # splitlines() also copes with \r\n input
        problem, separator, answer = line.partition('|')
        problem, answer = problem.strip(), answer.strip()
        if separator and problem and answer:
            rows.append({'Problem': problem, 'Answer': answer})
    return rows


def main():
    """Render the Streamlit page: few-shot examples, dataset input, and a test prompt."""
    st.title("🤖 Self-Taught Reasoner (STaR) Demonstration")

    # Keep one reasoner per browser session so added examples survive reruns.
    if 'star' not in st.session_state:
        st.session_state.star = SelfTaughtReasoner()
    star = st.session_state.star

    # Step 1: Few-Shot Prompt Examples
    st.header("Step 1: Add Few-Shot Prompt Examples")
    st.write("Choose an example from the dropdown or input your own.")

    # Select by index and only format the label, instead of parsing the index
    # back out of the label text.
    example_idx = st.selectbox(
        "Select a predefined example",
        list(range(len(EXAMPLE_PROBLEM_ANSWERS))),
        format_func=lambda i: f"Example {i + 1}: {EXAMPLE_PROBLEM_ANSWERS[i]['Problem']}",
    )
    example = EXAMPLE_PROBLEM_ANSWERS[example_idx]

    # The widget keys include the selected index: a keyed widget keeps its
    # stored value across reruns, so a fixed key would never show the newly
    # selected example's text.
    example_problem = st.text_area(
        "Problem",
        value=example['Problem'],
        height=68,  # st.text_area requires at least 68 pixels
        key=f"example_problem_{example_idx}",
    )
    example_answer = st.text_input(
        "Answer",
        value=example['Answer'],
        key=f"example_answer_{example_idx}",
    )

    if st.button("Add Example"):
        star.add_prompt_example(example_problem, "Rationale placeholder", example_answer)
        st.success("Example added successfully!")

    # Step 2: Input Dataset (Problem | Answer format)
    st.header("Step 2: Input Dataset")

    dataset_problems = st.text_area(
        "Enter problems and answers in the format 'Problem | Answer', one per line.",
        value=format_examples_for_text_area(EXAMPLE_PROBLEM_ANSWERS),
        height=200,
    )

    if st.button("Submit Dataset"):
        rows = parse_dataset_text(dataset_problems)
        if rows:
            st.session_state.dataset = pd.DataFrame(rows, columns=['Problem', 'Answer'])
            st.success("Dataset loaded.")
        else:
            # Keep any previously loaded dataset rather than replacing it
            # with an empty, column-less frame.
            st.warning("No valid 'Problem | Answer' lines found.")

    if 'dataset' in st.session_state:
        st.subheader("Current Dataset:")
        st.dataframe(st.session_state.dataset.head())

    # Step 3: Test the model on a held-out problem
    st.header("Step 3: Test the Fine-Tuned Model")

    test_problem = st.selectbox(
        "Select a problem to test the fine-tuned model",
        TEST_PROBLEM_SET,
    )

    if st.button("Solve Problem"):
        if not test_problem:
            st.warning("Please enter or select a problem to solve.")
        else:
            rationale, answer = star.generate_rationale_and_answer(test_problem)
            st.subheader("Rationale:")
            st.write(rationale)
            st.subheader("Answer:")
            st.write(answer)

    # Footer
    st.write("---")
    st.write("Developed as a demonstration of the STaR method.")
234
+
235
# Build the page only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()