mrfakename committed on
Commit c0acc92 · verified · 1 Parent(s): cf6316c

Update app.py

Files changed (1)
  1. app.py +330 -110
app.py CHANGED
@@ -1,114 +1,334 @@
- import spaces
- import gradio as gr
- from PIL import Image
- from transformers import AutoModelForCausalLM, AutoProcessor
- from starvector.data.util import process_and_rasterize_svg
  import torch
- import io
-
- USE_BOTH_MODELS = True  # Set this to True to load both models
-
- # Load models at startup
- models = {}
- if USE_BOTH_MODELS:
-     # Load 8b model
-     model_name_8b = "starvector/starvector-8b-im2svg"
-     models['8b'] = {
-         'model': AutoModelForCausalLM.from_pretrained(model_name_8b, torch_dtype=torch.float16, trust_remote_code=True),
-         'processor': None  # Will be set below
-     }
-     models['8b']['model'].cuda()
-     models['8b']['model'].eval()
-     models['8b']['processor'] = models['8b']['model'].model.processor
-
-     # Load 1b model
-     model_name_1b = "starvector/starvector-1b-im2svg"
-     models['1b'] = {
-         'model': AutoModelForCausalLM.from_pretrained(model_name_1b, torch_dtype=torch.float16, trust_remote_code=True),
-         'processor': None
-     }
-     models['1b']['model'].cuda()
-     models['1b']['model'].eval()
-     models['1b']['processor'] = models['1b']['model'].model.processor
- else:
-     # Load only 8b model
-     model_name = "starvector/starvector-8b-im2svg"
-     models['8b'] = {
-         'model': AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, trust_remote_code=True),
-         'processor': None
      }
-     models['8b']['model'].cuda()
-     models['8b']['model'].eval()
-     models['8b']['processor'] = models['8b']['model'].model.processor
-
- @spaces.GPU
- def convert_to_svg(image, model_choice):
-     try:
-         if image is None:
-             return None, None, "Please upload an image first"
-
-         # Select the model based on user choice
-         selected_model = models[model_choice]['model']
-         selected_processor = models[model_choice]['processor']
-
-         # Process the uploaded image
-         image_pil = Image.open(image)
-         image_tensor = selected_processor(image_pil, return_tensors="pt")['pixel_values'].cuda()
-
-         if not image_tensor.shape[0] == 1:
-             image_tensor = image_tensor.squeeze(0)
-
-         batch = {"image": image_tensor}
-
-         # Generate SVG
-         raw_svg = selected_model.generate_im2svg(batch, max_length=4000)[0]
-         svg, raster_image = process_and_rasterize_svg(raw_svg)
-
-         # Convert SVG string to bytes for download
-         svg_bytes = io.BytesIO(svg.encode('utf-8'))
-
-         return raster_image, svg_bytes, f"Conversion successful using {model_choice} model!"
-     except Exception as e:
-         return None, None, f"Error: {str(e)}"
-
- # Create Blocks interface
- with gr.Blocks(title="StarVector") as demo:
-     gr.Markdown("# StarVector")
-     gr.Markdown("Upload an image to convert it to SVG format using StarVector model")
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             # Input section
-             input_image = gr.Image(type="filepath", label="Upload Image")
-             if USE_BOTH_MODELS:
-                 model_choice = gr.Radio(
-                     choices=["8b", "1b"],
-                     value="8b",
-                     label="Select Model",
-                     info="Choose between 8b (larger) and 1b (smaller) models"
-                 )
-             convert_btn = gr.Button("Convert to SVG")
-             example = gr.Examples(
-                 examples=[["assets/examples/sample-18.png"]],
-                 inputs=input_image
-             )
-
-         with gr.Column(scale=1):
-             # Output section
-             output_preview = gr.Image(type="pil", label="Rasterized SVG Preview")
-             output_file = gr.File(label="Download SVG")
-             status = gr.Textbox(label="Status")
-
-     # Connect button click to conversion function
-     inputs = [input_image]
-     if USE_BOTH_MODELS:
-         inputs.append(model_choice)
-
-     convert_btn.click(
-         fn=convert_to_svg,
-         inputs=inputs,
-         outputs=[output_preview, output_file, status]
      )

- # Launch the app
- demo.launch()
+ #!/usr/bin/env python3
+ """
+ gradio_tts_app.py
+
+ Run:
+     python gradio_tts_app.py
+
+ Then open the printed local or public URL in your browser.
+ """
+
+ import os
+ import random
+ import numpy as np
  import torch
+ import torchaudio
+ import whisper
+ import gradio as gr
+ from argparse import Namespace
+
+ # ---------------------------------------------------------------------
+ # The following imports assume your local project structure:
+ #   data/tokenizer.py
+ #   models/voice_star.py
+ #   inference_tts_utils.py
+ # Adjust if needed.
+ # ---------------------------------------------------------------------
+ from data.tokenizer import AudioTokenizer, TextTokenizer
+ from models import voice_star
+ from inference_tts_utils import inference_one_sample
+
+
+ ############################################################
+ # Utility Functions
+ ############################################################
+
+ def seed_everything(seed=1):
+     os.environ['PYTHONHASHSEED'] = str(seed)
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.backends.cudnn.benchmark = False
+     torch.backends.cudnn.deterministic = True
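+     # Deterministic cuDNN kernels trade some speed for run-to-run reproducibility.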
+
+
+ def estimate_duration(ref_audio_path, ref_text, target_text):
+     """
+     Estimate the target duration from the reference audio's seconds-per-character
+     rate, applied to the target text.
+     """
+     info = torchaudio.info(ref_audio_path)
+     audio_duration = info.num_frames / info.sample_rate
+     spc = audio_duration / max(len(ref_text), 1)  # seconds per character
+     return len(target_text) * spc
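+ # For example, a 6 s reference clip whose transcript is 60 characters gives
+ # 0.1 s/char, so a 120-character target text is estimated at 12 s.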
+
+
+ ############################################################
+ # Main Inference Function
+ ############################################################
+
+ def run_inference(
+     # User-adjustable parameters (not marked "do not change" in the snippet)
+     reference_speech="./demo/5895_34622_000026_000002.wav",
+     target_text="VoiceStar is a very interesting model, it's duration controllable and can extrapolate",
+     model_name="VoiceStar_840M_40s",
+     model_root="./pretrained",
+     reference_text=None,    # optional
+     target_duration=None,   # optional
+     top_k=10,               # can try 10, 20, 30, 40
+     temperature=1,
+     kvcache=1,              # if OOM, set to 0
+     repeat_prompt=1,        # use higher to improve speaker similarity
+     stop_repetition=3,      # snippet says "will not use it" but not "do not change"
+     seed=1,
+     output_dir="./generated_tts",
+
+     # Non-adjustable parameters (per the snippet's instructions)
+     codec_audio_sr=16000,   # do not change
+     codec_sr=50,            # do not change
+     top_p=1,                # do not change
+     min_p=1,                # do not change
+     silence_tokens=None,    # do not change
+     multi_trial=None,       # do not change
+     sample_batch_size=1,    # do not change
+     cut_off_sec=100,        # do not adjust
+ ):
+     """
+     Inference script for VoiceStar TTS.
+     """
+     # 1. Set seed
+     seed_everything(seed)
+
+     # 2. Load model checkpoint
+     torch.serialization.add_safe_globals([Namespace])
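+     # The allowlist above lets torch.load(weights_only=True) unpickle the
+     # argparse.Namespace stored alongside the weights in the checkpoint.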
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     ckpt_fn = os.path.join(model_root, model_name + ".pth")
+     if not os.path.exists(ckpt_fn):
+         # Download the checkpoint with wget (the URL is quoted so the shell
+         # does not interpret the "?download=true" query string)
+         os.makedirs(model_root, exist_ok=True)
+         print(f"[Info] Downloading {model_name} checkpoint...")
+         os.system(f'wget "https://huggingface.co/pyp1/VoiceStar/resolve/main/{model_name}.pth?download=true" -O {ckpt_fn}')
+     bundle = torch.load(ckpt_fn, map_location=device, weights_only=True)
+     args = bundle["args"]
+     phn2num = bundle["phn2num"]
+
+     model = voice_star.VoiceStar(args)
+     model.load_state_dict(bundle["model"])
+     model.to(device)
+     model.eval()
+
+     # 3. If reference_text not provided, transcribe reference speech with Whisper
+     if reference_text is None:
+         print("[Info] No reference_text provided. Transcribing reference_speech with Whisper (large-v3-turbo).")
+         wh_model = whisper.load_model("large-v3-turbo")
+         result = wh_model.transcribe(reference_speech)
+         prefix_transcript = result["text"]
+         print(f"[Info] Whisper transcribed text: {prefix_transcript}")
+     else:
+         prefix_transcript = reference_text
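+     # Note: the "large-v3-turbo" checkpoint requires a recent openai-whisper
+     # release; substitute "medium" or "large-v3" if your install lacks it.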
+
+     # 4. If target_duration not provided, estimate from reference speech + target_text
+     if target_duration is None:
+         target_generation_length = estimate_duration(reference_speech, prefix_transcript, target_text)
+         print(f"[Info] target_duration not provided, estimated as {target_generation_length:.2f}s. Provide target_duration if needed.")
+     else:
+         target_generation_length = float(target_duration)
+
+     # 5. Prepare signature from snippet
+     if args.n_codebooks == 4:
+         signature = "./pretrained/encodec_6f79c6a8.th"
+     elif args.n_codebooks == 8:
+         signature = "./pretrained/encodec_8cb1024_giga.th"
+     else:
+         signature = "./pretrained/encodec_6f79c6a8.th"
+
+     if silence_tokens is None:
+         silence_tokens = []
+
+     if multi_trial is None:
+         multi_trial = []
+
+     delay_pattern_increment = args.n_codebooks + 1  # from snippet
+
+     info = torchaudio.info(reference_speech)
+     prompt_end_frame = int(cut_off_sec * info.sample_rate)
+
+     # 6. Tokenizers
+     audio_tokenizer = AudioTokenizer(signature=signature)
+     text_tokenizer = TextTokenizer(backend="espeak")
+
+     # 7. decode_config
+     decode_config = {
+         "top_k": top_k,
+         "top_p": top_p,
+         "min_p": min_p,
+         "temperature": temperature,
+         "stop_repetition": stop_repetition,
+         "kvcache": kvcache,
+         "codec_audio_sr": codec_audio_sr,
+         "codec_sr": codec_sr,
+         "silence_tokens": silence_tokens,
+         "sample_batch_size": sample_batch_size,
      }
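+     # top_k and temperature are the main sampling knobs here; top_p, min_p, and
+     # the codec rates stay at the snippet's fixed defaults.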
+
+     # 8. Run inference
+     print("[Info] Running TTS inference...")
+     concated_audio, gen_audio = inference_one_sample(
+         model, args, phn2num, text_tokenizer, audio_tokenizer,
+         reference_speech, target_text,
+         device, decode_config,
+         prompt_end_frame=prompt_end_frame,
+         target_generation_length=target_generation_length,
+         delay_pattern_increment=delay_pattern_increment,
+         prefix_transcript=prefix_transcript,
+         multi_trial=multi_trial,
+         repeat_prompt=repeat_prompt,
      )

+     # The model returns a list of waveforms; pick the first
+     concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()
+
+     # 9. Save generated audio
+     os.makedirs(output_dir, exist_ok=True)
+     out_filename = "generated.wav"
+     out_path = os.path.join(output_dir, out_filename)
+     torchaudio.save(out_path, gen_audio, codec_audio_sr)
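+     # torchaudio.save expects a 2-D (channels, samples) tensor; gen_audio is
+     # assumed to have that shape after the indexing above.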
+
+     print(f"[Success] Generated audio saved to {out_path}")
+     return out_path  # Return the path for Gradio to load
+
+
+ ############################
+ # Transcription function
+ ############################
+
+ def transcribe_audio(reference_speech):
+     """
+     Transcribe uploaded reference audio with Whisper and return the text.
+     If no file is provided, return an empty string.
+     """
+     if reference_speech is None:
+         return ""
+     audio_path = reference_speech  # Because type="filepath"
+
+     if not os.path.exists(audio_path):
+         return "File not found."
+
+     print("[Info] Transcribing with Whisper...")
+     model = whisper.load_model("medium")  # or "large-v2", etc.
+     result = model.transcribe(audio_path)
+     return result["text"]
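+ # Note: loading Whisper on every click is slow. A module-level cache is a
+ # reasonable optimization (hypothetical helper, not part of the original app):
+ #
+ #     _whisper_models = {}
+ #
+ #     def get_whisper(name="medium"):
+ #         if name not in _whisper_models:
+ #             _whisper_models[name] = whisper.load_model(name)
+ #         return _whisper_models[name]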
+
+ ############################
+ # Gradio UI
+ ############################
+
+ def main():
+     with gr.Blocks() as demo:
+         gr.Markdown("## VoiceStar TTS with Editable Reference Text")
+
+         with gr.Row():
+             reference_speech_input = gr.Audio(
+                 label="Reference Speech",
+                 type="filepath",
+                 elem_id="ref_speech"
+             )
+             transcribe_button = gr.Button("Transcribe")
+
+         # The transcribed text appears here and can be edited
+         reference_text_box = gr.Textbox(
+             label="Reference Text (Editable)",
+             placeholder="Click 'Transcribe' to auto-fill from reference speech...",
+             lines=2
+         )
+
+         target_text_box = gr.Textbox(
+             label="Target Text",
+             value="VoiceStar is a very interesting model, it's duration controllable and can extrapolate to unseen duration.",
+             lines=3
+         )
+
+         model_name_box = gr.Textbox(
+             label="Model Name",
+             value="VoiceStar_840M_40s"
+         )
+
+         model_root_box = gr.Textbox(
+             label="Model Root Directory",
+             value="./pretrained"  # Matches run_inference's default; checkpoints are downloaded here if missing
+         )
+
+         reference_duration_box = gr.Textbox(
+             label="Target Duration (Optional)",
+             placeholder="Leave empty for auto-estimate."
+         )
+
+         top_k_box = gr.Number(label="top_k", value=10)
+         temperature_box = gr.Number(label="temperature", value=1.0)
+         kvcache_box = gr.Number(label="kvcache (1 or 0)", value=1)
+         repeat_prompt_box = gr.Number(label="repeat_prompt", value=1)
+         stop_repetition_box = gr.Number(label="stop_repetition", value=3)
+         seed_box = gr.Number(label="Random Seed", value=1)
+         output_dir_box = gr.Textbox(label="Output Directory", value="./generated_tts")
+
+         generate_button = gr.Button("Generate TTS")
+         output_audio = gr.Audio(label="Generated Audio", type="filepath")
+
+         # 1) When the user clicks "Transcribe", call `transcribe_audio`
+         transcribe_button.click(
+             fn=transcribe_audio,
+             inputs=[reference_speech_input],
+             outputs=[reference_text_box],
+         )
+
+         # 2) The actual TTS generation function.
+         def gradio_inference(
+             reference_speech,
+             reference_text,
+             target_text,
+             model_name,
+             model_root,
+             target_duration,
+             top_k,
+             temperature,
+             kvcache,
+             repeat_prompt,
+             stop_repetition,
+             seed,
+             output_dir
+         ):
+             # Convert any empty strings to None for optional fields
+             dur = float(target_duration) if target_duration else None
+
+             out_path = run_inference(
+                 reference_speech=reference_speech,
+                 reference_text=reference_text if reference_text else None,
+                 target_text=target_text,
+                 model_name=model_name,
+                 model_root=model_root,
+                 target_duration=dur,
+                 top_k=int(top_k),
+                 temperature=float(temperature),
+                 kvcache=int(kvcache),
+                 repeat_prompt=int(repeat_prompt),
+                 stop_repetition=int(stop_repetition),
+                 seed=int(seed),
+                 output_dir=output_dir
+             )
+             return out_path
+
+         # 3) Link the "Generate TTS" button
+         generate_button.click(
+             fn=gradio_inference,
+             inputs=[
+                 reference_speech_input,
+                 reference_text_box,
+                 target_text_box,
+                 model_name_box,
+                 model_root_box,
+                 reference_duration_box,
+                 top_k_box,
+                 temperature_box,
+                 kvcache_box,
+                 repeat_prompt_box,
+                 stop_repetition_box,
+                 seed_box,
+                 output_dir_box
+             ],
+             outputs=[output_audio],
+         )
+
+     demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
+
+ if __name__ == "__main__":
+     main()
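+ # A minimal sketch of bypassing the UI entirely (paths are the module defaults;
+ # the target text here is only an illustration):
+ #
+ #     from gradio_tts_app import run_inference
+ #     run_inference(
+ #         reference_speech="./demo/5895_34622_000026_000002.wav",
+ #         target_text="Hello from VoiceStar.",
+ #     )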