mrfakename commited on
Commit
81c68d9
·
verified ·
1 Parent(s): d27404a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -227
app.py CHANGED
@@ -1,234 +1,112 @@
1
- import spaces
2
  import gradio as gr
3
- import sys
4
- import threading
5
- import queue
6
- from io import TextIOBase
7
- from inference import inference_patch
8
- import datetime
9
- import subprocess
10
- import os
11
-
12
- # Predefined valid combinations set
13
- with open('prompts.txt', 'r') as f:
14
- prompts = f.readlines()
15
- valid_combinations = set()
16
- for prompt in prompts:
17
- prompt = prompt.strip()
18
- parts = prompt.split('_')
19
- valid_combinations.add((parts[0], parts[1], parts[2]))
20
-
21
- # Generate available options
22
- periods = sorted({p for p, _, _ in valid_combinations})
23
- composers = sorted({c for _, c, _ in valid_combinations})
24
- instruments = sorted({i for _, _, i in valid_combinations})
25
-
26
- # Dynamic component updates
27
- def update_components(period, composer):
28
- if not period:
29
- return [
30
- gr.Dropdown(choices=[], value=None, interactive=False),
31
- gr.Dropdown(choices=[], value=None, interactive=False)
32
- ]
33
-
34
- valid_composers = sorted({c for p, c, _ in valid_combinations if p == period})
35
- valid_instruments = sorted({i for p, c, i in valid_combinations if p == period and c == composer}) if composer else []
36
-
37
- return [
38
- gr.Dropdown(
39
- choices=valid_composers,
40
- value=composer if composer in valid_composers else None,
41
- interactive=True
42
- ),
43
- gr.Dropdown(
44
- choices=valid_instruments,
45
- value=None,
46
- interactive=bool(valid_instruments)
47
- )
48
- ]
49
-
50
-
51
- class RealtimeStream(TextIOBase):
52
- def __init__(self, queue):
53
- self.queue = queue
54
-
55
- def write(self, text):
56
- self.queue.put(text)
57
- return len(text)
58
-
59
-
60
- def save_and_convert(abc_content, period, composer, instrumentation):
61
- if not all([period, composer, instrumentation]):
62
- raise gr.Error("Please complete a valid generation first before saving")
63
-
64
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
65
- prompt_str = f"{period}_{composer}_{instrumentation}"
66
- filename_base = f"{timestamp}_{prompt_str}"
67
-
68
- abc_filename = f"{filename_base}.abc"
69
- with open(abc_filename, "w", encoding="utf-8") as f:
70
- f.write(abc_content)
71
-
72
- xml_filename = f"{filename_base}.xml"
73
  try:
74
- subprocess.run(
75
- ["python", "abc2xml.py", '-o', '.', abc_filename, ],
76
- check=True,
77
- capture_output=True,
78
- text=True
79
- )
80
- except subprocess.CalledProcessError as e:
81
- error_msg = f"Conversion failed: {e.stderr}" if e.stderr else "Unknown error"
82
- raise gr.Error(f"ABC to XML conversion failed: {error_msg}. Please try to generate another composition.")
83
-
84
- return f"Saved successfully: {abc_filename} -> {xml_filename}"
85
-
86
-
87
- @spaces.GPU
88
- def generate_music(period, composer, instrumentation):
89
- if (period, composer, instrumentation) not in valid_combinations:
90
- raise gr.Error("Invalid prompt combination! Please re-select from the period options")
91
-
92
- output_queue = queue.Queue()
93
- original_stdout = sys.stdout
94
- sys.stdout = RealtimeStream(output_queue)
95
-
96
- result_container = []
97
- def run_inference():
98
- try:
99
- result_container.append(inference_patch(period, composer, instrumentation))
100
- finally:
101
- sys.stdout = original_stdout
102
-
103
- thread = threading.Thread(target=run_inference)
104
- thread.start()
105
-
106
- process_output = ""
107
- while thread.is_alive():
108
- try:
109
- text = output_queue.get(timeout=0.1)
110
- process_output += text
111
- yield process_output, None
112
- except queue.Empty:
113
- continue
114
-
115
- while not output_queue.empty():
116
- text = output_queue.get()
117
- process_output += text
118
- yield process_output, None
119
-
120
- final_result = result_container[0] if result_container else ""
121
- yield process_output, final_result
122
-
123
- with gr.Blocks() as demo:
124
- gr.Markdown("## NotaGen")
125
 
