import os
import zipfile

# Make sure PyTorch is available before the torch-dependent imports below.
try:
    import torch
except ModuleNotFoundError:
    print("PyTorch is not installed. Installing now...")
    os.system("pip install torch torchvision torchaudio")
    import torch

from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
import gradio as gr

# Unpack the dataset archive into ./data.
if not os.path.exists("data"):
    os.makedirs("data")

print("Extracting Data.zip...")
with zipfile.ZipFile("Data.zip", 'r') as zip_ref:
    zip_ref.extractall("data")
print("Extraction complete.")

def find_dataset_path(root_dir):
    """Return the first directory under root_dir that contains both 'safe' and 'unsafe' subfolders."""
    for root, dirs, files in os.walk(root_dir):
        if 'safe' in dirs and 'unsafe' in dirs:
            return root
    return None


dataset_path = find_dataset_path("data/Data")
if dataset_path is None:
    # Dump the extracted tree to help diagnose an unexpected archive layout.
    print("Debugging extracted structure:")
    for root, dirs, files in os.walk("data"):
        print(f"Root: {root}")
        print(f"Directories: {dirs}")
        print(f"Files: {files}")
    raise FileNotFoundError(
        "Expected 'safe' and 'unsafe' folders not found inside 'data/Data'. "
        "Please check the Data.zip structure."
    )
print(f"Dataset path found: {dataset_path}")

class CustomImageDataset(Dataset):
    """Binary dataset: images under 'safe' get label 0, images under 'unsafe' get label 1."""

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        for label, folder in enumerate(["safe", "unsafe"]):
            folder_path = os.path.join(root_dir, folder)
            if not os.path.exists(folder_path):
                raise FileNotFoundError(f"Folder '{folder}' not found in '{root_dir}'")
            for filename in os.listdir(folder_path):
                if filename.lower().endswith((".jpg", ".jpeg", ".png")):
                    self.image_paths.append(os.path.join(folder_path, filename))
                    self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Resize to CLIP's 224x224 input resolution and scale pixel values to roughly [-1, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
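# Note: the 0.5/0.5 values above are generic, not the per-channel statistics CLIP was
# pretrained with. Recent transformers versions expose those statistics on the processor
# (processor.image_processor.image_mean / image_std), which could be reused here instead;
# treat that as an untested suggestion.
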
train_dataset = CustomImageDataset(dataset_path, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

print(f"Number of samples in the dataset: {len(train_dataset)}")
if len(train_dataset) == 0:
    raise ValueError(
        "The dataset is empty. Please check if 'Data.zip' is correctly unzipped "
        "and contains 'safe' and 'unsafe' folders."
    )

# Load the pretrained CLIP backbone and attach a small binary classification head
# on top of its projected image embeddings.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

model.classifier = nn.Linear(model.visual_projection.out_features, 2)

# Only the new classifier head is handed to the optimizer, so the CLIP weights themselves are not updated.
optimizer = optim.Adam(model.classifier.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
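# Optional sketch (not in the original flow): explicitly freeze the CLIP backbone so the
# backward pass skips its gradients. Only the classifier head is optimized either way,
# so this affects memory and speed, not the trained result.
for param in model.parameters():
    param.requires_grad = False
for param in model.classifier.parameters():
    param.requires_grad = True
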
# Fine-tune the classifier head on the safe/unsafe labels for a few epochs.
model.train()
for epoch in range(3):
    total_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()
        images = images.to(torch.float32)  # the DataLoader already yields a batched float tensor
        features = model.get_image_features(pixel_values=images)
        logits = model.classifier(features)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_loader):.4f}")
# Save the fine-tuned weights and the processor configuration for reuse below.
model.save_pretrained("fine-tuned-model")
processor.save_pretrained("fine-tuned-model")
print("Model fine-tuned and saved successfully.")
def classify_image(image, class_names):
    """Score the uploaded image against the comma-separated class names with CLIP."""
    model = CLIPModel.from_pretrained("fine-tuned-model")
    processor = CLIPProcessor.from_pretrained("fine-tuned-model")

    labels = [label.strip() for label in class_names.split(",") if label.strip()]
    if not labels:
        return {"Error": "Please enter at least one valid class name."}

    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)

    result = {label: probs[0][i].item() for i, label in enumerate(labels)}
    return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
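# Note: this handler scores the image against the typed class names with CLIP's zero-shot
# image-text similarity; the classifier head trained above is not used at inference time.
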
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Possible class names (comma-separated)", placeholder="e.g., safe, unsafe"),
    ],
    outputs=gr.Label(num_top_classes=2),
    title="Content Safety Classification",
    description="Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model.",
)

if __name__ == "__main__":
    iface.launch()