# VoiceStar / app.py
# Last commit: "Update app.py" by mrfakename (bbd6c3f, verified) — 4.02 kB
import spaces
import gradio as gr
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from starvector.data.util import process_and_rasterize_svg
import torch
import io
# True loads both the 8b and 1b checkpoints; False loads only the 8b one.
USE_BOTH_MODELS = True


def _load_starvector(model_id):
    """Load one StarVector checkpoint and return its ``{'model', 'processor'}`` entry.

    The model is loaded in float16 with the checkpoint's own remote code,
    moved to the GPU, and put in eval mode. The image processor is the one
    the remote-code model exposes at ``model.model.processor``.

    Args:
        model_id: Hugging Face Hub id of the checkpoint.

    Returns:
        A dict with keys ``'model'`` and ``'processor'``, the shape used by
        the module-level ``models`` registry.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16, trust_remote_code=True
    )
    model.cuda()
    model.eval()
    return {'model': model, 'processor': model.model.processor}


# Load models at startup so requests only pay for inference.
# The 8b model is always loaded; the 1b model only when USE_BOTH_MODELS is set.
models = {}
models['8b'] = _load_starvector("starvector/starvector-8b-im2svg")
if USE_BOTH_MODELS:
    models['1b'] = _load_starvector("starvector/starvector-1b-im2svg")
@spaces.GPU
def convert_to_svg(image, model_choice="8b"):
    """Convert an uploaded raster image to SVG with a StarVector model.

    Args:
        image: Path to the uploaded image file (the input component is created
            with ``type="filepath"``), or ``None`` if nothing was uploaded.
        model_choice: Key into the module-level ``models`` dict, e.g. ``"8b"``
            or ``"1b"``. Defaults to ``"8b"`` so the function still works when
            the UI wires up only the image input.

    Returns:
        A ``(preview, svg_path, status)`` tuple for the three output
        components: the raster preview returned by
        ``process_and_rasterize_svg``, the path of a temporary ``.svg`` file
        for ``gr.File`` to serve, and a status message. On any failure the
        first two items are ``None`` and the status carries the error text.
    """
    # Local imports keep the module's import block (with `spaces` first) untouched.
    import logging
    import tempfile

    if image is None:
        return None, None, "Please upload an image first"
    try:
        # Select the model based on user choice; an unknown key raises KeyError,
        # which is reported through the status box like any other failure.
        selected_model = models[model_choice]['model']
        selected_processor = models[model_choice]['processor']

        # Open the upload in a context manager so the file handle is released
        # as soon as the processor has produced the pixel tensor.
        with Image.open(image) as image_pil:
            image_tensor = selected_processor(image_pil, return_tensors="pt")['pixel_values'].cuda()
        # NOTE(review): Tensor.squeeze(0) only removes dim 0 when its size is 1,
        # which this condition excludes, so the branch is a no-op as written.
        # Kept from the original pending confirmation of the processor's output shape.
        if image_tensor.shape[0] != 1:
            image_tensor = image_tensor.squeeze(0)
        batch = {"image": image_tensor}

        # Generate SVG; the result is indexed with [0] — presumably one
        # output per batch item (TODO confirm against generate_im2svg).
        raw_svg = selected_model.generate_im2svg(batch, max_length=4000)[0]
        svg, raster_image = process_and_rasterize_svg(raw_svg)

        # gr.File serves files by path, not in-memory buffers, so write the SVG
        # to a temp file. delete=False keeps it on disk for Gradio to pick up;
        # it is not cleaned up here.
        with tempfile.NamedTemporaryFile(
            "w", suffix=".svg", delete=False, encoding="utf-8"
        ) as svg_file:
            svg_file.write(svg)

        return raster_image, svg_file.name, f"Conversion successful using {model_choice} model!"
    except Exception as e:
        # UI boundary: show the error to the user, but keep the traceback in the logs.
        logging.getLogger(__name__).exception("SVG conversion failed")
        return None, None, f"Error: {str(e)}"
# Create Blocks interface: image (and optional model picker) on the left,
# rasterized preview, SVG download and status on the right.
with gr.Blocks(title="StarVector") as demo:
    gr.Markdown("# StarVector")
    gr.Markdown("Upload an image to convert it to SVG format using StarVector model")
    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            input_image = gr.Image(type="filepath", label="Upload Image")
            if USE_BOTH_MODELS:
                model_choice = gr.Radio(
                    choices=["8b", "1b"],
                    value="8b",
                    label="Select Model",
                    info="Choose between 8b (larger) and 1b (smaller) models"
                )
            else:
                # Only the 8b model is loaded; pass its key as a fixed, hidden
                # input so convert_to_svg always receives both arguments.
                model_choice = gr.State("8b")
            convert_btn = gr.Button("Convert to SVG")
            example = gr.Examples(
                examples=[["assets/examples/sample-18.png"]],
                inputs=input_image
            )
        with gr.Column(scale=1):
            # Output section
            output_preview = gr.Image(type="pil", label="Rasterized SVG Preview")
            output_file = gr.File(label="Download SVG")
            status = gr.Textbox(label="Status")
    # Connect button click to conversion function; inputs map positionally
    # onto convert_to_svg(image, model_choice).
    inputs = [input_image, model_choice]
    convert_btn.click(
        fn=convert_to_svg,
        inputs=inputs,
        outputs=[output_preview, output_file, status]
    )
# Launch the app only when this file is executed as a script (e.g. `python app.py`),
# so that importing the module does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()