from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Optional
import uvicorn
import numpy as np
import io
from PIL import Image
import base64
import torch
import torch.nn.functional as F
from transformers import ViTImageProcessor, SwinForImageClassification
import lightning as L
import uuid
import cv2
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Label mappings
label2id = {'fake': 0, 'real': 1}
id2label = {0: 'fake', 1: 'real'}
# Load model
hyper_params = {
"MODEL_CKPT": "microsoft/swin-small-patch4-window7-224",
"num_labels": 2,
"id2label": id2label,
"label2id": label2id,
}
vit_img_processor = ViTImageProcessor.from_pretrained(hyper_params['MODEL_CKPT'])
class DeepFakeModel(L.LightningModule):
    def __init__(self, hyperparams: dict):
        super().__init__()
        self.model = SwinForImageClassification.from_pretrained(
            hyperparams["MODEL_CKPT"],
            num_labels=hyperparams["num_labels"],
            id2label=hyperparams["id2label"],
            label2id=hyperparams["label2id"],
            ignore_mismatched_sizes=True,
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, pixel_values):
        output = self.model(pixel_values=pixel_values)
        return output.logits
# Load trained model
model = DeepFakeModel(hyper_params)
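# Assumption: "deepfake_new_trained.pth" holds a plain state_dict, e.g. saved with
# torch.save(model.state_dict(), "deepfake_new_trained.pth") after fine-tuning;
# adjust the loading logic if your checkpoint wraps the weights differently.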
state_dict = torch.load("deepfake_new_trained.pth", map_location=torch.device(device))
model.load_state_dict(state_dict)
model.to(device)
model.eval()
print("Model loaded successfully")
# Initialize FastAPI app
app = FastAPI(title="DeepFake Detector API", description="API for detecting deepfake images", version="1.0.0")
# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Replace with the frontend origin in production
    allow_credentials=True,  # Note: browsers reject credentialed requests when the origin is "*"
    allow_methods=["*"],
    allow_headers=["*"],
)
class ImageData(BaseModel):
    image: str  # Base64-encoded image, optionally with a "data:image/...;base64," prefix

class AnalysisResult(BaseModel):
    id: str
    isDeepfake: Optional[bool]  # None when no face was detected
    confidence: float
    details: str
def preprocess_image(img):
    # Resize/normalize with the ViT processor and move the (1, 3, H, W) batch tensor to the device
    img = vit_img_processor(img, return_tensors='pt')['pixel_values'].to(device)
    return img
# Load the face detector once
face_net = cv2.dnn.readNetFromCaffe(
    "deploy.prototxt",
    "res10_300x300_ssd_iter_140000.caffemodel",
)
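# The two files above are assumed to be the stock OpenCV res10 SSD face detector
# assets (deploy.prototxt ships in OpenCV's samples/dnn/face_detector directory;
# the .caffemodel weights come from the opencv_3rdparty repository), placed next
# to this script.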
def detect_face_opencv(image: Image.Image) -> bool:
    """Detect a face using OpenCV's DNN face detector."""
    try:
        # Convert PIL image (RGB) to OpenCV format (BGR)
        open_cv_image = np.array(image)
        open_cv_image = open_cv_image[:, :, ::-1].copy()
        blob = cv2.dnn.blobFromImage(open_cv_image, 1.0, (300, 300),
                                     (104.0, 177.0, 123.0))
        face_net.setInput(blob)
        detections = face_net.forward()
        # Accept the image if any detection has confidence > 0.5
        for i in range(detections.shape[2]):
            confidence = detections[0, 0, i, 2]
            if confidence > 0.5:
                return True  # Face detected
        return False  # No face detected
    except Exception as e:
        print(f"Face detection error: {e}")
        return False  # Fail safe: assume no face
def predict_deepfake(image):
    try:
        # Step 1: face detection; bail out early if no face is found
        has_face = detect_face_opencv(image)
        if not has_face:
            return {
                "id": str(uuid.uuid4()),
                "isDeepfake": None,
                "confidence": 0.0,
                "details": "No face detected in the image. Cannot proceed with deepfake analysis."
            }
        # Step 2: deepfake classification
        img_tensor = preprocess_image(image)
        with torch.inference_mode():
            logits = model(img_tensor)
            probabilities = F.softmax(logits, dim=-1)
            confidence, predicted_index = torch.max(probabilities, dim=-1)
        predicted_label = id2label[predicted_index.item()]
        details = "Deepfake detected." if predicted_label == "fake" else "Image appears to be real."
        return {
            "id": str(uuid.uuid4()),
            "isDeepfake": predicted_label == "fake",
            "confidence": round(confidence.item() * 100, 2),
            "details": details,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
# @app.post("/api/analyze", response_model=AnalysisResult)
# async def analyze_image(file: UploadFile = File(...)):
#     if not file.content_type.startswith("image/"):
#         raise HTTPException(status_code=400, detail="File must be an image")
#     try:
#         contents = await file.read()
#         image = Image.open(io.BytesIO(contents)).convert("RGB")
#         result = predict_deepfake(image)
#         return JSONResponse(content=result)
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
@app.post("/api/analyze-base64", response_model=AnalysisResult)
async def analyze_base64_image(data: ImageData):
try:
image_data = data.image.split("base64,")[-1]
image_bytes = base64.b64decode(image_data)
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
result = predict_deepfake(image)
return JSONResponse(content=result)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
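# Example request (a minimal sketch, assuming the server runs locally on port 8000
# and "face.jpg" is a hypothetical test image):
#
#   import base64, requests
#   with open("face.jpg", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode()
#   resp = requests.post("http://localhost:8000/api/analyze-base64",
#                        json={"image": f"data:image/jpeg;base64,{b64}"})
#   print(resp.json())  # -> {"id": ..., "isDeepfake": ..., "confidence": ..., "details": ...}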
@app.get("/")
async def root():
    return {"message": "DeepFake Detector API is running"}
if __name__ == "__main__":
    # On Hugging Face Spaces the platform launches the server, so remove this call there
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)