File size: 1,193 Bytes
0c99d9c
fb54360
0c99d9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1674c44
a94cbc1
0c99d9c
b889bb5
0c99d9c
 
a94cbc1
0c99d9c
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import json
import gradio as gr

from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline


# Hugging Face Hub checkpoint for the prompt-injection classifier. Defined once so
# the tokenizer and the model are always loaded from the same checkpoint.
MODEL_NAME = "ProtectAI/deberta-v3-base-prompt-injection"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# Text-classification pipeline over the loaded model/tokenizer pair.
# truncation=True with max_length=512 truncates tokenized inputs to 512 tokens;
# inference runs on CUDA when torch reports a GPU, otherwise on the CPU.
classifier = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    truncation=True,
    max_length=512,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

def predict(user_input: str):
    """Classify ``user_input`` and render the top prediction as text.

    Runs the module-level ``classifier`` pipeline on the input, takes the first
    result, and returns a two-line string with its label and its score rounded
    to three decimal places.
    """
    top = classifier(user_input)[0]
    label = top["label"]
    score = round(top["score"], 3)
    return f"Label: {label}\nProbability: {score}"


# Multi-line input box for the text to be classified.
textbox = gr.Textbox(placeholder="Enter user input for injection attack classification", lines=12)

# Single-function UI: the textbox feeds `predict`, whose string result is shown
# as plain text. Users can manually flag outputs as "Useful" / "Not Useful".
interface = gr.Interface(
    inputs=textbox, fn=predict, outputs="text",
    title="Injection Attack Classifier",
    description="This web API flags if the text presented as input to an LLM qualifies to be an injection attack",
    allow_flagging="manual", flagging_options=["Useful", "Not Useful"]
)

with gr.Blocks() as demo:
    # render() (not launch()) embeds the Interface into `demo`, so `demo` is the
    # one app that gets queued and served below. Calling launch() here would
    # start a separate, unqueued server and leave `demo` empty.
    interface.render()

# Allow up to 4 queued requests to be processed concurrently.
# NOTE(review): `concurrency_count` is the Gradio 3.x queue() parameter; Gradio 4
# replaced it with `default_concurrency_limit` — confirm against the pinned version.
demo.queue(concurrency_count=4)
demo.launch()