Update app.py
Browse files
app.py
CHANGED
@@ -1,18 +1,17 @@
|
|
1 |
-
import json
|
2 |
-
|
3 |
-
import gradio as gr
|
4 |
from PIL import Image
|
5 |
-
import
|
6 |
-
import
|
7 |
-
import
|
8 |
-
from timm.models import VisionTransformer
|
9 |
import torch
|
10 |
from torchvision.transforms import transforms
|
11 |
from torchvision.transforms import InterpolationMode
|
12 |
import torchvision.transforms.functional as TF
|
|
|
|
|
|
|
|
|
|
|
13 |
from huggingface_hub import hf_hub_download
|
14 |
-
import numpy as np
|
15 |
-
import matplotlib.cm as cm
|
16 |
|
17 |
class Fit(torch.nn.Module):
|
18 |
def __init__(
|
@@ -147,12 +146,13 @@ cached_model = hf_hub_download(
|
|
147 |
safetensors.torch.load_model(model, cached_model)
|
148 |
model.eval()
|
149 |
|
150 |
-
with open("tagger_tags.json", "
|
151 |
-
tags = json.
|
152 |
-
allowed_tags = list(tags.keys())
|
153 |
|
154 |
-
for
|
155 |
-
|
|
|
|
|
156 |
|
157 |
@spaces.GPU(duration=5)
|
158 |
def run_classifier(image: Image.Image, threshold):
|
@@ -161,11 +161,10 @@ def run_classifier(image: Image.Image, threshold):
|
|
161 |
|
162 |
with torch.no_grad():
|
163 |
probits = model(tensor)[0] # type: torch.Tensor
|
164 |
-
values, indices = probits.topk(250)
|
|
|
|
|
165 |
|
166 |
-
tag_score = dict()
|
167 |
-
for i in range(indices.size(0)):
|
168 |
-
tag_score[allowed_tags[indices[i]]] = values[i].item()
|
169 |
sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
|
170 |
|
171 |
return *create_tags(threshold, sorted_tag_score), img, sorted_tag_score
|
@@ -178,8 +177,9 @@ def create_tags(threshold, sorted_tag_score: dict):
|
|
178 |
def clear_image():
|
179 |
return "", {}, None, {}, None
|
180 |
|
|
|
181 |
def cam_inference(img, threshold, alpha, evt: gr.SelectData):
|
182 |
-
|
183 |
tensor = transform(img).unsqueeze(0)
|
184 |
|
185 |
gradients = {}
|
@@ -191,7 +191,6 @@ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
|
|
191 |
def hook_backward(module, grad_in, grad_out):
|
192 |
gradients['value'] = grad_out[0]
|
193 |
|
194 |
-
target_tag_index = allowed_tags.index(target_tag)
|
195 |
handle_forward = model.norm.register_forward_hook(hook_forward)
|
196 |
handle_backward = model.norm.register_full_backward_hook(hook_backward)
|
197 |
|
@@ -287,11 +286,11 @@ with gr.Blocks(css=custom_css) as demo:
|
|
287 |
with gr.Row():
|
288 |
with gr.Column():
|
289 |
image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', show_label=False, elem_id="image_container")
|
290 |
-
threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
|
291 |
cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="CAM Threshold", elem_classes="inferno-slider")
|
292 |
alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
|
293 |
with gr.Column():
|
294 |
tag_string = gr.Textbox(label="Tag String")
|
|
|
295 |
label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
|
296 |
|
297 |
gr.Markdown("""
|
|
|
|
|
|
|
|
|
1 |
from PIL import Image
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib.cm as cm
|
4 |
+
import msgspec
|
|
|
5 |
import torch
|
6 |
from torchvision.transforms import transforms
|
7 |
from torchvision.transforms import InterpolationMode
|
8 |
import torchvision.transforms.functional as TF
|
9 |
+
import timm
|
10 |
+
from timm.models import VisionTransformer
|
11 |
+
import safetensors.torch
|
12 |
+
import gradio as gr
|
13 |
+
import spaces
|
14 |
from huggingface_hub import hf_hub_download
|
|
|
|
|
15 |
|
16 |
class Fit(torch.nn.Module):
|
17 |
def __init__(
|
|
|
146 |
safetensors.torch.load_model(model, cached_model)
|
147 |
model.eval()
|
148 |
|
149 |
+
# Load the tag vocabulary: a JSON object decoded as a dict[str, int].
with open("tagger_tags.json", "rb") as file:
    tags = msgspec.json.decode(file.read(), type=dict[str, int])

# Show tags with spaces instead of underscores. A new dict is built instead of
# popping and re-inserting keys while iterating over the same dict: mutating a
# dict during iteration may raise RuntimeError or skip/reorder entries, and
# allowed_tags below relies on the original key order being preserved.
tags = {tag.replace("_", " "): index for tag, index in tags.items()}

# Tag names in file order.
# NOTE(review): run_classifier indexes this list by model output position while
# cam_inference looks the index up through the dict values -- presumably the
# JSON lists tags in index order; confirm against tagger_tags.json.
allowed_tags = list(tags)
|
156 |
|
157 |
@spaces.GPU(duration=5)
|
158 |
def run_classifier(image: Image.Image, threshold):
|
|
|
161 |
|
162 |
with torch.no_grad():
|
163 |
probits = model(tensor)[0] # type: torch.Tensor
|
164 |
+
values, indices = probits.cpu().topk(250)
|
165 |
+
|
166 |
+
tag_score = {allowed_tags[idx.item()]: val.item() for idx, val in zip(indices, values)}
|
167 |
|
|
|
|
|
|
|
168 |
sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
|
169 |
|
170 |
return *create_tags(threshold, sorted_tag_score), img, sorted_tag_score
|
|
|
177 |
def clear_image():
    """Return a 5-tuple of empty placeholders: ("", {}, None, {}, None).

    NOTE(review): presumably these reset the interface components bound to
    the clear event -- the event wiring is not visible here; confirm.
    """
    blank_text = ""
    cleared = (blank_text, {}, None, {}, None)
    return cleared
|
179 |
|
180 |
+
@spaces.GPU(duration=5)
|
181 |
def cam_inference(img, threshold, alpha, evt: gr.SelectData):
|
182 |
+
target_tag_index = tags[evt.value]
|
183 |
tensor = transform(img).unsqueeze(0)
|
184 |
|
185 |
gradients = {}
|
|
|
191 |
def hook_backward(module, grad_in, grad_out):
|
192 |
gradients['value'] = grad_out[0]
|
193 |
|
|
|
194 |
handle_forward = model.norm.register_forward_hook(hook_forward)
|
195 |
handle_backward = model.norm.register_full_backward_hook(hook_backward)
|
196 |
|
|
|
286 |
with gr.Row():
|
287 |
with gr.Column():
|
288 |
image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', show_label=False, elem_id="image_container")
|
|
|
289 |
cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="CAM Threshold", elem_classes="inferno-slider")
|
290 |
alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
|
291 |
with gr.Column():
|
292 |
tag_string = gr.Textbox(label="Tag String")
|
293 |
+
threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
|
294 |
label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
|
295 |
|
296 |
gr.Markdown("""
|