# Spaces: Running on Zero
import io
import tempfile

# NOTE(review): `spaces` is kept ahead of torch/transformers, as in the
# original file; ZeroGPU Spaces are understood to want it imported before
# CUDA is initialised - confirm before reordering.
import spaces
import gradio as gr
import torch
from PIL import Image
from starvector.data.util import process_and_rasterize_svg
from transformers import AutoModelForCausalLM, AutoProcessor
USE_BOTH_MODELS = True  # Set this to True to load both models

# Hugging Face repo id for each selectable model size.
MODEL_REPOS = {
    "8b": "starvector/starvector-8b-im2svg",
    "1b": "starvector/starvector-1b-im2svg",
}


def _load_model(repo_id):
    """Load one checkpoint and return it together with its processor.

    The model is loaded in float16, moved to the GPU and switched to eval
    mode, exactly once per call.

    Args:
        repo_id: Hugging Face Hub repository id to load from.

    Returns:
        A dict with keys ``'model'`` and ``'processor'``.
    """
    # NOTE: trust_remote_code=True executes Python code shipped in the Hub
    # repository; only point this at repositories you trust.
    model = AutoModelForCausalLM.from_pretrained(
        repo_id, torch_dtype=torch.float16, trust_remote_code=True
    )
    model.cuda()
    model.eval()
    # The processor is taken from the loaded model itself, not loaded
    # separately.
    return {'model': model, 'processor': model.model.processor}


# Load models at startup. The 8b model is always loaded (and loaded first);
# the 1b model is added only when both models are requested.
models = {
    size: _load_model(MODEL_REPOS[size])
    for size in (("8b", "1b") if USE_BOTH_MODELS else ("8b",))
}
def _save_svg_to_tempfile(svg):
    """Write *svg* to a new ``.svg`` temp file and return the file's path.

    The file is created with ``delete=False`` so it still exists when the
    caller hands the path on; nothing in this module removes it afterwards.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".svg", delete=False, encoding="utf-8"
    ) as handle:
        handle.write(svg)
    return handle.name


# NOTE(review): this file imports `spaces` but never applies `spaces.GPU`.
# If this app is deployed on a ZeroGPU Space, confirm whether this handler
# needs the `@spaces.GPU` decorator to be given a GPU at call time.
def convert_to_svg(image, model_choice="8b"):
    """Convert an uploaded image to SVG with the selected model.

    Args:
        image: Path of the uploaded image file, or ``None`` when nothing
            has been uploaded yet.
        model_choice: Key into the module-level ``models`` dict. Defaults
            to ``"8b"`` so the handler also works when it is called with
            the image only (single-model mode).

    Returns:
        A 3-tuple ``(raster_preview, svg_path, status_message)``.
        ``raster_preview`` and ``svg_path`` are ``None`` when no image was
        given or when the conversion failed; in the failure case the
        message starts with ``"Error: "``.
    """
    if image is None:
        return None, None, "Please upload an image first"

    try:
        # Select the model based on user choice
        chosen = models[model_choice]
        selected_model = chosen['model']
        selected_processor = chosen['processor']

        # Process the uploaded image; the context manager closes the file
        # handle once the pixel data has been turned into a tensor.
        with Image.open(image) as image_pil:
            image_tensor = selected_processor(
                image_pil, return_tensors="pt"
            )['pixel_values'].cuda()

        # NOTE(review): squeeze(0) only drops a leading dimension of size 1,
        # so this branch looks like a no-op; kept unchanged - confirm intent.
        if image_tensor.shape[0] != 1:
            image_tensor = image_tensor.squeeze(0)
        batch = {"image": image_tensor}

        # Generate SVG
        raw_svg = selected_model.generate_im2svg(batch, max_length=4000)[0]
        svg, raster_image = process_and_rasterize_svg(raw_svg)

        # The download component takes a file path, not an in-memory buffer.
        svg_path = _save_svg_to_tempfile(svg)
        return raster_image, svg_path, f"Conversion successful using {model_choice} model!"
    except Exception as e:
        # UI boundary: report the failure in the status box instead of
        # letting the exception propagate out of the handler.
        return None, None, f"Error: {str(e)}"
# Assemble the Gradio page: inputs on the left, results on the right.
with gr.Blocks(title="StarVector") as demo:
    gr.Markdown("# StarVector")
    gr.Markdown("Upload an image to convert it to SVG format using StarVector model")

    with gr.Row():
        # Left column: upload, optional model picker, trigger button, sample.
        with gr.Column(scale=1):
            input_image = gr.Image(label="Upload Image", type="filepath")
            if USE_BOTH_MODELS:
                model_choice = gr.Radio(
                    label="Select Model",
                    info="Choose between 8b (larger) and 1b (smaller) models",
                    choices=["8b", "1b"],
                    value="8b",
                )
            convert_btn = gr.Button("Convert to SVG")
            example = gr.Examples(
                inputs=input_image,
                examples=[["assets/examples/sample-18.png"]],
            )

        # Right column: rendered preview, downloadable file, status line.
        with gr.Column(scale=1):
            output_preview = gr.Image(label="Rasterized SVG Preview", type="pil")
            output_file = gr.File(label="Download SVG")
            status = gr.Textbox(label="Status")

    # The model selector exists, and is passed to the handler, only when
    # both models are enabled.
    inputs = [input_image, model_choice] if USE_BOTH_MODELS else [input_image]
    convert_btn.click(
        fn=convert_to_svg,
        inputs=inputs,
        outputs=[output_preview, output_file, status],
    )

# Start serving the app.
demo.launch()