Spaces:
Build error
Build error
import gradio as gr | |
import torch | |
from PIL import Image | |
from transformers import AutoModelForCausalLM, AutoProcessor | |
from starvector.data.util import process_and_rasterize_svg | |
# Hub checkpoint id shared by the model weights and the matching processor.
MODEL_ID = "starvector/starvector-8b-im2svg"

# Load the checkpoint in half precision. trust_remote_code=True lets
# transformers execute the model code that ships with the checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
# Move the weights to the GPU; this requires a CUDA device to be present.
model = model.cuda()

# Processor used to turn input images into model-ready tensors.
processor = AutoProcessor.from_pretrained(MODEL_ID)
def generate_svg(input_data, input_type):
    """Generate SVG code and a rasterized preview from an image or a prompt.

    Args:
        input_data: The user input. For ``input_type == "image"`` this is the
            image handed to ``processor``; for any other ``input_type`` it is
            passed to the model as a text prompt.
        input_type: ``"image"`` selects image-to-SVG generation; every other
            value selects text-to-SVG generation.

    Returns:
        A ``(svg_code, raster_image)`` pair as produced by
        ``process_and_rasterize_svg``.

    Raises:
        ValueError: If ``input_data`` is ``None`` or a blank string, which is
            what the UI sends when the button is clicked with nothing entered.
    """
    # Fail fast with a clear message instead of letting an empty input reach
    # the processor/model and surface as an opaque error from deep inside.
    if input_data is None or (
        isinstance(input_data, str) and not input_data.strip()
    ):
        raise ValueError(
            "No input provided: upload an image or enter a text prompt."
        )

    # Inference only: gradients are never needed, so skip autograd tracking.
    with torch.no_grad():
        if input_type == "image":
            # Convert the image to pixel tensors on the same device as the model.
            pixel_values = processor(input_data, return_tensors="pt")[
                "pixel_values"
            ].cuda()
            batch = {"image": pixel_values}
            raw_svg = model.generate_im2svg(batch, max_length=4000)[0]
        else:
            # NOTE(review): the loaded checkpoint is named "...-im2svg";
            # confirm that it actually supports text-to-SVG generation.
            raw_svg = model.generate_text2svg(input_data, max_length=4000)[0]

    # Post-process the raw model output into SVG code plus a raster preview.
    svg_code, raster_image = process_and_rasterize_svg(raw_svg)
    return svg_code, raster_image
# Two-tab Gradio UI: one tab converts an uploaded image to SVG, the other
# generates SVG from a text prompt. Both tabs call generate_svg.
with gr.Blocks() as demo:
    gr.Markdown("# 💫 StarVector SVG Generator")

    with gr.Tab("Image to SVG"):
        gr.Markdown("Upload an image to convert to SVG")
        with gr.Row():
            # type="pil" makes Gradio hand the upload over as a PIL image.
            image_input = gr.Image(type="pil", label="Input Image")
            image_output = gr.Image(label="SVG Preview")
        svg_code = gr.Code(label="Generated SVG Code")
        image_button = gr.Button("Convert to SVG")

    with gr.Tab("Text to SVG"):
        gr.Markdown("Enter text to generate SVG")
        with gr.Row():
            text_input = gr.Textbox(label="Text Prompt")
            text_output = gr.Image(label="SVG Preview")
        text_svg_code = gr.Code(label="Generated SVG Code")
        text_button = gr.Button("Generate SVG")

    # Event wiring. The outputs are listed as [code, preview] to match the
    # (svg_code, raster_image) order returned by generate_svg.
    image_button.click(
        fn=lambda x: generate_svg(x, "image"),
        inputs=image_input,
        outputs=[svg_code, image_output],
    )
    text_button.click(
        fn=lambda x: generate_svg(x, "text"),
        inputs=text_input,
        outputs=[text_svg_code, text_output],
    )

# Start the web server for the app.
demo.launch()