"""Gradio demo that converts a raster image to SVG with StarVector models."""

import io
import logging
import tempfile

import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from starvector.data.util import process_and_rasterize_svg

logger = logging.getLogger(__name__)

USE_BOTH_MODELS = True  # True loads the 8b and 1b models; False loads only the 8b model.

# Hugging Face checkpoint id for each selectable model size.
MODEL_IDS = {
    "8b": "starvector/starvector-8b-im2svg",
    "1b": "starvector/starvector-1b-im2svg",
}

# Value passed as ``max_length`` to ``generate_im2svg``.
MAX_SVG_LENGTH = 4000


def _load_model(model_id):
    """Load a StarVector checkpoint in float16, move it to CUDA and set eval mode.

    Returns a dict holding the model under ``'model'`` and the processor exposed
    by the checkpoint's remote code (``model.model.processor``) under ``'processor'``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16, trust_remote_code=True
    )
    model.cuda()
    model.eval()
    return {"model": model, "processor": model.model.processor}


# Load models at startup: the 8b model always, the 1b model only when both are enabled.
models = {"8b": _load_model(MODEL_IDS["8b"])}
if USE_BOTH_MODELS:
    models["1b"] = _load_model(MODEL_IDS["1b"])


@spaces.GPU
def convert_to_svg(image, model_choice="8b"):
    """Convert an uploaded image to SVG with the selected StarVector model.

    Args:
        image: Path of the uploaded image (the input component uses
            ``type="filepath"``), or None when nothing has been uploaded.
        model_choice: Key into ``models``. Defaults to "8b" so the handler
            still works when USE_BOTH_MODELS is False and the UI passes only
            the image.

    Returns:
        A ``(raster_image, svg_path, status)`` tuple: the rasterized preview
        produced by ``process_and_rasterize_svg``, the path of a ``.svg`` file
        for download, and a status message. On any error the first two items
        are None and the status carries the error text.
    """
    try:
        if image is None:
            return None, None, "Please upload an image first"

        # Select the model based on user choice (unknown keys surface as an error status).
        selected = models[model_choice]
        selected_model = selected["model"]
        selected_processor = selected["processor"]

        # Process the uploaded image; the context manager releases the file handle.
        with Image.open(image) as image_pil:
            image_tensor = selected_processor(image_pil, return_tensors="pt")["pixel_values"].cuda()

        # Kept from the original logic: drop a leading dim when dim 0 is not 1.
        # NOTE(review): squeeze(0) is a no-op when dim 0 != 1; confirm the
        # processor's output shape and the intent of this check.
        if not image_tensor.shape[0] == 1:
            image_tensor = image_tensor.squeeze(0)
        batch = {"image": image_tensor}

        # Generate SVG
        raw_svg = selected_model.generate_im2svg(batch, max_length=MAX_SVG_LENGTH)[0]
        svg, raster_image = process_and_rasterize_svg(raw_svg)

        # gr.File outputs expect a filesystem path, not an in-memory buffer, so
        # write the SVG to a temp file; delete=False keeps it alive until Gradio
        # has copied it for download.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".svg", delete=False, encoding="utf-8"
        ) as svg_file:
            svg_file.write(svg)

        return raster_image, svg_file.name, f"Conversion successful using {model_choice} model!"
    except Exception as e:
        # UI boundary: log the traceback for the Space logs, show the message to the user.
        logger.exception("SVG conversion failed")
        return None, None, f"Error: {str(e)}"


# Create Blocks interface
with gr.Blocks(title="StarVector") as demo:
    gr.Markdown("# StarVector")
    gr.Markdown("Upload an image to convert it to SVG format using StarVector model")

    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            input_image = gr.Image(type="filepath", label="Upload Image")
            if USE_BOTH_MODELS:
                model_choice = gr.Radio(
                    choices=["8b", "1b"],
                    value="8b",
                    label="Select Model",
                    info="Choose between 8b (larger) and 1b (smaller) models",
                )
            convert_btn = gr.Button("Convert to SVG")
            example = gr.Examples(
                examples=[["assets/examples/sample-18.png"]],
                inputs=input_image,
            )

        with gr.Column(scale=1):
            # Output section
            output_preview = gr.Image(type="pil", label="Rasterized SVG Preview")
            output_file = gr.File(label="Download SVG")
            status = gr.Textbox(label="Status")

    # Connect button click to conversion function; the model selector is only
    # wired in when it exists (otherwise convert_to_svg falls back to "8b").
    inputs = [input_image]
    if USE_BOTH_MODELS:
        inputs.append(model_choice)

    convert_btn.click(
        fn=convert_to_svg,
        inputs=inputs,
        outputs=[output_preview, output_file, status],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()