Spaces:
Sleeping
Sleeping
File size: 4,219 Bytes
f228a1c 99351b6 338f4c1 4fa87d4 f228a1c 338f4c1 f228a1c 14c8502 f228a1c 14c8502 99351b6 5c19b8d 4fa87d4 5c19b8d 4fa87d4 5c19b8d 4fa87d4 5c19b8d 99351b6 338f4c1 99351b6 338f4c1 99351b6 f228a1c bc21776 338f4c1 f228a1c 338f4c1 14c8502 f228a1c 5c19b8d 4fa87d4 5c19b8d 4fa87d4 5c19b8d f228a1c 5c19b8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import torch
from detoxify import Detoxify
import asyncio
from fastapi.concurrency import run_in_threadpool
from typing import List
class Guardrail:
    """Prompt-injection detector backed by ``ProtectAI/deberta-v3-base-prompt-injection``.

    Wraps a transformers ``text-classification`` pipeline and maps its
    ``(label, score)`` output onto a coarse 0-4 severity scale.
    """

    def __init__(self):
        checkpoint = "ProtectAI/deberta-v3-base-prompt-injection"
        tok = AutoTokenizer.from_pretrained(checkpoint)
        net = AutoModelForSequenceClassification.from_pretrained(checkpoint)
        # Use the GPU when torch can see one, otherwise fall back to CPU.
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # truncation/max_length ask the pipeline to cut inputs at 512 tokens.
        self.classifier = pipeline(
            "text-classification",
            model=net,
            tokenizer=tok,
            truncation=True,
            max_length=512,
            device=target,
        )

    async def guard(self, prompt):
        """Run the classifier on *prompt* in a worker thread and return its raw output."""
        classify = self.classifier
        return await run_in_threadpool(classify, prompt)

    def determine_level(self, label, score):
        """Map a classifier result to ``(level, severity_label)``.

        ``"SAFE"`` always yields ``(0, "safe")``. Any other label is bucketed
        by score (strict lower bounds): > 0.9 -> ``(4, "high")``,
        > 0.75 -> ``(3, "medium")``, > 0.5 -> ``(2, "low")``, otherwise
        ``(1, "very low")``.
        """
        if label == "SAFE":
            return 0, "safe"
        # Checked from most to least severe; the first bound exceeded wins.
        for floor, level, severity in (
            (0.9, 4, "high"),
            (0.75, 3, "medium"),
            (0.5, 2, "low"),
        ):
            if score > floor:
                return level, severity
        return 1, "very low"
# Request body used by the toxicity and prompt-injection endpoints.
class TextPrompt(BaseModel):
    # Raw text to classify.
    prompt: str
# Response body for the prompt-injection endpoint.
class ClassificationResult(BaseModel):
    label: str  # label reported by the classifier pipeline (e.g. "SAFE")
    score: float  # score the pipeline reported for ``label``
    level: int  # 0-4 severity level produced by Guardrail.determine_level
    severity_label: str  # human-readable name paired with ``level``
# Response body for the toxicity endpoint: one score per Detoxify category.
# NOTE(review): presumably each value is a probability in [0, 1] — confirm
# against the Detoxify model in use.
class ToxicityResult(BaseModel):
    toxicity: float
    severe_toxicity: float
    obscene: float
    threat: float
    insult: float
    identity_attack: float
class TopicBannerClassifier:
    """Zero-shot topic classifier built on ``MoritzLaurer/deberta-v3-large-zeroshot-v2.0``.

    Scores a text against caller-supplied candidate labels, phrasing each
    label through the hypothesis template ``"This text is about {}"``.
    """

    def __init__(self):
        # Use the GPU when torch can see one, otherwise fall back to CPU.
        run_on = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.classifier = pipeline(
            "zero-shot-classification",
            model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
            device=run_on,
        )
        self.hypothesis_template = "This text is about {}"

    async def classify(self, text, labels):
        """Score *text* against *labels* in a worker thread.

        The pipeline is called with this instance's hypothesis template and
        ``multi_label=False``; its result is returned unchanged.
        """
        call_kwargs = {
            "hypothesis_template": self.hypothesis_template,
            "multi_label": False,
        }
        return await run_in_threadpool(self.classifier, text, labels, **call_kwargs)
