Dileep7729 committed
Commit d5de525 · verified · 1 Parent(s): ab3b271

Update app.py

Files changed (1)
  1. app.py +14 -33
app.py CHANGED
@@ -20,63 +20,44 @@ except Exception as e:
 def classify_image(image):
     """
     Classify an image as 'safe' or 'unsafe' and return probabilities.
-
-    Args:
-        image (PIL.Image.Image): Uploaded image.
-
-    Returns:
-        dict: Classification results or an error message.
     """
     try:
-        print("Starting image classification...")
-
-        # Validate input
         if image is None:
             raise ValueError("No image provided. Please upload a valid image.")
 
-        # Validate image format
-        if not hasattr(image, "convert"):
-            raise ValueError("Invalid image format. Please upload a valid image (JPEG, PNG, etc.).")
-
         # Define categories
         categories = ["safe", "unsafe"]
 
-        # Process the image with the processor
-        print("Processing the image...")
+        # Process the image
         inputs = processor(text=categories, images=image, return_tensors="pt", padding=True)
-        print(f"Processed inputs: {inputs}")
 
-        # Run inference with the model
-        print("Running model inference...")
+        # Run inference
         outputs = model(**inputs)
-        logits_per_image = outputs.logits_per_image  # Image-text similarity scores
-        print(f"Logits per image: {logits_per_image}")
 
-        # Apply softmax to convert logits to probabilities
-        probs = logits_per_image.softmax(dim=1)  # Convert logits to probabilities
-        print(f"Softmax probabilities: {probs}")
+        # Extract logits and apply softmax
+        logits_per_image = outputs.logits_per_image  # Image-text similarity scores
+        probs = logits_per_image.softmax(dim=1).detach().numpy()  # Convert logits to probabilities
 
         # Extract probabilities for each category
-        safe_prob = probs[0][0].item()
-        unsafe_prob = probs[0][1].item()
+        safe_prob = probs[0][0]  # Safe probability
+        unsafe_prob = probs[0][1]  # Unsafe probability
 
         # Normalize probabilities to ensure they sum to 100%
-        total_prob = safe_prob + unsafe_prob
-        safe_percentage = (safe_prob / total_prob) * 100
-        unsafe_percentage = (unsafe_prob / total_prob) * 100
-        print(f"Normalized percentages: safe={safe_percentage}, unsafe={unsafe_percentage}")
+        total = safe_prob + unsafe_prob
+        safe_percentage = (safe_prob / total) * 100
+        unsafe_percentage = (unsafe_prob / total) * 100
 
-        # Return results
+        # Return results as percentages
         return {
-            "safe": safe_percentage,
-            "unsafe": unsafe_percentage
+            "safe": round(safe_percentage, 2),  # Rounded to 2 decimal places
+            "unsafe": round(unsafe_percentage, 2)
        }
 
     except Exception as e:
-        print(f"Error during classification: {e}")
         return {"Error": str(e)}
 
 
+
 # Step 3: Set Up Gradio Interface
 iface = gr.Interface(
     fn=classify_image,
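
For reference, the hunk context (`except Exception as e:` in the `@@` header) implies that `processor` and `model` are created earlier in app.py inside a try/except; that setup is not part of this diff. A minimal sketch of what it plausibly looks like, assuming a CLIP checkpoint loaded through Hugging Face transformers (the checkpoint name below is a placeholder, not necessarily the one this Space uses):

# Steps 1-2 (not in this diff): load the model and processor (sketch only)
import gradio as gr
from transformers import CLIPModel, CLIPProcessor

MODEL_NAME = "openai/clip-vit-base-patch32"  # placeholder checkpoint, assumed for illustration

try:
    model = CLIPModel.from_pretrained(MODEL_NAME)            # provides outputs.logits_per_image
    processor = CLIPProcessor.from_pretrained(MODEL_NAME)    # pairs the text labels with the image
    model.eval()  # inference mode: disables dropout
except Exception as e:
    print(f"Model loading failed: {e}")
    raise

With a CLIP-style model, `processor(text=categories, images=image, ...)` batches the two label prompts against the uploaded image, and `outputs.logits_per_image` holds one similarity score per label, which the softmax in `classify_image` converts into probabilities.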
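
The hunk cuts off after `fn=classify_image,`, so the remaining `gr.Interface` arguments are not shown. A hedged sketch of how the interface might be completed, assuming a PIL image input and a label output (component choices and title are illustrative, not the Space's actual configuration):

# Step 3 (continued): a plausible completion of the interface (sketch only)
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),            # classify_image expects a PIL.Image
    outputs=gr.Label(num_top_classes=2),    # renders the returned {"safe": ..., "unsafe": ...} dict
    title="Safe/Unsafe Image Classifier",   # illustrative title
)

if __name__ == "__main__":
    iface.launch()  # starts the server; Spaces executes app.py directly

`gr.Label` treats the returned dict as label-to-score pairs, so the rounded percentages render directly.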