Dileep7729 commited on
Commit
610954a
·
verified ·
1 Parent(s): a16e363

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -10
app.py CHANGED
@@ -12,13 +12,13 @@ print("Model loaded successfully.")
12
  # Step 2: Define the Inference Function
13
  def classify_image(image):
14
  """
15
- Classify an image as 'safe' or 'unsafe' with the corresponding percentage.
16
 
17
  Args:
18
  image (PIL.Image.Image): The input image.
19
 
20
  Returns:
21
- dict: A dictionary containing the main category (safe/unsafe) and its percentage.
22
  """
23
  # Define the main categories
24
  main_categories = ["safe", "unsafe"]
@@ -29,23 +29,23 @@ def classify_image(image):
29
  logits_per_image = outputs.logits_per_image # Image-text similarity scores
30
  probs = logits_per_image.softmax(dim=1) # Convert logits to probabilities
31
 
32
- # Extract the category with the highest probability
33
  safe_probability = probs[0][0].item() * 100 # Safe percentage
34
  unsafe_probability = probs[0][1].item() * 100 # Unsafe percentage
35
 
36
- # Determine the main category
37
- if safe_probability > unsafe_probability:
38
- return {"Category": "safe", "Probability": f"{safe_probability:.2f}%"}
39
- else:
40
- return {"Category": "unsafe", "Probability": f"{unsafe_probability:.2f}%"}
41
 
42
  # Step 3: Set Up Gradio Interface
43
  iface = gr.Interface(
44
  fn=classify_image,
45
  inputs=gr.Image(type="pil"),
46
- outputs="json",
47
  title="Content Safety Classification",
48
- description="Classify images as 'safe' or 'unsafe' with their respective percentage.",
49
  )
50
 
51
  # Step 4: Launch Gradio Interface
@@ -67,3 +67,4 @@ if __name__ == "__main__":
67
 
68
 
69
 
 
 
12
  # Step 2: Define the Inference Function
13
  def classify_image(image):
14
  """
15
+ Classify an image as 'safe' or 'unsafe' with probabilities and display as a progress bar.
16
 
17
  Args:
18
  image (PIL.Image.Image): The input image.
19
 
20
  Returns:
21
+ dict: A dictionary containing probabilities for 'safe' and 'unsafe'.
22
  """
23
  # Define the main categories
24
  main_categories = ["safe", "unsafe"]
 
29
  logits_per_image = outputs.logits_per_image # Image-text similarity scores
30
  probs = logits_per_image.softmax(dim=1) # Convert logits to probabilities
31
 
32
+ # Extract the probabilities
33
  safe_probability = probs[0][0].item() * 100 # Safe percentage
34
  unsafe_probability = probs[0][1].item() * 100 # Unsafe percentage
35
 
36
+ # Return probabilities as a dictionary for display in Gradio's Label component
37
+ return {
38
+ "safe": f"{safe_probability:.2f}%",
39
+ "unsafe": f"{unsafe_probability:.2f}%"
40
+ }
41
 
42
  # Step 3: Set Up Gradio Interface
43
  iface = gr.Interface(
44
  fn=classify_image,
45
  inputs=gr.Image(type="pil"),
46
+ outputs=gr.Label(label="Output"),
47
  title="Content Safety Classification",
48
+ description="Upload an image to classify it as 'safe' or 'unsafe' with corresponding probabilities.",
49
  )
50
 
51
  # Step 4: Launch Gradio Interface
 
67
 
68
 
69
 
70
+