Dileep7729 committed
Commit bbfef86 · verified · 1 Parent(s): ec20dc5

Update app.py

Files changed (1):
  app.py +18 -120
app.py CHANGED
@@ -1,129 +1,26 @@
 import os
-import zipfile
 import torch
-from torch import nn, optim
-from torch.utils.data import DataLoader, Dataset
 from torchvision import transforms
 from PIL import Image
 from transformers import CLIPModel, CLIPProcessor
 import gradio as gr
 
-# Ensure PyTorch is installed
-try:
-    import torch
-except ModuleNotFoundError:
-    print("PyTorch is not installed. Installing now...")
-    os.system("pip install torch torchvision torchaudio")
-    import torch
-
-# Step 1: Unzip the dataset
-if not os.path.exists("data"):
-    os.makedirs("data")
-
-print("Extracting Data.zip...")
-with zipfile.ZipFile("Data.zip", 'r') as zip_ref:
-    zip_ref.extractall("data")
-print("Extraction complete.")
-
-# Step 2: Dynamically find the 'safe' and 'unsafe' folders
-def find_dataset_path(root_dir):
-    for root, dirs, files in os.walk(root_dir):
-        if 'safe' in dirs and 'unsafe' in dirs:
-            return root
-    return None
-
-# Look for 'safe' and 'unsafe' inside 'data/Data'
-dataset_path = find_dataset_path("data/Data")
-if dataset_path is None:
-    print("Debugging extracted structure:")
-    for root, dirs, files in os.walk("data"):
-        print(f"Root: {root}")
-        print(f"Directories: {dirs}")
-        print(f"Files: {files}")
-    raise FileNotFoundError("Expected 'safe' and 'unsafe' folders not found inside 'data/Data'. Please check the Data.zip structure.")
-print(f"Dataset path found: {dataset_path}")
-
-# Step 3: Define Custom Dataset Class
-class CustomImageDataset(Dataset):
-    def __init__(self, root_dir, transform=None):
-        self.root_dir = root_dir
-        self.transform = transform
-        self.image_paths = []
-        self.labels = []
-
-        for label, folder in enumerate(["safe", "unsafe"]):  # 0 = safe, 1 = unsafe
-            folder_path = os.path.join(root_dir, folder)
-            if not os.path.exists(folder_path):
-                raise FileNotFoundError(f"Folder '{folder}' not found in '{root_dir}'")
-            for filename in os.listdir(folder_path):
-                if filename.endswith((".jpg", ".jpeg", ".png")):  # Only load image files
-                    self.image_paths.append(os.path.join(folder_path, filename))
-                    self.labels.append(label)
-
-    def __len__(self):
-        return len(self.image_paths)
-
-    def __getitem__(self, idx):
-        image_path = self.image_paths[idx]
-        image = Image.open(image_path).convert("RGB")
-        label = self.labels[idx]
-        if self.transform:
-            image = self.transform(image)
-        return image, label
-
-# Step 4: Data Transformations
-transform = transforms.Compose([
-    transforms.Resize((224, 224)),         # Resize to 224x224 pixels
-    transforms.ToTensor(),                 # Convert to tensor
-    transforms.Normalize((0.5,), (0.5,)),  # Normalize image values
-])
-
-# Step 5: Load the Dataset
-train_dataset = CustomImageDataset(dataset_path, transform=transform)
-train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
-
-# Debugging: Check the dataset
-print(f"Number of samples in the dataset: {len(train_dataset)}")
-if len(train_dataset) == 0:
-    raise ValueError("The dataset is empty. Please check if 'Data.zip' is correctly unzipped and contains 'safe' and 'unsafe' folders.")
-
-# Step 6: Load Pretrained CLIP Model
-model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
-processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
-
-# Add a Classification Layer
-model.classifier = nn.Linear(model.visual_projection.out_features, 2)  # 2 classes: safe, unsafe
-
-# Define Optimizer and Loss Function
-optimizer = optim.Adam(model.classifier.parameters(), lr=1e-4)
-criterion = nn.CrossEntropyLoss()
-
-# Step 7: Fine-Tune the Model
-model.train()
-for epoch in range(3):  # Number of epochs
-    total_loss = 0
-    for images, labels in train_loader:
-        optimizer.zero_grad()
-        images = torch.stack([img.to(torch.float32) for img in images])  # Batch of images
-        outputs = model.get_image_features(pixel_values=images)
-        logits = model.classifier(outputs)
-        loss = criterion(logits, labels)
-        loss.backward()
-        optimizer.step()
-        total_loss += loss.item()
-    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}")
-
-# Save the Fine-Tuned Model
-model.save_pretrained("fine-tuned-model")
-processor.save_pretrained("fine-tuned-model")
-print("Model fine-tuned and saved successfully.")
-
-# Step 8: Define Gradio Inference Function
+# Step 1: Ensure Fine-Tuned Model is Available
+fine_tuned_model_path = "fine-tuned-model"
+
+if not os.path.exists(fine_tuned_model_path):
+    raise FileNotFoundError(
+        f"The fine-tuned model is missing. Ensure that the fine-tuned model files are available in the '{fine_tuned_model_path}' directory."
+    )
+
+# Step 2: Load Fine-Tuned Model
+print("Loading fine-tuned model...")
+model = CLIPModel.from_pretrained(fine_tuned_model_path)
+processor = CLIPProcessor.from_pretrained(fine_tuned_model_path)
+print("Fine-tuned model loaded successfully.")
+
+# Step 3: Define Gradio Inference Function
 def classify_image(image, class_names):
-    # Load Fine-Tuned Model
-    model = CLIPModel.from_pretrained("fine-tuned-model")
-    processor = CLIPProcessor.from_pretrained("fine-tuned-model")
-
     # Split class names from comma-separated input
     labels = [label.strip() for label in class_names.split(",") if label.strip()]
     if not labels:
@@ -139,7 +36,7 @@ def classify_image(image, class_names):
     result = {label: probs[0][i].item() for i, label in enumerate(labels)}
     return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
 
-# Step 9: Set Up Gradio Interface
+# Step 4: Set Up Gradio Interface
 iface = gr.Interface(
     fn=classify_image,
     inputs=[
@@ -151,7 +48,7 @@ iface = gr.Interface(
     description="Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model.",
 )
 
-# Launch Gradio Interface
+# Step 5: Launch Gradio Interface
 if __name__ == "__main__":
     iface.launch()
 
@@ -166,3 +63,4 @@ if __name__ == "__main__":
 
 
 
+
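The middle hunks elide the body of classify_image (old lines 130-138 / new lines 27-35), so the scoring code is not shown in this diff. The visible probs[0][i] indexing matches the standard CLIP zero-shot pattern; a hypothetical reconstruction of the elided body, for orientation only, is:

    # Inside classify_image -- hypothetical, not part of this commit: score the
    # image against each label prompt with the loaded CLIP model and processor.
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)  # shape: (1, len(labels))

The elided inputs=[...] list in the gr.Interface call presumably pairs a gr.Image(type="pil") with a gr.Textbox for the comma-separated class names, matching the classify_image(image, class_names) signature.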
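Because app.py no longer trains at startup, the fine-tuned-model directory must be produced ahead of time. A condensed sketch of the removed training flow as an offline script (the filename offline_finetune.py is hypothetical, and it assumes the same data/Data/safe and data/Data/unsafe layout the old code searched for):

    # offline_finetune.py -- hypothetical companion script, distilled from the
    # code this commit removes; writes the "fine-tuned-model" directory.
    import os
    import torch
    from torch import nn, optim
    from torch.utils.data import DataLoader, Dataset
    from torchvision import transforms
    from PIL import Image
    from transformers import CLIPModel, CLIPProcessor

    class CustomImageDataset(Dataset):
        """Images from 'safe' (label 0) and 'unsafe' (label 1) subfolders."""
        def __init__(self, root_dir, transform=None):
            self.transform = transform
            self.image_paths, self.labels = [], []
            for label, folder in enumerate(["safe", "unsafe"]):
                folder_path = os.path.join(root_dir, folder)
                for name in os.listdir(folder_path):
                    if name.endswith((".jpg", ".jpeg", ".png")):
                        self.image_paths.append(os.path.join(folder_path, name))
                        self.labels.append(label)

        def __len__(self):
            return len(self.image_paths)

        def __getitem__(self, idx):
            image = Image.open(self.image_paths[idx]).convert("RGB")
            if self.transform:
                image = self.transform(image)
            return image, self.labels[idx]

    # Same preprocessing as the removed code (the CLIP processor's own image
    # statistics would normally be preferred over a flat 0.5 mean/std).
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    loader = DataLoader(CustomImageDataset("data/Data", transform),
                        batch_size=8, shuffle=True)

    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    model.classifier = nn.Linear(model.visual_projection.out_features, 2)  # safe/unsafe head

    optimizer = optim.Adam(model.classifier.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(3):
        total_loss = 0.0
        for images, labels in loader:  # default collate stacks images, tensorizes labels
            optimizer.zero_grad()
            features = model.get_image_features(pixel_values=images)
            loss = criterion(model.classifier(features), labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch + 1}, Loss: {total_loss / len(loader)}")

    model.save_pretrained("fine-tuned-model")
    processor.save_pretrained("fine-tuned-model")

Note that CLIPModel.from_pretrained will not restore the ad-hoc classifier head when the app reloads this directory, which is consistent with classify_image scoring arbitrary comma-separated labels through CLIP's text-image similarity rather than through the head.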