# pujanpaudel · "Update app.py" · 2b4ae9b (verified)
import base64
import io
import uuid
from typing import Optional

import cv2
import lightning as L
import numpy as np
import torch
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel
from transformers import SwinForImageClassification, ViTImageProcessor
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Class-name <-> class-index lookup tables; the second is the inverse of the first.
label2id = {'fake': 0, 'real': 1}
id2label = {index: name for name, index in label2id.items()}

# Settings consumed by the image processor and by the model constructor.
hyper_params = dict(
    MODEL_CKPT="microsoft/swin-small-patch4-window7-224",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)

# Image processor loaded from the same checkpoint name the model uses.
vit_img_processor = ViTImageProcessor.from_pretrained(hyper_params["MODEL_CKPT"])
class DeepFakeModel(L.LightningModule):
    """Lightning module wrapping a Swin image-classification model."""

    def __init__(self, hyperparams: dict):
        """Build the wrapped classifier.

        Args:
            hyperparams: Mapping that must contain the keys ``"MODEL_CKPT"``,
                ``"num_labels"``, ``"id2label"`` and ``"label2id"``.
        """
        super().__init__()
        checkpoint = hyperparams["MODEL_CKPT"]
        head_config = {
            key: hyperparams[key]
            for key in ("num_labels", "id2label", "label2id")
        }
        # ignore_mismatched_sizes=True is passed so that size differences
        # between the checkpoint and the requested configuration are tolerated.
        self.model = SwinForImageClassification.from_pretrained(
            checkpoint,
            ignore_mismatched_sizes=True,
            **head_config,
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, pixel_values):
        """Run the wrapped model on ``pixel_values`` and return its logits."""
        return self.model(pixel_values=pixel_values).logits
# Instantiate the network, then restore the trained weights from disk,
# mapping tensors onto whichever device was selected above.
model = DeepFakeModel(hyper_params)
state_dict = torch.load(
    "deepfake_new_trained.pth",
    map_location=torch.device(device),
)
model.load_state_dict(state_dict)

# Place the model on the selected device and switch it to evaluation mode.
model.to(device).eval()
print("Model loaded successfully")
# Initialize FastAPI app
app = FastAPI(
    title="DeepFake Detector API",
    description="API for detecting deepfake images",
    version="1.0.0",
)

# Configure CORS.
# A wildcard origin must not be combined with allow_credentials=True: the
# CORS specification forbids it, and the middleware would otherwise reflect
# whatever Origin a credentialed request sends, effectively trusting every
# site. Nothing in this file reads cookies or auth headers, so credentials
# are disabled for as long as origins are wildcarded. Re-enable them only
# together with an explicit allow_origins list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Update with frontend server address in production
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
class ImageData(BaseModel):
    """Request body carrying a single image as text."""

    # Base64-encoded image payload.
    image: str
class AnalysisResult(BaseModel):
    """Response schema describing the outcome of one image analysis."""

    # Unique identifier generated for this analysis.
    id: str
    # True for a predicted fake, False for a predicted real image, and None
    # when no face was detected and therefore no prediction was made.
    isDeepfake: Optional[bool]
    # Probability of the predicted class expressed as a percentage
    # (0.0 when no prediction was made).
    confidence: float
    # Human-readable explanation of the result.
    details: str
def preprocess_image(img):
    """Run ``img`` through the image processor and return its pixel tensor.

    The processor is asked for PyTorch tensors; its ``pixel_values`` entry is
    moved to the module-level ``device`` before being returned.
    """
    encoded = vit_img_processor(img, return_tensors='pt')
    pixel_values = encoded['pixel_values']
    return pixel_values.to(device)
