import streamlit as st
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import time
import json
import re
import os
# -------------------------------
# Utility Functions
# -------------------------------
# Read the Hugging Face access token from the environment; never hardcode
# secrets in source code.
token = os.environ.get("HF_TOKEN", "")
# Workaround for environments with a broken certificate bundle; note that this
# disables TLS certificate verification.
os.environ['CURL_CA_BUNDLE'] = ''
@st.cache_resource(show_spinner=False)
def load_model(model_id: str, token: str):
    """
    Loads and caches the Gemma model and tokenizer with an authentication token.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        model = AutoModelForCausalLM.from_pretrained(model_id, token=token)
        return tokenizer, model
    except Exception as e:
        st.error(f"Model loading failed: {e}")
        return None, None
def preprocess_data(uploaded_file, file_extension):
    """
    Reads the uploaded file and returns a processed version.
    Supports CSV, JSONL, and TXT.
    """
    data = None
    try:
        if file_extension == "csv":
            data = pd.read_csv(uploaded_file)
        elif file_extension == "jsonl":
            # Each non-empty line is a JSON object.
            lines = uploaded_file.read().decode("utf-8").splitlines()
            data = [json.loads(line) for line in lines if line.strip()]
            try:
                data = pd.DataFrame(data)
            except Exception:
                st.warning("Unable to convert JSONL to a table. Previewing raw JSON objects.")
        elif file_extension == "txt":
            text_data = uploaded_file.read().decode("utf-8")
            data = text_data.splitlines()
    except Exception as e:
        st.error(f"Error processing file: {e}")
    return data
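# A minimal sketch of the JSONL layout this loader expects -- one JSON object
# per line (the field names here are hypothetical; use your own schema):
#   {"prompt": "Translate to French: Hello", "response": "Bonjour"}
#   {"prompt": "Summarize the article ...", "response": "..."}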
def clean_text(text, lowercase=True, remove_punctuation=True):
    """
    Cleans text data by applying basic normalization.
    """
    if lowercase:
        text = text.lower()
    if remove_punctuation:
        text = re.sub(r'[^\w\s]', '', text)
    return text
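# Illustrative behaviour of the defaults above:
#   clean_text("Hello, World!")                -> "hello world"
#   clean_text("Keep CASE!", lowercase=False)  -> "Keep CASE"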
def plot_training_metrics(epochs, loss_values, accuracy_values):
    """
    Returns a matplotlib figure plotting training loss and accuracy.
    """
    fig, ax = plt.subplots(1, 2, figsize=(12, 4))
    ax[0].plot(range(1, epochs + 1), loss_values, marker='o', color='red')
    ax[0].set_title("Training Loss")
    ax[0].set_xlabel("Epoch")
    ax[0].set_ylabel("Loss")
    ax[1].plot(range(1, epochs + 1), accuracy_values, marker='o', color='green')
    ax[1].set_title("Training Accuracy")
    ax[1].set_xlabel("Epoch")
    ax[1].set_ylabel("Accuracy")
    return fig
def simulate_training(num_epochs):
    """
    Simulates a training loop for demonstration.
    Yields the current epoch, loss values, and accuracy values.
    Replace this with your actual fine-tuning loop.
    """
    loss_values = []
    accuracy_values = []
    for epoch in range(1, num_epochs + 1):
        loss = np.exp(-epoch) + np.random.random() * 0.1
        acc = 0.5 + (epoch / num_epochs) * 0.5 + np.random.random() * 0.05
        loss_values.append(loss)
        accuracy_values.append(acc)
        yield epoch, loss_values, accuracy_values
        time.sleep(1)  # Simulate computation time
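# A rough sketch of what could replace this simulation with real fine-tuning
# (hypothetical: assumes a tokenized `train_dataset` built from the uploaded
# data, and uses the hyperparameters collected in the UI below):
#   from transformers import Trainer, TrainingArguments
#   args = TrainingArguments(output_dir="checkpoints", num_train_epochs=epochs,
#                            per_device_train_batch_size=batch_size,
#                            learning_rate=learning_rate)
#   trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
#   trainer.train()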
def quantize_model(model):
    """
    Applies dynamic quantization for demonstration.
    In practice, adjust this based on your model and target hardware.
    """
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    return quantized_model
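# Dynamic quantization converts nn.Linear weights to int8 and runs on CPU.
# Depending on your platform you may need to select the quantized backend
# first (a sketch: "fbgemm" targets x86, "qnnpack" targets ARM):
#   torch.backends.quantized.engine = "fbgemm"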
def convert_to_torchscript(model):
    """
    Converts the model to TorchScript format via tracing.
    strict=False is used because Hugging Face models return dict-like outputs,
    which strict tracing rejects.
    """
    example_input = torch.randint(0, 100, (1, 10))
    traced_model = torch.jit.trace(model, example_input, strict=False)
    return traced_model
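# If tracing still fails on the model's outputs, the usual workaround is to
# reload the model with the torchscript flag (a sketch):
#   model = AutoModelForCausalLM.from_pretrained(model_id, token=token, torchscript=True)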
def convert_to_onnx(model, output_path="model.onnx"):
    """
    Converts the model to ONNX format.
    """
    dummy_input = torch.randint(0, 100, (1, 10))
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        input_names=["input"],
        output_names=["output"],
        # Allow variable batch size and sequence length at inference time.
        dynamic_axes={"input": {0: "batch", 1: "sequence"}},
        opset_version=14,
    )
    return output_path
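# To sanity-check the exported graph (assumes the `onnx` package is installed):
#   import onnx
#   onnx.checker.check_model(onnx.load("model.onnx"))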
def load_finetuned_model(model, checkpoint_path="fine_tuned_model.pt"):
    """
    Loads the fine-tuned model weights from the checkpoint.
    """
    if os.path.exists(checkpoint_path):
        # weights_only=True restricts unpickling to tensors, which is safer.
        model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=True))
        model.eval()
        st.success("Fine-tuned model loaded successfully!")
    else:
        st.error(f"Checkpoint not found: {checkpoint_path}")
    return model
def generate_response(prompt, model, tokenizer, max_length=200):
    """
    Generates a response using the fine-tuned model.
    """
    # Tokenize the prompt (keep the attention mask for generate()).
    inputs = tokenizer(prompt, return_tensors="pt")
    # Generate text; do_sample=True is required for temperature to take effect.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=max_length, num_return_sequences=1, do_sample=True, temperature=0.7)
    # Decode the output
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
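# Note: max_length counts prompt tokens plus generated tokens. To bound only
# the newly generated text, max_new_tokens is the usual alternative:
#   outputs = model.generate(**inputs, max_new_tokens=max_length, do_sample=True, temperature=0.7)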
# -------------------------------
# Application Layout
# -------------------------------
st.title("One-Stop Gemma Model Fine-tuning, Quantization & Conversion UI")
st.markdown("""
This application is designed for beginners in generative AI.
It lets you fine-tune, quantize, and convert Gemma models through an intuitive UI:
upload a dataset, clean and preview the data, configure training parameters, and export the model in different formats.
""")
# Sidebar: Model selection and data upload
st.sidebar.header("Configuration")
# Model Selection: map display names to Hugging Face model IDs.
# (Adjust these IDs to the checkpoints your token has access to.)
MODEL_IDS = {
    "Gemma-Small": "google/gemma-3-1b-it",
    "Gemma-Medium": "google/gemma-3-4b-it",
    "Gemma-Large": "google/gemma-3-12b-it",
}
selected_model = st.sidebar.selectbox("Select Gemma Model", options=list(MODEL_IDS.keys()))
model_id = MODEL_IDS[selected_model]
loading_placeholder = st.sidebar.empty()
loading_placeholder.info("Loading model...")
tokenizer, model = load_model(model_id, token)
if model is not None:
    loading_placeholder.success("Model loaded.")
else:
    loading_placeholder.error("Model failed to load.")
# Dataset Upload
uploaded_file = st.sidebar.file_uploader("Upload Dataset (CSV, JSONL, TXT)", type=["csv", "jsonl", "txt"])
data = None
if uploaded_file is not None:
    file_ext = uploaded_file.name.split('.')[-1].lower()
    data = preprocess_data(uploaded_file, file_ext)
    st.sidebar.subheader("Dataset Preview:")
    if isinstance(data, pd.DataFrame):
        st.sidebar.dataframe(data.head())
    elif isinstance(data, list):
        st.sidebar.write(data[:5])
    else:
        st.sidebar.write(data)
else:
    st.sidebar.info("Awaiting dataset upload.")
# Data Cleaning Options (for TXT files)
if uploaded_file is not None and file_ext == "txt" and data:
    st.sidebar.subheader("Data Cleaning Options")
    lowercase_option = st.sidebar.checkbox("Convert to lowercase", value=True)
    remove_punct = st.sidebar.checkbox("Remove punctuation", value=True)
    cleaned_data = [clean_text(line, lowercase=lowercase_option, remove_punctuation=remove_punct) for line in data]
    st.sidebar.text_area("Cleaned Data Preview", value="\n".join(cleaned_data[:5]), height=150)
# Main Tabs for Different Operations
tabs = st.tabs(["Fine-tuning", "Quantization", "Model Conversion"])
# -------------------------------
# Fine-tuning Tab
# -------------------------------
with tabs[0]:
    st.header("Fine-tuning")
    st.markdown("Configure hyperparameters and start fine-tuning your Gemma model.")
    col1, col2, col3 = st.columns(3)
    with col1:
        learning_rate = st.number_input("Learning Rate", value=1e-4, format="%.5f")
    with col2:
        batch_size = st.number_input("Batch Size", value=16, step=1)
    with col3:
        epochs = st.number_input("Epochs", value=3, step=1)
    if st.button("Start Fine-tuning"):
        if data is None:
            st.error("Please upload a dataset first!")
        else:
            st.info("Starting fine-tuning...")
            progress_bar = st.progress(0)
            training_placeholder = st.empty()
            # Simulated training loop (replace with your actual training code)
            for epoch, losses, accs in simulate_training(epochs):
                fig = plot_training_metrics(epoch, losses, accs)
                training_placeholder.pyplot(fig)
                progress_bar.progress(epoch / epochs)
            st.success("Fine-tuning completed!")
            # Save the fine-tuned model (for demonstration, saving the state_dict)
            if model is not None:
                torch.save(model.state_dict(), "fine_tuned_model.pt")
                with open("fine_tuned_model.pt", "rb") as f:
                    st.download_button("Download Fine-tuned Model", data=f, file_name="fine_tuned_model.pt", mime="application/octet-stream")
            else:
                st.error("Model not loaded. Cannot save.")
# -------------------------------
# Quantization Tab
# -------------------------------
with tabs[1]:
    st.header("Model Quantization")
    st.markdown("Quantize your model to optimize it for inference.")
    quantize_choice = st.radio("Select Quantization Type", options=["Dynamic Quantization"], index=0)
    if st.button("Apply Quantization"):
        if model is None:
            st.error("Model not loaded. Cannot quantize.")
        else:
            with st.spinner("Applying quantization..."):
                quantized_model = quantize_model(model)
            st.success("Model quantized successfully!")
            torch.save(quantized_model.state_dict(), "quantized_model.pt")
            with open("quantized_model.pt", "rb") as f:
                st.download_button("Download Quantized Model", data=f, file_name="quantized_model.pt", mime="application/octet-stream")
# -------------------------------
# Model Conversion Tab
# -------------------------------
with tabs[2]:
    st.header("Model Conversion")
    st.markdown("Convert your model to a different format for deployment or optimization.")
    conversion_option = st.selectbox("Select Conversion Format", options=["TorchScript", "ONNX"])
    if st.button("Convert Model"):
        if model is None:
            st.error("Model not loaded. Cannot convert.")
        elif conversion_option == "TorchScript":
            with st.spinner("Converting to TorchScript..."):
                ts_model = convert_to_torchscript(model)
                ts_model.save("model_ts.pt")
            st.success("Converted to TorchScript!")
            with open("model_ts.pt", "rb") as f:
                st.download_button("Download TorchScript Model", data=f, file_name="model_ts.pt", mime="application/octet-stream")
        elif conversion_option == "ONNX":
            with st.spinner("Converting to ONNX..."):
                onnx_path = convert_to_onnx(model, "model.onnx")
            st.success("Converted to ONNX!")
            with open(onnx_path, "rb") as f:
                st.download_button("Download ONNX Model", data=f, file_name="model.onnx", mime="application/octet-stream")
# -------------------------------
# Response Generation Section
# -------------------------------
st.header("Generate Responses with Fine-Tuned Model")
st.markdown("Use the fine-tuned model to generate text responses based on your prompts.")
# Check that both a loaded model and a fine-tuned checkpoint are available.
if model is not None and os.path.exists("fine_tuned_model.pt"):
    # Load the fine-tuned weights
    model = load_finetuned_model(model, "fine_tuned_model.pt")
    # Input prompt for generating responses
    prompt = st.text_area("Enter a prompt:", "Once upon a time...")
    # Max length slider
    max_length = st.slider("Max Response Length", min_value=50, max_value=500, value=200, step=10)
    if st.button("Generate Response"):
        with st.spinner("Generating response..."):
            response = generate_response(prompt, model, tokenizer, max_length)
        st.success("Generated Response:")
        st.write(response)
else:
    st.warning("Fine-tuned model not found. Please fine-tune the model first.")
# -------------------------------
# Optional: Cloud Integration Snippet
# -------------------------------
st.header("Cloud Integration")
st.markdown("""
For large-scale training or model storage, consider integrating with Google Cloud Storage or Vertex AI.
Below is an example snippet for uploading your model to GCS:
""")
st.code("""
from google.cloud import storage

def upload_to_gcs(bucket_name, source_file_name, destination_blob_name):
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(destination_blob_name)
    blob.upload_from_filename(source_file_name)
    print(f"Uploaded {source_file_name} to {destination_blob_name}")

# Example usage:
# upload_to_gcs("your-bucket-name", "fine_tuned_model.pt", "models/fine_tuned_model.pt")
""", language="python")