Spaces:
Sleeping
Sleeping
File size: 4,202 Bytes
f228a1c 99351b6 338f4c1 f228a1c 338f4c1 f228a1c 14c8502 f228a1c 14c8502 99351b6 5c19b8d 99351b6 338f4c1 99351b6 338f4c1 99351b6 f228a1c bc21776 338f4c1 f228a1c 338f4c1 14c8502 f228a1c 5c19b8d f228a1c 5c19b8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import asyncio
import logging
from typing import List

import torch
from detoxify import Detoxify
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
class Guardrail:
    """Prompt-injection detector backed by a Hugging Face text-classification pipeline.

    Loads the ``ProtectAI/deberta-v3-base-prompt-injection`` checkpoint and
    translates its ``(label, score)`` predictions into a coarse severity scale.
    """

    def __init__(self):
        checkpoint = "ProtectAI/deberta-v3-base-prompt-injection"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
        # Run on CUDA when it is available, otherwise fall back to the CPU.
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Inputs longer than 512 tokens are truncated instead of rejected.
        self.classifier = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            truncation=True,
            max_length=512,
            device=target_device,
        )

    async def guard(self, prompt):
        """Run the classifier on *prompt* in a worker thread.

        The pipeline call is synchronous, so it is dispatched through
        ``run_in_threadpool`` to keep the event loop free. The pipeline's raw
        output is returned unchanged.
        """
        predictions = await run_in_threadpool(self.classifier, prompt)
        return predictions

    def determine_level(self, label, score):
        """Map a prediction to a ``(level, severity_label)`` pair.

        ``"SAFE"`` always maps to ``(0, "safe")``. Any other label is graded by
        its score, using exclusive lower bounds: above 0.9 -> ``(4, "high")``,
        above 0.75 -> ``(3, "medium")``, above 0.5 -> ``(2, "low")``, and
        anything else -> ``(1, "very low")``.
        """
        if label == "SAFE":
            return 0, "safe"
        # Bands are checked from the strictest downward; the first hit wins.
        for threshold, level, severity in (
            (0.9, 4, "high"),
            (0.75, 3, "medium"),
            (0.5, 2, "low"),
        ):
            if score > threshold:
                return level, severity
        return 1, "very low"
class TextPrompt(BaseModel):
    """Request body accepted by every classification endpoint."""

    # Raw text submitted for classification.
    prompt: str
class ClassificationResult(BaseModel):
    """Response body of the prompt-injection endpoint."""

    label: str  # label predicted by the classifier, e.g. "SAFE"
    score: float  # score the classifier reported for ``label``
    level: int  # numeric severity, 0 ("safe") up to 4 ("high")
    severity_label: str  # human-readable name of ``level``
class ToxicityResult(BaseModel):
    """Response body of the toxicity endpoint: one score per Detoxify category."""

    # Field names mirror the keys read from Detoxify's ``predict`` output.
    toxicity: float
    severe_toxicity: float
    obscene: float
    threat: float
    insult: float
    identity_attack: float
class TopicBannerClassifier:
    """Zero-shot topic classifier built on an NLI model.

    Each candidate label is substituted into ``hypothesis_template`` and scored
    against the input text by a ``zero-shot-classification`` pipeline.

    Args:
        candidate_labels: Topics to score the text against. ``None`` (the
            default) uses politics, economy, entertainment and environment.
        hypothesis_template: Hypothesis sentence; ``{}`` is replaced by a label.
        multi_label: Forwarded to the pipeline. ``False`` normalises scores
            across all labels; ``True`` scores each label independently.

    Raises:
        ValueError: If ``candidate_labels`` is empty. The pipeline would
            otherwise reject every request.
    """

    _DEFAULT_LABELS = ("politics", "economy", "entertainment", "environment")

    def __init__(
        self,
        candidate_labels=None,
        hypothesis_template="This text is about {}",
        multi_label=False,
    ):
        labels = list(
            self._DEFAULT_LABELS if candidate_labels is None else candidate_labels
        )
        if not labels:
            # Fail fast, before loading the large model, instead of on every call.
            raise ValueError("candidate_labels must contain at least one label")
        self.classifier = pipeline(
            "zero-shot-classification",
            model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )
        self.hypothesis_template = hypothesis_template
        # A fresh list, so a caller's sequence is never aliased by this instance.
        self.classes_verbalized = labels
        self.multi_label = multi_label

    async def classify(self, text):
        """Score *text* against the candidate labels in a worker thread.

        The synchronous pipeline call is dispatched through
        ``run_in_threadpool``. The pipeline's raw output is returned unchanged.
        """
        return await run_in_threadpool(
            self.classifier,
            text,
            self.classes_verbalized,
            hypothesis_template=self.hypothesis_template,
            multi_label=self.multi_label,
        )