# Request body for the zero-shot topic ("TopicBanner") endpoint.
class TopicBannerRequest(BaseModel):
    prompt: str  # text to classify
    # Candidate topics to score the prompt against.
    # NOTE(review): emptiness is not validated by this model itself.
    labels: List[str]
# Response body for the zero-shot topic ("TopicBanner") endpoint.
# Field types mirror TopicBannerRequest.labels so the OpenAPI schema describes
# the array items instead of advertising untyped lists.
class TopicBannerResult(BaseModel):
    # The input text as echoed back by the zero-shot pipeline.
    sequence: str
    # Candidate labels in the order returned by the pipeline
    # (presumably highest score first — confirm against transformers docs).
    labels: List[str]
    # One score per entry in ``labels``, aligned by index.
    scores: List[float]
# Application object and model singletons.
# All three models are constructed eagerly, in this order, when the module is
# imported, so startup blocks until they are loaded (possibly downloaded from
# the Hugging Face Hub — confirm caching setup). The endpoints below share
# these module-level instances across all requests.
app = FastAPI()
guardrail = Guardrail()  # prompt-injection classifier
toxicity_classifier = Detoxify('original')  # Detoxify 'original' checkpoint
topic_banner_classifier = TopicBannerClassifier()  # zero-shot topic classifier
@app.post("/api/models/toxicity/classify", response_model=ToxicityResult)
async def classify_toxicity(text_prompt: TextPrompt):
    """Score ``text_prompt.prompt`` with the Detoxify ``'original'`` model.

    Returns the six per-category scores declared on ``ToxicityResult``.

    Raises:
        HTTPException: 500 carrying the underlying error message if
            inference fails or an expected category is missing.
    """
    try:
        # Detoxify inference is blocking; keep it off the event loop.
        result = await run_in_threadpool(toxicity_classifier.predict, text_prompt.prompt)
        # For a single string Detoxify may hand back NumPy scalars rather than
        # Python floats; cast explicitly so response validation/serialization
        # does not depend on pydantic's handling of NumPy types.
        return {
            category: float(result[category])
            for category in (
                "toxicity",
                "severe_toxicity",
                "obscene",
                "threat",
                "insult",
                "identity_attack",
            )
        }
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/api/models/PromptInjection/classify", response_model=ClassificationResult)
async def classify_text(text_prompt: TextPrompt):
    """Classify ``text_prompt.prompt`` for prompt injection.

    Uses the top prediction from ``guardrail`` and attaches the severity
    level computed by ``Guardrail.determine_level``.

    Raises:
        HTTPException: 500 carrying the underlying error message if
            inference fails or the pipeline output is not as expected.
    """
    try:
        predictions = await guardrail.guard(text_prompt.prompt)
        # The text-classification pipeline returns a list of
        # {"label", "score"} dicts; only the first entry is used.
        top = predictions[0]
        label = top["label"]
        score = top["score"]
        level, severity_label = guardrail.determine_level(label, score)
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"label": label, "score": score, "level": level, "severity_label": severity_label}
@app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult)
async def classify_topic_banner(request: TopicBannerRequest):
    """Zero-shot classify ``request.prompt`` against ``request.labels``.

    Returns the pipeline's ``sequence``, ``labels`` and ``scores``.

    Raises:
        HTTPException: 422 if ``labels`` is empty (a client error);
            500 carrying the underlying error message if inference fails.
    """
    # transformers' zero-shot pipeline raises ValueError on an empty label
    # list, which previously surfaced as a 500. Reject it up front instead.
    if not request.labels:
        raise HTTPException(status_code=422, detail="labels must contain at least one label")
    try:
        result = await topic_banner_classifier.classify(request.prompt, request.labels)
        return {
            "sequence": result["sequence"],
            "labels": result["labels"],
            "scores": result["scores"],
        }
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    import uvicorn

    # Running the module directly serves the app on all interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")