Dileep7729 committed (verified)
Commit a16e363 · 1 Parent(s): a41b014

Update app.py

Files changed (1):
  1. app.py (+18, -35)
app.py CHANGED
@@ -1,5 +1,4 @@
 import gradio as gr
-import gradio as gr
 from transformers import CLIPModel, CLIPProcessor
 
 # Step 1: Load Fine-Tuned Model from Hugging Face Model Hub
@@ -13,57 +12,40 @@ print("Model loaded successfully.")
 # Step 2: Define the Inference Function
 def classify_image(image):
     """
-    Classify an image as 'safe' or 'unsafe' with probabilities and subcategories.
+    Classify an image as 'safe' or 'unsafe' with the corresponding percentage.
 
     Args:
         image (PIL.Image.Image): The input image.
 
     Returns:
-        dict: A dictionary containing main categories (safe/unsafe) and their probabilities.
+        dict: A dictionary containing the main category (safe/unsafe) and its percentage.
     """
-    # Define the predefined categories
+    # Define the main categories
    main_categories = ["safe", "unsafe"]
-    safe_subcategories = ["retail product", "other safe content"]
-    unsafe_subcategories = ["harmful", "violent", "sexual", "self harm"]
 
     # Process the image with the main categories
-    main_inputs = processor(text=main_categories, images=image, return_tensors="pt", padding=True)
-    main_outputs = model(**main_inputs)
-    logits_per_image = main_outputs.logits_per_image  # Image-text similarity scores
-    main_probs = logits_per_image.softmax(dim=1)  # Convert logits to probabilities
+    inputs = processor(text=main_categories, images=image, return_tensors="pt", padding=True)
+    outputs = model(**inputs)
+    logits_per_image = outputs.logits_per_image  # Image-text similarity scores
+    probs = logits_per_image.softmax(dim=1)  # Convert logits to probabilities
+
+    # Extract the category with the highest probability
+    safe_probability = probs[0][0].item() * 100  # Safe percentage
+    unsafe_probability = probs[0][1].item() * 100  # Unsafe percentage
 
     # Determine the main category
-    main_result = {main_categories[i]: main_probs[0][i].item() for i in range(len(main_categories))}
-    main_category = max(main_result, key=main_result.get)  # Either "safe" or "unsafe"
-
-    # Process the image with subcategories based on the main category
-    subcategories = safe_subcategories if main_category == "safe" else unsafe_subcategories
-    sub_inputs = processor(text=subcategories, images=image, return_tensors="pt", padding=True)
-    sub_outputs = model(**sub_inputs)
-    sub_logits = sub_outputs.logits_per_image
-    sub_probs = sub_logits.softmax(dim=1)  # Convert logits to probabilities
-
-    # Create a structured result
-    result = {
-        "Main Category": main_category,
-        "Main Probabilities": main_result,
-        "Subcategory Probabilities": {
-            subcategories[i]: sub_probs[0][i].item() for i in range(len(subcategories))
-        }
-    }
-    return result
+    if safe_probability > unsafe_probability:
+        return {"Category": "safe", "Probability": f"{safe_probability:.2f}%"}
+    else:
+        return {"Category": "unsafe", "Probability": f"{unsafe_probability:.2f}%"}
 
 # Step 3: Set Up Gradio Interface
 iface = gr.Interface(
     fn=classify_image,
     inputs=gr.Image(type="pil"),
     outputs="json",
-    title="Enhanced Content Safety Classification",
-    description=(
-        "Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model. "
-        "For 'safe', identify subcategories such as 'retail product'. "
-        "For 'unsafe', identify subcategories such as 'harmful', 'violent', 'sexual', or 'self harm'."
-    ),
+    title="Content Safety Classification",
+    description="Classify images as 'safe' or 'unsafe' with their respective percentage.",
 )
 
 # Step 4: Launch Gradio Interface
@@ -84,3 +66,4 @@ if __name__ == "__main__":
 
 
 
+
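
The hunks above leave Step 1 untouched, so only the `print("Model loaded successfully.")` context line hints at how `model` and `processor` are created. As a rough, self-contained sketch of how the updated safe/unsafe inference path could be exercised outside the Gradio app, assuming Step 1 uses the standard `from_pretrained` calls (the repo id below is a placeholder, not the one in app.py):

# Sketch only: load a CLIP checkpoint and reproduce the new safe/unsafe flow directly.
# "your-username/your-finetuned-clip" is a placeholder repo id; the real model id is not shown in this diff.
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("your-username/your-finetuned-clip")
processor = CLIPProcessor.from_pretrained("your-username/your-finetuned-clip")
print("Model loaded successfully.")

image = Image.open("example.jpg")  # any local test image
inputs = processor(text=["safe", "unsafe"], images=image, return_tensors="pt", padding=True)
probs = model(**inputs).logits_per_image.softmax(dim=1)  # shape (1, 2): [safe, unsafe]
safe_pct, unsafe_pct = probs[0][0].item() * 100, probs[0][1].item() * 100
print({"Category": "safe" if safe_pct > unsafe_pct else "unsafe",
       "Probability": f"{max(safe_pct, unsafe_pct):.2f}%"})

Because `classify_image` reads `model` and `processor` as module-level globals, loading them once in Step 1 keeps each Gradio request down to a single forward pass.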
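
Hunk 3 touches only trailing blank lines under the `if __name__ == "__main__":` guard, so the launch call itself is unchanged and not visible here; by Gradio convention it presumably reads:

if __name__ == "__main__":
    iface.launch()  # unchanged Step 4; any launch arguments are not shown in this diff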