Julien Simon commited on
Commit
e0b3a35
·
1 Parent(s): 464f630
Files changed (2) hide show
  1. app.py +14 -2
  2. requirements.txt +6 -3
app.py CHANGED
@@ -1,5 +1,9 @@
1
  import gradio as gr
2
  from transformers import pipeline
 
 
 
 
3
 
4
  model_names = [
5
  "apple/mobilevit-small",
@@ -15,11 +19,15 @@ model_names = [
15
  "shi-labs/dinat-base-in1k-224",
16
  ]
17
 
 
 
18
 
19
  def process(image_file, top_k):
20
  labels = []
21
  for m in model_names:
22
- p = pipeline("image-classification", model=m)
 
 
23
  pred = p(image_file)
24
  labels.append({x["label"]: x["score"] for x in pred[:top_k]})
25
  return labels
@@ -32,7 +40,11 @@ top_k = gr.Slider(minimum=1, maximum=5, step=1, value=5, label="Top k classes")
32
  # Output
33
  labels = [gr.Label(label=m) for m in model_names]
34
 
35
- description = "This Space lets you quickly compare the most popular image classifiers available on the hub, including the recent NAT and DINAT models. All of them have been fine-tuned on the ImageNet-1k dataset. Anecdotally, the three sample images have been generated with a Stable Diffusion model :)"
 
 
 
 
36
 
37
  iface = gr.Interface(
38
  theme="huggingface",
 
1
  import gradio as gr
2
  from transformers import pipeline
3
+ import torch
4
+
5
+ # Check if CUDA is available
6
+ device = 0 if torch.cuda.is_available() else -1
7
 
8
  model_names = [
9
  "apple/mobilevit-small",
 
19
  "shi-labs/dinat-base-in1k-224",
20
  ]
21
 
22
+ # Cache for pipelines to avoid reloading models
23
+ pipelines = {}
24
 
25
  def process(image_file, top_k):
26
  labels = []
27
  for m in model_names:
28
+ if m not in pipelines:
29
+ pipelines[m] = pipeline("image-classification", model=m, device=device)
30
+ p = pipelines[m]
31
  pred = p(image_file)
32
  labels.append({x["label"]: x["score"] for x in pred[:top_k]})
33
  return labels
 
40
  # Output
41
  labels = [gr.Label(label=m) for m in model_names]
42
 
43
+ description = (
44
+ "This Space compares popular image classifiers available on the Hugging Face hub, "
45
+ "including NAT and DINAT models. All models have been fine-tuned on ImageNet-1k. "
46
+ "The sample images were generated with Stable Diffusion."
47
+ )
48
 
49
  iface = gr.Interface(
50
  theme="huggingface",
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
- torch
2
- transformers
3
- natten
 
 
 
 
1
+ torch==2.7.0
2
+ transformers==4.51.3
3
+ gradio>=3.50.0
4
+ --no-build-isolation
5
+ natten==0.17.5
6
+ Pillow>=9.0.0