merve HF Staff commited on
Commit
49927a6
·
verified ·
1 Parent(s): f90f608

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoProcessor, ShieldGemma2ForImageClassification
2
+ from PIL import Image
3
+ import requests
4
+ import torch
5
+ import gradio as gr
6
+ import spaces
7
+ model_id = "google/shieldgemma-2-4b-it"
8
+ model = ShieldGemma2ForImageClassification.from_pretrained(model_id).to("cuda")
9
+ processor = AutoProcessor.from_pretrained(model_id)
10
+
11
+ @spaces.GPU()
12
+ def infer(image, policies, policy_descriptions):
13
+ policies = policies.split(";")
14
+ policy_descriptions = policy_descriptions.split(";")
15
+ custom_policies = dict(zip(policies, policy_descriptions))
16
+ print(custom_policies)
17
+
18
+ inputs = processor(
19
+ images=[image],
20
+ custom_policies=custom_policies,
21
+ policies=policies,
22
+ return_tensors="pt",
23
+ ).to(model.device)
24
+
25
+ with torch.inference_mode():
26
+ output = model(**inputs)
27
+ print(output.probabilities)
28
+
29
+ outs = {}
30
+ for idx, policy in enumerate(output.probabilities):
31
+ yes_prob = policy[0]
32
+ no_prob = policy[1]
33
+
34
+ outs[f"Yes for {policies[idx]}"] = yes_prob
35
+ outs[f"No for {policies[idx]}"] = no_prob
36
+ print(outs)
37
+
38
+ return outs
39
+
40
+
41
+ IMG = """
42
+ <img src='https://storage.googleapis.com/gweb-developer-goog-blog-assets/images/SheildGemma2_WagtailBlog_RD1_V01a.original.png' />
43
+ """
44
+
45
+
46
+ with gr.Blocks() as demo:
47
+ gr.Markdown(IMG)
48
+ gr.Markdown("## ShieldGemma2 for Multimodal Safety")
49
+ gr.Markdown("ShieldGemma2 is a safety moderation model for vision language models. It can detect unsafe images. To use it, simply input an image, and provide policies. A policy is a description of what should be detected, and also provide names of policies.")
50
+ gr.Markdown("You can test it with an image of your choice and example policies provided below.")
51
+
52
+ with gr.Row():
53
+ with gr.Column():
54
+ image = gr.Image(type="pil")
55
+ policies = gr.Textbox(label = "Name of policies separated by semicolon")
56
+ policy_descriptions = gr.Textbox(label = "Description of policies separated by semicolon", lines=10)
57
+ btn = gr.Button("Submit")
58
+ with gr.Column():
59
+ outputs = gr.Label()
60
+
61
+ btn.click(fn=infer, inputs=[image, policies, policy_descriptions], outputs=outputs)
62
+ gr.Examples(examples = [["Sexually Explicit content;Dangerous content;Violent content", "The image shall not contain content that depicts explicit or graphic sexual acts.; The image shall not contain content that facilitates or encourages activities that could cause real-world harm (e.g., building firearms and explosive devices, promotion of terrorism, instructions for suicide).;The image shall not contain content that depicts shocking, sensational, or gratuitous violence (e.g., excessive blood and gore, gratuitous violence against animals, extreme injury or moment of death)."]],
63
+ inputs = [policies, policy_descriptions])
64
+
65
+
66
+ demo.launch(debug=True)
67
+