# Author: shakii
# Last commit: "Update app.py" (8990b6e, verified)
import logging
import os

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
app = FastAPI()

# Browser origins allowed to call this API cross-origin, read from the
# comma-separated CORS_ALLOW_ORIGINS environment variable. The default "*"
# keeps the allow-everything behaviour intended for development; set e.g.
# CORS_ALLOW_ORIGINS="https://example.com,https://app.example.com" to restrict
# it in production. Blank entries are ignored.
cors_allow_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_allow_origins,
    # NOTE(review): the Starlette CORS docs say credentials should not be
    # combined with wildcard origins/methods/headers. Nothing visible here
    # reads cookies or auth headers, so consider disabling this (or listing
    # explicit origins) before deploying.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Hugging Face checkpoint used by /predict. The tokenizer and model are loaded
# once, at import time, and shared by every request.
model_name = "fakespot-ai/roberta-base-ai-text-detection-v1"
# Alternative checkpoint, currently unused:
# model_name = "SuperAnnotate/ai-detector"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSequenceClassification.from_pretrained(model_name),
)
class TextRequest(BaseModel):
    """JSON body accepted by ``POST /predict``."""

    # Raw passage to classify as human-written or AI-generated.
    text: str
@app.post("/predict")
async def predict(request: TextRequest):
    """Classify ``request.text`` as human-written or AI-generated.

    The text is tokenized (truncated to 512 tokens) and run through the
    classifier without gradient tracking. The logits are then turned into
    probabilities with a softmax.

    Returns:
        A dict that echoes ``text`` and adds:
        - ``human_probability`` and ``ai_probability``: percentages rounded
          to 2 decimals.
        - ``prediction``: either ``"AI-generated"`` or ``"Human-written"``.
          Ties resolve to ``"Human-written"``.

    Raises:
        HTTPException: 422 if the text is empty or whitespace-only;
            500 if tokenization or inference fails.
    """
    # Reject input with no content up front: a verdict on "" is meaningless.
    # This check sits outside the try below so the 422 is not re-wrapped as
    # a 500.
    if not request.text.strip():
        raise HTTPException(status_code=422, detail="text must not be empty")

    # NOTE(review): inference runs synchronously inside an async handler, so
    # it blocks the event loop while it runs. Declaring this with plain `def`
    # would move it to FastAPI's threadpool, but first confirm that the shared
    # tokenizer/model are safe to call from several threads at once.
    try:
        # Tokenize the input text; tokens past the 512th are dropped.
        inputs = tokenizer(
            request.text, return_tensors="pt", truncation=True, max_length=512
        )
        # Keep the input tensors on the same device as the model.
        inputs = inputs.to(model.device)
        with torch.no_grad():
            outputs = model(**inputs)
        # A single text was submitted, so take row 0 of the softmaxed logits.
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
        # NOTE(review): assumes label 0 = human and label 1 = AI for this
        # checkpoint. Confirm against model.config.id2label.
        human_prob = probabilities[0].item()
        ai_prob = probabilities[1].item()
    except Exception as e:
        # Keep the full traceback server-side; the client only sees a 500.
        logging.getLogger(__name__).exception("Prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "text": request.text,
        "human_probability": round(human_prob * 100, 2),
        "ai_probability": round(ai_prob * 100, 2),
        "prediction": "AI-generated" if ai_prob > human_prob else "Human-written",
    }
@app.get("/")
async def root():
    """Report that the API is up by returning a fixed status message."""
    payload = {"message": "AI Text Detection API is running"}
    return payload
if __name__ == "__main__":
    # Imported here because uvicorn is only needed when this file is run
    # directly.
    import uvicorn

    # Listen on all interfaces. The port defaults to 7860 and can be
    # overridden with the PORT environment variable.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))