# Author: mohitkumarrajbadi
# Commit e37cfd0: New Improvement in Pages
import streamlit as st
import pandas as pd
import json
import os
from datetime import datetime
from utils import (
load_model,
get_hf_token,
simulate_training,
plot_training_metrics,
load_finetuned_model,
save_model
)
st.title("πŸ”₯ Fine-tune the Gemma Model")

# -------------------------------
# Finetuning Option Selection
# -------------------------------
# The two workflows offered by this page: start from a published base
# checkpoint, or continue training a checkpoint saved earlier.
finetune_modes = ["Fine-tune from scratch", "Refinetune existing model"]
finetune_option = st.radio("Select Finetuning Option", finetune_modes)
# -------------------------------
# Model Selection Logic
# -------------------------------
# Exactly one of these is set by the time the user can press the start button:
#   selected_model   - Hugging Face id of a base checkpoint (from-scratch mode)
#   saved_model_path - path of a local `.pt` checkpoint (re-finetune mode)
selected_model = None
saved_model_path = None

# Base checkpoints offered for from-scratch fine-tuning.
model_list = [
    "google/gemma-3-1b-pt",
    "google/gemma-3-1b-it",
    "google/gemma-3-4b-pt",
    "google/gemma-3-4b-it",
    "google/gemma-3-12b-pt",
    "google/gemma-3-12b-it",
    "google/gemma-3-27b-pt",
    "google/gemma-3-27b-it",
]

if finetune_option == "Refinetune existing model":
    # Dynamically list all saved models from the models/ folder.
    model_dir = "models"
    if os.path.isdir(model_dir):
        # Sorted so the dropdown order is stable across reruns.
        saved_models = sorted(f for f in os.listdir(model_dir) if f.endswith(".pt"))
    else:
        saved_models = []

    if saved_models:
        saved_model_path = st.selectbox("Select a saved model to re-finetune", saved_models)
        saved_model_path = os.path.join(model_dir, saved_model_path)
        st.success(f"βœ… Selected model for refinement: `{saved_model_path}`")
    else:
        st.warning("⚠️ No saved models found! Switching to fine-tuning from scratch.")
        finetune_option = "Fine-tune from scratch"

# Evaluated after the block above on purpose: when the re-finetune branch
# falls back to from-scratch mode, the base-model picker must still be shown,
# otherwise `selected_model` stays None and training can never be started.
if finetune_option == "Fine-tune from scratch":
    selected_model = st.selectbox("πŸ› οΈ Select Gemma Model to Fine-tune", model_list)
# -------------------------------
# Dataset Selection
# -------------------------------
st.subheader("πŸ“š Dataset Selection")
dataset_option = st.radio("Choose dataset:", ["Upload New Dataset", "Use Existing Dataset (`train_data.csv`)"])
dataset_path = "datasets/train_data.csv"

if dataset_option == "Upload New Dataset":
    uploaded_file = st.file_uploader("πŸ“€ Upload Dataset (CSV or JSON)", type=["csv", "json"])
    if uploaded_file is None:
        # Forget the last ingested upload so the same file can be uploaded
        # again on purpose after it has been removed from the uploader.
        st.session_state.pop("ingested_upload", None)
    else:
        # Streamlit reruns this script on every widget interaction while the
        # uploader still holds the file. Remember what was already written so
        # the same rows are not appended again on each rerun.
        upload_signature = (uploaded_file.name, uploaded_file.size)
        if st.session_state.get("ingested_upload") == upload_signature:
            st.info(f"`{uploaded_file.name}` has already been added to `{dataset_path}`.")
        else:
            if uploaded_file.name.endswith(".csv"):
                new_data = pd.read_csv(uploaded_file)
            elif uploaded_file.name.endswith(".json"):
                json_data = json.load(uploaded_file)
                new_data = pd.json_normalize(json_data)
            else:
                st.error("❌ Unsupported file format. Please upload CSV or JSON.")
                st.stop()

            # The target folder may not exist on a fresh checkout.
            os.makedirs(os.path.dirname(dataset_path), exist_ok=True)

            if os.path.exists(dataset_path) and os.path.getsize(dataset_path) > 0:
                # Rows are appended without a header, so they must line up
                # with the columns already in the file.
                existing_columns = list(pd.read_csv(dataset_path, nrows=0).columns)
                if set(new_data.columns) != set(existing_columns):
                    st.error(
                        "❌ Uploaded columns do not match the existing dataset. "
                        f"Expected: {existing_columns}; got: {list(new_data.columns)}."
                    )
                    st.stop()
                new_data[existing_columns].to_csv(dataset_path, mode='a', index=False, header=False)
                st.success(f"βœ… Data appended to `{dataset_path}`!")
            else:
                # Missing or empty file: write a fresh CSV including the header.
                new_data.to_csv(dataset_path, index=False)
                st.success(f"βœ… Dataset saved as `{dataset_path}`!")

            st.session_state["ingested_upload"] = upload_signature
