aolko committed
Commit 1376e14 · verified · Parent: 35b7dc5

Update app.py

Files changed (1)
  1. app.py +90 -60
app.py CHANGED
@@ -1,74 +1,104 @@
 import gradio as gr
-from PIL import Image
+import torch
+from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
+from diffusers import DiffusionPipeline
 import requests
-from diffusers import StableDiffusionPipeline
-
-# Load models using diffusers
-general_model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
-anime_model = StableDiffusionPipeline.from_pretrained("hakurei/waifu-diffusion")
-
-# Placeholder functions for the actual implementations
-def check_anime_image(image):
-    # Use SauceNAO or similar service to check if the image is anime
-    # and fetch similar images and tags
-    return False, [], []
-
-def describe_image_general(image):
-    # Use the general model to describe the image
-    description = general_model(image)
-    return description
-
-def describe_image_anime(image):
-    # Use the anime model to describe the image
-    description = anime_model(image)
-    return description
-
-def merge_tags(tags1, tags2):
-    # Merge tags, removing duplicates
-    return list(set(tags1 + tags2))
-
-# Gradio app functions
-def process_image(image, mode):
-    # Convert the image to a format suitable for the models
-    image = image.resize((256, 256))
-
-    if mode == "Anime":
-        is_anime, similar_images, original_tags = check_anime_image(image)
-        if is_anime:
-            tags = describe_image_anime(image)
-            return tags, original_tags
-        else:
-            return ["Not an anime image"], []
+from PIL import Image
+from io import BytesIO
+
+# Initialize models
+anime_model = DiffusionPipeline.from_pretrained("SmilingWolf/wd-v1-4-vit-tagger")
+photo_model = AutoModelForZeroShotImageClassification.from_pretrained("facebook/florence-base-in21k-retrieval")
+processor = AutoProcessor.from_pretrained("facebook/florence-base-in21k-retrieval")
+
+def get_booru_image(booru, image_id):
+    # This is a placeholder function. You'd need to implement the actual API calls for each booru.
+    url = f"https://api.{booru}.org/images/{image_id}"
+    response = requests.get(url)
+    img = Image.open(BytesIO(response.content))
+    tags = ["tag1", "tag2", "tag3"]  # Placeholder
+    return img, tags
+
+def transcribe_image(image, image_type, transcriber, booru_tags=None):
+    if image_type == "Anime":
+        with torch.no_grad():
+            tags = anime_model(image)
     else:
-        tags = describe_image_general(image)
-        return tags, []
-
-def describe(image, mode):
-    tags, original_tags = process_image(image, mode)
-    return gr.update(value="\n".join(tags)), gr.update(value="\n".join(original_tags))
-
-def merge(tags, original_tags):
-    merged_tags = merge_tags(tags.split("\n"), original_tags.split("\n"))
-    return "\n".join(merged_tags)
-
-# Gradio interface
-with gr.Blocks() as demo:
-    with gr.Row():
-        image_input = gr.Image(type="pil", tool="editor", label="Upload/Paste Image")
-        mode = gr.Dropdown(choices=["Anime", "General"], label="Mode")
-
-    describe_button = gr.Button("Describe")
-    merge_button = gr.Button("Merge Tags")
-
-    with gr.TabGroup() as tab_group:
-        with gr.TabItem("Described Tags"):
-            described_tags = gr.TextArea(label="Described Tags")
-        with gr.TabItem("Original Tags"):
-            original_tags = gr.TextArea(label="Original Tags")
-
-    merged_tags = gr.TextArea(label="Merged Tags")
-
-    describe_button.click(describe, inputs=[image_input, mode], outputs=[described_tags, original_tags])
-    merge_button.click(merge, inputs=[described_tags, original_tags], outputs=merged_tags)
-
-demo.launch()
+        inputs = processor(images=image, return_tensors="pt")
+        outputs = photo_model(**inputs)
+        tags = outputs.logits.topk(50).indices.squeeze().tolist()
+        tags = [processor.config.id2label[t] for t in tags]
+
+    if booru_tags:
+        tags = list(set(tags + booru_tags))
+
+    return ", ".join(tags)
+
+def update_image(image_type, booru, image_id, uploaded_image):
+    if image_type == "Anime" and booru != "Upload":
+        image, booru_tags = get_booru_image(booru, image_id)
+        return image, gr.update(visible=True), booru_tags
+    elif uploaded_image is not None:
+        return uploaded_image, gr.update(visible=True), None
+    else:
+        return None, gr.update(visible=False), None
+
+def on_image_type_change(image_type):
+    if image_type == "Anime":
+        return gr.update(visible=True), gr.update(visible=True), gr.update(choices=["Anime", "Photo/Other"])
+    else:
+        return gr.update(visible=False), gr.update(visible=True), gr.update(choices=["Photo/Other", "Anime"])
+
+with gr.Blocks() as app:
+    gr.Markdown("# Image Transcription App")
+
+    with gr.Tab("Step 1: Image"):
+        image_type = gr.Dropdown(["Anime", "Photo/Other"], label="Image type")
+
+        with gr.Column(visible=False) as anime_options:
+            booru = gr.Dropdown(["Gelbooru", "Danbooru", "Upload"], label="Boorus")
+            image_id = gr.Textbox(label="Image ID")
+            get_image_btn = gr.Button("Get image")
+
+        upload_btn = gr.UploadButton("Upload Image", visible=False)
+
+        image_display = gr.Image(label="Image to transcribe", visible=False)
+        booru_tags = gr.State(None)
+
+        transcribe_btn = gr.Button("Transcribe", visible=False)
+        transcribe_with_tags_btn = gr.Button("Transcribe with booru tags", visible=False)
+
+    with gr.Tab("Step 2: Transcribe"):
+        transcriber = gr.Dropdown(["Anime", "Photo/Other"], label="Transcriber")
+        transcribe_image_display = gr.Image(label="Image to transcribe")
+        transcribe_btn_final = gr.Button("Transcribe")
+        tags_output = gr.Textbox(label="Transcribed tags")
+
+    image_type.change(on_image_type_change, inputs=[image_type],
+                      outputs=[anime_options, upload_btn, transcriber])
+
+    get_image_btn.click(update_image,
+                        inputs=[image_type, booru, image_id, upload_btn],
+                        outputs=[image_display, transcribe_btn, booru_tags])
+
+    upload_btn.upload(update_image,
+                      inputs=[image_type, booru, image_id, upload_btn],
+                      outputs=[image_display, transcribe_btn, booru_tags])
+
+    def transcribe_and_update(image, image_type, transcriber, booru_tags):
+        tags = transcribe_image(image, image_type, transcriber, booru_tags)
+        return image, tags
+
+    transcribe_btn.click(transcribe_and_update,
+                         inputs=[image_display, image_type, transcriber, booru_tags],
+                         outputs=[transcribe_image_display, tags_output])
+
+    transcribe_with_tags_btn.click(transcribe_and_update,
+                                   inputs=[image_display, image_type, transcriber, booru_tags],
+                                   outputs=[transcribe_image_display, tags_output])
+
+    transcribe_btn_final.click(transcribe_image,
+                               inputs=[transcribe_image_display, image_type, transcriber],
+                               outputs=[tags_output])
+
+app.launch()
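
A note on the model initialization above: SmilingWolf/wd-v1-4-vit-tagger is a WD 1.4 tagger checkpoint, not a diffusers pipeline, so DiffusionPipeline.from_pretrained is unlikely to load it as written. Below is a minimal sketch of how that tagger is commonly run instead, assuming the repo ships an ONNX export (model.onnx) and a selected_tags.csv label list, and assuming the usual WD 1.4 preprocessing (448x448, BGR, channels-last) with a 0.35 score threshold; these file names and values are assumptions, not taken from the commit.

import csv
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

REPO = "SmilingWolf/wd-v1-4-vit-tagger"  # same repo the commit references

# Assumed file names: the WD 1.4 tagger repos generally ship an ONNX model
# and a CSV listing the tag vocabulary.
model_path = hf_hub_download(REPO, "model.onnx")
tags_path = hf_hub_download(REPO, "selected_tags.csv")

session = ort.InferenceSession(model_path)
with open(tags_path, newline="") as f:
    tag_names = [row["name"] for row in csv.DictReader(f)]

def tag_anime_image(image: Image.Image, threshold: float = 0.35):
    # Assumed preprocessing: resize to 448x448, RGB -> BGR, float32, batch of one (NHWC).
    arr = np.asarray(image.convert("RGB").resize((448, 448)), dtype=np.float32)
    arr = np.ascontiguousarray(arr[:, :, ::-1])[None]
    probs = session.run(None, {session.get_inputs()[0].name: arr})[0][0]
    return [tag for tag, p in zip(tag_names, probs) if p >= threshold]

With a helper like this, the Anime branch of transcribe_image could call tag_anime_image(image) directly and drop the torch.no_grad() block, since inference runs through onnxruntime rather than torch.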
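
Likewise, get_booru_image is committed as a placeholder, and https://api.{booru}.org/images/{image_id} is not a real endpoint for either listed booru. Here is a sketch of what the lookups could look like, assuming Danbooru's /posts/{id}.json and Gelbooru's dapi JSON endpoints with anonymous access; field names, rate limits, and API-key requirements may differ in practice.

import requests
from io import BytesIO
from PIL import Image

def get_booru_image(booru, image_id):
    # Hypothetical replacement for the placeholder: fetch a post by ID and
    # return the image plus its tag list. Endpoints and fields are assumptions.
    if booru == "Danbooru":
        post = requests.get(f"https://danbooru.donmai.us/posts/{image_id}.json", timeout=30).json()
        file_url, tags = post["file_url"], post["tag_string"].split()
    elif booru == "Gelbooru":
        data = requests.get(
            "https://gelbooru.com/index.php",
            params={"page": "dapi", "s": "post", "q": "index", "json": 1, "id": image_id},
            timeout=30,
        ).json()
        post = data["post"][0]  # response layout assumed; older API versions return a bare list
        file_url, tags = post["file_url"], post["tags"].split()
    else:
        raise ValueError(f"Unsupported booru: {booru}")
    img = Image.open(BytesIO(requests.get(file_url, timeout=30).content))
    return img, tags

The return shape matches what update_image already expects (an image and a list of tags), so a helper along these lines could drop in for the placeholder without touching the Gradio wiring.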