Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -1,74 +1,104 @@
|
|
1 |
import gradio as gr
|
2 |
-
|
|
|
|
|
3 |
import requests
|
4 |
-
from
|
5 |
-
|
6 |
-
# Load models using diffusers
|
7 |
-
general_model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
|
8 |
-
anime_model = StableDiffusionPipeline.from_pretrained("hakurei/waifu-diffusion")
|
9 |
-
|
10 |
-
# Placeholder functions for the actual implementations
|
11 |
-
def check_anime_image(image):
|
12 |
-
# Use SauceNAO or similar service to check if the image is anime
|
13 |
-
# and fetch similar images and tags
|
14 |
-
return False, [], []
|
15 |
-
|
16 |
-
def describe_image_general(image):
|
17 |
-
# Use the general model to describe the image
|
18 |
-
description = general_model(image)
|
19 |
-
return description
|
20 |
-
|
21 |
-
def describe_image_anime(image):
|
22 |
-
# Use the anime model to describe the image
|
23 |
-
description = anime_model(image)
|
24 |
-
return description
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
tags =
|
39 |
-
return tags, original_tags
|
40 |
-
else:
|
41 |
-
return ["Not an anime image"], []
|
42 |
else:
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
-
def
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
-
def
|
51 |
-
|
52 |
-
|
|
|
|
|
53 |
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
-
|
61 |
-
|
|
|
62 |
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
with gr.TabItem("Original Tags"):
|
67 |
-
original_tags = gr.TextArea(label="Original Tags")
|
68 |
|
69 |
-
|
|
|
|
|
70 |
|
71 |
-
|
72 |
-
|
|
|
73 |
|
74 |
-
|
|
|
1 |
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
|
4 |
+
from diffusers import DiffusionPipeline
|
5 |
import requests
|
6 |
+
from PIL import Image
|
7 |
+
from io import BytesIO
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
+
# Initialize models (loaded once at import time; each download is large).
# NOTE(review): "SmilingWolf/wd-v1-4-vit-tagger" is an image *tagger* repo,
# not a diffusion model — loading it through DiffusionPipeline is very likely
# to fail at startup (plausible cause of the Space's build error). Confirm
# the intended loader for this checkpoint.
anime_model = DiffusionPipeline.from_pretrained("SmilingWolf/wd-v1-4-vit-tagger")
# NOTE(review): verify "facebook/florence-base-in21k-retrieval" exists on the
# Hub and is compatible with AutoModelForZeroShotImageClassification.
photo_model = AutoModelForZeroShotImageClassification.from_pretrained("facebook/florence-base-in21k-retrieval")
# Processor paired with the photo model; prepares PIL images as tensors.
processor = AutoProcessor.from_pretrained("facebook/florence-base-in21k-retrieval")
|
13 |
|
14 |
+
def get_booru_image(booru, image_id):
    """Fetch an image and its tags from a booru by image ID.

    Placeholder: the real per-booru API calls still need to be implemented,
    and the returned tag list is hard-coded.

    Args:
        booru: Booru name selected in the UI (e.g. "Gelbooru", "Danbooru").
        image_id: The image's ID on that booru.

    Returns:
        tuple: (PIL.Image.Image, list[str]) — the fetched image and its tags.

    Raises:
        requests.HTTPError: if the booru responds with an error status.
        requests.Timeout: if the request exceeds the timeout.
    """
    # NOTE(review): stand-in URL pattern; each booru has its own real API.
    url = f"https://api.{booru}.org/images/{image_id}"
    # Bounded timeout so a slow/unreachable host cannot hang the UI forever.
    response = requests.get(url, timeout=30)
    # Fail loudly on HTTP errors instead of handing an error page to PIL.
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
    tags = ["tag1", "tag2", "tag3"]  # Placeholder
    return img, tags
|
21 |
|
22 |
+
def transcribe_image(image, image_type, transcriber, booru_tags=None):
    """Produce a comma-separated tag string describing *image*.

    Args:
        image: Input image (PIL image or whatever the chosen model accepts).
        image_type: "Anime" routes to the anime tagger; anything else goes
            to the zero-shot photo classifier.
        transcriber: Currently unused; kept so the Gradio callbacks that pass
            it keep working.
        booru_tags: Optional extra tags to merge (deduplicated) into the result.

    Returns:
        str: comma-joined tags.
    """
    if image_type == "Anime":
        # Inference only — no gradients needed.
        with torch.no_grad():
            # NOTE(review): assumes the pipeline yields an iterable of tag
            # strings — verify the anime model's actual output type.
            tags = anime_model(image)
    else:
        inputs = processor(images=image, return_tensors="pt")
        outputs = photo_model(**inputs)
        # Top-50 class indices for the single image in the batch.
        tags = outputs.logits.topk(50).indices.squeeze().tolist()
        # BUGFIX: id2label lives on the *model* config. AutoProcessor objects
        # have no `.config` attribute, so the original line raised
        # AttributeError on every non-anime transcription.
        tags = [photo_model.config.id2label[t] for t in tags]

    if booru_tags:
        # Merge and deduplicate; set() makes the output order unspecified.
        tags = list(set(tags + booru_tags))

    return ", ".join(tags)
