fastapi-app / app.py
vaibhaviiii28's picture
Update app.py
b702e41 verified
raw
history blame
2.29 kB
import io
import os
import re
import string
import sys

import joblib
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from pydantic import BaseModel
from transformers import pipeline
# Hugging Face cache directory. Per the original note, the default cache
# location raised a permission error in the hosting environment, so model
# files are downloaded under /tmp instead.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# FastAPI application instance; the endpoints below register on it.
app = FastAPI()

# NSFW image classification pipeline, cached under CACHE_DIR (single source
# of truth for the path instead of repeating the literal).
# NOTE(review): `cache_dir` is forwarded through pipeline(**kwargs); the
# documented route for from_pretrained options is
# model_kwargs={"cache_dir": CACHE_DIR}. If a transformers upgrade stops
# honouring the plain kwarg, switch to that — confirm against the installed
# version.
pipe = pipeline(
    "image-classification",
    model="LukeJacob2023/nsfw-image-detector",
    cache_dir=CACHE_DIR,
)
# Toxic text classification model and its fitted vectorizer, both stored as
# joblib pickles. Paths are relative to the current working directory.
# joblib.load unpickles arbitrary objects: only point these paths at trusted
# artifacts shipped with the app, never at user-supplied files.
try:
    model = joblib.load("toxic_classifier.pkl")
    vectorizer = joblib.load("vectorizer.pkl")
    print("βœ… Model & Vectorizer Loaded Successfully!")
except Exception as e:
    # /classify_text/ cannot work without these artifacts, so abort start-up.
    # Report on stderr and use sys.exit: the bare exit() builtin is injected
    # by the `site` module for interactive use and is not guaranteed to exist.
    print(f"❌ Error: {e}", file=sys.stderr)
    sys.exit(1)
# Request body schema for the text endpoint.
class TextInput(BaseModel):
    """JSON payload accepted by ``POST /classify_text/``."""

    # Raw user text; classify_text runs it through preprocess_text first.
    text: str
# Text normalisation applied before vectorisation.
def preprocess_text(text):
    """Return a normalised copy of *text* for the toxic-text vectorizer.

    Steps, in order: lower-case, delete every run of digits, delete ASCII
    punctuation (``string.punctuation`` only; Unicode punctuation is kept),
    then trim leading/trailing whitespace.

    NOTE(review): this presumably mirrors the cleaning used when the
    vectorizer was fitted; changing it would skew predictions — confirm
    against the training code before editing.
    """
    lowered = text.lower()
    digits_removed = re.sub(r"\d+", "", lowered)
    drop_punctuation = str.maketrans("", "", string.punctuation)
    cleaned = digits_removed.translate(drop_punctuation)
    return cleaned.strip()
# NSFW image classification API

# Detector labels that count as "not safe for work". Kept at module level so
# the set is built once rather than on every request.
NSFW_LABELS = frozenset({"sexy", "porn", "hentai"})


@app.post("/classify_image/")
async def classify_image(file: UploadFile = File(...)):
    """Classify an uploaded image as NSFW or SFW.

    Args:
        file: Uploaded image; any format PIL's ``Image.open`` can decode.

    Returns:
        ``{"status": "NSFW" | "SFW", "results": [...]}``, where ``results`` is
        the pipeline's raw list of ``{"label", "score"}`` dicts and the status
        is decided by the single highest-scoring label. On any failure
        (unreadable upload, undecodable image, inference error) returns
        ``{"error": "<message>"}`` instead.
    """
    try:
        contents = await file.read()
        with Image.open(io.BytesIO(contents)) as image:
            # The transformers pipeline is synchronous and compute-bound.
            # Calling it directly inside this coroutine would block the event
            # loop, and every other request, until inference finishes. Run it
            # in Starlette's worker thread pool instead.
            results = await run_in_threadpool(pipe, image)
        top_label = max(results, key=lambda r: r["score"])["label"]
        nsfw_status = "NSFW" if top_label in NSFW_LABELS else "SFW"
        return {"status": nsfw_status, "results": results}
    except Exception as e:
        # Errors are reported in the JSON body (HTTP 200) to preserve the
        # endpoint's existing response contract.
        return {"error": str(e)}
# Toxic text classification API
@app.post("/classify_text/")
async def classify_text(data: TextInput):
    """Label the submitted text as "Toxic" or "Safe".

    The text is cleaned with ``preprocess_text``, transformed by the
    module-level ``vectorizer`` and scored by the module-level ``model``.
    A first predicted class equal to ``1`` maps to "Toxic"; anything else
    maps to "Safe".

    Returns:
        ``{"prediction": "Toxic" | "Safe"}``, or ``{"error": "<message>"}``
        if any step raises.
    """
    try:
        cleaned = preprocess_text(data.text)
        features = vectorizer.transform([cleaned])
        first_class = model.predict(features)[0]
        verdict = "Toxic" if first_class == 1 else "Safe"
    except Exception as e:
        return {"error": str(e)}
    return {"prediction": verdict}
# Script entry point: serve the app with Uvicorn when run directly.
if __name__ == "__main__":
    # Listen on all interfaces; per the original note, Hugging Face requires
    # the app to be served on port 7860.
    listen_host = "0.0.0.0"
    listen_port = 7860
    uvicorn.run(app, host=listen_host, port=listen_port)