# Face detector network, read from its Caffe definition and weights files
# once at import time and reused by every detection call.
_FACE_MODEL_FILES = (
    "deploy.prototxt",
    "res10_300x300_ssd_iter_140000.caffemodel",
)
face_net = cv2.dnn.readNetFromCaffe(*_FACE_MODEL_FILES)
def detect_face_opencv(image: Image.Image, conf_threshold: float = 0.5) -> bool:
    """Report whether the OpenCV DNN face detector finds a face in ``image``.

    Args:
        image: PIL image in any mode; it is converted to RGB internally.
        conf_threshold: A detection counts as a face when its confidence is
            strictly greater than this value.

    Returns:
        True if at least one detection exceeds ``conf_threshold``. False if
        none does, or if detection itself raised (the error is printed and
        treated as "no face").
    """
    try:
        # Normalise to 3-channel RGB first so that grayscale, palette or RGBA
        # inputs do not break (or silently mis-order) the channel flip below.
        rgb = np.asarray(image.convert("RGB"))
        bgr = rgb[:, :, ::-1].copy()  # RGB -> BGR, as a contiguous array

        blob = cv2.dnn.blobFromImage(
            bgr, 1.0, (300, 300), (104.0, 177.0, 123.0)
        )
        face_net.setInput(blob)
        detections = face_net.forward()

        # Index 2 along the last axis holds each detection's confidence;
        # axis 2 enumerates the detections.
        confidences = detections[0, 0, :, 2]
        return bool((confidences > conf_threshold).any())
    except Exception as e:
        # Deliberate fail-safe: any detector failure is reported as "no face".
        print(f"Face detection error: {e}")
        return False
def predict_deepfake(image):
    """Classify ``image`` as fake or real, provided a face is detected.

    Returns:
        A dict with the keys ``id``, ``isDeepfake``, ``confidence`` and
        ``details``. When no face is detected, ``isDeepfake`` is None and
        ``confidence`` is 0.0; otherwise ``confidence`` is the winning class
        probability as a percentage rounded to two decimals.

    Raises:
        HTTPException: With status 500 if any step fails.
    """
    try:
        # Stage 1: stop early when the face detector finds nothing.
        if not detect_face_opencv(image):
            return {
                "id": str(uuid.uuid4()),
                "isDeepfake": None,
                "confidence": 0.0,
                "details": "No face detected in the image. Cannot proceed with deepfake analysis."
            }

        # Stage 2: run the classifier and take the most probable class.
        pixel_values = preprocess_image(image)
        with torch.inference_mode():
            class_probs = F.softmax(model(pixel_values), dim=-1)
            top_prob, top_index = class_probs.max(dim=-1)

        label = id2label[top_index.item()]
        is_fake = label == "fake"
        if is_fake:
            details = "Deepfake detected."
        else:
            details = "Image appears to be real."

        return {
            "id": str(uuid.uuid4()),
            "isDeepfake": is_fake,
            "confidence": round(top_prob.item() * 100, 2),
            "details": details
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
# @app.post("/api/analyze", response_model=AnalysisResult)
# async def analyze_image(file: UploadFile = File(...)):
# if not file.content_type.startswith("image/"):
# raise HTTPException(status_code=400, detail="File must be an image")
# try:
# contents = await file.read()
# image = Image.open(io.BytesIO(contents)).convert("RGB")
# result = predict_deepfake(image)
# return JSONResponse(content=result)
# except Exception as e:
# raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
@app.post("/api/analyze-base64", response_model=AnalysisResult)
async def analyze_base64_image(data: ImageData):
    """Analyze a base64-encoded image and return the deepfake verdict.

    Args:
        data: Request body whose ``image`` field holds the base64 payload,
            either bare or as a data URL containing ``"base64,"``.

    Returns:
        A JSON response built from the dict produced by ``predict_deepfake``.

    Raises:
        HTTPException: 400 when the payload is not valid base64 or cannot be
            read as an image; 500 for any other failure.
    """
    try:
        # Keep only what follows the last "base64," marker, so both raw
        # base64 and data URLs are accepted.
        image_data = data.image.split("base64,")[-1]
        image_bytes = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (ValueError, OSError) as e:
        # Malformed base64 raises binascii.Error (a ValueError); bytes that
        # are not a readable image raise an OSError subclass. Both are client
        # mistakes, so answer 400 rather than 500.
        raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e

    # predict_deepfake converts its own failures into an HTTPException; let
    # that propagate untouched instead of wrapping it in a second one.
    result = predict_deepfake(image)
    return JSONResponse(content=result)
@app.get("/")
async def root():
    """Return a fixed status message for the API root path."""
    status = {"message": "DeepFake Detector API is running"}
    return status
if __name__ == "__main__":
    # Direct-execution entry point with auto-reload enabled.
    # Remove uvicorn.run for Hugging Face Spaces
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )