danielnashed committed (verified)
Commit c7476e8 · 1 Parent(s): d11b2ea

Create app.py

Files changed (1)
  1. app.py +329 -0
app.py ADDED
@@ -0,0 +1,329 @@
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, TrainerCallback, DataCollatorWithPadding, DefaultDataCollator
from openai import OpenAI
from huggingface_hub import login
import datasets
from datasets import Dataset
import json
import pandas as pd
import torch
import wandb
import os
import sys
from peft import LoraConfig, TaskType, get_peft_model, AutoPeftModelForCausalLM
from sklearn.model_selection import train_test_split

IS_COLAB = False
if "google.colab" in sys.modules or "google.colab" in os.environ:
    IS_COLAB = True

# Load env secrets
if IS_COLAB:
    from google.colab import userdata
    OPENAI_API_KEY = userdata.get('OPENAI_API_KEY')
    WANDB_API_KEY = userdata.get('WANDB_API_KEY')
else:
    OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
    WANDB_API_KEY = os.environ.get("WANDB_API_KEY")

# Authenticate Weights and Biases
wandb.login(key=WANDB_API_KEY)

# Custom callback to capture logs
class LoggingCallback(TrainerCallback):
    def __init__(self):
        self.logs = []  # Store logs

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs:
            self.logs.append(logs)  # Append logs to list


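# LLMTrainingApp wires the whole workflow into a single Gradio app:
# log in to Hugging Face, pick a small base model, collect a golden
# question/answer dataset, extend it with synthetic pairs from GPT-4o,
# fine-tune the base model with a LoRA adapter, and run inference on the result.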
class LLMTrainingApp:
    def __init__(self):
        # self.metric = datasets.load_metric('sacrebleu')
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.finetuning_dataset = []
        self.prompt_template = """### Question: {question} ### Answer: """
        self.training_output = "/content/peft-model" if IS_COLAB else "./peft-model"
        self.localpath = "/content/finetuned-model" if IS_COLAB else "./finetuned-model"
        self.tokenizer = None
        self.model = None
        self.model_name = None
        self.fine_tuned_model = None
        self.teacher_model = OpenAI(api_key=OPENAI_API_KEY)
        self.base_models = {
            "SmolLM": {"hf_name": "HuggingFaceTB/SmolLM2-135M",
                       "model_size": "135M",
                       "training_size": "2T",
                       "context_window": "8192"},
            "GPT2": {"hf_name": "openai-community/gpt2",
                     "model_size": "137M",
                     "training_size": "2T",
                     "context_window": "1024"}
        }
        self.peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
        self.logging_callback = LoggingCallback()

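    # Hugging Face login. The token needs write access because the trained
    # LoRA adapter is pushed to the Hub at the end of train_model().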
    def login_into_hf(self, token):
        if not token:
            return "❌ Please enter a valid token."
        try:
            login(token)
            return "✅ Logged in successfully!"
        except Exception as e:
            return f"❌ Login failed: {str(e)}"

    def select_model(self, model_name):
        self.model_name = model_name
        model_hf_name = self.base_models[model_name]["hf_name"]
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_hf_name)
            self.tokenizer.pad_token = self.tokenizer.eos_token
            base_model = AutoModelForCausalLM.from_pretrained(
                model_hf_name,
                torch_dtype="auto",
                device_map="auto"
            )
            self.model = get_peft_model(base_model, self.peft_config)
            params = self.model.get_nb_trainable_parameters()
            percent_trainable = round(100 * (params[0] / params[1]), 2)
            return f"✅ Loaded model into memory! Base Model card: {json.dumps(self.base_models[model_name])} - % of trainable parameters for PEFT model: {percent_trainable}"
        except Exception as e:
            return f"❌ Failed to load model and/or tokenizer: {str(e)}"

    def create_golden_dataset(self, dataset):
        try:
            dataset = pd.DataFrame(dataset)
            for _, row in dataset.iterrows():
                # Skip rows the user left empty in the Gradio table
                if not row["Question"] or not row["Answer"]:
                    continue
                self.finetuning_dataset.append({"question": self.prompt_template.format(question=row["Question"]), "answer": row["Answer"]})
            return "✅ Golden dataset created!"
        except Exception as e:
            return f"❌ Failed to create dataset: {str(e)}"

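    # Use GPT-4o as a teacher model to synthesize additional question/answer
    # pairs in the style of the golden dataset, then fold them into the
    # fine-tuning dataset alongside the human-written examples.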
    def extend_dataset(self):
        try:
            # Build the instruction for the teacher model; the existing golden
            # pairs are appended so GPT-4o has examples to imitate.
            instructions = """Given the following question-answer pairs, generate 10 similar pairs in the following json format below. Do not respond with anything other than the json.
```json
[
    {
        "question": "question 1",
        "answer": "answer 1"
    },
    {
        "question": "question 2",
        "answer": "answer 2"
    }
]
```

Question-answer pairs:
"""
            completion = self.teacher_model.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {
                        "role": "user",
                        "content": instructions + json.dumps(self.finetuning_dataset)
                    }
                ]
            )
            response = completion.choices[0].message.content
            print(f"raw response: {response}")
            clean_response = response.replace("```json", "").replace("```", "").strip()
            print(f"clean response: {clean_response}")
            new_data = json.loads(clean_response)
            for row in new_data:
                self.finetuning_dataset.append({"question": self.prompt_template.format(question=row["question"]), "answer": row["answer"]})
            # create df to display
            df = pd.DataFrame(new_data)
            return "✅ Synthetic dataset generated!", df
        except Exception as e:
            return f"❌ Failed to generate synthetic dataset: {str(e)}", pd.DataFrame()

    def tokenize_function(self, examples):
        try:
            # Tokenize the question and answer as input and target (labels) for causal LM
            encoding = self.tokenizer(examples['question'], examples['answer'], padding=True)
            # Set the labels as the input_ids
            encoding['labels'] = encoding['input_ids'].copy()
            return encoding
        except Exception as e:
            return f"❌ Failed to tokenize input: {str(e)}"


    def prepare_data_for_training(self):
        try:
            dataset = Dataset.from_dict({
                "question": [entry["question"] for entry in self.finetuning_dataset],
                "answer": [entry["answer"] for entry in self.finetuning_dataset],
            })
            dataset = dataset.map(self.tokenize_function, batched=True)
            train_dataset, test_dataset = dataset.train_test_split(test_size=0.2).values()
            return {"train": train_dataset, "test": test_dataset}
        except Exception as e:
            return f"❌ Failed to prepare data for training: {str(e)}"


    def compute_bleu(self, eval_pred):
        # Placeholder metric: the BLEU computation below is not wired up yet
        # (it needs the sacrebleu metric commented out in __init__), so a constant is returned.
        predictions, labels = eval_pred
        # # Flatten predictions and labels if they are in nested lists
        # predictions = predictions.flatten()
        # labels = labels.flatten()
        # # Ensure that predictions and labels are integers
        # predictions = predictions.astype(int)  # Convert to integer
        # labels = labels.astype(int)  # Convert to integer
        # # Decode the predicted tokens
        # decoded_preds = self.tokenizer.decode(predictions, skip_special_tokens=True)
        # decoded_labels = self.tokenizer.decode(labels, skip_special_tokens=True)
        # result = self.metric.compute(predictions=[decoded_preds], references=[[decoded_labels]])
        result = {"bleu": 1}
        return result

    def log_generator(self):
        """Continuously send logs to frontend during training."""
        for log in self.logging_callback.logs:
            yield str(log)

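    # Fine-tune the selected base model with the LoRA adapter attached in
    # select_model(). Because self.model is a PEFT model, save_pretrained()
    # and push_to_hub() store only the adapter weights, not the full base model.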
    def train_model(self):
        try:
            tokenized_datasets = self.prepare_data_for_training()

            # Create training arguments
            training_args = TrainingArguments(
                output_dir=self.training_output,
                learning_rate=1e-3,
                per_device_train_batch_size=32,
                per_device_eval_batch_size=32,
                num_train_epochs=2,
                weight_decay=0.01,
                eval_strategy="epoch",
                save_strategy="epoch",
                load_best_model_at_end=True,
            )

            # Create trainer & attach logging callback
            trainer = Trainer(
                model=self.model,
                args=training_args,
                train_dataset=tokenized_datasets["train"],
                eval_dataset=tokenized_datasets["test"],
                tokenizer=self.tokenizer,
                data_collator=DefaultDataCollator(),
                compute_metrics=self.compute_bleu,
                callbacks=[self.logging_callback],
            )

            # Start training and yield logs in real-time
            trainer.train()

            # for log in logging_callback.logs:
            #     yield str(log)

            # Save trained model to HF
            self.model.save_pretrained(self.localpath)  # save to local
            self.model.push_to_hub(f"{self.model_name}-lora")

            return "✅ Training complete!"
        except Exception as e:
            return f"❌ Training failed: {str(e)}"

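    # Load the fine-tuned adapter from disk for inference.
    # AutoPeftModelForCausalLM re-loads the base model referenced in the adapter
    # config and applies the saved LoRA weights on top of it.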
    def run_inference(self, prompt):
        try:
            # Load fine-tuned model into memory and set mode to eval
            self.fine_tuned_model = AutoPeftModelForCausalLM.from_pretrained(self.localpath)
            self.fine_tuned_model = self.fine_tuned_model.to(self.device)
            self.fine_tuned_model.eval()

            # Wrap the prompt in the same template used for fine-tuning
            prompt = self.prompt_template.format(question=prompt)

            # Tokenize input with padding and attention mask
            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True).to(self.device)

            # Generate response
            output = self.fine_tuned_model.generate(
                **inputs,
                max_length=50,            # Limit total length (prompt + response)
                num_return_sequences=1,   # Single response
                do_sample=True,           # Enable sampling so temperature/top_p take effect
                temperature=0.7,          # Sampling randomness
                top_p=0.9                 # Nucleus sampling
            )

            response = self.tokenizer.batch_decode(output.detach().cpu().numpy(), skip_special_tokens=True)[0]
            return response
        except Exception as e:
            return f"❌ Inference failed: {str(e)}"

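    # Gradio UI: one group per step of the workflow (HF login, model selection,
    # golden dataset, synthetic extension, log streaming, training, inference).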
    def build_ui(self):
        with gr.Blocks() as demo:
            gr.Markdown("# LLM Fine-tuning")

            # Hugging Face Login
            with gr.Group():
                gr.Markdown("### 1. Log in to Hugging Face")
                with gr.Column():
                    token = gr.Textbox(label="Enter Hugging Face Access Token (w/ write permissions)", type="password")
                    login_btn = gr.Button("Login", variant="primary")
                    status = gr.Textbox(label="Status")
                    login_btn.click(self.login_into_hf, inputs=token, outputs=status)

            # Model Selection
            with gr.Group():
                gr.Markdown("### 2. Select Model")
                with gr.Column():
                    model_dropdown = gr.Dropdown(list(self.base_models.keys()), label="Small Models")
                    select_model_btn = gr.Button("Select", variant="primary")
                    selected_model_text = gr.Textbox(label="Model Status")
                    select_model_btn.click(self.select_model, inputs=model_dropdown, outputs=[selected_model_text])

            # Create Golden Dataset
            with gr.Group():
                gr.Markdown("### 3. Create Golden Dataset")
                with gr.Column():
                    dataset_table = gr.Dataframe(
                        headers=["Question", "Answer"],
                        value=[["", ""] for _ in range(3)],
                        label="Golden Dataset"
                    )
                    create_data_btn = gr.Button("Create Dataset", variant="primary")
                    dataset_status = gr.Textbox(label="Dataset Status")
                    create_data_btn.click(self.create_golden_dataset, inputs=dataset_table, outputs=[dataset_status])

            # Generate Full Dataset
            with gr.Group():
                gr.Markdown("### 4. Extend Dataset with Synthetic Data")
                with gr.Column():
                    extended_dataset_table = gr.Dataframe(
                        headers=["Question", "Answer"],
                        label="Golden + Synthetic Dataset"
                    )
                    generate_status = gr.Textbox(label="Dataset Generation Status")
                    generate_data_btn = gr.Button("Generate Dataset", variant="primary")
                    generate_data_btn.click(self.extend_dataset, outputs=[generate_status, extended_dataset_table])

            # Stream Training Logs
            with gr.Group():
                gr.Markdown("### 5. Start Logging")
                with gr.Column():
                    log_status = gr.Textbox(label="Training Logs", lines=10)
                    log_btn = gr.Button("Stream Logs", variant="primary")
                    log_btn.click(self.log_generator, outputs=[log_status])

            # Train Model & Visualize Loss
            with gr.Group():
                gr.Markdown("### 6. Train Model")
                with gr.Column():
                    train_status = gr.Textbox(label="Training Status")
                    train_btn = gr.Button("Train", variant="primary")
                    train_btn.click(self.train_model, outputs=[train_status])

            # Run Inference
            with gr.Group():
                gr.Markdown("### 7. Run Inference")
                with gr.Column():
                    user_prompt = gr.Textbox(label="Enter Prompt")
                    inference_btn = gr.Button("Run Inference", variant="primary")
                    inference_output = gr.Textbox(label="Inference Output")
                    inference_btn.click(self.run_inference, inputs=user_prompt, outputs=inference_output)

        return demo

# Create an instance of the app
app = LLMTrainingApp()

# Launch the Gradio app using the class method
app.build_ui().launch()
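
# Note: OPENAI_API_KEY and WANDB_API_KEY must be available as environment
# variables (or Colab secrets) before launch; the Hugging Face token is
# entered through the UI in step 1.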