from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Optional
import uvicorn
import numpy as np
import io
from PIL import Image
import base64
import torch
import torch.nn.functional as F
from transformers import ViTImageProcessor, SwinForImageClassification
import lightning as L
import uuid
import cv2

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Label mappings
label2id = {'fake': 0, 'real': 1}
id2label = {0: 'fake', 1: 'real'}
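# Note: this ordering must match the label mapping used during fine-tuning.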

# Model hyperparameters (the fine-tuned weights are loaded further down)
hyper_params = {
    "MODEL_CKPT": "microsoft/swin-small-patch4-window7-224",
    "num_labels": 2,
    "id2label": id2label,
    "label2id": label2id,
}

# ViTImageProcessor reads the checkpoint's preprocessor config (resize to 224x224, normalize)
vit_img_processor = ViTImageProcessor.from_pretrained(hyper_params['MODEL_CKPT'])

class DeepFakeModel(L.LightningModule):
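    """Lightning wrapper around a Swin transformer fine-tuned for fake/real classification."""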
    def __init__(self, hyperparams: dict):
        super().__init__()
        self.model = SwinForImageClassification.from_pretrained(
            hyperparams["MODEL_CKPT"],
            num_labels=hyperparams["num_labels"],
            id2label=hyperparams["id2label"],
            label2id=hyperparams["label2id"],
            ignore_mismatched_sizes=True
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()  # retained from training; not used at inference
    
    def forward(self, pixel_values):
        output = self.model(pixel_values=pixel_values)
        return output.logits

# Load trained model
model = DeepFakeModel(hyper_params)
state_dict = torch.load("deepfake_new_trained.pth", map_location=torch.device(device))
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # switch off dropout and other training-only behavior
print("Model loaded successfully")

# Initialize FastAPI app
app = FastAPI(title="DeepFake Detector API", description="API for detecting deepfake images", version="1.0.0")

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # restrict to the frontend's origin in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class ImageData(BaseModel):
    image: str  # Base64 encoded image

class AnalysisResult(BaseModel):
    id: str
    isDeepfake: Optional[bool]  # None when no face is detected
    confidence: float
    details: str


def preprocess_image(img):
    """Resize/normalize a PIL image and return a batched pixel tensor on `device`."""
    return vit_img_processor(img, return_tensors='pt')['pixel_values'].to(device)
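# Quick sanity check (a sketch; assumes the checkpoint's default 224x224 input size):
#   example = Image.new("RGB", (640, 480))
#   assert preprocess_image(example).shape == (1, 3, 224, 224)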



# Load the face detector once
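# (expects OpenCV's stock SSD face detector files, deploy.prototxt and the
# res10 Caffe weights, to be present in the working directory)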
face_net = cv2.dnn.readNetFromCaffe(
    "deploy.prototxt",
    "res10_300x300_ssd_iter_140000.caffemodel"
)

def detect_face_opencv(image: Image.Image) -> bool:
    """Return True if the SSD detector finds at least one face with confidence > 0.5."""
    try:
        # Convert PIL Image to OpenCV format
        open_cv_image = np.array(image)
        open_cv_image = open_cv_image[:, :, ::-1].copy()  # RGB to BGR

        # 300x300 input with the detector's standard BGR mean subtraction
        blob = cv2.dnn.blobFromImage(open_cv_image, 1.0, (300, 300),
                                     (104.0, 177.0, 123.0))

        face_net.setInput(blob)
        detections = face_net.forward()

        # Check if any detection has confidence > 0.5
        for i in range(detections.shape[2]):
            confidence = detections[0, 0, i, 2]
            if confidence > 0.5:
                return True  # Face detected
        
        return False  # No face detected
    except Exception as e:
        print(f"Face detection error: {e}")
        return False  # fail-safe: treat detector errors as no face

def predict_deepfake(image):
    """Run face detection, then classify the image as real or fake."""
    try:
        # Step 1: face detection; bail out early if no face is present
        has_face = detect_face_opencv(image)
        if not has_face:
            return {
                "id": str(uuid.uuid4()),
                "isDeepfake": None,
                "confidence": 0.0,
                "details": "No face detected in the image. Cannot proceed with deepfake analysis."
            }
        
        # Step 2: deepfake classification
        img_tensor = preprocess_image(image)
        with torch.inference_mode():
            logits = model(img_tensor)
            probabilities = F.softmax(logits, dim=-1)
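            # confidence = max softmax probability; its index selects the predicted class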
            confidence, predicted_index = torch.max(probabilities, dim=-1)
            predicted_label = id2label[predicted_index.item()]
        
        details = "Deepfake detected." if predicted_label == "fake" else "Image appears to be real."
        return {
            "id": str(uuid.uuid4()),
            "isDeepfake": predicted_label == "fake",
            "confidence": round(confidence.item() * 100, 2),
            "details": details
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")




# @app.post("/api/analyze", response_model=AnalysisResult)
# async def analyze_image(file: UploadFile = File(...)):
#     if not file.content_type.startswith("image/"):
#         raise HTTPException(status_code=400, detail="File must be an image")
#     try:
#         contents = await file.read()
#         image = Image.open(io.BytesIO(contents)).convert("RGB")
#         result = predict_deepfake(image)
#         return JSONResponse(content=result)
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")


@app.post("/api/analyze-base64", response_model=AnalysisResult)
async def analyze_base64_image(data: ImageData):
    try:
        # Strip the data-URI prefix ("data:image/...;base64,") if present
        image_data = data.image.split("base64,")[-1]
        image_bytes = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        result = predict_deepfake(image)
        return result  # returning the dict lets FastAPI validate it against AnalysisResult
    except HTTPException:
        raise  # don't re-wrap errors raised inside predict_deepfake
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
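
# Example client call (a sketch; assumes the server is reachable on localhost:8000
# and that "photo.jpg" exists):
#
#   import base64, requests
#   with open("photo.jpg", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode()
#   resp = requests.post(
#       "http://localhost:8000/api/analyze-base64",
#       json={"image": f"data:image/jpeg;base64,{b64}"},
#   )
#   print(resp.json())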


@app.get("/")
async def root():
    return {"message": "DeepFake Detector API is running"}

if __name__ == "__main__":
    # Remove uvicorn.run for Hugging Face Spaces
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)