File size: 4,202 Bytes
f228a1c
 
 
 
99351b6
338f4c1
 
f228a1c
 
 
 
 
 
 
 
 
 
 
 
 
 
338f4c1
 
f228a1c
14c8502
 
 
 
 
 
 
 
 
 
 
 
 
f228a1c
 
 
14c8502
 
 
 
 
 
99351b6
 
 
 
 
 
 
 
5c19b8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99351b6
338f4c1
99351b6
338f4c1
99351b6
 
 
 
 
 
 
 
 
 
f228a1c
bc21776
338f4c1
f228a1c
338f4c1
14c8502
 
 
 
f228a1c
 
 
5c19b8d
 
 
 
 
 
 
 
 
 
 
 
f228a1c
 
5c19b8d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import torch
from detoxify import Detoxify
import asyncio
from fastapi.concurrency import run_in_threadpool

class Guardrail:
    def __init__(self):
        tokenizer = AutoTokenizer.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
        model = AutoModelForSequenceClassification.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
        self.classifier = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            truncation=True,
            max_length=512,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )

    async def guard(self, prompt):
        return await run_in_threadpool(self.classifier, prompt)

    def determine_level(self, label, score):
        if label == "SAFE":
            return 0, "safe"
        else:
            if score > 0.9:
                return 4, "high"
            elif score > 0.75:
                return 3, "medium"
            elif score > 0.5:
                return 2, "low"
            else:
                return 1, "very low"

class TextPrompt(BaseModel):
    prompt: str

class ClassificationResult(BaseModel):
    label: str
    score: float
    level: int
    severity_label: str

class ToxicityResult(BaseModel):
    toxicity: float
    severe_toxicity: float
    obscene: float
    threat: float
    insult: float
    identity_attack: float

class TopicBannerClassifier:
    def __init__(self):
        self.classifier = pipeline(
            "zero-shot-classification",
            model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.hypothesis_template = "This text is about {}"
        self.classes_verbalized = ["politics", "economy", "entertainment", "environment"]

    async def classify(self, text):
        return await run_in_threadpool(
            self.classifier,
            text,
            self.classes_verbalized,
            hypothesis_template=self.hypothesis_template,
            multi_label=False
        )

class TopicBannerResult(BaseModel):
    sequence: str
    labels: list
    scores: list

app = FastAPI()
guardrail = Guardrail()
toxicity_classifier = Detoxify('original')
topic_banner_classifier = TopicBannerClassifier()

@app.post("/api/models/toxicity/classify", response_model=ToxicityResult)
async def classify_toxicity(text_prompt: TextPrompt):
    try:
        result = await run_in_threadpool(toxicity_classifier.predict, text_prompt.prompt)
        return {
            "toxicity": result['toxicity'],
            "severe_toxicity": result['severe_toxicity'],
            "obscene": result['obscene'],
            "threat": result['threat'],
            "insult": result['insult'],
            "identity_attack": result['identity_attack']
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/models/PromptInjection/classify", response_model=ClassificationResult)
async def classify_text(text_prompt: TextPrompt):
    try:
        result = await guardrail.guard(text_prompt.prompt)
        label = result[0]['label']
        score = result[0]['score']
        level, severity_label = guardrail.determine_level(label, score)
        return {"label": label, "score": score, "level": level, "severity_label": severity_label}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult)
async def classify_topic_banner(text_prompt: TextPrompt):
    try:
        result = await topic_banner_classifier.classify(text_prompt.prompt)
        return {
            "sequence": result["sequence"],
            "labels": result["labels"],
            "scores": result["scores"]
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)