class TopicBannerResult(BaseModel):
    """Response body of the topic-banner endpoint.

    ``labels`` and ``scores`` are parallel lists: ``scores[i]`` is the score
    for ``labels[i]``. Typed element types (rather than bare ``list``) make
    response validation and the generated OpenAPI schema precise.
    """

    sequence: str  # the input text as echoed back by the pipeline
    labels: List[str]
    scores: List[float]
# Application object and model singletons. They are created once at import time,
# and every endpoint below reuses these already-loaded instances.
app = FastAPI()
guardrail = Guardrail()  # prompt-injection classifier
toxicity_classifier = Detoxify("original")  # Detoxify "original" checkpoint
topic_banner_classifier = TopicBannerClassifier()  # zero-shot topic classifier
@app.post("/api/models/toxicity/classify", response_model=ToxicityResult)
async def classify_toxicity(text_prompt: TextPrompt):
    """Return Detoxify ("original" model) scores for the prompt, one per category.

    Raises:
        HTTPException: 500 if the model call fails. The original exception is
            logged with its traceback and chained as the cause.
    """
    categories = (
        "toxicity",
        "severe_toxicity",
        "obscene",
        "threat",
        "insult",
        "identity_attack",
    )
    try:
        result = await run_in_threadpool(toxicity_classifier.predict, text_prompt.prompt)
        # Detoxify may return numpy scalars (e.g. float32) for a single string;
        # converting to plain floats keeps response validation/serialisation reliable.
        return {name: float(result[name]) for name in categories}
    except Exception as exc:
        # Previously the failure reached the client but was never logged server-side.
        logging.getLogger(__name__).exception("Toxicity classification failed")
        # NOTE(review): str(exc) exposes internal error text to clients; kept for
        # backward compatibility — consider a generic message.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@app.post("/api/models/PromptInjection/classify", response_model=ClassificationResult)
async def classify_text(text_prompt: TextPrompt):
    """Classify the prompt for prompt injection and attach a severity level.

    Raises:
        HTTPException: 500 if classification fails. The original exception is
            logged with its traceback and chained as the cause.
    """
    try:
        predictions = await guardrail.guard(text_prompt.prompt)
        # A single input string yields a list whose first entry holds the prediction.
        top = predictions[0]
        label = top["label"]
        score = float(top["score"])
        level, severity_label = guardrail.determine_level(label, score)
        return {"label": label, "score": score, "level": level, "severity_label": severity_label}
    except Exception as exc:
        logging.getLogger(__name__).exception("Prompt-injection classification failed")
        # NOTE(review): str(exc) exposes internal error text to clients; kept for
        # backward compatibility — consider a generic message.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult)
async def classify_topic_banner(text_prompt: TextPrompt):
    """Score the prompt against the topic classifier's candidate labels.

    Raises:
        HTTPException: 422 if the prompt is empty. The zero-shot pipeline
            rejects empty input, which previously surfaced as a 500.
        HTTPException: 500 if classification fails. The original exception is
            logged with its traceback and chained as the cause.
    """
    if not text_prompt.prompt:
        raise HTTPException(status_code=422, detail="prompt must not be empty")
    try:
        result = await topic_banner_classifier.classify(text_prompt.prompt)
        return {
            "sequence": result["sequence"],
            "labels": result["labels"],
            "scores": result["scores"],
        }
    except Exception as exc:
        logging.getLogger(__name__).exception("Topic classification failed")
        # NOTE(review): str(exc) exposes internal error text to clients; kept for
        # backward compatibility — consider a generic message.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)