File size: 4,022 Bytes
bbd6c3f
3f32750
81c68d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbd6c3f
81c68d9
3f32750
81c68d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbd6c3f
 
81c68d9
3f32750
 
81c68d9
 
 
 
 
 
 
 
 
 
 
 
 
 
3f32750
81c68d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f32750
 
81c68d9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import io
import logging
import tempfile

import spaces  # keep before torch/transformers: ZeroGPU's `spaces` must be imported before CUDA is initialized
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

from starvector.data.util import process_and_rasterize_svg

USE_BOTH_MODELS = True  # True: load both the 8b and 1b models; False: load only the 8b model

# Hugging Face Hub repo id for each selectable model size.
MODEL_REPOS = {
    "8b": "starvector/starvector-8b-im2svg",
    "1b": "starvector/starvector-1b-im2svg",
}


def _load_model(repo_id):
    """Load one StarVector checkpoint in float16, move it to the GPU and set eval mode.

    Args:
        repo_id: Hugging Face Hub repository id of the checkpoint.

    Returns:
        A dict with ``'model'`` (the loaded model) and ``'processor'`` (the image
        processor exposed by the checkpoint as ``model.model.processor``).
    """
    # trust_remote_code is required because the checkpoint ships its own model code.
    model = AutoModelForCausalLM.from_pretrained(
        repo_id, torch_dtype=torch.float16, trust_remote_code=True
    )
    model.cuda()
    model.eval()
    return {'model': model, 'processor': model.model.processor}


# Load models at startup. The 8b model is always loaded (first); the 1b model
# only when USE_BOTH_MODELS is set. Keys are the values the UI passes as model_choice.
models = {'8b': _load_model(MODEL_REPOS['8b'])}
if USE_BOTH_MODELS:
    models['1b'] = _load_model(MODEL_REPOS['1b'])

@spaces.GPU
def convert_to_svg(image, model_choice="8b"):
    """Convert an uploaded raster image to SVG with the selected StarVector model.

    Args:
        image: Filesystem path of the uploaded image, or ``None`` when nothing
            has been uploaded.
        model_choice: Key into ``models`` selecting the checkpoint ("8b" or
            "1b"). Defaults to "8b" so the function still works when the UI
            wires up only the image input (single-model configuration).

    Returns:
        A ``(preview, svg_path, status)`` tuple: the rasterized image produced
        by ``process_and_rasterize_svg``, the path of a ``.svg`` file for the
        download component, and a human-readable status message. On failure
        the first two items are ``None`` and the status starts with "Error: ".
    """
    if image is None:
        return None, None, "Please upload an image first"

    try:
        # An unknown key raises KeyError here and is reported via the status text.
        selected = models[model_choice]
        model = selected['model']
        processor = selected['processor']

        # Use a context manager so the uploaded file's handle is closed once
        # the pixels have been converted to a tensor.
        with Image.open(image) as image_pil:
            image_tensor = processor(image_pil, return_tensors="pt")['pixel_values'].cuda()

        # NOTE(review): squeeze(0) only removes dim 0 when its size is 1, which
        # this condition excludes, so the branch never changes the tensor.
        # Confirm the shape generate_im2svg expects before relying on it.
        if not image_tensor.shape[0] == 1:
            image_tensor = image_tensor.squeeze(0)

        batch = {"image": image_tensor}

        # Generate SVG markup, then clean it up and rasterize it for the preview.
        raw_svg = model.generate_im2svg(batch, max_length=4000)[0]
        svg, raster_image = process_and_rasterize_svg(raw_svg)

        # gr.File outputs take a filesystem path, not an in-memory buffer, so
        # write the SVG to a named temp file that Gradio can serve for download.
        # delete=False keeps the file on disk after the handle is closed.
        with tempfile.NamedTemporaryFile(
            "w", suffix=".svg", delete=False, encoding="utf-8"
        ) as svg_file:
            svg_file.write(svg)
            svg_path = svg_file.name

        return raster_image, svg_path, f"Conversion successful using {model_choice} model!"
    except Exception as e:
        # UI boundary: keep the full traceback in the server log and show only
        # a short message to the user.
        logging.getLogger(__name__).exception("SVG conversion failed")
        return None, None, f"Error: {str(e)}"

# Build the Gradio UI: image upload and model picker on the left; rasterized
# preview, SVG download and status text on the right.
with gr.Blocks(title="StarVector") as demo:
    gr.Markdown("# StarVector")
    gr.Markdown("Upload an image to convert it to SVG format using StarVector model")

    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            input_image = gr.Image(type="filepath", label="Upload Image")
            # Always created so convert_to_svg always receives a model key.
            # It is hidden when only one model was loaded at startup, and its
            # choices are exactly the models that were loaded.
            model_choice = gr.Radio(
                choices=list(models),
                value="8b",
                label="Select Model",
                info="Choose between 8b (larger) and 1b (smaller) models",
                visible=USE_BOTH_MODELS,
            )
            convert_btn = gr.Button("Convert to SVG")
            gr.Examples(
                examples=[["assets/examples/sample-18.png"]],
                inputs=input_image,
            )

        with gr.Column(scale=1):
            # Output section
            output_preview = gr.Image(type="pil", label="Rasterized SVG Preview")
            output_file = gr.File(label="Download SVG")
            status = gr.Textbox(label="Status")

    # Connect button click to conversion function; argument order matches
    # convert_to_svg(image, model_choice).
    convert_btn.click(
        fn=convert_to_svg,
        inputs=[input_image, model_choice],
        outputs=[output_preview, output_file, status],
    )

# Launch the app only when run as a script (e.g. `python app.py`), so the
# module can also be imported, for example by Gradio's reload mode.
if __name__ == "__main__":
    demo.launch()