# Provenance (copied from a web file viewer together with the code):
# author: mohitkumarrajbadi; commit e6f0893 "Add application file"; size 13.2 kB.
import streamlit as st
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import time
import json
import re
import os
import asyncio
# -------------------------------
# Utility Functions
# -------------------------------
# Hugging Face access token, read from the environment. Secrets must never be
# hard-coded: the token that used to be written here was committed to source
# control and should be treated as leaked and revoked. When HF_TOKEN is unset
# this is None, and `from_pretrained(..., token=None)` falls back to the
# credentials huggingface_hub finds on its own (e.g. a cached CLI login).
#
# The former `os.environ['CURL_CA_BUNDLE'] = ''` line was removed: it is a
# hack to switch off TLS certificate verification for Hub downloads. If you
# are behind a TLS-intercepting proxy, point REQUESTS_CA_BUNDLE at the proxy's
# CA certificate instead of disabling verification.
token = os.environ.get("HF_TOKEN") or None
@st.cache_resource
def _load_model_cached(model_id: str, token: str):
    """Load the tokenizer and causal LM for ``model_id`` and cache the pair.

    Exceptions are deliberately not caught here. ``st.cache_resource`` only
    stores successful return values, so a failure (network hiccup, missing
    authorization) is retried on the next rerun. It is not frozen into the
    cache as a ``(None, None)`` result.
    """
    # Kept from the original implementation: run a short dummy coroutine.
    # asyncio.run creates a fresh event loop and closes it again afterwards.
    asyncio.run(async_load(model_id, token))
    # Kept from the original: defensive patch of ``torch.classes``. The
    # fallback relies on a private torch API. NOTE(review): confirm it is
    # still needed with the installed torch/streamlit versions.
    if not hasattr(torch, "classes") or not torch.classes:
        torch.classes = torch._C._get_python_module("torch.classes")
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(model_id, token=token)
    return tokenizer, model


def load_model(model_id: str, token: str):
    """Load (and cache) the Gemma tokenizer and model with an auth token.

    Args:
        model_id: Hugging Face Hub repository id, e.g. ``"google/gemma-3-1b-it"``.
        token: Access token forwarded to ``from_pretrained``.

    Returns:
        ``(tokenizer, model)`` on success. ``(None, None)`` if loading failed;
        in that case the error is also printed and shown in the Streamlit UI.
        Failed loads are not cached, so the next rerun tries again.
    """
    try:
        return _load_model_cached(model_id, token)
    except Exception as e:  # UI boundary: report the failure instead of crashing the app
        print(f"An error occurred: {e}")
        st.error(f"Model loading failed: {e}")
        return None, None
async def async_load(model_id, token):
    """Placeholder coroutine that just waits for one tenth of a second.

    ``model_id`` and ``token`` are accepted so the call site can pass them,
    but they are not used.
    """
    pause_seconds = 0.1
    await asyncio.sleep(pause_seconds)
def preprocess_data(uploaded_file, file_extension):
    """Read an uploaded dataset file and return a previewable structure.

    Supported formats (the extension is matched case-insensitively, and a
    leading dot is optional):

    * ``csv``: a ``pandas.DataFrame`` via ``pd.read_csv``.
    * ``jsonl``: a ``pandas.DataFrame`` built from one JSON object per
      non-blank line. If the DataFrame conversion fails, the raw list of
      parsed objects is returned instead.
    * ``txt``: a list of lines, decoded as UTF-8 with an optional BOM stripped.

    Args:
        uploaded_file: Binary file-like object (e.g. a Streamlit
            ``UploadedFile``). It is rewound before reading, so the same
            object can be processed again on a later rerun.
        file_extension: File extension such as ``"csv"`` or ``".TXT"``.

    Returns:
        The parsed data. ``None`` for an unsupported extension, or when
        reading fails; failures are reported with ``st.error``.
    """
    ext = str(file_extension).lower().lstrip(".")
    data = None
    try:
        # A previous read of this object may have left the cursor at EOF,
        # which would silently produce an empty dataset.
        if hasattr(uploaded_file, "seek"):
            uploaded_file.seek(0)
        if ext == "csv":
            data = pd.read_csv(uploaded_file)
        elif ext == "jsonl":
            # One JSON object per line. Skip blank lines (e.g. an extra
            # trailing newline) instead of failing the whole file on them.
            data = [json.loads(line) for line in uploaded_file.readlines() if line.strip()]
            try:
                data = pd.DataFrame(data)
            except Exception:
                st.warning("Unable to convert JSONL to a table. Previewing raw JSON objects.")
        elif ext == "txt":
            # utf-8-sig also strips a leading BOM, which plain utf-8 would
            # leave glued to the first line.
            text_data = uploaded_file.read().decode("utf-8-sig")
            data = text_data.splitlines()
    except Exception as e:
        st.error(f"Error processing file: {e}")
    return data
