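# Streamlit page for fine-tuning OpenAI CLIP (ViT-B/32) on an annotated
# image-caption dataset stored under annotations/<dataset>/.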
import clip
import clip.model
from datasets import Dataset
import json
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import os
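
# model_training renders the whole training page: dataset validation, a data
# preview, hyperparameter widgets, and the CLIP fine-tuning loop itself.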
def model_training():
    dataset_path = st.session_state.get("selected_dataset", None)
    if not dataset_path:
        st.error("Please select a dataset to proceed.")
        return
    annotations_file = f"annotations/{dataset_path}/annotations.json"
    if not os.path.exists(annotations_file):
        st.error("No annotations found for the selected dataset.")
        return
    with open(annotations_file, "r") as f:
        annotations_dict = json.load(f)
    # annotations.json is assumed to map image path -> caption string
    annotations_df = pd.DataFrame(annotations_dict.items(), columns=['file_name', 'text'])
    st.subheader("Data Preview")
    st.dataframe(annotations_df.head(), use_container_width=True)
    if len(annotations_df) < 2:
        st.error("Not enough data to train the model.")
        return
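    # Hold out part of the data for validation; the split ratio is user-selectable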
    test_size = st.selectbox("Select Test Size", options=[0.1, 0.2, 0.3, 0.4, 0.5], index=1)
    train_df, val_df = train_test_split(annotations_df, test_size=test_size, random_state=42)
    if len(train_df) < 2:
        st.error("Not enough data to train the model.")
        return
    st.write(f"Train Size: {len(train_df)} | Validation Size: {len(val_df)}")
    col1, col2 = st.columns(2)
    with col1:
        optimizer_name = st.selectbox("Select Optimizer", options=optim.__all__, index=3)
        optimizer_cls = getattr(optim, optimizer_name)
    with col2:
        # Heuristic default: batch size near sqrt(len(train_df)), snapped to the
        # nearest smaller option
        batch_size_options = [2, 4, 8, 16, 32, 64, 128]
        ideal_batch_size = int(np.sqrt(len(train_df)))
        if ideal_batch_size in batch_size_options:
            ideal_batch_size_index = batch_size_options.index(ideal_batch_size)
        else:
            ideal_batch_size_index = len(batch_size_options) - 1  # fallback when sqrt(N) exceeds every option
            for candidate in batch_size_options:
                if candidate > ideal_batch_size:
                    ideal_batch_size_index = max(batch_size_options.index(candidate) - 1, 0)
                    break
        batch_size = st.selectbox("Select Batch Size", options=batch_size_options, index=ideal_batch_size_index)
    col1, col2 = st.columns(2)
    with col1:
        weight_decay = st.number_input("Weight Decay", value=0.3, format="%.5f")
    with col2:
        learning_rate = st.number_input("Learning Rate", value=1e-3, format="%.5f")
    device = "cuda" if torch.cuda.is_available() else "cpu"
if st.button("Train", key="train_button", use_container_width=True, type="primary"): | |
def convert_models_to_fp32(model): | |
for p in model.parameters(): | |
p.data = p.data.float() | |
p.grad.data = p.grad.data.float() | |
device = "cuda:0" if torch.cuda.is_available() else "cpu" | |
with st.spinner("Loading Model..."): | |
model, preprocess = clip.load("ViT-B/32", device=device, jit=False) | |
clip.model.convert_weights(model) | |
loss_img = nn.CrossEntropyLoss() | |
loss_txt = nn.CrossEntropyLoss() | |
optimizer = optimizer(model.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-6, weight_decay=weight_decay) | |
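        # Build (image tensor batch, caption list) pairs; preprocess handles resizing
        # and normalization, while clip.tokenize is applied later, per batch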
        def collate_fn(batch):
            # Each batch element is a dict with 'file_name' (image path) and 'text' (caption)
            image_paths = [entry['file_name'] for entry in batch]
            texts = [entry['text'] for entry in batch]
            images = torch.stack([preprocess(Image.open(path)) for path in image_paths])
            return images, texts

        train_df['file_name'] = train_df['file_name'].str.strip()
        val_df['file_name'] = val_df['file_name'].str.strip()
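        # Wrap the dataframes in Hugging Face Datasets so the DataLoader yields dict rows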
        dataset = Dataset.from_pandas(train_df)
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate_fn
        )
        val_dataset = Dataset.from_pandas(val_df)
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=collate_fn
        )
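        # Validation loss mirrors the training objective: symmetric cross-entropy over
        # the image-text similarity matrix, with matching pairs on the diagonal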
        def calculate_val_loss(model):
            model.eval()
            total_loss = 0
            with torch.no_grad():
                for images, texts in val_dataloader:
                    texts = clip.tokenize(texts).to(device)
                    images = images.to(device)
                    logits_per_image, logits_per_text = model(images, texts)
                    # The i-th image matches the i-th caption, so the target for row i is i
                    ground_truth = torch.arange(len(images)).to(device)
                    image_loss = loss_img(logits_per_image, ground_truth)
                    text_loss = loss_txt(logits_per_text, ground_truth)
                    total_loss += (image_loss + text_loss) / 2
            model.train()
            return total_loss / len(val_dataloader)
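
        # Training loop: a single pass over the data, with validation loss
        # recomputed every 20 steps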
        step = 0
        progress_bar = st.progress(0, text=f"Model Training in progress... \nStep: {step}/{len(dataloader)} | 0% Completed | Val Loss: 0.0")
        for batch_idx, (images, texts) in enumerate(dataloader):
            optimizer.zero_grad()
            texts = clip.tokenize(texts).to(device)
            images = images.to(device)
            logits_per_image, logits_per_text = model(images, texts)
            ground_truth = torch.arange(len(images)).to(device)
            image_loss = loss_img(logits_per_image, ground_truth)
            text_loss = loss_txt(logits_per_text, ground_truth)
            total_loss = (image_loss + text_loss) / 2
            total_loss.backward()
            if step % 20 == 0:
                print("\nStep : ", step)
                print("Total Loss : ", total_loss.item())
                val_loss = calculate_val_loss(model)
                print("Validation Loss : ", val_loss.item())
            if device == "cpu":
                optimizer.step()
            else:
                # Step in fp32, then convert back to fp16 for the next forward pass
                convert_models_to_fp32(model)
                optimizer.step()
                clip.model.convert_weights(model)
            step += 1
            progress_bar.progress((batch_idx + 1) / len(dataloader), f"Model Training in progress... \nStep: {step}/{len(dataloader)} | {round((batch_idx + 1) / len(dataloader) * 100)}% Completed | Val Loss: {val_loss.item():.4f}")
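        # After the epoch, save the fine-tuned weights alongside the dataset's annotations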
        st.toast("Training Completed!", icon="🎉")
        with st.spinner("Saving Model..."):
            finetuned_model = model.eval()
            torch.save(finetuned_model.state_dict(), f"annotations/{dataset_path}/finetuned_model.pt")
        st.success("Model Saved Successfully!")