import streamlit as st
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import time
import json
import re
import os


# -------------------------------
# Utility Functions
# -------------------------------

token = "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"
os.environ['CURL_CA_BUNDLE'] = ''

@st.cache_resource
def load_model(model_id: str, token: str):
    """
    Loads and caches the Gemma model and tokenizer, authenticating with the
    given Hugging Face token.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        model = AutoModelForCausalLM.from_pretrained(model_id, token=token)
        return tokenizer, model
    except Exception as e:
        st.error(f"Model loading failed: {e}")
        return None, None

def preprocess_data(uploaded_file, file_extension):
    """
    Reads the uploaded file and returns a processed version.
    Supports CSV, JSONL, and TXT.
    """
    data = None
    try:
        if file_extension == "csv":
            data = pd.read_csv(uploaded_file)
        elif file_extension == "jsonl":
            # Each line is a JSON object; the uploaded file arrives as bytes.
            lines = uploaded_file.read().decode("utf-8").splitlines()
            data = [json.loads(line) for line in lines if line.strip()]
            try:
                data = pd.DataFrame(data)
            except Exception:
                st.warning("Unable to convert JSONL to a table. Previewing raw JSON objects.")
        elif file_extension == "txt":
            text_data = uploaded_file.read().decode("utf-8")
            data = text_data.splitlines()
    except Exception as e:
        st.error(f"Error processing file: {e}")
    return data

def clean_text(text, lowercase=True, remove_punctuation=True):
    """
    Cleans text data by applying basic normalization.
    """
    if lowercase:
        text = text.lower()
    if remove_punctuation:
        text = re.sub(r'[^\w\s]', '', text)
    return text
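
# Example (illustrative): with the default options,
#   clean_text("Hello, World!")  ->  "hello world"
# punctuation is stripped and case is folded.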

def plot_training_metrics(epochs, loss_values, accuracy_values):
    """
    Returns a matplotlib figure plotting training loss and accuracy.
    """
    fig, ax = plt.subplots(1, 2, figsize=(12, 4))
    ax[0].plot(range(1, epochs+1), loss_values, marker='o', color='red')
    ax[0].set_title("Training Loss")
    ax[0].set_xlabel("Epoch")
    ax[0].set_ylabel("Loss")
    
    ax[1].plot(range(1, epochs+1), accuracy_values, marker='o', color='green')
    ax[1].set_title("Training Accuracy")
    ax[1].set_xlabel("Epoch")
    ax[1].set_ylabel("Accuracy")
    
    return fig

def simulate_training(num_epochs):
    """
    Simulates a training loop for demonstration.
    Yields current epoch, loss values, and accuracy values.
    Replace this with your actual fine-tuning loop (see the sketch below).
    """
    loss_values = []
    accuracy_values = []
    for epoch in range(1, num_epochs + 1):
        loss = np.exp(-epoch) + np.random.random() * 0.1
        acc = 0.5 + (epoch / num_epochs) * 0.5 + np.random.random() * 0.05
        loss_values.append(loss)
        accuracy_values.append(acc)
        yield epoch, loss_values, accuracy_values
        time.sleep(1)  # Simulate computation time
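
# A minimal sketch of what a real causal-LM fine-tuning loop could look like in
# place of simulate_training above. It is illustrative only and not wired into
# the UI; `texts` (a list of training strings) and the optimizer settings are
# assumptions, and it yields the same (epoch, losses, accuracies) shape the
# plotting code expects, using a perplexity-derived proxy for the accuracy slot.
def finetune_loop_sketch(model, tokenizer, texts, num_epochs, lr=1e-4):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    loss_values, accuracy_values = [], []
    for epoch in range(1, num_epochs + 1):
        epoch_loss = 0.0
        for text in texts:
            batch = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            # For causal LM fine-tuning, the labels are the input ids themselves.
            outputs = model(**batch, labels=batch["input_ids"])
            outputs.loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            epoch_loss += outputs.loss.item()
        loss_values.append(epoch_loss / max(len(texts), 1))
        # exp(-loss) is a rough 0-1 proxy so the accuracy plot stays meaningful.
        accuracy_values.append(float(np.exp(-loss_values[-1])))
        yield epoch, loss_values, accuracy_values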

def quantize_model(model):
    """
    Applies dynamic quantization for demonstration.
    In practice, adjust this based on your model and target hardware.
    """
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    return quantized_model
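
# A small helper sketch (not called by the UI) for eyeballing the size win
# from quantization; the temporary file name is an arbitrary assumption.
def state_dict_size_mb(state_dict, tmp_path="size_check.pt"):
    torch.save(state_dict, tmp_path)
    size_mb = os.path.getsize(tmp_path) / 1e6
    os.remove(tmp_path)
    return size_mb
# Example: compare state_dict_size_mb(model.state_dict()) against
# state_dict_size_mb(quantize_model(model).state_dict()).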

def convert_to_torchscript(model):
    """
    Converts the model to TorchScript via tracing. Hugging Face models must
    return plain tuples (not ModelOutput objects) to be traceable.
    """
    model.eval()
    model.config.return_dict = False
    example_input = torch.randint(0, 100, (1, 10))
    traced_model = torch.jit.trace(model, example_input)
    return traced_model

def convert_to_onnx(model, output_path="model.onnx"):
    """
    Converts the model to ONNX format.
    """
    model.eval()
    model.config.return_dict = False  # export plain tuples rather than ModelOutput
    dummy_input = torch.randint(0, 100, (1, 10))
    torch.onnx.export(model, dummy_input, output_path, input_names=["input"], output_names=["output"])
    return output_path
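
# A hedged sketch for sanity-checking the exported ONNX file; it assumes the
# optional `onnxruntime` package is installed and is not called by the UI.
def check_onnx_export(onnx_path="model.onnx"):
    import onnxruntime as ort
    session = ort.InferenceSession(onnx_path)
    dummy = np.random.randint(0, 100, size=(1, 10)).astype(np.int64)
    outputs = session.run(None, {"input": dummy})
    return [o.shape for o in outputs]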

def load_finetuned_model(model, checkpoint_path="fine_tuned_model.pt"):
    """
    Loads the fine-tuned model from the checkpoint.
    """
    if os.path.exists(checkpoint_path):
        model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
        model.eval()
        st.success("Fine-tuned model loaded successfully!")
    else:
        st.error(f"Checkpoint not found: {checkpoint_path}")
    return model


def generate_response(prompt, model, tokenizer, max_length=200):
    """
    Generates a response using the fine-tuned model.
    """
    # Tokenize the prompt (keep the attention mask for generate()).
    inputs = tokenizer(prompt, return_tensors="pt")

    # Generate text; temperature only takes effect when sampling is enabled.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.7,
        )

    # Decode the output
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
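
# Example (illustrative):
#   generate_response("Once upon a time", model, tokenizer, max_length=100)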


# -------------------------------
# Application Layout
# -------------------------------

st.title("One-Stop Gemma Model Fine-tuning, Quantization & Conversion UI")
st.markdown("""
This application is designed for beginners in generative AI. 
It allows you to fine-tune, quantize, and convert Gemma models with an intuitive UI. 
You can upload your dataset, clean and preview your data, configure training parameters, and export your model in different formats.
""")

# Sidebar: Model selection and data upload
st.sidebar.header("Configuration")

# Model Selection (display names mapped to Hugging Face model IDs)
model_options = {
    "Gemma-Small": "google/gemma-3-1b-it",
    "Gemma-Medium": "google/gemma-3-4b-it",
    "Gemma-Large": "google/gemma-3-1b-it",  # no larger ID given; falls back to the 1B model
}
selected_model = st.sidebar.selectbox("Select Gemma Model", options=list(model_options.keys()))
model_id = model_options[selected_model]

loading_placeholder = st.sidebar.empty()
loading_placeholder.info("Loading model...")
tokenizer, model = load_model(model_id, token)
if model is not None:
    loading_placeholder.success("Model loaded.")
else:
    loading_placeholder.error("Model failed to load. Check your token and model ID.")


# Dataset Upload
uploaded_file = st.sidebar.file_uploader("Upload Dataset (CSV, JSONL, TXT)", type=["csv", "jsonl", "txt"])
data = None
if uploaded_file is not None:
    file_ext = uploaded_file.name.split('.')[-1].lower()
    data = preprocess_data(uploaded_file, file_ext)
    st.sidebar.subheader("Dataset Preview:")
    if isinstance(data, pd.DataFrame):
        st.sidebar.dataframe(data.head())
    elif isinstance(data, list):
        st.sidebar.write(data[:5])
    else:
        st.sidebar.write(data)
else:
    st.sidebar.info("Awaiting dataset upload.")

# Data Cleaning Options (for TXT files)
if uploaded_file is not None and file_ext == "txt" and data is not None:
    st.sidebar.subheader("Data Cleaning Options")
    lowercase_option = st.sidebar.checkbox("Convert to lowercase", value=True)
    remove_punct = st.sidebar.checkbox("Remove punctuation", value=True)
    cleaned_data = [clean_text(line, lowercase=lowercase_option, remove_punctuation=remove_punct) for line in data]
    st.sidebar.text_area("Cleaned Data Preview", value="\n".join(cleaned_data[:5]), height=150)

# Main Tabs for Different Operations
tabs = st.tabs(["Fine-tuning", "Quantization", "Model Conversion"])

# -------------------------------
# Fine-tuning Tab
# -------------------------------
with tabs[0]:
    st.header("Fine-tuning")
    st.markdown("Configure hyperparameters and start fine-tuning your Gemma model.")
    
    col1, col2, col3 = st.columns(3)
    with col1:
        learning_rate = st.number_input("Learning Rate", value=1e-4, format="%.5f")
    with col2:
        batch_size = st.number_input("Batch Size", value=16, step=1)
    with col3:
        epochs = st.number_input("Epochs", value=3, step=1)
    
    if st.button("Start Fine-tuning"):
        if data is None:
            st.error("Please upload a dataset first!")
        else:
            st.info("Starting fine-tuning...")
            progress_bar = st.progress(0)
            training_placeholder = st.empty()
            loss_values = []
            accuracy_values = []
            
            # Simulate training loop (replace with your actual training code)
            for epoch, losses, accs in simulate_training(epochs):
                fig = plot_training_metrics(epoch, losses, accs)
                training_placeholder.pyplot(fig)
                plt.close(fig)  # avoid accumulating open figures across iterations
                progress_bar.progress(epoch / epochs)
            st.success("Fine-tuning completed!")
            
            # Save the fine-tuned model (for demonstration, saving state_dict)
            if model:
                torch.save(model.state_dict(), "fine_tuned_model.pt")
                with open("fine_tuned_model.pt", "rb") as f:
                    st.download_button("Download Fine-tuned Model", data=f, file_name="fine_tuned_model.pt", mime="application/octet-stream")
            else:
                st.error("Model not loaded. Cannot save.")


# -------------------------------
# Quantization Tab
# -------------------------------
with tabs[1]:
    st.header("Model Quantization")
    st.markdown("Quantize your model to optimize for inference performance.")
    quantize_choice = st.radio("Select Quantization Type", options=["Dynamic Quantization"], index=0)
    
    if st.button("Apply Quantization"):
        with st.spinner("Applying quantization..."):
            quantized_model = quantize_model(model)
            st.success("Model quantized successfully!")
            torch.save(quantized_model.state_dict(), "quantized_model.pt")
            with open("quantized_model.pt", "rb") as f:
                st.download_button("Download Quantized Model", data=f, file_name="quantized_model.pt", mime="application/octet-stream")

# -------------------------------
# Model Conversion Tab
# -------------------------------
with tabs[2]:
    st.header("Model Conversion")
    st.markdown("Convert your model to a different format for deployment or optimization.")
    conversion_option = st.selectbox("Select Conversion Format", options=["TorchScript", "ONNX"])
    
    if st.button("Convert Model"):
        if model is None:
            st.error("Model not loaded. Cannot convert.")
        elif conversion_option == "TorchScript":
            with st.spinner("Converting to TorchScript..."):
                ts_model = convert_to_torchscript(model)
                ts_model.save("model_ts.pt")
                st.success("Converted to TorchScript!")
                with open("model_ts.pt", "rb") as f:
                    st.download_button("Download TorchScript Model", data=f, file_name="model_ts.pt", mime="application/octet-stream")
        elif conversion_option == "ONNX":
            with st.spinner("Converting to ONNX..."):
                onnx_path = convert_to_onnx(model, "model.onnx")
                st.success("Converted to ONNX!")
                with open(onnx_path, "rb") as f:
                    st.download_button("Download ONNX Model", data=f, file_name="model.onnx", mime="application/octet-stream")

# -------------------------------
# Response Generation Section
# -------------------------------
st.header("Generate Responses with Fine-Tuned Model")
st.markdown("Use the fine-tuned model to generate text responses based on your prompts.")

# Check if the fine-tuned model exists
if os.path.exists("fine_tuned_model.pt"):
    # Load the fine-tuned model
    model = load_finetuned_model(model, "fine_tuned_model.pt")

    # Input prompt for generating responses
    prompt = st.text_area("Enter a prompt:", "Once upon a time...")

    # Max length slider
    max_length = st.slider("Max Response Length", min_value=50, max_value=500, value=200, step=10)

    if st.button("Generate Response"):
        with st.spinner("Generating response..."):
            response = generate_response(prompt, model, tokenizer, max_length)
            st.success("Generated Response:")
            st.write(response)

else:
    st.warning("Fine-tuned model not found. Please fine-tune the model first.")


# -------------------------------
# Optional: Cloud Integration Snippet
# -------------------------------
st.header("Cloud Integration")
st.markdown("""
For large-scale training or model storage, consider integrating with Google Cloud Storage or Vertex AI. 
Below is an example snippet for uploading your model to GCS:
""")
st.code("""
from google.cloud import storage

def upload_to_gcs(bucket_name, source_file_name, destination_blob_name):
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(destination_blob_name)
    blob.upload_from_filename(source_file_name)
    print(f"Uploaded {source_file_name} to {destination_blob_name}")

# Example usage:
# upload_to_gcs("your-bucket-name", "fine_tuned_model.pt", "models/fine_tuned_model.pt")
""", language="python")