from typing import Optional

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import uvicorn
import numpy as np
import io
from PIL import Image
import base64
import torch
import torch.nn.functional as F
from transformers import ViTImageProcessor, SwinForImageClassification
import lightning as L
import uuid
import cv2

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Label mappings
label2id = {'fake': 0, 'real': 1}
id2label = {0: 'fake', 1: 'real'}

# Model hyperparameters
hyper_params = {
    "MODEL_CKPT": "microsoft/swin-small-patch4-window7-224",
    "num_labels": 2,
    "id2label": id2label,
    "label2id": label2id,
}

vit_img_processor = ViTImageProcessor.from_pretrained(hyper_params['MODEL_CKPT'])


class DeepFakeModel(L.LightningModule):
    def __init__(self, hyperparams: dict):
        super().__init__()
        self.model = SwinForImageClassification.from_pretrained(
            hyperparams["MODEL_CKPT"],
            num_labels=hyperparams["num_labels"],
            id2label=hyperparams["id2label"],
            label2id=hyperparams["label2id"],
            ignore_mismatched_sizes=True,
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, pixel_values):
        output = self.model(pixel_values=pixel_values)
        return output.logits


# Load trained model weights
model = DeepFakeModel(hyper_params)
state_dict = torch.load("deepfake_new_trained.pth", map_location=torch.device(device))
model.load_state_dict(state_dict)
model.to(device)
model.eval()
print("Model loaded successfully")

# Initialize FastAPI app
app = FastAPI(
    title="DeepFake Detector API",
    description="API for detecting deepfake images",
    version="1.0.0",
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Update with the frontend server address in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ImageData(BaseModel):
    image: str  # Base64-encoded image


class AnalysisResult(BaseModel):
    id: str
    isDeepfake: Optional[bool]  # None when no face is detected
    confidence: float
    details: str


def preprocess_image(img):
    return vit_img_processor(img, return_tensors='pt')['pixel_values'].to(device)


# Load the face detector once. These are OpenCV's standard SSD face-detector
# files and must be present in the working directory.
face_net = cv2.dnn.readNetFromCaffe(
    "deploy.prototxt",
    "res10_300x300_ssd_iter_140000.caffemodel"
)


def detect_face_opencv(image: Image.Image) -> bool:
    """Detect whether the image contains at least one face using OpenCV's DNN module."""
    try:
        # Convert PIL Image to OpenCV format (RGB to BGR)
        open_cv_image = np.array(image)[:, :, ::-1].copy()

        blob = cv2.dnn.blobFromImage(open_cv_image, 1.0, (300, 300), (104.0, 177.0, 123.0))
        face_net.setInput(blob)
        detections = face_net.forward()

        # Report a face if any detection has confidence > 0.5
        for i in range(detections.shape[2]):
            confidence = detections[0, 0, i, 2]
            if confidence > 0.5:
                return True
        return False
    except Exception as e:
        print(f"Face detection error: {e}")
        return False  # Fail-safe: assume no face


def predict_deepfake(image):
    try:
        # Step 1: Face detection
        has_face = detect_face_opencv(image)
        if not has_face:
            return {
                "id": str(uuid.uuid4()),
                "isDeepfake": None,
                "confidence": 0.0,
                "details": "No face detected in the image. Cannot proceed with deepfake analysis."
            }

        # Step 2: Deepfake prediction
        img_tensor = preprocess_image(image)
        with torch.inference_mode():
            logits = model(img_tensor)
            probabilities = F.softmax(logits, dim=-1)
            confidence, predicted_index = torch.max(probabilities, dim=-1)

        predicted_label = id2label[predicted_index.item()]
        details = "Deepfake detected." if predicted_label == "fake" else "Image appears to be real."

        return {
            "id": str(uuid.uuid4()),
            "isDeepfake": predicted_label == "fake",
            "confidence": round(confidence.item() * 100, 2),
            "details": details
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")


# @app.post("/api/analyze", response_model=AnalysisResult)
# async def analyze_image(file: UploadFile = File(...)):
#     if not file.content_type.startswith("image/"):
#         raise HTTPException(status_code=400, detail="File must be an image")
#     try:
#         contents = await file.read()
#         image = Image.open(io.BytesIO(contents)).convert("RGB")
#         result = predict_deepfake(image)
#         return JSONResponse(content=result)
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")


@app.post("/api/analyze-base64", response_model=AnalysisResult)
async def analyze_base64_image(data: ImageData):
    try:
        # Strip an optional "data:image/...;base64," prefix before decoding
        image_data = data.image.split("base64,")[-1]
        image_bytes = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        result = predict_deepfake(image)
        return JSONResponse(content=result)
    except HTTPException:
        raise  # Preserve the status code raised inside predict_deepfake
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")


@app.get("/")
async def root():
    return {"message": "DeepFake Detector API is running"}


if __name__ == "__main__":
    # Note: remove this uvicorn.run call when deploying to Hugging Face Spaces
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
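

# Example client call for /api/analyze-base64 -- a minimal sketch, assuming the
# server is running locally on port 8000, a local file "test.jpg" exists, and
# the `requests` package is installed:
#
#   import base64, requests
#
#   with open("test.jpg", "rb") as f:
#       encoded = base64.b64encode(f.read()).decode("utf-8")
#   payload = {"image": f"data:image/jpeg;base64,{encoded}"}
#   resp = requests.post("http://localhost:8000/api/analyze-base64", json=payload)
#   print(resp.json())
#
# A successful response looks like:
#   {"id": "<uuid>", "isDeepfake": true, "confidence": 97.12, "details": "Deepfake detected."}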