|
36 |
|
37 |
+
def update_image(image_type, booru, image_id, uploaded_image):
    """Resolve which image to display and whether transcription can start.

    Returns:
        tuple: (image or None,
                gr.update toggling the transcribe button's visibility,
                booru tags or None).
    """
    fetch_from_booru = image_type == "Anime" and booru != "Upload"
    if fetch_from_booru:
        fetched, fetched_tags = get_booru_image(booru, image_id)
        return fetched, gr.update(visible=True), fetched_tags
    if uploaded_image is not None:
        return uploaded_image, gr.update(visible=True), None
    # Nothing to show: hide the transcribe button again.
    return None, gr.update(visible=False), None
|
45 |
|
46 |
+
def on_image_type_change(image_type):
    """Toggle anime-only controls and reorder the transcriber choices.

    Returns:
        tuple of gr.update: (anime options visibility, upload button
        visibility, transcriber dropdown choices) — the preferred
        transcriber is listed first.
    """
    is_anime = image_type == "Anime"
    ordered = ["Anime", "Photo/Other"] if is_anime else ["Photo/Other", "Anime"]
    return (
        gr.update(visible=is_anime),
        gr.update(visible=True),
        gr.update(choices=ordered),
    )
|
51 |
|
52 |
+
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks() as app:
    gr.Markdown("# Image Transcription App")

    with gr.Tab("Step 1: Image"):
        image_type = gr.Dropdown(["Anime", "Photo/Other"], label="Image type")

        # Booru controls only apply to anime images; hidden until selected.
        with gr.Column(visible=False) as anime_options:
            booru = gr.Dropdown(["Gelbooru", "Danbooru", "Upload"], label="Boorus")
            image_id = gr.Textbox(label="Image ID")
            get_image_btn = gr.Button("Get image")

        upload_btn = gr.UploadButton("Upload Image", visible=False)

        image_display = gr.Image(label="Image to transcribe", visible=False)
        # Tags fetched alongside a booru image, carried between callbacks.
        booru_tags = gr.State(None)

        transcribe_btn = gr.Button("Transcribe", visible=False)
        # NOTE(review): nothing ever toggles this button visible —
        # update_image only outputs to transcribe_btn. Confirm intent.
        transcribe_with_tags_btn = gr.Button("Transcribe with booru tags", visible=False)

    with gr.Tab("Step 2: Transcribe"):
        transcriber = gr.Dropdown(["Anime", "Photo/Other"], label="Transcriber")
        transcribe_image_display = gr.Image(label="Image to transcribe")
        transcribe_btn_final = gr.Button("Transcribe")
        tags_output = gr.Textbox(label="Transcribed tags")

    image_type.change(on_image_type_change, inputs=[image_type],
                      outputs=[anime_options, upload_btn, transcriber])

    get_image_btn.click(update_image,
                        inputs=[image_type, booru, image_id, upload_btn],
                        outputs=[image_display, transcribe_btn, booru_tags])

    upload_btn.upload(update_image,
                      inputs=[image_type, booru, image_id, upload_btn],
                      outputs=[image_display, transcribe_btn, booru_tags])

    def transcribe_and_update(image, image_type, transcriber, booru_tags):
        """Transcribe, then mirror the image into the Step 2 display."""
        tags = transcribe_image(image, image_type, transcriber, booru_tags)
        return image, tags

    transcribe_btn.click(transcribe_and_update,
                         inputs=[image_display, image_type, transcriber, booru_tags],
                         outputs=[transcribe_image_display, tags_output])

    transcribe_with_tags_btn.click(transcribe_and_update,
                                   inputs=[image_display, image_type, transcriber, booru_tags],
                                   outputs=[transcribe_image_display, tags_output])

    transcribe_btn_final.click(transcribe_image,
                               inputs=[transcribe_image_display, image_type, transcriber],
                               outputs=[tags_output])

# BUGFIX: launch only when run as a script, not on import (e.g. by tests or
# tooling that imports this module); the unguarded call blocked import.
if __name__ == "__main__":
    app.launch()
|