126
  with gr.Row():
127
- # 左侧栏
128
- with gr.Column():
129
- period_dd = gr.Dropdown(
130
- choices=periods,
131
- value=None,
132
- label="Period",
133
- interactive=True
134
- )
135
- composer_dd = gr.Dropdown(
136
- choices=[],
137
- value=None,
138
- label="Composer",
139
- interactive=False
140
- )
141
- instrument_dd = gr.Dropdown(
142
- choices=[],
143
- value=None,
144
- label="Instrumentation",
145
- interactive=False
146
- )
147
-
148
- generate_btn = gr.Button("Generate!", variant="primary")
149
-
150
- process_output = gr.Textbox(
151
- label="Generation process",
152
- interactive=False,
153
- lines=15,
154
- max_lines=15,
155
- placeholder="Generation progress will be shown here...",
156
- elem_classes="process-output"
157
  )
158
-
159
- # 右侧栏
160
- with gr.Column():
161
- final_output = gr.Textbox(
162
- label="Post-processed ABC notation scores",
163
- interactive=True,
164
- lines=23,
165
- placeholder="Post-processed ABC scores will be shown here...",
166
- elem_classes="final-output"
167
- )
168
-
169
- with gr.Row():
170
- save_btn = gr.Button("💾 Save as ABC & XML files", variant="secondary")
171
-
172
- save_status = gr.Textbox(
173
- label="Save Status",
174
- interactive=False,
175
- visible=True,
176
- max_lines=2
177
- )
178
-
179
- period_dd.change(
180
- update_components,
181
- inputs=[period_dd, composer_dd],
182
- outputs=[composer_dd, instrument_dd]
183
- )
184
- composer_dd.change(
185
- update_components,
186
- inputs=[period_dd, composer_dd],
187
- outputs=[composer_dd, instrument_dd]
188
  )
189
-
190
- generate_btn.click(
191
- generate_music,
192
- inputs=[period_dd, composer_dd, instrument_dd],
193
- outputs=[process_output, final_output]
194
- )
195
-
196
- save_btn.click(
197
- save_and_convert,
198
- inputs=[final_output, period_dd, composer_dd, instrument_dd],
199
- outputs=[save_status]
200
- )
201
-
202
-
203
- css = """
204
- .process-output {
205
- background-color: #f0f0f0;
206
- font-family: monospace;
207
- padding: 10px;
208
- border-radius: 5px;
209
- }
210
- .final-output {
211
- background-color: #ffffff;
212
- font-family: sans-serif;
213
- padding: 10px;
214
- border-radius: 5px;
215
- }
216
-
217
- .process-output textarea {
218
- max-height: 500px !important;
219
- overflow-y: auto !important;
220
- white-space: pre-wrap;
221
- }
222
-
223
- """
224
- css += """
225
- button#💾-save-convert:hover {
226
- background-color: #ffe6e6;
227
- }
228
- """
229
-
230
- demo.css = css
231
-
232
- if __name__ == "__main__":
233
 
234
- demo.queue().launch()
 
 
 
1
  import gradio as gr
