fossbk committed
Commit 0b6fb66 · verified · 1 Parent(s): 14ca073

Update app.py

Files changed (1)
  1. app.py +172 -67
app.py CHANGED
@@ -1,68 +1,173 @@
-import gradio as gr
-from transformers import pipeline
-from PIL import Image
-import torch
-import tempfile
 import os
-from moviepy.editor import VideoFileClip
-
-# Check whether a GPU or the CPU is available
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Load the image-classification model from Hugging Face (the image model is also used for video)
-image_classifier = pipeline("image-classification", model="google/vit-base-patch16-224-in21k", device=0 if device == "cuda" else -1)
-
-# Image classification function
-def classify_image(image, model_name):
-    if model_name == "ViT":
-        classifier = image_classifier
-    else:
-        classifier = image_classifier  # Edit here to support additional models
-
-    # Classify the image
-    result = classifier(image)
-    return result[0]['label'], result[0]['score']
-
-# Video classification function (extracts the first frame of the video)
-def classify_video(video, model_name):
-    # Extract the first frame of the video
-    video_clip = VideoFileClip(video.name)
-    frame = video_clip.get_frame(0)  # Get the first frame
-    image = Image.fromarray(frame)
-
-    # Classify the first frame of the video
-    if model_name == "ViT":
-        classifier = image_classifier
-    else:
-        classifier = image_classifier  # Edit here to support additional models
-
-    result = classifier(image)
-    return result[0]['label'], result[0]['score']
-
-# Gradio interface with tabs
-with gr.Blocks() as demo:
-    with gr.Tab("Image Classification"):
-        gr.Markdown("### Upload an image for classification")
-        with gr.Row():
-            model_choice_image = gr.Dropdown(choices=["ViT", "ResNet"], label="Choose a Model", value="ViT")
-            image_input = gr.Image(type="pil", label="Upload Image")
-            image_output_label = gr.Textbox(label="Prediction")
-            image_output_score = gr.Textbox(label="Confidence Score")
-
-        classify_image_button = gr.Button("Classify Image")
-
-        classify_image_button.click(classify_image, inputs=[image_input, model_choice_image], outputs=[image_output_label, image_output_score])
-
-    with gr.Tab("Video Classification"):
-        gr.Markdown("### Upload a video for classification")
-        with gr.Row():
-            model_choice_video = gr.Dropdown(choices=["ViT", "ResNet"], label="Choose a Model", value="ViT")
-            video_input = gr.Video(label="Upload Video")
-            video_output_label = gr.Textbox(label="Prediction")
-            video_output_score = gr.Textbox(label="Confidence Score")
-
-        classify_video_button = gr.Button("Classify Video")
-
-        classify_video_button.click(classify_video, inputs=[video_input, model_choice_video], outputs=[video_output_label, video_output_score])
-
-demo.launch()
+
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+import torchvision.transforms as T
+from PIL import Image
+from decord import VideoReader
+from decord import cpu
+from uniformer_light_video import uniformer_xxs_video
+from uniformer_light_image import uniformer_xxs_image
+from kinetics_class_index import kinetics_classnames
+from imagenet_class_index import imagenet_classnames
+from transforms import (
+    GroupNormalize, GroupScale, GroupCenterCrop,
+    Stack, ToTorchFormatTensor
+)
+
+import gradio as gr
+from huggingface_hub import hf_hub_download
+
+
+# Device on which to run the model
+# Set to cuda to load on GPU
+device = "cpu"
+model_video_path = hf_hub_download(repo_id="Andy1621/uniformer_light", filename="uniformer_xxs16_160_k400.pth")
+model_image_path = hf_hub_download(repo_id="Andy1621/uniformer_light", filename="uniformer_xxs_160_in1k.pth")
+# Pick a pretrained model
+model_video = uniformer_xxs_video()
+model_video.load_state_dict(torch.load(model_video_path, map_location='cpu'))
+model_image = uniformer_xxs_image()
+model_image.load_state_dict(torch.load(model_image_path, map_location='cpu'))
+# Set to eval mode and move to desired device
+model_video = model_video.to(device).eval()
+model_image = model_image.to(device).eval()
+
+# Create an id to label name mapping
+kinetics_id_to_classname = {}
+for k, v in kinetics_classnames.items():
+    kinetics_id_to_classname[k] = v
+imagenet_id_to_classname = {}
+for k, v in imagenet_classnames.items():
+    imagenet_id_to_classname[k] = v[1]
+
+
+def get_index(num_frames, num_segments=8):
+    seg_size = float(num_frames - 1) / num_segments
+    start = int(seg_size / 2)
+    offsets = np.array([
+        start + int(np.round(seg_size * idx)) for idx in range(num_segments)
+    ])
+    return offsets
+
+
+def load_video(video_path):
+    vr = VideoReader(video_path, ctx=cpu(0))
+    num_frames = len(vr)
+    frame_indices = get_index(num_frames, 16)
+
+    # transform
+    crop_size = 160
+    scale_size = 160
+    input_mean = [0.485, 0.456, 0.406]
+    input_std = [0.229, 0.224, 0.225]
+
+    transform = T.Compose([
+        GroupScale(int(scale_size)),
+        GroupCenterCrop(crop_size),
+        Stack(),
+        ToTorchFormatTensor(),
+        GroupNormalize(input_mean, input_std)
+    ])
+
+    images_group = list()
+    for frame_index in frame_indices:
+        img = Image.fromarray(vr[frame_index].asnumpy())
+        images_group.append(img)
+    torch_imgs = transform(images_group)
+    return torch_imgs
+
+
+def inference_video(video):
+    vid = load_video(video)
+
+    # The video model expects inputs of shape: B x C x T x H x W
+    TC, H, W = vid.shape
+    inputs = vid.reshape(1, TC//3, 3, H, W).permute(0, 2, 1, 3, 4)
+
+    with torch.no_grad():
+        prediction = model_video(inputs)
+        prediction = F.softmax(prediction, dim=1).flatten()
+
+    return {kinetics_id_to_classname[str(i)]: float(prediction[i]) for i in range(400)}
+
+
+def set_example_video(example: list) -> dict:
+    return gr.Video.update(value=example[0])
+
+
+def inference_image(img):
+    image = img
+    image_transform = T.Compose(
+        [
+            T.Resize(224),
+            T.CenterCrop(224),
+            T.ToTensor(),
+            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        ]
+    )
+    image = image_transform(image)
+
+    # The model expects inputs of shape: B x C x H x W
+    image = image.unsqueeze(0)
+
+    with torch.no_grad():
+        prediction = model_image(image)
+        prediction = F.softmax(prediction, dim=1).flatten()
+
+    return {imagenet_id_to_classname[str(i)]: float(prediction[i]) for i in range(1000)}
+
+
+def set_example_image(example: list) -> dict:
+    return gr.Image.update(value=example[0])
+
+
+demo = gr.Blocks()
+with demo:
+    gr.Markdown(
+        """
+        # UniFormer Light
+        Gradio demo for <a href='https://github.com/Sense-X/UniFormer' target='_blank'>UniFormer</a>: To use it, simply upload your video, or click one of the examples to load them. Read more at the links below.
+        """
+    )
+
+    with gr.Tab("Video"):
+        with gr.Box():
+            with gr.Row():
+                with gr.Column():
+                    with gr.Row():
+                        input_video = gr.Video(label='Input Video').style(height=360)
+                    with gr.Row():
+                        submit_video_button = gr.Button('Submit')
+                with gr.Column():
+                    label_video = gr.Label(num_top_classes=5)
+            with gr.Row():
+                example_videos = gr.Dataset(components=[input_video], samples=[['./videos/hitting_baseball.mp4'], ['./videos/hoverboarding.mp4'], ['./videos/yoga.mp4']])
+
+    with gr.Tab("Image"):
+        with gr.Box():
+            with gr.Row():
+                with gr.Column():
+                    with gr.Row():
+                        input_image = gr.Image(label='Input Image', type='pil').style(height=360)
+                    with gr.Row():
+                        submit_image_button = gr.Button('Submit')
+                with gr.Column():
+                    label_image = gr.Label(num_top_classes=5)
+            with gr.Row():
+                example_images = gr.Dataset(components=[input_image], samples=[['./images/cat.png'], ['./images/dog.png'], ['./images/panda.png']])
+
+    gr.Markdown(
+        """
+        <p style='text-align: center'><a href='https://arxiv.org/abs/2201.09450' target='_blank'>[TPAMI] UniFormer: Unifying Convolution and Self-attention for Visual Recognition</a> | <a href='https://github.com/Sense-X/UniFormer' target='_blank'>Github Repo</a></p>
+        """
+    )
+
+    submit_video_button.click(fn=inference_video, inputs=input_video, outputs=label_video)
+    example_videos.click(fn=set_example_video, inputs=example_videos, outputs=example_videos.components)
+    submit_image_button.click(fn=inference_image, inputs=input_image, outputs=label_image)
+    example_images.click(fn=set_example_image, inputs=example_images, outputs=example_images.components)
+
+demo.launch(enable_queue=True)
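
For reference, the uniform frame sampling that the new app.py performs in load_video() is easy to check in isolation. The snippet below is a minimal sketch that reuses get_index() exactly as committed above and prints the indices it would pick for a 100-frame clip with the 16 segments used by load_video(); the 100-frame length is only an illustrative assumption.

```python
import numpy as np

def get_index(num_frames, num_segments=8):
    # Split the clip into `num_segments` equal windows and take
    # (approximately) the middle frame of each window, as in app.py.
    seg_size = float(num_frames - 1) / num_segments
    start = int(seg_size / 2)
    return np.array([start + int(np.round(seg_size * idx)) for idx in range(num_segments)])

# Illustrative 100-frame clip, sampled with the 16 segments used by load_video()
print(get_index(100, 16))
# [ 3  9 15 22 28 34 40 46 53 59 65 71 77 83 90 96]
```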
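
The reshape in inference_video() can be sanity-checked the same way: Stack() and ToTorchFormatTensor() deliver the 16 sampled RGB frames as a single (T*C, H, W) tensor, which the function turns into the B x C x T x H x W batch the video model consumes. The sketch below just uses a random tensor at the 160x160 crop size from load_video().

```python
import torch

# Dummy stack of 16 RGB frames at the 160x160 crop size used in load_video()
vid = torch.randn(16 * 3, 160, 160)

TC, H, W = vid.shape
inputs = vid.reshape(1, TC // 3, 3, H, W).permute(0, 2, 1, 3, 4)
print(inputs.shape)  # torch.Size([1, 3, 16, 160, 160]), i.e. B x C x T x H x W
```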