def clean_text(text, lowercase=True, remove_punctuation=True):
    """Normalize one piece of text.

    Args:
        text: The string to normalize.
        lowercase: When true, convert the text to lower case.
        remove_punctuation: When true, delete every character that is neither
            a word character (letters, digits, underscore) nor whitespace.

    Returns:
        The normalized string.
    """
    normalized = text.lower() if lowercase else text
    if not remove_punctuation:
        return normalized
    return re.sub(r'[^\w\s]', '', normalized)
def plot_training_metrics(epochs, loss_values, accuracy_values):
    """Build a figure of training loss (left) and accuracy (right) per epoch.

    The figure is created with ``matplotlib.figure.Figure``, not
    ``plt.subplots``, so it is not registered with pyplot's global figure
    manager. This function is called once per epoch and on every Streamlit
    rerun. Pyplot figures that are never closed would accumulate, leaking
    memory and triggering matplotlib's "too many open figures" warning.

    Args:
        epochs: Number of epochs plotted; the x axis runs from 1 to ``epochs``.
        loss_values: One loss value per epoch (length ``epochs``).
        accuracy_values: One accuracy value per epoch (length ``epochs``).

    Returns:
        A ``matplotlib.figure.Figure`` with two side-by-side axes.
    """
    # Local import: only this function needs the object-oriented Figure API.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(12, 4))
    ax = fig.subplots(1, 2)
    x_values = range(1, epochs + 1)
    ax[0].plot(x_values, loss_values, marker='o', color='red')
    ax[0].set_title("Training Loss")
    ax[0].set_xlabel("Epoch")
    ax[0].set_ylabel("Loss")
    ax[1].plot(x_values, accuracy_values, marker='o', color='green')
    ax[1].set_title("Training Accuracy")
    ax[1].set_xlabel("Epoch")
    ax[1].set_ylabel("Accuracy")
    return fig
def simulate_training(num_epochs, delay=1.0):
    """Simulate a training loop for demonstration purposes.

    Replace this with your actual fine-tuning loop.

    Args:
        num_epochs: Number of epochs to simulate (converted to ``int``).
        delay: Seconds to sleep per epoch to mimic computation time.

    Yields:
        ``(epoch, loss_values, accuracy_values)``. The two lists hold every
        value recorded for epochs ``1..epoch``. Fresh copies are yielded, so
        results kept by the caller are not mutated by later epochs.
    """
    num_epochs = int(num_epochs)
    loss_values = []
    accuracy_values = []
    for epoch in range(1, num_epochs + 1):
        # Simulate the epoch's work before reporting it. This leaves no dead
        # wait after the final epoch has been yielded.
        if delay:
            time.sleep(delay)
        loss = np.exp(-epoch) + np.random.random() * 0.1
        acc = 0.5 + (epoch / num_epochs) * 0.5 + np.random.random() * 0.05
        loss_values.append(loss)
        accuracy_values.append(acc)
        yield epoch, list(loss_values), list(accuracy_values)
def quantize_model(model):
    """Return a dynamically int8-quantized copy of ``model``.

    Every ``torch.nn.Linear`` layer is replaced by a dynamically quantized
    equivalent with int8 weights. ``quantize_dynamic`` works on a copy by
    default (``inplace=False``), so the input model is left unchanged. In
    practice, adjust this based on your model and target hardware.

    Raises:
        ValueError: if ``model`` is ``None`` (e.g. model loading failed).
    """
    if model is None:
        raise ValueError("No model to quantize: the model failed to load.")
    try:
        # torch.ao.quantization is the current home of the eager-mode API.
        from torch.ao.quantization import quantize_dynamic
    except ImportError:  # older torch releases only expose torch.quantization
        from torch.quantization import quantize_dynamic
    return quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
