drhead committed on
Commit
2c3b7a4
·
verified ·
1 Parent(s): 57e083a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -21
app.py CHANGED
@@ -1,18 +1,17 @@
1
- import json
2
-
3
- import gradio as gr
4
  from PIL import Image
5
- import safetensors.torch
6
- import spaces
7
- import timm
8
- from timm.models import VisionTransformer
9
  import torch
10
  from torchvision.transforms import transforms
11
  from torchvision.transforms import InterpolationMode
12
  import torchvision.transforms.functional as TF
 
 
 
 
 
13
  from huggingface_hub import hf_hub_download
14
- import numpy as np
15
- import matplotlib.cm as cm
16
 
17
  class Fit(torch.nn.Module):
18
  def __init__(
@@ -147,12 +146,13 @@ cached_model = hf_hub_download(
147
  safetensors.torch.load_model(model, cached_model)
148
  model.eval()
149
 
150
- with open("tagger_tags.json", "r") as file:
151
- tags = json.load(file) # type: dict
152
- allowed_tags = list(tags.keys())
153
 
154
- for idx, tag in enumerate(allowed_tags):
155
- allowed_tags[idx] = tag.replace("_", " ")
 
 
156
 
157
  @spaces.GPU(duration=5)
158
  def run_classifier(image: Image.Image, threshold):
@@ -161,11 +161,10 @@ def run_classifier(image: Image.Image, threshold):
161
 
162
  with torch.no_grad():
163
  probits = model(tensor)[0] # type: torch.Tensor
164
- values, indices = probits.topk(250)
 
 
165
 
166
- tag_score = dict()
167
- for i in range(indices.size(0)):
168
- tag_score[allowed_tags[indices[i]]] = values[i].item()
169
  sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
170
 
171
  return *create_tags(threshold, sorted_tag_score), img, sorted_tag_score
@@ -178,8 +177,9 @@ def create_tags(threshold, sorted_tag_score: dict):
178
  def clear_image():
179
  return "", {}, None, {}, None
180
 
 
181
  def cam_inference(img, threshold, alpha, evt: gr.SelectData):
182
- target_tag = evt.value
183
  tensor = transform(img).unsqueeze(0)
184
 
185
  gradients = {}
@@ -191,7 +191,6 @@ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
191
  def hook_backward(module, grad_in, grad_out):
192
  gradients['value'] = grad_out[0]
193
 
194
- target_tag_index = allowed_tags.index(target_tag)
195
  handle_forward = model.norm.register_forward_hook(hook_forward)
196
  handle_backward = model.norm.register_full_backward_hook(hook_backward)
197
 
@@ -287,11 +286,11 @@ with gr.Blocks(css=custom_css) as demo:
287
  with gr.Row():
288
  with gr.Column():
289
  image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', show_label=False, elem_id="image_container")
290
- threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
291
  cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="CAM Threshold", elem_classes="inferno-slider")
292
  alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
293
  with gr.Column():
294
  tag_string = gr.Textbox(label="Tag String")
 
295
  label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
296
 
297
  gr.Markdown("""
 
 
 
 
1
  from PIL import Image
2
+ import numpy as np
3
+ import matplotlib.cm as cm
4
+ import msgspec
 
5
  import torch
6
  from torchvision.transforms import transforms
7
  from torchvision.transforms import InterpolationMode
8
  import torchvision.transforms.functional as TF
9
+ import timm
10
+ from timm.models import VisionTransformer
11
+ import safetensors.torch
12
+ import gradio as gr
13
+ import spaces
14
  from huggingface_hub import hf_hub_download
 
 
15
 
16
  class Fit(torch.nn.Module):
17
  def __init__(
 
146
  safetensors.torch.load_model(model, cached_model)
147
  model.eval()
148
 
# Load the tag vocabulary: a JSON object mapping each tag name to an integer,
# decoded and validated as dict[str, int].
with open("tagger_tags.json", "rb") as file:
    raw_tags = msgspec.json.decode(file.read(), type=dict[str, int])

# Expose tags with spaces instead of underscores. Build a new dict instead of
# popping and re-inserting keys while iterating over the same dict: mutating a
# dict during iteration may raise RuntimeError or skip/repeat entries. The
# comprehension visits each original entry exactly once and keeps file order.
tags = {tag.replace("_", " "): index for tag, index in raw_tags.items()}

# Tag names in the same order as the entries of `tags`.
allowed_tags = list(tags)
156
 
157
  @spaces.GPU(duration=5)
158
  def run_classifier(image: Image.Image, threshold):
 
161
 
162
  with torch.no_grad():
163
  probits = model(tensor)[0] # type: torch.Tensor
164
+ values, indices = probits.cpu().topk(250)
165
+
166
+ tag_score = {allowed_tags[idx.item()]: val.item() for idx, val in zip(indices, values)}
167
 
 
 
 
168
  sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
169
 
170
  return *create_tags(threshold, sorted_tag_score), img, sorted_tag_score
 
177
  def clear_image():
178
  return "", {}, None, {}, None
179
 
180
+ @spaces.GPU(duration=5)
181
  def cam_inference(img, threshold, alpha, evt: gr.SelectData):
182
+ target_tag_index = tags[evt.value]
183
  tensor = transform(img).unsqueeze(0)
184
 
185
  gradients = {}
 
191
  def hook_backward(module, grad_in, grad_out):
192
  gradients['value'] = grad_out[0]
193
 
 
194
  handle_forward = model.norm.register_forward_hook(hook_forward)
195
  handle_backward = model.norm.register_full_backward_hook(hook_backward)
196
 
 
286
  with gr.Row():
287
  with gr.Column():
288
  image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', show_label=False, elem_id="image_container")
 
289
  cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="CAM Threshold", elem_classes="inferno-slider")
290
  alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
291
  with gr.Column():
292
  tag_string = gr.Textbox(label="Tag String")
293
+ threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
294
  label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
295
 
296
  gr.Markdown("""