Spaces:
Runtime error
Runtime error
File size: 3,145 Bytes
c5d1577 388f879 c5d1577 f561c25 c5d1577 ebf9292 c5d1577 f561c25 ebf9292 f561c25 c5d1577 388f879 c5d1577 d02110f c5d1577 f561c25 388f879 c5d1577 d02110f c5d1577 f561c25 388f879 c5d1577 388f879 c5d1577 f561c25 c5d1577 388f879 c5d1577 f561c25 c5d1577 bfaed2e c5d1577 388f879 c5d1577 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import gradio as gr
from setup import setup
import torch
import gc
from PIL import Image
from transformers import AutoModel, AutoImageProcessor
from anime2sketch.model import Anime2Sketch
import spaces
# Run the project-local setup step before any model below is loaded.
# NOTE(review): presumably this prepares assets such as ./models/netG.pth
# that are read further down — confirm in setup.py.
setup()
print("Setup finished")
# Hugging Face Hub repository providing both the MangaLineExtraction model
# and its image processor.
MLE_MODEL_REPO = "p1atdev/MangaLineExtraction-hf"
class MangaLineExtractor:
    """Callable wrapper around the MangaLineExtraction model from the Hugging Face Hub.

    The model and its image processor are loaded once, when this class body is
    executed, and are shared by every instance through class attributes. Both
    are fetched with ``trust_remote_code=True``, so code from the Hub repository
    is executed.
    """

    model = AutoModel.from_pretrained(MLE_MODEL_REPO, trust_remote_code=True)
    processor = AutoImageProcessor.from_pretrained(
        MLE_MODEL_REPO, trust_remote_code=True
    )

    @spaces.GPU
    @torch.no_grad()
    def __call__(self, image: Image.Image) -> Image.Image:
        """Extract a line drawing from *image*.

        Args:
            image: Input illustration, handed directly to the image processor.

        Returns:
            An 8-bit greyscale (mode ``"L"``) PIL image of the extracted lines.
        """
        inputs = self.processor(image, return_tensors="pt")
        # Send the tensor to whichever device the model's weights live on, so the
        # call keeps working if the model is moved to the GPU that spaces.GPU
        # allocates (today it stays on the CPU and this is a no-op).
        model_device = next(self.model.parameters()).device
        outputs = self.model(inputs.pixel_values.to(model_device))
        # NOTE(review): assumes outputs.pixel_values[0] is a 2-D (H, W) map whose
        # values are meant to lie in 0-255 — TODO confirm against the model card.
        # Clamp before the uint8 cast: NumPy's float->uint8 cast is undefined for
        # out-of-range values and typically wraps (e.g. -1.0 -> 255).
        # .cpu() is required before .numpy() for GPU tensors; no-op on CPU.
        line_map = outputs.pixel_values[0].clamp(0, 255).cpu().numpy()
        return Image.fromarray(line_map.astype("uint8"), mode="L")
# Shared model instances used by the request handlers below.
# MangaLineExtractor's model/processor were already loaded when its class body
# ran; this line only creates the callable wrapper.
mle_model = MangaLineExtractor()
# NOTE(review): presumably loads the Anime2Sketch generator checkpoint from
# ./models/netG.pth onto the device named by the second argument ("cpu") —
# confirm against anime2sketch.model.
a2s_model = Anime2Sketch("./models/netG.pth", "cpu")
def flush():
    """Free memory held by unreachable objects and by PyTorch's CUDA cache.

    Not referenced elsewhere in the code shown here. Returns ``None``.
    """
    # Collect Python garbage first so tensors it was keeping alive are released
    # before the CUDA caching allocator is asked to hand memory back.
    gc.collect()
    torch.cuda.empty_cache()
@torch.no_grad()
def extract(image):
    """Return the MangaLineExtraction line drawing for *image*.

    Delegates to the module-level ``mle_model`` with autograd disabled.
    """
    return mle_model(image)
@torch.no_grad()
def convert_to_sketch(image):
    """Return the Anime2Sketch rendering of *image*.

    Delegates to ``a2s_model.predict`` with autograd disabled.
    """
    return a2s_model.predict(image)
def start(image):
    """Run both line-extraction models on one input image.

    Args:
        image: Value of the ``gr.Image`` input component — an array accepted by
            ``PIL.Image.fromarray``, or ``None`` when no image has been provided.

    Returns:
        ``[manga_line_extraction_result, anime2sketch_result]``, in the same
        order as the two output components; ``[None, None]`` when *image* is
        ``None``, which leaves both outputs empty.
    """
    # Clicking "Start" with an empty input sends None, and
    # Image.fromarray(None) would raise instead of simply showing nothing.
    if image is None:
        return [None, None]
    # Convert once and feed the same RGB PIL image to both models
    # (MangaLineExtractor.__call__ is annotated to take a PIL image).
    rgb_image = Image.fromarray(image).convert("RGB")
    return [extract(rgb_image), convert_to_sketch(rgb_image)]
def clear():
    """Return a fresh ``[None, None]`` so both output images are emptied."""
    # One None per output component wired to the Clear button.
    return [None] * 2
def ui():
    """Build (but do not launch) the Gradio interface for the demo.

    The layout is:
    - an intro Markdown header;
    - a row with the input image and Start/Clear buttons on the left, and the
      two model outputs on the right;
    - clickable example images;
    - an attribution note.

    "Start" runs ``start`` on the input image; "Clear" runs ``clear`` to empty
    both outputs.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    with gr.Blocks() as blocks:
        gr.Markdown(
            """
# Anime to Sketch
Unofficial demo for converting illustrations into sketches.
Original repos:
- [MangaLineExtraction_PyTorch](https://github.com/ljsabc/MangaLineExtraction_PyTorch)
- [Anime2Sketch](https://github.com/Mukosame/Anime2Sketch)
Using with 🤗 transformers:
- [MangaLineExtraction-hf](https://huggingface.co/p1atdev/MangaLineExtraction-hf)
"""
        )

        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input", interactive=True)
                extract_btn = gr.Button("Start", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
            with gr.Column():
                extract_output_img = gr.Image(
                    label="MangaLineExtraction", interactive=False
                )
                to_sketch_output_img = gr.Image(label="Anime2Sketch", interactive=False)

        # Every handler below fills the same two outputs, in the order that
        # start() and clear() return their results.
        result_imgs = [extract_output_img, to_sketch_output_img]

        # Clickable sample inputs, placed below the main input/output row.
        gr.Examples(
            examples=[
                ["./examples/0.jpg"],
                ["./examples/1.jpg"],
                ["./examples/2.jpg"],
            ],
            inputs=[input_img],
            outputs=result_imgs,
            fn=start,
            label="Examples",
        )
        gr.Markdown("Images are from nijijourney.")

        extract_btn.click(fn=start, inputs=[input_img], outputs=result_imgs)
        clear_btn.click(fn=clear, inputs=[], outputs=result_imgs)

    return blocks
if __name__ == "__main__":
    # Build the interface, then start serving it.
    demo = ui()
    demo.launch()
|