def convert_to_torchscript(model, example_input=None):
    """Trace ``model`` into a TorchScript module.

    Args:
        model: Module to trace. It is assumed to accept a single tensor of
            token ids; TODO confirm for non-LM models.
        example_input: Example input used for tracing. Defaults to a random
            ``(1, 10)`` tensor of ids in ``[0, 100)``.

    Returns:
        The traced module (``torch.jit.trace`` result).

    Raises:
        ValueError: if ``model`` is ``None``.
    """
    if model is None:
        raise ValueError("No model to convert: the model failed to load.")
    if example_input is None:
        example_input = torch.randint(0, 100, (1, 10))
    # No autograd graph is needed while recording the trace.
    with torch.no_grad():
        # strict=False allows dict-like outputs, such as a Hugging Face
        # ModelOutput. With the default strict=True, tracing such models
        # raises. This is only valid if the output structure does not depend
        # on the input.
        return torch.jit.trace(model, example_input, strict=False)
def convert_to_onnx(model, output_path="model.onnx"):
    """Export ``model`` to an ONNX file and return the file path.

    A random ``(1, 10)`` tensor of token ids is the example input. The input's
    batch and sequence dimensions are declared dynamic, so the exported graph
    is not hard-wired to exactly that shape.

    Args:
        model: Module to export. It is assumed to take a single tensor of
            token ids; TODO confirm.
        output_path: Destination file for the ONNX graph.

    Returns:
        ``output_path``.

    Raises:
        ValueError: if ``model`` is ``None``.
    """
    if model is None:
        raise ValueError("No model to convert: the model failed to load.")
    dummy_input = torch.randint(0, 100, (1, 10))
    with torch.no_grad():
        torch.onnx.export(
            model,
            dummy_input,
            output_path,
            input_names=["input"],
            output_names=["output"],
            # Without this, the exported graph only accepts shape (1, 10).
            dynamic_axes={"input": {0: "batch", 1: "sequence"}},
        )
    return output_path
def load_finetuned_model(model, checkpoint_path="fine_tuned_model.pt"):
    """Load fine-tuned weights from ``checkpoint_path`` into ``model`` in place.

    Args:
        model: Module whose ``state_dict`` layout matches the checkpoint.
        checkpoint_path: File written with ``torch.save(model.state_dict(), ...)``.

    Returns:
        The same ``model`` object, switched to eval mode if weights were
        loaded. If ``model`` is ``None`` or the checkpoint is missing, an
        error is shown and ``model`` is returned unchanged.
    """
    if model is None:
        st.error("Model not loaded. Cannot load fine-tuned weights.")
        return model
    if not os.path.exists(checkpoint_path):
        st.error(f"Checkpoint not found: {checkpoint_path}")
        return model
    # A state_dict only holds tensors. weights_only=True refuses to unpickle
    # arbitrary objects, which plain pickle loading of a tampered checkpoint
    # could use to execute code.
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    st.success("Fine-tuned model loaded successfully!")
    return model
