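"""Gradio demo for CLIP-based deepfake detection.

The app loads a CLIP ViT-L/14 vision encoder plus a linear classification
head (either "facial" or "general" pretrained weights) and runs detection
on an uploaded image or video, returning a prediction score and an
attention-map visualisation.
"""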
import gradio as gr
import cv2
from PIL import Image
import torch

from transformers import AutoProcessor, CLIPVisionModel
from detection import detect_image, detect_video
from model import LinearClassifier


def load_model(detection_type):
    """Load the CLIP vision backbone and the linear detection head.

    detection_type selects the pretrained weights ("facial" or "general").
    Note: this is called on every request, so the weights are reloaded each time.
    """
    device = torch.device("cpu")

    # CLIP image encoder with attentions enabled so the demo can visualise
    # an attention map alongside the prediction score.
    processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
    clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14", output_attentions=True)

    # Infer the classifier's input dimension from the saved linear layer.
    model_path = f"pretrained_models/{detection_type}/clip_weights.pth"
    checkpoint = torch.load(model_path, map_location="cpu")
    input_dim = checkpoint["linear.weight"].shape[1]

    detection_model = LinearClassifier(input_dim)
    detection_model.load_state_dict(checkpoint)
    detection_model = detection_model.to(device)

    return processor, clip_model, detection_model

def process_image(image, detection_type):
    """Run deepfake detection on a single PIL image."""
    processor, clip_model, detection_model = load_model(detection_type)

    results = detect_image(image, processor, clip_model, detection_model)

    pred_score = results["pred_score"]
    attn_map = results["attn_map"]

    return pred_score, attn_map

def process_video(video, detection_type):
    """Run deepfake detection on a video by decoding it into PIL frames."""
    processor, clip_model, detection_model = load_model(detection_type)

    # Decode every frame with OpenCV and convert BGR -> RGB before wrapping in PIL.
    cap = cv2.VideoCapture(video)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(Image.fromarray(frame))
    cap.release()

    results = detect_video(frames, processor, clip_model, detection_model)

    pred_score = results["pred_score"]
    attn_map = results["attn_map"]

    return pred_score, attn_map

def change_input(input_type):
    """Toggle visibility of the image/video inputs to match the selected type."""
    if input_type == "Image":
        return gr.update(visible=True), gr.update(visible=False)
    elif input_type == "Video":
        return gr.update(visible=False), gr.update(visible=True)
    # Fallback: leave both components unchanged.
    return gr.update(), gr.update()


def process_input(input_type, model_type, image, video):
    """Dispatch to image or video processing based on the selected input type."""
    detection_type = "facial" if model_type == "Facial" else "general"

    if input_type == "Image" and image is not None:
        return process_image(image, detection_type)
    elif input_type == "Video" and video is not None:
        return process_video(video, detection_type)
    else:
        return None, None


with gr.Blocks() as demo:

    gr.Markdown("## Deepfake Detection: Facial / General")

    # Input selectors: media type and which pretrained head to use.
    input_type = gr.Radio(["Image", "Video"], label="Choose Input Type", value="Image")
    model_type = gr.Radio(["Facial", "General"], label="Choose Model Type", value="General")

    # Only one of these is visible at a time (toggled by change_input).
    image_input = gr.Image(type="pil", label="Upload Image", visible=True)
    video_input = gr.Video(label="Upload Video", visible=False)

    process_button = gr.Button("Run Model")

    # Outputs: scalar prediction score and the attention-map visualisation.
    pred_score_output = gr.Textbox(label="Prediction Score")
    attn_map_output = gr.Image(type="pil", label="Attention Map")

    # Event wiring.
    input_type.change(fn=change_input, inputs=[input_type], outputs=[image_input, video_input])

    process_button.click(
        fn=process_input,
        inputs=[input_type, model_type, image_input, video_input],
        outputs=[pred_score_output, attn_map_output],
    )

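# Launch with `python app.py`; demo.launch(share=True) would additionally create
# a temporary public link (a standard Gradio option, not enabled here).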
if __name__ == "__main__":
    demo.launch()