import io
import os
import torch
from PIL import Image
from dotenv import load_dotenv
import matplotlib.pyplot as plt
from torch.nn.functional import softmax
from transformers import AutoTokenizer, AutoModelForCausalLM

load_dotenv()

HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
  raise RuntimeError("HF_TOKEN is not set; add it to your .env file.")

tokenizer = AutoTokenizer.from_pretrained("google/shieldgemma-2b", token=HF_TOKEN)

model = AutoModelForCausalLM.from_pretrained(
    "google/shieldgemma-2b",
    device_map="auto",
    token=HF_TOKEN,
)

def get_safety_score(content, safety_policy):
  """Return the probability that `content` violates `safety_policy` (ShieldGemma's 'Yes' score)."""
  # Truncate long inputs; note this truncates characters, not tokens.
  max_length = 1024
  content = content[:max_length]

  prompt = f"""You are a policy expert trying to help determine whether a user
    prompt is in violation of the defined safety policies.

    <start_of_turn>
    Content: {content.strip()}
    <end_of_turn>

    Our safety principle is defined in the below:

    {safety_policy.strip()}

    Does the content violate the above principle? Your answer must start
    with 'Yes' or 'No'. And then walk through step by step to be sure we answer
    correctly.
  """

  # Keep the inputs on the same device as the model (device_map="auto" may place it on a GPU).
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  with torch.no_grad():
    logits = model(**inputs).logits

  # Compare the logits of the 'Yes' and 'No' tokens at the final position.
  vocab = tokenizer.get_vocab()
  selected_logits = logits[0, -1, [vocab['Yes'], vocab['No']]]

  # The probability assigned to 'Yes' is the violation score.
  probabilities = softmax(selected_logits, dim=0)

  score = probabilities[0].item()
  return score

def safety_scorer(text):
  """Score `text` against each safety policy and return a bar chart of the results as a PIL image."""
  safety_policies = [
      """ "No Harassment in Content": The content retrieved from the provided URL or document shall not contain or seek to promote malicious, intimidating, bullying, or abusive content targeting an individual or group (e.g., physical threats, denial of tragic events, disparaging victims of violence).""",
      """ "No Violence in Content": The content retrieved from the provided URL or document shall not contain or promote any form of violence, harm, or threats against individuals or groups. This includes explicit depictions of violence, threats of physical harm, or any content that encourages violent acts.""",
      """ "No Misinformation in Content": The content retrieved from the provided URL or document shall not contain or spread false or misleading information. This includes intentional misrepresentation of facts, events, or data. This policy also covers the prevention of fake news, conspiracy theories, and any content that intentionally manipulates or distorts the truth.""",
      """ "No Adult Content in Content": The content retrieved from the provided URL or document shall not contain sexually explicit, pornographic, or any material intended to arouse sexual interest. This includes explicit descriptions, images, videos, or discussions of nudity or sexual acts that violate acceptable content guidelines."""
  ]

  safety_scores = []
  keys = ["Harassment", "Violence", "Misinformation", "Adult Content"]

  for safety_policy in safety_policies:
    score = get_safety_score(text, safety_policy)
    safety_scores.append(score)

  safety_scores = [round(x * 100, 2) for x in safety_scores]

  plt.bar(keys, safety_scores)
  plt.ylabel('Score')
  plt.title('Safety Scores')

  # Add the scores on top of the bars
  for i, score in enumerate(safety_scores):
    plt.text(i, score + 0.01, str(score) + '%', ha='center')

  # Render the plot to an in-memory PNG buffer
  buf = io.BytesIO()
  plt.savefig(buf, format='png')
  buf.seek(0)
  plt.close()

  image = Image.open(buf)
  return image
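
# Usage sketch (assumption: run as a standalone script; the sample text and
# output filename below are illustrative and not part of the original file).
if __name__ == "__main__":
  sample_text = "This is a friendly reminder that our team meeting starts at 10am tomorrow."
  chart = safety_scorer(sample_text)
  chart.save("safety_scores.png")
  print("Saved safety score chart to safety_scores.png")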