2
+ from PIL import Image
3
+ from transformers import AutoModelForCausalLM, AutoProcessor
4
+ from starvector.data.util import process_and_rasterize_svg
5
+ import torch
6
+ import io
7
+
8
+ USE_BOTH_MODELS = True # Set this to True to load both models
9
+
10
+ # Load models at startup
11
+ models = {}
12
+ if USE_BOTH_MODELS:
13
+ # Load 8b model
14
+ model_name_8b = "starvector/starvector-8b-im2svg"
15
+ models['8b'] = {
16
+ 'model': AutoModelForCausalLM.from_pretrained(model_name_8b, torch_dtype=torch.float16, trust_remote_code=True),
17
+ 'processor': None # Will be set below
18
+ }
19
+ models['8b']['model'].cuda()
20
+ models['8b']['model'].eval()
21
+ models['8b']['processor'] = models['8b']['model'].model.processor
22
+
23
+ # Load 1b model
24
+ model_name_1b = "starvector/starvector-1b-im2svg"
25
+ models['1b'] = {
26
+ 'model': AutoModelForCausalLM.from_pretrained(model_name_1b, torch_dtype=torch.float16, trust_remote_code=True),
27
+ 'processor': None
28
+ }
29
+ models['1b']['model'].cuda()
30
+ models['1b']['model'].eval()
31
+ models['1b']['processor'] = models['1b']['model'].model.processor
32
+ else:
33
+ # Load only 8b model
34
+ model_name = "starvector/starvector-8b-im2svg"
35
+ models['8b'] = {
36
+ 'model': AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, trust_remote_code=True),
37
+ 'processor': None
38
+ }
39
+ models['8b']['model'].cuda()
40
+ models['8b']['model'].eval()
41
+ models['8b']['processor'] = models['8b']['model'].model.processor
42
+
43
+ def convert_to_svg(image, model_choice):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  try:
45
+ if image is None:
46
+ return None, None, "Please upload an image first"
47
+
48
+ # Select the model based on user choice
49
+ selected_model = models[model_choice]['model']
50
+ selected_processor = models[model_choice]['processor']
51
+
52
+ # Process the uploaded image
53
+ image_pil = Image.open(image)
54
+ image_tensor = selected_processor(image_pil, return_tensors="pt")['pixel_values'].cuda()
55
+
56
+ if not image_tensor.shape[0] == 1:
57
+ image_tensor = image_tensor.squeeze(0)
58
+
59
+ batch = {"image": image_tensor}
60
+
61
+ # Generate SVG
62
+ raw_svg = selected_model.generate_im2svg(batch, max_length=4000)[0]
63
+ svg, raster_image = process_and_rasterize_svg(raw_svg)
64
+
65
+ # Convert SVG string to bytes for download
66
+ svg_bytes = io.BytesIO(svg.encode('utf-8'))
67
+
68
+ return raster_image, svg_bytes, f"Conversion successful using {model_choice} model!"
69
+ except Exception as e:
70
+ return None, None, f"Error: {str(e)}"
71
+
72
+ # Create Blocks interface
73
+ with gr.Blocks(title="Image to SVG Converter") as demo:
74
+ gr.Markdown("# Image to SVG Converter")
75
+ gr.Markdown("Upload an image to convert it to SVG format using StarVector model")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  with gr.Row():
78
+ with gr.Column(scale=1):
79
+ # Input section
80
+ input_image = gr.Image(type="filepath", label="Upload Image")
81
+ if USE_BOTH_MODELS:
82
+ model_choice = gr.Radio(
83
+ choices=["8b", "1b"],
84
+ value="8b",
85
+ label="Select Model",
86
+ info="Choose between 8b (larger) and 1b (smaller) models"
87
+ )
88
+ convert_btn = gr.Button("Convert to SVG")
89
+ example = gr.Examples(
90
+ examples=[["assets/examples/sample-18.png"]],
91
+ inputs=input_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  )
93
+
94
+ with gr.Column(scale=1):
95
+ # Output section
96
+ output_preview = gr.Image(type="pil", label="Rasterized SVG Preview")
97
+ output_file = gr.File(label="Download SVG")
98
+ status = gr.Textbox(label="Status")
99
+
100
+ # Connect button click to conversion function
101
+ inputs = [input_image]
102
+ if USE_BOTH_MODELS:
103
+ inputs.append(model_choice)
104
+
105
+ convert_btn.click(
106
+ fn=convert_to_svg,
107
+ inputs=inputs,
108
+ outputs=[output_preview, output_file, status]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
+ # Launch the app
112
+ demo.launch()