|
import base64
import io
import uuid
from typing import Optional

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import uvicorn
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from transformers import ViTImageProcessor, SwinForImageClassification
import lightning as L
import cv2
|
|
|
|
|
# Inference device: CUDA when torch can see a GPU, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Classifier label <-> index mapping; id2label is built as the exact inverse
# of label2id so the two can never drift apart.
label2id = {'fake': 0, 'real': 1}
id2label = {index: label for label, index in label2id.items()}

# Settings passed to DeepFakeModel: the base checkpoint plus the
# classification-head configuration (size and label mappings).
hyper_params = dict(
    MODEL_CKPT="microsoft/swin-small-patch4-window7-224",
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)
|
|
|
# Image preprocessor loaded from the same checkpoint as the classifier
# backbone (both read hyper_params["MODEL_CKPT"]), so its preprocessing
# settings come from that checkpoint's config.
vit_img_processor = ViTImageProcessor.from_pretrained(
    hyper_params["MODEL_CKPT"],
)
|
|
|
class DeepFakeModel(L.LightningModule):
    """Lightning wrapper around a Hugging Face Swin image classifier.

    The backbone is loaded from ``hyperparams["MODEL_CKPT"]`` with a
    classification head configured by ``num_labels``, ``id2label`` and
    ``label2id``. ``ignore_mismatched_sizes=True`` lets that head differ in
    shape from the one stored in the pretrained checkpoint.
    """

    def __init__(self, hyperparams: dict):
        super().__init__()
        checkpoint = hyperparams["MODEL_CKPT"]
        head_config = dict(
            num_labels=hyperparams["num_labels"],
            id2label=hyperparams["id2label"],
            label2id=hyperparams["label2id"],
            ignore_mismatched_sizes=True,
        )
        # state_dict keys are derived from attribute names, so renaming
        # `model` would break loading previously saved checkpoints.
        self.model = SwinForImageClassification.from_pretrained(checkpoint, **head_config)
        # Not referenced by forward(); presumably kept for training code — confirm.
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, pixel_values):
        """Run the Swin classifier and return only its logits tensor."""
        return self.model(pixel_values=pixel_values).logits
