import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

from starvector.data.util import process_and_rasterize_svg

# Hugging Face checkpoint used for both the model weights and the processor.
MODEL_ID = "starvector/starvector-8b-im2svg"
# Generation length limit passed as ``max_length`` to both generation paths below.
MAX_LENGTH = 4000

# Load model and processor once at import time. Weights are loaded in fp16
# and moved to the GPU, so a CUDA device is required.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).cuda()
processor = AutoProcessor.from_pretrained(MODEL_ID)


def generate_svg(input_data, input_type):
    """Generate SVG code and a rasterized preview from an image or a text prompt.

    Args:
        input_data: The PIL image from the "Image to SVG" tab when
            ``input_type == "image"``, or the prompt string from the
            "Text to SVG" tab when ``input_type == "text"``.
        input_type: Either ``"image"`` or ``"text"``.

    Returns:
        The ``(svg_code, raster_image)`` pair returned by
        ``process_and_rasterize_svg`` for the first generated sample.

    Raises:
        gr.Error: If the image is missing or the prompt is empty; Gradio
            shows this message to the user instead of a stack trace.
        ValueError: If ``input_type`` is neither ``"image"`` nor ``"text"``.
    """
    # Validate before touching the GPU so bad input fails fast and clearly.
    if input_type == "image":
        if input_data is None:
            raise gr.Error("Please upload an image first.")
    elif input_type == "text":
        if input_data is None or not str(input_data).strip():
            raise gr.Error("Please enter a text prompt first.")
    else:
        raise ValueError(
            f"input_type must be 'image' or 'text', got {input_type!r}"
        )

    # Pure inference: disable autograd tracking for the generation calls.
    with torch.inference_mode():
        if input_type == "image":
            pixel_values = processor(input_data, return_tensors="pt")["pixel_values"].cuda()
            raw_svg = model.generate_im2svg(
                {"image": pixel_values}, max_length=MAX_LENGTH
            )[0]
        else:
            raw_svg = model.generate_text2svg(input_data, max_length=MAX_LENGTH)[0]

    svg, raster = process_and_rasterize_svg(raw_svg)
    return svg, raster


def image_to_svg(image):
    """Gradio handler for the "Image to SVG" tab."""
    return generate_svg(image, "image")


def text_to_svg(prompt):
    """Gradio handler for the "Text to SVG" tab."""
    return generate_svg(prompt, "text")


with gr.Blocks() as demo:
    gr.Markdown("# 💫 StarVector SVG Generator")

    with gr.Tab("Image to SVG"):
        gr.Markdown("Upload an image to convert to SVG")
        with gr.Row():
            image_input = gr.Image(type="pil", label="Input Image")
            image_output = gr.Image(label="SVG Preview")
        svg_code = gr.Code(label="Generated SVG Code")
        image_button = gr.Button("Convert to SVG")

    with gr.Tab("Text to SVG"):
        gr.Markdown("Enter text to generate SVG")
        with gr.Row():
            text_input = gr.Textbox(label="Text Prompt")
            text_output = gr.Image(label="SVG Preview")
        text_svg_code = gr.Code(label="Generated SVG Code")
        text_button = gr.Button("Generate SVG")

    # Output order must match generate_svg's return order: (svg_code, preview).
    image_button.click(
        fn=image_to_svg,
        inputs=image_input,
        outputs=[svg_code, image_output],
    )
    text_button.click(
        fn=text_to_svg,
        inputs=text_input,
        outputs=[text_svg_code, text_output],
    )

demo.launch()