fossbk commited on
Commit
8dad76f
·
verified ·
1 Parent(s): d3d18f9
Files changed (1) hide show
  1. app.py +14 -11
app.py CHANGED
@@ -2,19 +2,18 @@ import gradio as gr
2
  from transformers import pipeline
3
  from PIL import Image
4
  import torch
 
 
 
5
 
6
  # Kiểm tra thiết bị sử dụng GPU hay CPU
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
9
- # Tải các mô hình phân loại ảnh và video từ Hugging Face
10
  image_classifier = pipeline("image-classification", model="google/vit-base-patch16-224-in21k", device=0 if device == "cuda" else -1)
11
 
12
- # Sử dụng mô hình phân loại video có sẵn trên Hugging Face
13
- video_classifier = pipeline("video-classification", model="google/vit-base-patch16-224-in21k", device=0 if device == "cuda" else -1)
14
-
15
  # Hàm phân loại ảnh
16
  def classify_image(image, model_name):
17
- # Tùy chọn chọn model ảnh khác nếu người dùng yêu cầu
18
  if model_name == "ViT":
19
  classifier = image_classifier
20
  else:
@@ -24,16 +23,20 @@ def classify_image(image, model_name):
24
  result = classifier(image)
25
  return result[0]['label'], result[0]['score']
26
 
27
- # Hàm phân loại video
28
  def classify_video(video, model_name):
29
- # Tùy chọn chọn model video khác nếu người dùng yêu cầu
 
 
 
 
 
30
  if model_name == "ViT":
31
- classifier = video_classifier
32
  else:
33
- classifier = video_classifier # Chỉnh sửa ở đây nếu muốn hỗ trợ thêm các mô hình khác
34
 
35
- # Phân loại video trực tiếp mà không cần trích xuất frame
36
- result = classifier(video)
37
  return result[0]['label'], result[0]['score']
38
 
39
  # Giao diện Gradio
 
2
  from transformers import pipeline
3
  from PIL import Image
4
  import torch
5
+ import tempfile
6
+ import os
7
+ from moviepy.editor import VideoFileClip
8
 
9
  # Kiểm tra thiết bị sử dụng GPU hay CPU
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
+ # Tải mô hình phân loại ảnh từ Hugging Face (sử dụng mô hình ảnh cho video)
13
  image_classifier = pipeline("image-classification", model="google/vit-base-patch16-224-in21k", device=0 if device == "cuda" else -1)
14
 
 
 
 
15
  # Hàm phân loại ảnh
16
  def classify_image(image, model_name):
 
17
  if model_name == "ViT":
18
  classifier = image_classifier
19
  else:
 
23
  result = classifier(image)
24
  return result[0]['label'], result[0]['score']
25
 
26
+ # Hàm phân loại video (trích xuất frame đầu tiên của video)
27
  def classify_video(video, model_name):
28
+ # Trích xuất frame đầu tiên của video
29
+ video_clip = VideoFileClip(video.name)
30
+ frame = video_clip.get_frame(0) # Lấy frame đầu tiên
31
+ image = Image.fromarray(frame)
32
+
33
+ # Phân loại frame đầu tiên của video
34
  if model_name == "ViT":
35
+ classifier = image_classifier
36
  else:
37
+ classifier = image_classifier # Chỉnh sửa ở đây nếu muốn hỗ trợ thêm các mô hình khác
38
 
39
+ result = classifier(image)
 
40
  return result[0]['label'], result[0]['score']
41
 
42
  # Giao diện Gradio