elif dataset_option == "Use Existing Dataset (`train_data.csv`)":
    if os.path.exists(dataset_path):
        st.success("βœ… Using existing `train_data.csv` for fine-tuning.")
    else:
        st.error("❌ `train_data.csv` not found! Please upload a new dataset.")
        st.stop()
# -------------------------------
# Hyperparameters Configuration
# -------------------------------
st.subheader("πŸ”§ Hyperparameter Configuration")
# Lower bounds keep obviously invalid settings (zero or negative values) out
# of the training loop. The learning-rate step matches the precision shown by
# the "%.5f" format; the default float step of 0.01 would be 100x the default
# value.
learning_rate = st.number_input(
    "πŸ“Š Learning Rate",
    min_value=1e-5,
    value=1e-4,
    step=1e-5,
    format="%.5f",
)
# Integer bounds/steps keep these widgets returning ints.
batch_size = st.number_input("πŸ› οΈ Batch Size", min_value=1, value=16, step=1)
epochs = st.number_input("⏱️ Epochs", min_value=1, value=3, step=1)
# -------------------------------
# Fine-tuning Execution with Real-Time Visualization
# -------------------------------
if st.button("πŸš€ Start Fine-tuning"):
    st.info("Fine-tuning process initiated...")
    hf_token = get_hf_token()

    # Model loading logic
    if finetune_option == "Refinetune existing model" and saved_model_path:
        # NOTE(review): the base checkpoint is hard-coded here, so a saved
        # model that was fine-tuned from a different Gemma variant is loaded
        # onto the 1b-it base - confirm this is intended.
        tokenizer, model = load_model("google/gemma-3-1b-it", hf_token)
        model = load_finetuned_model(model, saved_model_path)
        if model:
            st.success(f"βœ… Loaded saved model: `{saved_model_path}` for refinement!")
        else:
            st.error("❌ Failed to load the saved model. Aborting.")
            st.stop()
    else:
        if not selected_model:
            st.error("❌ Please select a model to fine-tune.")
            st.stop()
        tokenizer, model = load_model(selected_model, hf_token)
        if model:
            st.success(f"βœ… Base model loaded: `{selected_model}`")
        else:
            st.error("❌ Failed to load the base model. Aborting.")
            st.stop()

    # Create placeholders for training progress
    loss_chart = st.line_chart()  # Loss curve
    acc_chart = st.line_chart()  # Accuracy curve
    progress_text = st.empty()

    # Simulate training loop with real-time visualization
    losses_over_epochs = []
    accuracies_over_epochs = []
    for epoch, losses, accs in simulate_training(epochs, learning_rate, batch_size):
        # Update training text
        progress_text.text(f"Epoch {epoch}/{epochs} in progress...")
        # Assumes simulate_training yields one aggregate loss and accuracy
        # value per epoch - TODO confirm against utils.simulate_training.
        losses_over_epochs.append(losses)
        accuracies_over_epochs.append(accs)
        # Update real-time charts
        loss_chart.add_rows(pd.DataFrame({"Loss": [losses]}))
        acc_chart.add_rows(pd.DataFrame({"Accuracy": [accs]}))
    progress_text.text("Fine-tuning completed!")

    # Save fine-tuned model with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if selected_model:
        model_identifier = selected_model
    else:
        # Re-finetune mode: derive the identifier from the previous file name,
        # dropping its ".pt" extension and the "fine_tuned_model_" prefix so
        # they are not duplicated inside the new file name.
        model_identifier = os.path.splitext(os.path.basename(saved_model_path))[0]
        name_prefix = "fine_tuned_model_"
        if model_identifier.startswith(name_prefix):
            model_identifier = model_identifier[len(name_prefix):]

    # Make sure the output folder exists before handing the path to save_model.
    os.makedirs("models", exist_ok=True)
    new_model_name = f"models/fine_tuned_model_{model_identifier.replace('/', '_')}_{timestamp}.pt"
    saved_model_path = save_model(model, new_model_name)

    if saved_model_path:
        st.success(f"βœ… Fine-tuning completed! Model saved as `{saved_model_path}`")
        # Reload from disk to confirm the saved checkpoint is usable.
        model = load_finetuned_model(model, saved_model_path)
        if model:
            st.success("πŸ› οΈ Fine-tuned model loaded and ready for inference!")
        else:
            st.error("❌ Failed to load the fine-tuned model for inference.")
    else:
        st.error("❌ Failed to save the fine-tuned model.")