import streamlit as st
import pandas as pd
import json
import os
from datetime import datetime

from utils import (
    load_model,
    get_hf_token,
    simulate_training,
    plot_training_metrics,
    load_finetuned_model,
    save_model
)
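# NOTE: the helpers above live in utils.py and are assumed to behave roughly as follows
# (signatures inferred from how they are used on this page, not from the implementation):
#   get_hf_token()                    -> Hugging Face token from the Space's secrets/env
#   load_model(name, token)           -> (tokenizer, model) for the given checkpoint
#   load_finetuned_model(model, path) -> model with the saved weights applied, or None
#   save_model(model, path)           -> the path the weights were written to, or None
#   simulate_training(epochs, lr, bs) -> yields (epoch, avg_loss, avg_accuracy) per epoch
#   plot_training_metrics(...)        -> optional plotting helper (unused on this page)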
st.title("🔥 Fine-tune the Gemma Model")

# -------------------------------
# Finetuning Option Selection
# -------------------------------
finetune_option = st.radio("Select Finetuning Option", ["Fine-tune from scratch", "Refinetune existing model"])

# -------------------------------
# Model Selection Logic
# -------------------------------
selected_model = None
saved_model_path = None

if finetune_option == "Fine-tune from scratch":
    # Display Hugging Face model list
    model_list = [
        "google/gemma-3-1b-pt",
        "google/gemma-3-1b-it",
        "google/gemma-3-4b-pt",
        "google/gemma-3-4b-it",
        "google/gemma-3-12b-pt",
        "google/gemma-3-12b-it",
        "google/gemma-3-27b-pt",
        "google/gemma-3-27b-it"
    ]
    selected_model = st.selectbox("🛠️ Select Gemma Model to Fine-tune", model_list)
elif finetune_option == "Refinetune existing model":
    # Dynamically list all saved models from the models/ folder
    model_dir = "models"
    if os.path.exists(model_dir):
        saved_models = [f for f in os.listdir(model_dir) if f.endswith(".pt")]
    else:
        saved_models = []

    if saved_models:
        saved_model_path = st.selectbox("Select a saved model to re-finetune", saved_models)
        saved_model_path = os.path.join(model_dir, saved_model_path)
        st.success(f"✅ Selected model for refinement: `{saved_model_path}`")
    else:
        st.warning("⚠️ No saved models found! Switching to fine-tuning from scratch.")
        finetune_option = "Fine-tune from scratch"

# -------------------------------
# Dataset Selection
# -------------------------------
st.subheader("📂 Dataset Selection")
dataset_option = st.radio("Choose dataset:", ["Upload New Dataset", "Use Existing Dataset (`train_data.csv`)"])

dataset_path = "datasets/train_data.csv"
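# The datasets/ folder is assumed to already exist in the Space; creating it here is a
# harmless safeguard so the CSV writes below do not fail on a fresh deployment.
os.makedirs(os.path.dirname(dataset_path), exist_ok=True)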
if dataset_option == "Upload New Dataset":
    uploaded_file = st.file_uploader("📤 Upload Dataset (CSV or JSON)", type=["csv", "json"])
    if uploaded_file is not None:
        if uploaded_file.name.endswith(".csv"):
            new_data = pd.read_csv(uploaded_file)
        elif uploaded_file.name.endswith(".json"):
            json_data = json.load(uploaded_file)
            new_data = pd.json_normalize(json_data)
        else:
            st.error("❌ Unsupported file format. Please upload CSV or JSON.")
            st.stop()
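
        # Appending with header=False assumes the uploaded columns match the existing
        # train_data.csv column order; schema mismatches are not validated here.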
        if os.path.exists(dataset_path):
            new_data.to_csv(dataset_path, mode='a', index=False, header=False)
            st.success(f"✅ Data appended to `{dataset_path}`!")
        else:
            new_data.to_csv(dataset_path, index=False)
            st.success(f"✅ Dataset saved as `{dataset_path}`!")
elif dataset_option == "Use Existing Dataset (`train_data.csv`)":
    if os.path.exists(dataset_path):
        st.success("✅ Using existing `train_data.csv` for fine-tuning.")
    else:
        st.error("❌ `train_data.csv` not found! Please upload a new dataset.")
        st.stop()

# -------------------------------
# Hyperparameter Configuration
# -------------------------------
st.subheader("🔧 Hyperparameter Configuration")
learning_rate = st.number_input("📉 Learning Rate", value=1e-4, format="%.5f")
batch_size = st.number_input("🛠️ Batch Size", value=16, step=1)
epochs = st.number_input("⏱️ Epochs", value=3, step=1)
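# The defaults above (1e-4 / 16 / 3) are illustrative starting points rather than tuned
# settings for any particular Gemma checkpoint; they are simply forwarded to the run below.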

# -------------------------------
# Fine-tuning Execution with Real-Time Visualization
# -------------------------------
if st.button("🚀 Start Fine-tuning"):
    st.info("Fine-tuning process initiated...")
    hf_token = get_hf_token()

    # Model loading logic
    if finetune_option == "Refinetune existing model" and saved_model_path:
        tokenizer, model = load_model("google/gemma-3-1b-it", hf_token)
        model = load_finetuned_model(model, saved_model_path)
        if model:
            st.success(f"✅ Loaded saved model: `{saved_model_path}` for refinement!")
        else:
            st.error("❌ Failed to load the saved model. Aborting.")
            st.stop()
    else:
        if not selected_model:
            st.error("❌ Please select a model to fine-tune.")
            st.stop()
        tokenizer, model = load_model(selected_model, hf_token)
        if model:
            st.success(f"✅ Base model loaded: `{selected_model}`")
        else:
            st.error("❌ Failed to load the base model. Aborting.")
            st.stop()

    # Create placeholders for training progress
    loss_chart = st.line_chart()   # Loss curve
    acc_chart = st.line_chart()    # Accuracy curve
    progress_text = st.empty()

    # Simulate training loop with real-time visualization
    losses_over_epochs = []
    accuracies_over_epochs = []
    for epoch, losses, accs in simulate_training(epochs, learning_rate, batch_size):
        # Update training text
        progress_text.text(f"Epoch {epoch}/{epochs} in progress...")

        # Assume simulate_training returns overall average loss and accuracy per epoch
        losses_over_epochs.append(losses)       # e.g., average loss of the epoch
        accuracies_over_epochs.append(accs)     # e.g., average accuracy of the epoch

        # Update real-time charts
        loss_chart.add_rows(pd.DataFrame({"Loss": [losses]}))
        acc_chart.add_rows(pd.DataFrame({"Accuracy": [accs]}))

    progress_text.text("Fine-tuning completed!")

    # Save fine-tuned model with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_identifier = selected_model if selected_model else os.path.basename(saved_model_path)
    new_model_name = f"models/fine_tuned_model_{model_identifier.replace('/', '_')}_{timestamp}.pt"
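    # The models/ folder is assumed to exist; creating it here is a harmless safeguard so
    # save_model() below has a valid target directory on a fresh deployment.
    os.makedirs("models", exist_ok=True)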
    saved_model_path = save_model(model, new_model_name)
    if saved_model_path:
        st.success(f"✅ Fine-tuning completed! Model saved as `{saved_model_path}`")
        model = load_finetuned_model(model, saved_model_path)
        if model:
            st.success("🛠️ Fine-tuned model loaded and ready for inference!")
        else:
            st.error("❌ Failed to load the fine-tuned model for inference.")
    else:
        st.error("❌ Failed to save the fine-tuned model.")