import streamlit as st
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import time
import json
import re
import os
import asyncio
# -------------------------------
# Utility Functions
# -------------------------------
token = "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"
os.environ['CURL_CA_BUNDLE'] = ''
@st.cache_resource
def load_model(model_id: str, token: str):
"""
Loads and caches the Gemma model and tokenizer with authentication token.
"""
try:
# Create and run an event loop explicitly
asyncio.run(async_load(model_id, token))
# Ensure torch classes path is valid (optional)
if not hasattr(torch, "classes") or not torch.classes:
torch.classes = torch._C._get_python_module("torch.classes")
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(model_id, token=token)
return tokenizer, model
except Exception as e:
print(f"An error occurred: {e}")
st.error(f"Model loading failed: {e}")
return None, None
async def async_load(model_id, token):
"""
Dummy async function to initialize the event loop.
"""
await asyncio.sleep(0.1) # Dummy async operation
def preprocess_data(uploaded_file, file_extension):
"""
Reads the uploaded file and returns a processed version.
Supports CSV, JSONL, and TXT.
"""
data = None
try:
if file_extension == "csv":
data = pd.read_csv(uploaded_file)
elif file_extension == "jsonl":
# Each line is a JSON object.
data = [json.loads(line) for line in uploaded_file.readlines()]
try:
data = pd.DataFrame(data)
except Exception:
st.warning("Unable to convert JSONL to a table. Previewing raw JSON objects.")
elif file_extension == "txt":
text_data = uploaded_file.read().decode("utf-8")
data = text_data.splitlines()
except Exception as e:
st.error(f"Error processing file: {e}")
return data
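# Example JSONL input (a hypothetical schema; the app does not require these exact keys):
# {"prompt": "Translate to French: hello", "response": "bonjour"}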
def clean_text(text, lowercase=True, remove_punctuation=True):
"""
Cleans text data by applying basic normalization.
"""
if lowercase:
text = text.lower()
if remove_punctuation:
text = re.sub(r'[^\w\s]', '', text)
return text
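# e.g. clean_text("Hello, World!") -> "hello world"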
def plot_training_metrics(epochs, loss_values, accuracy_values):
"""
Returns a matplotlib figure plotting training loss and accuracy.
"""
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax[0].plot(range(1, epochs+1), loss_values, marker='o', color='red')
ax[0].set_title("Training Loss")
ax[0].set_xlabel("Epoch")
ax[0].set_ylabel("Loss")
ax[1].plot(range(1, epochs+1), accuracy_values, marker='o', color='green')
ax[1].set_title("Training Accuracy")
ax[1].set_xlabel("Epoch")
ax[1].set_ylabel("Accuracy")
return fig
def simulate_training(num_epochs):
"""
Simulates a training loop for demonstration.
Yields current epoch, loss values, and accuracy values.
Replace this with your actual fine-tuning loop.
"""
loss_values = []
accuracy_values = []
for epoch in range(1, num_epochs + 1):
loss = np.exp(-epoch) + np.random.random() * 0.1
acc = 0.5 + (epoch / num_epochs) * 0.5 + np.random.random() * 0.05
loss_values.append(loss)
accuracy_values.append(acc)
yield epoch, loss_values, accuracy_values
time.sleep(1) # Simulate computation time
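# Illustrative sketch only (not wired into the UI): simulate_training above just fakes metrics.
# A minimal real fine-tuning pass for a causal LM could look like the helper below; it assumes
# `texts` is a list of training strings, the model fits in memory, and no gradient accumulation
# or device placement is needed. Adjust for real workloads.
def example_finetune_epoch(model, tokenizer, texts, lr=1e-4):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    total_loss = 0.0
    for text in texts:
        batch = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        # For causal LM fine-tuning, the labels are the input ids themselves.
        outputs = model(**batch, labels=batch["input_ids"])
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()
    return total_loss / max(len(texts), 1)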
def quantize_model(model):
"""
Applies dynamic quantization for demonstration.
In practice, adjust this based on your model and target hardware.
"""
quantized_model = torch.quantization.quantize_dynamic(
model, {torch.nn.Linear}, dtype=torch.qint8
)
return quantized_model
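# Note: torch.quantization.quantize_dynamic only rewrites the listed module types (nn.Linear here)
# and targets CPU inference. A hedged alternative is to load Gemma directly in 4-bit via
# transformers + bitsandbytes (requires the bitsandbytes package and a CUDA GPU); kept commented
# out so it does not run in this app:
#
# from transformers import BitsAndBytesConfig
# bnb_config = BitsAndBytesConfig(load_in_4bit=True)
# model_4bit = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, token=token)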
def convert_to_torchscript(model):
"""
Converts the model to TorchScript format.
"""
example_input = torch.randint(0, 100, (1, 10))
traced_model = torch.jit.trace(model, example_input)
return traced_model
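# Note: torch.jit.trace records only the forward pass for the example input shape; the generation
# loop (model.generate) is not captured. Tracing a Hugging Face causal LM typically also needs
# strict=False (or a wrapper returning plain tensors) because the model returns a dict-like output.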
def convert_to_onnx(model, output_path="model.onnx"):
"""
Converts the model to ONNX format.
"""
dummy_input = torch.randint(0, 100, (1, 10))
torch.onnx.export(model, dummy_input, output_path, input_names=["input"], output_names=["output"])
return output_path
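# Note: a production export usually pins the opset and declares dynamic axes, e.g. (sketch only):
#
# torch.onnx.export(
#     model, dummy_input, output_path,
#     input_names=["input"], output_names=["output"],
#     dynamic_axes={"input": {0: "batch", 1: "sequence"}},
#     opset_version=17,
# )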
def load_finetuned_model(model, checkpoint_path="fine_tuned_model.pt"):
"""
Loads the fine-tuned model from the checkpoint.
"""
if os.path.exists(checkpoint_path):
model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
model.eval()
st.success("Fine-tuned model loaded successfully!")
else:
st.error(f"Checkpoint not found: {checkpoint_path}")
return model
def generate_response(prompt, model, tokenizer, max_length=200):
"""
Generates a response using the fine-tuned model.
"""
# Tokenize the prompt
inputs = tokenizer(prompt, return_tensors="pt").input_ids
# Generate text
with torch.no_grad():
        outputs = model.generate(inputs, max_length=max_length, num_return_sequences=1, do_sample=True, temperature=0.7)  # do_sample=True so temperature actually affects generation
# Decode the output
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
return response
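# Example usage (assumes the model and tokenizer loaded above):
# generate_response("Once upon a time...", model, tokenizer, max_length=100)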
# -------------------------------
# Application Layout
# -------------------------------
st.title("One-Stop Gemma Model Fine-tuning, Quantization & Conversion UI")
st.markdown("""
This application is designed for beginners in generative AI.
It allows you to fine-tune, quantize, and convert Gemma models with an intuitive UI.
You can upload your dataset, clean and preview your data, configure training parameters, and export your model in different formats.
""")
# Sidebar: Model selection and data upload
st.sidebar.header("Configuration")
# Model Selection
# Offer the Gemma 3 instruction-tuned checkpoints used below; the selection is the model ID itself.
selected_model = st.sidebar.selectbox(
    "Select Gemma Model",
    options=["google/gemma-3-1b-it", "google/gemma-3-4b-it"],
)
model_id = selected_model
loading_placeholder = st.sidebar.empty()
loading_placeholder.info("Loading model...")
tokenizer, model = load_model(model_id, token)
if model is not None:
    loading_placeholder.success("Model loaded.")
else:
    loading_placeholder.error("Model failed to load.")
# Dataset Upload
uploaded_file = st.sidebar.file_uploader("Upload Dataset (CSV, JSONL, TXT)", type=["csv", "jsonl", "txt"])
data = None
if uploaded_file is not None:
file_ext = uploaded_file.name.split('.')[-1].lower()
data = preprocess_data(uploaded_file, file_ext)
st.sidebar.subheader("Dataset Preview:")
if isinstance(data, pd.DataFrame):
st.sidebar.dataframe(data.head())
elif isinstance(data, list):
st.sidebar.write(data[:5])
else:
st.sidebar.write(data)
else:
st.sidebar.info("Awaiting dataset upload.")
# Data Cleaning Options (for TXT files)
if uploaded_file is not None and file_ext == "txt":
st.sidebar.subheader("Data Cleaning Options")
lowercase_option = st.sidebar.checkbox("Convert to lowercase", value=True)
remove_punct = st.sidebar.checkbox("Remove punctuation", value=True)
cleaned_data = [clean_text(line, lowercase=lowercase_option, remove_punctuation=remove_punct) for line in data]
st.sidebar.text_area("Cleaned Data Preview", value="\n".join(cleaned_data[:5]), height=150)
# Main Tabs for Different Operations
tabs = st.tabs(["Fine-tuning", "Quantization", "Model Conversion"])
# -------------------------------
# Fine-tuning Tab
# -------------------------------
with tabs[0]:
st.header("Fine-tuning")
st.markdown("Configure hyperparameters and start fine-tuning your Gemma model.")
col1, col2, col3 = st.columns(3)
with col1:
learning_rate = st.number_input("Learning Rate", value=1e-4, format="%.5f")
with col2:
batch_size = st.number_input("Batch Size", value=16, step=1)
with col3:
epochs = st.number_input("Epochs", value=3, step=1)
if st.button("Start Fine-tuning"):
if data is None:
st.error("Please upload a dataset first!")
else:
st.info("Starting fine-tuning...")
progress_bar = st.progress(0)
training_placeholder = st.empty()
loss_values = []
accuracy_values = []
# Simulate training loop (replace with your actual training code)
for epoch, losses, accs in simulate_training(epochs):
fig = plot_training_metrics(epoch, losses, accs)
training_placeholder.pyplot(fig)
progress_bar.progress(epoch/epochs)
st.success("Fine-tuning completed!")
# Save the fine-tuned model (for demonstration, saving state_dict)
if model:
torch.save(model.state_dict(), "fine_tuned_model.pt")
with open("fine_tuned_model.pt", "rb") as f:
st.download_button("Download Fine-tuned Model", data=f, file_name="fine_tuned_model.pt", mime="application/octet-stream")
else:
st.error("Model not loaded. Cannot save.")
# -------------------------------
# Quantization Tab
# -------------------------------
with tabs[1]:
st.header("Model Quantization")
st.markdown("Quantize your model to optimize for inference performance.")
quantize_choice = st.radio("Select Quantization Type", options=["Dynamic Quantization"], index=0)
if st.button("Apply Quantization"):
with st.spinner("Applying quantization..."):
quantized_model = quantize_model(model)
st.success("Model quantized successfully!")
torch.save(quantized_model.state_dict(), "quantized_model.pt")
with open("quantized_model.pt", "rb") as f:
st.download_button("Download Quantized Model", data=f, file_name="quantized_model.pt", mime="application/octet-stream")
# -------------------------------
# Model Conversion Tab
# -------------------------------
with tabs[2]:
st.header("Model Conversion")
st.markdown("Convert your model to a different format for deployment or optimization.")
conversion_option = st.selectbox("Select Conversion Format", options=["TorchScript", "ONNX"])
if st.button("Convert Model"):
if conversion_option == "TorchScript":
with st.spinner("Converting to TorchScript..."):
ts_model = convert_to_torchscript(model)
ts_model.save("model_ts.pt")
st.success("Converted to TorchScript!")
with open("model_ts.pt", "rb") as f:
st.download_button("Download TorchScript Model", data=f, file_name="model_ts.pt", mime="application/octet-stream")
elif conversion_option == "ONNX":
with st.spinner("Converting to ONNX..."):
onnx_path = convert_to_onnx(model, "model.onnx")
st.success("Converted to ONNX!")
with open(onnx_path, "rb") as f:
st.download_button("Download ONNX Model", data=f, file_name="model.onnx", mime="application/octet-stream")
# -------------------------------
# Response Generation Section
# -------------------------------
st.header("Generate Responses with Fine-Tuned Model")
st.markdown("Use the fine-tuned model to generate text responses based on your prompts.")
# Check if the fine-tuned model exists
if os.path.exists("fine_tuned_model.pt"):
# Load the fine-tuned model
model = load_finetuned_model(model, "fine_tuned_model.pt")
# Input prompt for generating responses
prompt = st.text_area("Enter a prompt:", "Once upon a time...")
# Max length slider
max_length = st.slider("Max Response Length", min_value=50, max_value=500, value=200, step=10)
if st.button("Generate Response"):
with st.spinner("Generating response..."):
response = generate_response(prompt, model, tokenizer, max_length)
st.success("Generated Response:")
st.write(response)
else:
st.warning("Fine-tuned model not found. Please fine-tune the model first.")
# -------------------------------
# Optional: Cloud Integration Snippet
# -------------------------------
st.header("Cloud Integration")
st.markdown("""
For large-scale training or model storage, consider integrating with Google Cloud Storage or Vertex AI.
Below is an example snippet for uploading your model to GCS:
""")
st.code("""
from google.cloud import storage
def upload_to_gcs(bucket_name, source_file_name, destination_blob_name):
storage_client = storage.Client()
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(destination_blob_name)
blob.upload_from_filename(source_file_name)
print(f"Uploaded {source_file_name} to {destination_blob_name}")
# Example usage:
# upload_to_gcs("your-bucket-name", "fine_tuned_model.pt", "models/fine_tuned_model.pt")
""", language="python")