def generate_response(prompt, model, tokenizer, max_length=200, temperature=0.7):
    """Generate text for ``prompt`` with ``model`` using temperature sampling.

    Args:
        prompt: Input text.
        model: Language model exposing ``generate`` (e.g. a transformers model).
        tokenizer: Tokenizer matching ``model``. It is called with
            ``return_tensors="pt"`` and also used to decode the output.
        max_length: Maximum total sequence length passed to ``generate``.
            This counts the prompt tokens as well as the newly generated ones.
        temperature: Sampling temperature.

    Returns:
        The decoded first output sequence, with special tokens skipped. For
        decoder-only models the output normally starts with the prompt, so
        the text usually includes it.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"]
    # Pass the mask explicitly so generate() does not have to guess it.
    attention_mask = encoded.get("attention_mask")
    device = getattr(model, "device", None)
    if device is not None:
        input_ids = input_ids.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_return_sequences=1,
            # temperature only takes effect when sampling. Without
            # do_sample=True, transformers ignores it and warns.
            do_sample=True,
            temperature=temperature,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# -------------------------------
# Application Layout
# -------------------------------
st.title("One-Stop Gemma Model Fine-tuning, Quantization & Conversion UI")
st.markdown("""
This application is designed for beginners in generative AI.
It allows you to fine-tune, quantize, and convert Gemma models with an intuitive UI.
You can upload your dataset, clean and preview your data, configure training parameters, and export your model in different formats.
""")
# Sidebar: Model selection and data upload
st.sidebar.header("Configuration")
# Model Selection
# Map each selectbox label to a Hub checkpoint. The lookup must use the labels
# actually offered below; comparing the label against hub ids never matches.
# NOTE(review): gemma-3-4b-it is a multimodal checkpoint. Confirm that
# AutoModelForCausalLM can load it with the installed transformers version.
GEMMA_MODEL_IDS = {
    "Gemma-Small": "google/gemma-3-1b-it",
    "Gemma-Medium": "google/gemma-3-4b-it",
}
DEFAULT_MODEL_ID = "google/gemma-3-1b-it"
selected_model = st.sidebar.selectbox("Select Gemma Model", options=["Gemma-Small", "Gemma-Medium", "Gemma-Large"])
model_id = GEMMA_MODEL_IDS.get(selected_model, DEFAULT_MODEL_ID)
if selected_model not in GEMMA_MODEL_IDS:
    # TODO: configure a checkpoint for this option.
    st.sidebar.warning(f"No checkpoint configured for {selected_model}; using {DEFAULT_MODEL_ID}.")
loading_placeholder = st.sidebar.empty()
loading_placeholder.info("Loading model...")
tokenizer, model = load_model(model_id, token)
# load_model returns (None, None) on failure; do not claim success then.
if model is None or tokenizer is None:
    loading_placeholder.error("Model could not be loaded.")
else:
    loading_placeholder.success("Model loaded.")
# Dataset Upload
uploaded_file = st.sidebar.file_uploader("Upload Dataset (CSV, JSONL, TXT)", type=["csv", "jsonl", "txt"])
data = None
file_ext = None  # defined on every run so later checks never hit an unbound name
if uploaded_file is not None:
    file_ext = uploaded_file.name.split('.')[-1].lower()
    data = preprocess_data(uploaded_file, file_ext)
    st.sidebar.subheader("Dataset Preview:")
    if isinstance(data, pd.DataFrame):
        st.sidebar.dataframe(data.head())
    elif isinstance(data, list):
        st.sidebar.write(data[:5])
    else:
        st.sidebar.write(data)
else:
    st.sidebar.info("Awaiting dataset upload.")
# Data Cleaning Options (for TXT files)
# Only offered when the TXT file was actually parsed into a list of lines.
# preprocess_data returns None when reading fails, and iterating that would crash.
if file_ext == "txt" and isinstance(data, list):
    st.sidebar.subheader("Data Cleaning Options")
    lowercase_option = st.sidebar.checkbox("Convert to lowercase", value=True)
    remove_punct = st.sidebar.checkbox("Remove punctuation", value=True)
    cleaned_data = [clean_text(line, lowercase=lowercase_option, remove_punctuation=remove_punct) for line in data]
    st.sidebar.text_area("Cleaned Data Preview", value="\n".join(cleaned_data[:5]), height=150)
# Main Tabs for Different Operations
tabs = st.tabs(["Fine-tuning", "Quantization", "Model Conversion"])
# -------------------------------
# Fine-tuning Tab
# -------------------------------
with tabs[0]:
    st.header("Fine-tuning")
    st.markdown("Configure hyperparameters and start fine-tuning your Gemma model.")
    col1, col2, col3 = st.columns(3)
    with col1:
        learning_rate = st.number_input("Learning Rate", value=1e-4, min_value=0.0, format="%.5f")
    with col2:
        batch_size = st.number_input("Batch Size", value=16, min_value=1, step=1)
    with col3:
        # At least one epoch: otherwise nothing runs and the progress bar stays at 0.
        epochs = st.number_input("Epochs", value=3, min_value=1, step=1)
    if st.button("Start Fine-tuning"):
        if data is None:
            st.error("Please upload a dataset first!")
        else:
            st.info("Starting fine-tuning...")
            progress_bar = st.progress(0)
            training_placeholder = st.empty()
            # Simulate training loop (replace with your actual training code)
            for epoch, losses, accs in simulate_training(epochs):
                fig = plot_training_metrics(epoch, losses, accs)
                training_placeholder.pyplot(fig)
                # Drop the figure from pyplot's registry once rendered, so one
                # figure per epoch is not leaked. No-op for unregistered figures.
                plt.close(fig)
                progress_bar.progress(epoch / epochs)
            st.success("Fine-tuning completed!")
            # Save the fine-tuned model (for demonstration, saving state_dict).
            # NOTE: simulate_training does not touch the model, so this file
            # holds the loaded model's current (unchanged) weights.
            if model is not None:
                torch.save(model.state_dict(), "fine_tuned_model.pt")
                with open("fine_tuned_model.pt", "rb") as f:
                    st.download_button("Download Fine-tuned Model", data=f, file_name="fine_tuned_model.pt", mime="application/octet-stream")
            else:
                st.error("Model not loaded. Cannot save.")
# -------------------------------
# Quantization Tab
# -------------------------------
with tabs[1]:
    st.header("Model Quantization")
    st.markdown("Quantize your model to optimize for inference performance.")
    quantize_choice = st.radio("Select Quantization Type", options=["Dynamic Quantization"], index=0)
    if st.button("Apply Quantization"):
        if model is None:
            st.error("Model not loaded. Cannot quantize.")
        else:
            try:
                with st.spinner("Applying quantization..."):
                    quantized_model = quantize_model(model)
                    torch.save(quantized_model.state_dict(), "quantized_model.pt")
            except Exception as e:  # UI boundary: show the failure instead of a traceback
                st.error(f"Quantization failed: {e}")
            else:
                st.success("Model quantized successfully!")
                with open("quantized_model.pt", "rb") as f:
                    st.download_button("Download Quantized Model", data=f, file_name="quantized_model.pt", mime="application/octet-stream")
# -------------------------------
# Model Conversion Tab
# -------------------------------
with tabs[2]:
    st.header("Model Conversion")
    st.markdown("Convert your model to a different format for deployment or optimization.")
    conversion_option = st.selectbox("Select Conversion Format", options=["TorchScript", "ONNX"])
    if st.button("Convert Model"):
        if model is None:
            st.error("Model not loaded. Cannot convert.")
        elif conversion_option == "TorchScript":
            try:
                with st.spinner("Converting to TorchScript..."):
                    ts_model = convert_to_torchscript(model)
                    ts_model.save("model_ts.pt")
            except Exception as e:  # tracing large models often fails; report it in the UI
                st.error(f"TorchScript conversion failed: {e}")
            else:
                st.success("Converted to TorchScript!")
                with open("model_ts.pt", "rb") as f:
                    st.download_button("Download TorchScript Model", data=f, file_name="model_ts.pt", mime="application/octet-stream")
        elif conversion_option == "ONNX":
            try:
                with st.spinner("Converting to ONNX..."):
                    onnx_path = convert_to_onnx(model, "model.onnx")
            except Exception as e:  # report export failures in the UI
                st.error(f"ONNX conversion failed: {e}")
            else:
                st.success("Converted to ONNX!")
                with open(onnx_path, "rb") as f:
                    st.download_button("Download ONNX Model", data=f, file_name="model.onnx", mime="application/octet-stream")
# -------------------------------
# Response Generation Section
# -------------------------------
st.header("Generate Responses with Fine-Tuned Model")
st.markdown("Use the fine-tuned model to generate text responses based on your prompts.")
FINETUNED_CHECKPOINT = "fine_tuned_model.pt"
if model is None or tokenizer is None:
    st.error("Model not loaded. Cannot generate responses.")
# Check if the fine-tuned model exists
elif os.path.exists(FINETUNED_CHECKPOINT):
    # Streamlit re-executes this script on every widget interaction. Reload
    # the checkpoint only when it changed (or the model object changed) since
    # this session last loaded it, instead of on every slider move.
    checkpoint_key = (id(model), os.path.getmtime(FINETUNED_CHECKPOINT))
    if st.session_state.get("finetuned_checkpoint_key") != checkpoint_key:
        model = load_finetuned_model(model, FINETUNED_CHECKPOINT)
        st.session_state["finetuned_checkpoint_key"] = checkpoint_key
    # Input prompt for generating responses
    prompt = st.text_area("Enter a prompt:", "Once upon a time...")
    # Max length slider
    max_length = st.slider("Max Response Length", min_value=50, max_value=500, value=200, step=10)
    if st.button("Generate Response"):
        with st.spinner("Generating response..."):
            response = generate_response(prompt, model, tokenizer, max_length)
        st.success("Generated Response:")
        st.write(response)
else:
    st.warning("Fine-tuned model not found. Please fine-tune the model first.")
# -------------------------------
# Optional: Cloud Integration Snippet
# -------------------------------
st.header("Cloud Integration")
st.markdown("""
For large-scale training or model storage, consider integrating with Google Cloud Storage or Vertex AI.
Below is an example snippet for uploading your model to GCS:
""")
# The snippet is shown to users for copy/paste, so it must be valid,
# properly indented Python.
st.code("""
from google.cloud import storage

def upload_to_gcs(bucket_name, source_file_name, destination_blob_name):
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(destination_blob_name)
    blob.upload_from_filename(source_file_name)
    print(f"Uploaded {source_file_name} to {destination_blob_name}")

# Example usage:
# upload_to_gcs("your-bucket-name", "fine_tuned_model.pt", "models/fine_tuned_model.pt")
""", language="python")