import streamlit as st
import pandas as pd
import json
import os
from datetime import datetime
from utils import (
    load_model,
    get_hf_token,
    simulate_training,
    plot_training_metrics,
    load_finetuned_model,
    save_model
)
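
# The helpers imported from utils are assumed (based on how they are called
# below) to behave roughly as follows; the actual implementations live in utils.py:
#   get_hf_token()                     -> Hugging Face access token string
#   load_model(model_name, hf_token)   -> (tokenizer, model) for the given checkpoint
#   load_finetuned_model(model, path)  -> model with saved weights loaded, or None on failure
#   save_model(model, path)            -> path the model was saved to, or None on failure
#   simulate_training(epochs, lr, bs)  -> yields (epoch, avg_loss, avg_accuracy) per epoch
#   plot_training_metrics(...)         -> imported but not used on this page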
st.title("🔥 Fine-tune the Gemma Model")
# -------------------------------
# Finetuning Option Selection
# -------------------------------
finetune_option = st.radio("Select Finetuning Option", ["Fine-tune from scratch", "Refinetune existing model"])
# -------------------------------
# Model Selection Logic
# -------------------------------
selected_model = None
saved_model_path = None
if finetune_option == "Fine-tune from scratch":
    # Display Hugging Face model list
    model_list = [
        "google/gemma-3-1b-pt",
        "google/gemma-3-1b-it",
        "google/gemma-3-4b-pt",
        "google/gemma-3-4b-it",
        "google/gemma-3-12b-pt",
        "google/gemma-3-12b-it",
        "google/gemma-3-27b-pt",
        "google/gemma-3-27b-it"
    ]
    selected_model = st.selectbox("🛠️ Select Gemma Model to Fine-tune", model_list)
elif finetune_option == "Refinetune existing model":
    # Dynamically list all saved models from the /models folder
    model_dir = "models"
    if os.path.exists(model_dir):
        saved_models = [f for f in os.listdir(model_dir) if f.endswith(".pt")]
    else:
        saved_models = []

    if saved_models:
        saved_model_path = st.selectbox("Select a saved model to re-finetune", saved_models)
        saved_model_path = os.path.join(model_dir, saved_model_path)
        st.success(f"✅ Selected model for refinement: `{saved_model_path}`")
    else:
        st.warning("⚠️ No saved models found! Switching to fine-tuning from scratch.")
        finetune_option = "Fine-tune from scratch"
# -------------------------------
# Dataset Selection
# -------------------------------
st.subheader("📂 Dataset Selection")
dataset_option = st.radio("Choose dataset:", ["Upload New Dataset", "Use Existing Dataset (`train_data.csv`)"])
dataset_path = "datasets/train_data.csv"
if dataset_option == "Upload New Dataset":
    uploaded_file = st.file_uploader("📤 Upload Dataset (CSV or JSON)", type=["csv", "json"])
    if uploaded_file is not None:
        if uploaded_file.name.endswith(".csv"):
            new_data = pd.read_csv(uploaded_file)
        elif uploaded_file.name.endswith(".json"):
            json_data = json.load(uploaded_file)
            new_data = pd.json_normalize(json_data)
        else:
            st.error("❌ Unsupported file format. Please upload CSV or JSON.")
            st.stop()

        # Make sure the datasets directory exists before writing
        os.makedirs(os.path.dirname(dataset_path), exist_ok=True)
        if os.path.exists(dataset_path):
            # Append without headers so the new rows line up with the existing columns
            new_data.to_csv(dataset_path, mode='a', index=False, header=False)
            st.success(f"✅ Data appended to `{dataset_path}`!")
        else:
            new_data.to_csv(dataset_path, index=False)
            st.success(f"✅ Dataset saved as `{dataset_path}`!")
elif dataset_option == "Use Existing Dataset (`train_data.csv`)":
    if os.path.exists(dataset_path):
        st.success("✅ Using existing `train_data.csv` for fine-tuning.")
    else:
        st.error("❌ `train_data.csv` not found! Please upload a new dataset.")
        st.stop()
# -------------------------------
# Hyperparameters Configuration
# -------------------------------
st.subheader("🔧 Hyperparameter Configuration")
learning_rate = st.number_input("📉 Learning Rate", value=1e-4, format="%.5f")
batch_size = st.number_input("🛠️ Batch Size", value=16, step=1)
epochs = st.number_input("⏱️ Epochs", value=3, step=1)
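
# Note: these values are only passed to simulate_training() below, which (as its
# name and the comments in the training loop indicate) simulates the run and
# reports per-epoch metrics rather than launching an actual training job.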
# -------------------------------
# Fine-tuning Execution with Real-Time Visualization
# -------------------------------
if st.button("🚀 Start Fine-tuning"):
    st.info("Fine-tuning process initiated...")
    hf_token = get_hf_token()

    # Model loading logic
    if finetune_option == "Refinetune existing model" and saved_model_path:
        # The saved weights are loaded on top of a hard-coded gemma-3-1b-it base,
        # so the selected checkpoint is expected to match that architecture.
        tokenizer, model = load_model("google/gemma-3-1b-it", hf_token)
        model = load_finetuned_model(model, saved_model_path)
        if model:
            st.success(f"✅ Loaded saved model: `{saved_model_path}` for refinement!")
        else:
            st.error("❌ Failed to load the saved model. Aborting.")
            st.stop()
    else:
        if not selected_model:
            st.error("❌ Please select a model to fine-tune.")
            st.stop()
        tokenizer, model = load_model(selected_model, hf_token)
        if model:
            st.success(f"✅ Base model loaded: `{selected_model}`")
        else:
            st.error("❌ Failed to load the base model. Aborting.")
            st.stop()
    # Create placeholders for training progress
    loss_chart = st.line_chart()   # Loss curve
    acc_chart = st.line_chart()    # Accuracy curve
    progress_text = st.empty()

    # Simulate training loop with real-time visualization
    losses_over_epochs = []
    accuracies_over_epochs = []
    for epoch, losses, accs in simulate_training(epochs, learning_rate, batch_size):
        # Update training text
        progress_text.text(f"Epoch {epoch}/{epochs} in progress...")
        # Assume simulate_training returns overall average loss and accuracy per epoch
        losses_over_epochs.append(losses)       # e.g., average loss of the epoch
        accuracies_over_epochs.append(accs)     # e.g., average accuracy of the epoch
        # Update real-time charts
        loss_chart.add_rows(pd.DataFrame({"Loss": [losses]}))
        acc_chart.add_rows(pd.DataFrame({"Accuracy": [accs]}))
    progress_text.text("Fine-tuning completed!")
    # Save fine-tuned model with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_identifier = selected_model if selected_model else os.path.basename(saved_model_path)
    new_model_name = f"models/fine_tuned_model_{model_identifier.replace('/', '_')}_{timestamp}.pt"
    os.makedirs("models", exist_ok=True)  # make sure the output folder exists
    saved_model_path = save_model(model, new_model_name)

    if saved_model_path:
        st.success(f"✅ Fine-tuning completed! Model saved as `{saved_model_path}`")
        model = load_finetuned_model(model, saved_model_path)
        if model:
            st.success("🛠️ Fine-tuned model loaded and ready for inference!")
        else:
            st.error("❌ Failed to load the fine-tuned model for inference.")
    else:
        st.error("❌ Failed to save the fine-tuned model.")