mohitkumarrajbadi committed
Commit e6f0893 · 1 Parent(s): f75f33e

Add application file

Files changed (1): app.py (+354, −0)
app.py ADDED
@@ -0,0 +1,354 @@
+ import streamlit as st
+ import pandas as pd
+ import numpy as np
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import matplotlib.pyplot as plt
+ import time
+ import json
+ import re
+ import os
+ import asyncio
+
+
+ # -------------------------------
+ # Utility Functions
+ # -------------------------------
+
+ # Hugging Face access token. Read it from the environment; never commit a
+ # literal token to source control.
+ token = os.environ.get("HF_TOKEN")
+ os.environ['CURL_CA_BUNDLE'] = ''  # Disables TLS certificate verification; only use on trusted networks.
+
+ @st.cache_resource
+ def load_model(model_id: str, token: str):
+     """
+     Loads and caches the Gemma model and tokenizer with authentication token.
+     """
+     try:
+         # Create and run an event loop explicitly
+         asyncio.run(async_load(model_id, token))
+
+         # Ensure torch classes path is valid (optional)
+         if not hasattr(torch, "classes") or not torch.classes:
+             torch.classes = torch._C._get_python_module("torch.classes")
+
+         tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
+         model = AutoModelForCausalLM.from_pretrained(model_id, token=token)
+
+         return tokenizer, model
+
+     except Exception as e:
+         print(f"An error occurred: {e}")
+         st.error(f"Model loading failed: {e}")
+         return None, None
+
+ async def async_load(model_id, token):
+     """
+     Dummy async function to initialize the event loop.
+     """
+     await asyncio.sleep(0.1)  # Dummy async operation
+
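+ # Note (behavioral detail, not in the original comments): @st.cache_resource
+ # keeps one (tokenizer, model) pair per distinct (model_id, token) argument
+ # combination, so reruns of the script or switching back to an already-loaded
+ # model reuse the cached weights instead of downloading them again.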
+ def preprocess_data(uploaded_file, file_extension):
+     """
+     Reads the uploaded file and returns a processed version.
+     Supports CSV, JSONL, and TXT.
+     """
+     data = None
+     try:
+         if file_extension == "csv":
+             data = pd.read_csv(uploaded_file)
+         elif file_extension == "jsonl":
+             # Each line is a JSON object.
+             data = [json.loads(line) for line in uploaded_file.readlines()]
+             try:
+                 data = pd.DataFrame(data)
+             except Exception:
+                 st.warning("Unable to convert JSONL to a table. Previewing raw JSON objects.")
+         elif file_extension == "txt":
+             text_data = uploaded_file.read().decode("utf-8")
+             data = text_data.splitlines()
+     except Exception as e:
+         st.error(f"Error processing file: {e}")
+     return data
+
+ def clean_text(text, lowercase=True, remove_punctuation=True):
+     """
+     Cleans text data by applying basic normalization.
+     """
+     if lowercase:
+         text = text.lower()
+     if remove_punctuation:
+         text = re.sub(r'[^\w\s]', '', text)
+     return text
+
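+ # Worked example (illustrative):
+ #   clean_text("Hello, World!!")                  -> "hello world"
+ #   clean_text("Hello, World!!", lowercase=False) -> "Hello World"
+ # The regex [^\w\s] strips everything except word characters and whitespace.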
+ def plot_training_metrics(epochs, loss_values, accuracy_values):
+     """
+     Returns a matplotlib figure plotting training loss and accuracy.
+     """
+     fig, ax = plt.subplots(1, 2, figsize=(12, 4))
+     ax[0].plot(range(1, epochs + 1), loss_values, marker='o', color='red')
+     ax[0].set_title("Training Loss")
+     ax[0].set_xlabel("Epoch")
+     ax[0].set_ylabel("Loss")
+
+     ax[1].plot(range(1, epochs + 1), accuracy_values, marker='o', color='green')
+     ax[1].set_title("Training Accuracy")
+     ax[1].set_xlabel("Epoch")
+     ax[1].set_ylabel("Accuracy")
+
+     return fig
+
+ def simulate_training(num_epochs):
+     """
+     Simulates a training loop for demonstration.
+     Yields current epoch, loss values, and accuracy values.
+     Replace this with your actual fine-tuning loop.
+     """
+     loss_values = []
+     accuracy_values = []
+     for epoch in range(1, num_epochs + 1):
+         loss = np.exp(-epoch) + np.random.random() * 0.1
+         acc = 0.5 + (epoch / num_epochs) * 0.5 + np.random.random() * 0.05
+         loss_values.append(loss)
+         accuracy_values.append(acc)
+         yield epoch, loss_values, accuracy_values
+         time.sleep(1)  # Simulate computation time
+
+ def quantize_model(model):
+     """
+     Applies dynamic quantization for demonstration.
+     In practice, adjust this based on your model and target hardware.
+     """
+     quantized_model = torch.quantization.quantize_dynamic(
+         model, {torch.nn.Linear}, dtype=torch.qint8
+     )
+     return quantized_model
+
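+ # Illustrative size check (assumes the file names below; not part of the app
+ # flow): dynamic quantization stores int8 weights for the selected layer
+ # types, so for Linear-heavy models the serialized state_dict is typically
+ # close to 4x smaller.
+ #   torch.save(model.state_dict(), "fp32.pt")
+ #   torch.save(quantize_model(model).state_dict(), "int8.pt")
+ #   os.path.getsize("int8.pt") / os.path.getsize("fp32.pt")  # roughly 0.25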
+ def convert_to_torchscript(model):
+     """
+     Converts the model to TorchScript format.
+     """
+     example_input = torch.randint(0, 100, (1, 10))
+     # strict=False lets the tracer accept dict-style outputs, which
+     # transformers models return by default.
+     traced_model = torch.jit.trace(model, example_input, strict=False)
+     return traced_model
+
+ def convert_to_onnx(model, output_path="model.onnx"):
+     """
+     Converts the model to ONNX format.
+     """
+     dummy_input = torch.randint(0, 100, (1, 10))
+     torch.onnx.export(model, dummy_input, output_path, input_names=["input"], output_names=["output"])
+     return output_path
+
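+ # Note: the export above traces with a fixed (1, 10) input. If inference needs
+ # variable batch or sequence sizes, pass dynamic_axes to torch.onnx.export,
+ # for example:
+ #   torch.onnx.export(model, dummy_input, output_path,
+ #                     input_names=["input"], output_names=["output"],
+ #                     dynamic_axes={"input": {0: "batch", 1: "seq"}})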
+ def load_finetuned_model(model, checkpoint_path="fine_tuned_model.pt"):
+     """
+     Loads the fine-tuned model from the checkpoint.
+     """
+     if os.path.exists(checkpoint_path):
+         model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
+         model.eval()
+         st.success("Fine-tuned model loaded successfully!")
+     else:
+         st.error(f"Checkpoint not found: {checkpoint_path}")
+     return model
+
+
+ def generate_response(prompt, model, tokenizer, max_length=200):
+     """
+     Generates a response using the fine-tuned model.
+     """
+     # Tokenize the prompt
+     inputs = tokenizer(prompt, return_tensors="pt").input_ids
+
+     # Generate text
+     with torch.no_grad():
+         # do_sample=True is required for temperature to take effect;
+         # without it, generation is greedy and temperature is ignored.
+         outputs = model.generate(inputs, max_length=max_length, num_return_sequences=1,
+                                  do_sample=True, temperature=0.7)
+
+     # Decode the output
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return response
+
+
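+ # Example usage (illustrative):
+ #   generate_response("Once upon a time", model, tokenizer, max_length=50)
+ # Tip: passing the full tokenizer output instead of just input_ids, e.g.
+ #   model.generate(**tokenizer(prompt, return_tensors="pt"), ...)
+ # also supplies the attention_mask, which avoids attention-mask warnings on
+ # some models.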
+ # -------------------------------
+ # Application Layout
+ # -------------------------------
+
+ st.title("One-Stop Gemma Model Fine-tuning, Quantization & Conversion UI")
+ st.markdown("""
+ This application is designed for beginners in generative AI.
+ It allows you to fine-tune, quantize, and convert Gemma models with an intuitive UI.
+ You can upload your dataset, clean and preview your data, configure training parameters, and export your model in different formats.
+ """)
+
+ # Sidebar: Model selection and data upload
+ st.sidebar.header("Configuration")
+
+ # Model Selection: map the display names to Hugging Face model ids.
+ selected_model = st.sidebar.selectbox("Select Gemma Model", options=["Gemma-Small", "Gemma-Medium", "Gemma-Large"])
+ if selected_model == "Gemma-Small":
+     model_id = "google/gemma-3-1b-it"
+ elif selected_model == "Gemma-Medium":
+     model_id = "google/gemma-3-4b-it"
+ else:
+     # "Gemma-Large" falls back to the 1B checkpoint until a larger id is configured.
+     model_id = "google/gemma-3-1b-it"
+
+ loading_placeholder = st.sidebar.empty()
+ loading_placeholder.info("Loading model...")
+ tokenizer, model = load_model(model_id, token)
+ if model is not None:
+     loading_placeholder.success("Model loaded.")
+ else:
+     loading_placeholder.error("Model failed to load.")
+
+
+ # Dataset Upload
+ uploaded_file = st.sidebar.file_uploader("Upload Dataset (CSV, JSONL, TXT)", type=["csv", "jsonl", "txt"])
+ data = None
+ if uploaded_file is not None:
+     file_ext = uploaded_file.name.split('.')[-1].lower()
+     data = preprocess_data(uploaded_file, file_ext)
+     st.sidebar.subheader("Dataset Preview:")
+     if isinstance(data, pd.DataFrame):
+         st.sidebar.dataframe(data.head())
+     elif isinstance(data, list):
+         st.sidebar.write(data[:5])
+     else:
+         st.sidebar.write(data)
+ else:
+     st.sidebar.info("Awaiting dataset upload.")
+
+ # Data Cleaning Options (for TXT files)
+ if uploaded_file is not None and file_ext == "txt":
+     st.sidebar.subheader("Data Cleaning Options")
+     lowercase_option = st.sidebar.checkbox("Convert to lowercase", value=True)
+     remove_punct = st.sidebar.checkbox("Remove punctuation", value=True)
+     cleaned_data = [clean_text(line, lowercase=lowercase_option, remove_punctuation=remove_punct) for line in data]
+     st.sidebar.text_area("Cleaned Data Preview", value="\n".join(cleaned_data[:5]), height=150)
+
+ # Main Tabs for Different Operations
+ tabs = st.tabs(["Fine-tuning", "Quantization", "Model Conversion"])
+
+ # -------------------------------
+ # Fine-tuning Tab
+ # -------------------------------
+ with tabs[0]:
+     st.header("Fine-tuning")
+     st.markdown("Configure hyperparameters and start fine-tuning your Gemma model.")
+
+     col1, col2, col3 = st.columns(3)
+     with col1:
+         learning_rate = st.number_input("Learning Rate", value=1e-4, format="%.5f")
+     with col2:
+         batch_size = st.number_input("Batch Size", value=16, step=1)
+     with col3:
+         epochs = st.number_input("Epochs", value=3, step=1)
+
+     if st.button("Start Fine-tuning"):
+         if data is None:
+             st.error("Please upload a dataset first!")
+         else:
+             st.info("Starting fine-tuning...")
+             progress_bar = st.progress(0)
+             training_placeholder = st.empty()
+             loss_values = []
+             accuracy_values = []
+
+             # Simulate training loop (replace with your actual training code)
+             for epoch, losses, accs in simulate_training(epochs):
+                 fig = plot_training_metrics(epoch, losses, accs)
+                 training_placeholder.pyplot(fig)
+                 progress_bar.progress(epoch / epochs)
+             st.success("Fine-tuning completed!")
+
+             # Save the fine-tuned model (for demonstration, saving state_dict)
+             if model:
+                 torch.save(model.state_dict(), "fine_tuned_model.pt")
+                 with open("fine_tuned_model.pt", "rb") as f:
+                     st.download_button("Download Fine-tuned Model", data=f, file_name="fine_tuned_model.pt", mime="application/octet-stream")
+             else:
+                 st.error("Model not loaded. Cannot save.")
+
+
+ # -------------------------------
+ # Quantization Tab
+ # -------------------------------
+ with tabs[1]:
+     st.header("Model Quantization")
+     st.markdown("Quantize your model to optimize for inference performance.")
+     quantize_choice = st.radio("Select Quantization Type", options=["Dynamic Quantization"], index=0)
+
+     if st.button("Apply Quantization"):
+         if model is None:
+             st.error("Model not loaded. Cannot quantize.")
+         else:
+             with st.spinner("Applying quantization..."):
+                 quantized_model = quantize_model(model)
+             st.success("Model quantized successfully!")
+             torch.save(quantized_model.state_dict(), "quantized_model.pt")
+             with open("quantized_model.pt", "rb") as f:
+                 st.download_button("Download Quantized Model", data=f, file_name="quantized_model.pt", mime="application/octet-stream")
+
+ # -------------------------------
+ # Model Conversion Tab
+ # -------------------------------
+ with tabs[2]:
+     st.header("Model Conversion")
+     st.markdown("Convert your model to a different format for deployment or optimization.")
+     conversion_option = st.selectbox("Select Conversion Format", options=["TorchScript", "ONNX"])
+
+     if st.button("Convert Model"):
+         if conversion_option == "TorchScript":
+             with st.spinner("Converting to TorchScript..."):
+                 ts_model = convert_to_torchscript(model)
+                 ts_model.save("model_ts.pt")
+             st.success("Converted to TorchScript!")
+             with open("model_ts.pt", "rb") as f:
+                 st.download_button("Download TorchScript Model", data=f, file_name="model_ts.pt", mime="application/octet-stream")
+         elif conversion_option == "ONNX":
+             with st.spinner("Converting to ONNX..."):
+                 onnx_path = convert_to_onnx(model, "model.onnx")
+             st.success("Converted to ONNX!")
+             with open(onnx_path, "rb") as f:
+                 st.download_button("Download ONNX Model", data=f, file_name="model.onnx", mime="application/octet-stream")
+
+ # -------------------------------
+ # Response Generation Section
+ # -------------------------------
+ st.header("Generate Responses with Fine-Tuned Model")
+ st.markdown("Use the fine-tuned model to generate text responses based on your prompts.")
+
+ # Check if the fine-tuned model exists
+ if os.path.exists("fine_tuned_model.pt"):
+     # Load the fine-tuned model
+     model = load_finetuned_model(model, "fine_tuned_model.pt")
+
+     # Input prompt for generating responses
+     prompt = st.text_area("Enter a prompt:", "Once upon a time...")
+
+     # Max length slider
+     max_length = st.slider("Max Response Length", min_value=50, max_value=500, value=200, step=10)
+
+     if st.button("Generate Response"):
+         with st.spinner("Generating response..."):
+             response = generate_response(prompt, model, tokenizer, max_length)
+         st.success("Generated Response:")
+         st.write(response)
+
+ else:
+     st.warning("Fine-tuned model not found. Please fine-tune the model first.")
+
+
+ # -------------------------------
+ # Optional: Cloud Integration Snippet
+ # -------------------------------
+ st.header("Cloud Integration")
+ st.markdown("""
+ For large-scale training or model storage, consider integrating with Google Cloud Storage or Vertex AI.
+ Below is an example snippet for uploading your model to GCS:
+ """)
+ st.code("""
+ from google.cloud import storage
+
+ def upload_to_gcs(bucket_name, source_file_name, destination_blob_name):
+     storage_client = storage.Client()
+     bucket = storage_client.bucket(bucket_name)
+     blob = bucket.blob(destination_blob_name)
+     blob.upload_from_filename(source_file_name)
+     print(f"Uploaded {source_file_name} to {destination_blob_name}")
+
+ # Example usage:
+ # upload_to_gcs("your-bucket-name", "fine_tuned_model.pt", "models/fine_tuned_model.pt")
+ """, language="python")