# GemmaGuard/src/model.py
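# Loads Google's ShieldGemma 2B safety classifier and exposes helpers that score
# a piece of text against four content policies (harassment, violence,
# misinformation, adult content), returning the scores as a bar-chart image.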
import io
import os

import torch
from PIL import Image
from dotenv import load_dotenv
import matplotlib.pyplot as plt
from torch.nn.functional import softmax
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the Hugging Face access token from a local .env file.
load_dotenv()
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("HF_TOKEN is not set; add it to your environment or .env file.")

# ShieldGemma 2B is a gated model, so an access token is required to download it.
tokenizer = AutoTokenizer.from_pretrained("google/shieldgemma-2b", token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    "google/shieldgemma-2b",
    device_map="auto",
    token=HF_TOKEN,
)
def get_safety_score(content, safety_policy):
    """Return the probability (0-1) that `content` violates `safety_policy`."""
    # Truncate very long inputs (by characters) to keep the prompt manageable.
    max_length = 1024
    content = content[:max_length]

    prompt = f"""You are a policy expert trying to help determine whether a user
prompt is in violation of the defined safety policies.

<start_of_turn>
Content: {content.strip()}
<end_of_turn>

Our safety principle is defined below:

{safety_policy.strip()}

Does the content violate the above principle? Your answer must start
with 'Yes' or 'No'. And then walk through step by step to be sure we answer
correctly.
"""
    # Keep the inputs on the same device as the model to avoid device mismatches.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits

    # Compare the logits of the next token for 'Yes' vs. 'No'.
    vocab = tokenizer.get_vocab()
    selected_logits = logits[0, -1, [vocab["Yes"], vocab["No"]]]
    probabilities = softmax(selected_logits, dim=0)

    # Probability assigned to 'Yes', i.e. the policy-violation score.
    score = probabilities[0].item()
    return score
def saftey_scorer(text):
    """Score `text` against each safety policy and return a bar chart as a PIL image."""
    safety_policies = [
        """ "No Harassment in Content": The content retrieved from the provided URL or document shall not contain or seek to promote malicious, intimidating, bullying, or abusive content targeting an individual or group (e.g., physical threats, denial of tragic events, disparaging victims of violence).""",
        """ "No Violence in Content": The content retrieved from the provided URL or document shall not contain or promote any form of violence, harm, or threats against individuals or groups. This includes explicit depictions of violence, threats of physical harm, or any content that encourages violent acts.""",
        """ "No Misinformation in Content": The content retrieved from the provided URL or document shall not contain or spread false or misleading information. This includes intentional misrepresentation of facts, events, or data. This policy also covers the prevention of fake news, conspiracy theories, and any content that intentionally manipulates or distorts the truth.""",
        """ "No Adult Content in Content": The content retrieved from the provided URL or document shall not contain sexually explicit, pornographic, or any material intended to arouse sexual interest. This includes explicit descriptions, images, videos, or discussions of nudity or sexual acts that violate acceptable content guidelines.""",
    ]
    keys = ["Harassment", "Violence", "Misinformation", "Adult Content"]

    # Score the text against each policy and convert the results to percentages.
    safety_scores = []
    for safety_policy in safety_policies:
        score = get_safety_score(text, safety_policy)
        safety_scores.append(score)
    safety_scores = [round(x * 100, 2) for x in safety_scores]

    # Plot the scores as a bar chart.
    plt.bar(keys, safety_scores)
    plt.ylabel("Score")
    plt.title("Safety Scores")

    # Add the scores on top of the bars.
    for i, score in enumerate(safety_scores):
        plt.text(i, score + 0.01, f"{score}%", ha="center")

    # Render the plot into an in-memory PNG and return it as a PIL image.
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    buf.seek(0)
    plt.close()

    image = Image.open(buf)
    return image
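

# Minimal usage sketch: the sample text and output filename below are
# assumptions for illustration, not values from the project.
if __name__ == "__main__":
    sample_text = "Example article text fetched from a URL or document."  # hypothetical input
    chart = saftey_scorer(sample_text)
    chart.save("safety_scores.png")  # hypothetical output path
    print("Saved safety score chart to safety_scores.png")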