|
|
|
|
|
# Build the architecture (this downloads/initialises the pretrained Swin
# backbone), then overwrite every weight with the fine-tuned checkpoint.
model = DeepFakeModel(hyper_params)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered .pth file cannot execute code at load time. A plain state_dict (what
# load_state_dict below expects) loads fine this way. PyTorch 2.4/2.5 warn when
# the flag is omitted and 2.6 flips the default, so pinning it also keeps
# behaviour the same across versions.
# NOTE(review): the path is relative to the current working directory — confirm
# the service is always started from the directory holding the checkpoint.
state_dict = torch.load(
    "deepfake_new_trained.pth",
    map_location=torch.device(device),
    weights_only=True,
)
model.load_state_dict(state_dict)
model.to(device)
# Inference only: eval() disables training-time behaviour such as dropout.
model.eval()

print("Model loaded successfully")
|
|
|
|
|
app = FastAPI(
    title="DeepFake Detector API",
    description="API for detecting deepfake images",
    version="1.0.0",
)

# CORS: every origin, method and header is allowed, with credentials enabled.
# NOTE(review): with allow_origins=["*"] and allow_credentials=True, Starlette
# echoes back the requesting Origin for credentialed requests, which effectively
# grants every site credentialed access. No auth/cookies are visible in this
# file, so allow_credentials=False or an explicit origin list is likely
# sufficient — confirm with the frontend before changing.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
class ImageData(BaseModel):
    # Request body for POST /api/analyze-base64.
    # `image` holds base64-encoded image bytes. A data-URL prefix is tolerated
    # because the endpoint keeps only the text after the last "base64," marker.
    image: str
|
|
|
class AnalysisResult(BaseModel):
    """Result of a deepfake analysis returned by /api/analyze-base64."""

    # Random UUID4 string generated per analysis.
    id: str
    # True = classified as fake, False = classified as real. None when no face
    # was detected and the classifier was never run. The previous `bool`
    # annotation advertised a non-nullable field that the API does not honour.
    isDeepfake: Optional[bool]
    # Probability of the predicted class as a percentage (0-100, 2 decimals);
    # 0.0 when no face was detected.
    confidence: float
    # Human-readable summary of the outcome.
    details: str
|
|
|
|
|
def preprocess_image(img):
    """Convert an image into a `pixel_values` tensor on the inference device.

    Runs the module-level ``vit_img_processor`` with PyTorch output and moves
    the resulting tensor to ``device``.
    """
    encoded = vit_img_processor(img, return_tensors='pt')
    pixel_values = encoded['pixel_values']
    return pixel_values.to(device)
|
|
|
|
|
|
|
|
|
# OpenCV DNN face detector loaded from a Caffe network definition plus weights
# (per the weights filename, a 300x300 SSD model).
# NOTE(review): both paths resolve against the current working directory.
_FACE_PROTOTXT_PATH = "deploy.prototxt"
_FACE_WEIGHTS_PATH = "res10_300x300_ssd_iter_140000.caffemodel"
face_net = cv2.dnn.readNetFromCaffe(_FACE_PROTOTXT_PATH, _FACE_WEIGHTS_PATH)
|
|
|
def detect_face_opencv(image: Image.Image, confidence_threshold: float = 0.5) -> bool:
    """Detect face using OpenCV DNN.

    Runs the module-level ``face_net`` detector on ``image`` and reports
    whether any detection scores above ``confidence_threshold``.

    Args:
        image: PIL image to scan. Any mode PIL can convert to RGB is accepted
            (RGB, grayscale, RGBA, palette, ...).
        confidence_threshold: minimum detection score (strictly greater than)
            that counts as a face. Defaults to 0.5, the previously hard-coded value.

    Returns:
        True if at least one detection exceeds the threshold, otherwise False.
        Errors during detection are printed and reported as False (best-effort),
        so callers treat them the same as "no face found".
    """
    try:
        # Normalise to 3-channel RGB first. Grayscale arrays are 2-D and RGBA has
        # 4 channels; both previously broke the BGR flip below and were silently
        # reported as "no face".
        rgb = np.asarray(image.convert("RGB"))
        # PIL gives RGB; OpenCV works in BGR, so reverse the channel axis.
        # copy() makes the reversed view contiguous for OpenCV.
        bgr = rgb[:, :, ::-1].copy()

        # Resize to 300x300 and subtract the per-channel mean values
        # (104, 177, 123) before feeding the network.
        blob = cv2.dnn.blobFromImage(bgr, 1.0, (300, 300), (104.0, 177.0, 123.0))
        face_net.setInput(blob)
        detections = face_net.forward()

        # Index 2 of each detection row holds its confidence score.
        scores = detections[0, 0, :, 2]
        return bool(np.any(scores > confidence_threshold))
    except Exception as e:
        print(f"Face detection error: {e}")
        return False
|
|
|
def _analysis_payload(is_deepfake, confidence, details):
    """Build the response dict shared by every predict_deepfake outcome."""
    return {
        "id": str(uuid.uuid4()),
        "isDeepfake": is_deepfake,
        "confidence": confidence,
        "details": details,
    }


def predict_deepfake(image):
    """Classify ``image`` as real or deepfake.

    The image is first checked with ``detect_face_opencv``. When no face is
    found, the classifier is skipped.

    Args:
        image: PIL image. Presumably a single image, so the batch has one item;
            ``.item()`` below requires exactly one prediction — confirm.

    Returns:
        dict with keys ``id`` (fresh UUID4 string), ``isDeepfake`` (bool, or
        None when no face was detected), ``confidence`` (softmax probability of
        the predicted class x100, rounded to 2 decimals; 0.0 when no face) and
        ``details`` (human-readable summary).

    Raises:
        HTTPException: status 500 if preprocessing or inference fails. The
            original error is chained as ``__cause__``.
    """
    try:
        if not detect_face_opencv(image):
            return _analysis_payload(
                None,
                0.0,
                "No face detected in the image. Cannot proceed with deepfake analysis.",
            )

        pixel_values = preprocess_image(image)
        with torch.inference_mode():
            logits = model(pixel_values)
            probabilities = F.softmax(logits, dim=-1)
            confidence, predicted_index = torch.max(probabilities, dim=-1)

        is_fake = id2label[predicted_index.item()] == "fake"
        details = "Deepfake detected." if is_fake else "Image appears to be real."
        return _analysis_payload(is_fake, round(confidence.item() * 100, 2), details)
    except Exception as e:
        # Chain explicitly so logs show the underlying failure as the direct cause.
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/analyze-base64", response_model=AnalysisResult)
async def analyze_base64_image(data: ImageData):
    """Analyze a base64-encoded image (raw base64 or a data URL) for deepfakes."""
    # NOTE(review): decoding, face detection and inference are synchronous and
    # run directly inside this async handler, so they block the event loop for
    # the duration of each request. Moving them to a thread pool would also need
    # a lock around the shared face_net/model — confirm expected load first.
    try:
        # Keep only the payload after a data-URL prefix such as
        # "data:image/png;base64,"; raw base64 passes through unchanged.
        encoded = data.image.split("base64,")[-1]
        image_bytes = base64.b64decode(encoded)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (ValueError, OSError) as e:
        # Malformed client input. binascii.Error (bad base64) subclasses
        # ValueError; PIL's UnidentifiedImageError and truncated-image errors
        # are OSErrors. This is the caller's fault: 400, not 500.
        raise HTTPException(status_code=400, detail=f"Error processing image: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}") from e

    try:
        result = predict_deepfake(image)
    except HTTPException:
        # Already a well-formed HTTP error. Re-wrapping it would nest its
        # "<status>: <detail>" text inside a new 500 detail.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}") from e

    # Returning a Response directly skips response_model validation, so a None
    # isDeepfake (no face detected) is sent as-is.
    return JSONResponse(content=result)
|
|
|
|
|
@app.get("/")
async def root():
    # Lightweight liveness check: answers without touching the model.
    # (A comment rather than a docstring so the OpenAPI description is unchanged.)
    status_payload = {"message": "DeepFake Detector API is running"}
    return status_payload
|
|
|
if __name__ == "__main__":
    # Development entry point. reload=True requires passing the app as an import
    # string; "app:app" assumes this module is importable as `app` (app.py) — confirm.
    # NOTE(review): with reload enabled, uvicorn presumably re-imports the module in a
    # worker process, so the model above would be loaded twice.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("app:app", **server_options)
|
|
|
|
|
|