ElectricAlexis commited on
Commit
8e170be
·
verified ·
1 Parent(s): 94c7b25

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +57 -69
  2. inference.py +18 -20
app.py CHANGED
@@ -15,7 +15,7 @@ from inference import inference_patch
15
  from convert import abc2xml, xml2, pdf2img
16
 
17
 
18
- # 读取 prompt 组合
19
  with open('prompts.txt', 'r') as f:
20
  prompts = f.readlines()
21
 
@@ -25,12 +25,12 @@ for prompt in prompts:
25
  parts = prompt.split('_')
26
  valid_combinations.add((parts[0], parts[1], parts[2]))
27
 
28
- # 准备下拉框选项
29
  periods = sorted({p for p, _, _ in valid_combinations})
30
  composers = sorted({c for _, c, _ in valid_combinations})
31
  instruments = sorted({i for _, _, i in valid_combinations})
32
 
33
- # 动态更新作曲家、乐器下拉选项
34
  def update_components(period, composer):
35
  if not period:
36
  return [
@@ -54,7 +54,7 @@ def update_components(period, composer):
54
  )
55
  ]
56
 
57
- # 自定义实时流,用于把模型推理过程输出到前端
58
  class RealtimeStream(TextIOBase):
59
  def __init__(self, queue):
60
  self.queue = queue
@@ -81,7 +81,7 @@ def convert_files(abc_content, period, composer, instrumentation):
81
  with open(filename_base_postinst + ".abc", "w", encoding="utf-8") as f:
82
  f.write(postprocessed_inst_abc)
83
 
84
- # 转换文件
85
  file_paths = {'abc': abc_filename}
86
  try:
87
  # abc2xml
@@ -115,15 +115,15 @@ def convert_files(abc_content, period, composer, instrumentation):
115
  })
116
 
117
  except Exception as e:
118
- raise gr.Error(f"文件处理失败: {str(e)}")
119
 
120
  return file_paths
121
 
122
 
123
- # 翻页控制函数
124
  def update_page(direction, data):
125
  """
126
- data 里面包含了 'pages','current_page','base' 三个关键信息
127
  """
128
  if not data:
129
  return None, gr.update(interactive=False), gr.update(interactive=False), data
@@ -134,9 +134,9 @@ def update_page(direction, data):
134
  data['current_page'] += 1
135
 
136
  current_page_index = data['current_page']
137
- # 更新图片路径
138
  new_image = f"{data['base']}_page_{current_page_index+1}.png"
139
- # current_page==0 时,prev_btn 不可用;当 current_page==pages-1 时,next_btn 不可用
140
  prev_btn_state = gr.update(interactive=(current_page_index > 0))
141
  next_btn_state = gr.update(interactive=(current_page_index < data['pages'] - 1))
142
 
@@ -146,13 +146,13 @@ def update_page(direction, data):
146
  @spaces.GPU(duration=600)
147
  def generate_music(period, composer, instrumentation):
148
  """
149
- 需要保证每次 yield 的返回值数量一致。
150
- 我们这里准备返回 5 个值,对应:
151
- 1) process_output (中间推理信息)
152
- 2) final_output (最终 ABC)
153
- 3) pdf_image (PDF 第一页对应的 png 路径)
154
- 4) audio_player (mp3 路径)
155
- 5) pdf_state (翻页用的 state)
156
  """
157
  # Set a different random seed each time based on current timestamp
158
  random_seed = int(time.time()) % 10000
@@ -175,7 +175,7 @@ def generate_music(period, composer, instrumentation):
175
  pass
176
 
177
  if (period, composer, instrumentation) not in valid_combinations:
178
- # 如果组合非法,直接抛出错误
179
  raise gr.Error("Invalid prompt combination! Please re-select from the period options")
180
 
181
  output_queue = queue.Queue()
@@ -186,7 +186,7 @@ def generate_music(period, composer, instrumentation):
186
 
187
  def run_inference():
188
  try:
189
- # 使用下载的模型权重路径进行推理
190
  result = inference_patch(period, composer, instrumentation)
191
  result_container.append(result)
192
  finally:
@@ -201,40 +201,40 @@ def generate_music(period, composer, instrumentation):
201
  audio_file = None
202
  pdf_state = None
203
 
204
- # 先持续读中间输出
205
  while thread.is_alive():
206
  try:
207
  text = output_queue.get(timeout=0.1)
208
  process_output += text
209
- # 暂时没有最终 ABC,还没有转文件
210
  yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)
211
  except queue.Empty:
212
  continue
213
 
214
- # 线程结束后,把剩余的队列都拿出来
215
  while not output_queue.empty():
216
  text = output_queue.get()
217
  process_output += text
218
 
219
- # 最终推理结果
220
  final_result = result_container[0] if result_container else ""
221
 
222
- # 显示转换文件的提示
223
  final_output_abc = "Converting files..."
224
  yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)
225
 
226
 
227
- # 做文件转换
228
  try:
229
  file_paths = convert_files(final_result, period, composer, instrumentation)
230
  final_output_abc = final_result
231
- # 拿到第一张图片和 mp3 文件
232
  if file_paths['pages'] > 0:
233
  pdf_image = f"{file_paths['base']}_page_1.png"
234
  audio_file = file_paths['mp3']
235
- pdf_state = file_paths # 直接把转换后的信息字典拿来存到 state
236
 
237
- # 准备下载文件列表
238
  download_list = []
239
  if 'abc' in file_paths and os.path.exists(file_paths['abc']):
240
  download_list.append(file_paths['abc'])
@@ -247,60 +247,60 @@ def generate_music(period, composer, instrumentation):
247
  if 'mp3' in file_paths and os.path.exists(file_paths['mp3']):
248
  download_list.append(file_paths['mp3'])
249
  except Exception as e:
250
- # 如果失败了,把错误信息返回到输出框
251
  yield process_output, f"Error converting files: {str(e)}", None, None, None, gr.update(value=None, visible=False)
252
  return
253
 
254
- # 最后一次 yield,带上所有信息 - 修改此处让组件可见
255
  yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=download_list, visible=True)
256
 
257
 
258
  def get_file(file_type, period, composer, instrumentation):
259
  """
260
- 返回本地的指定类型文件,用于 Gradio 下载
261
  """
262
- # 这里其实需要你根据先前保存下来的具体文件路径来返回,演示时可以简化
263
- # 如果是按 timestamp 去匹配,可以把转换的文件都存在某个目录下再拿最新的
264
- # 这里仅做示例:
265
  possible_files = [f for f in os.listdir('.') if f.endswith(f'.{file_type}')]
266
  if not possible_files:
267
  return None
268
- # 简单返回最新的
269
  possible_files.sort(key=os.path.getmtime)
270
  return possible_files[-1]
271
 
272
 
273
  css = """
274
- /* 紧凑按钮样式 */
275
  button[size="sm"] {
276
  padding: 4px 8px !important;
277
  margin: 2px !important;
278
  min-width: 60px;
279
  }
280
 
281
- /* PDF预览区 */
282
  #pdf-preview {
283
- border-radius: 8px; /* 圆角 */
284
- box-shadow: 0 2px 8px rgba(0,0,0,0.1); /* 阴影 */
285
  }
286
 
287
  .page-btn {
288
- padding: 12px !important; /* 增大点击区域 */
289
- margin: auto !important; /* 垂直居中 */
290
  }
291
 
292
- /* 按钮悬停效果 */
293
  .page-btn:hover {
294
  background: #f0f0f0 !important;
295
  transform: scale(1.05);
296
  }
297
 
298
- /* 布局调整 */
299
  .gr-row {
300
- gap: 10px !important; /* 元素间距 */
301
  }
302
 
303
- /* 音频播放器 */
304
  .audio-panel {
305
  margin-top: 15px !important;
306
  max-width: 400px;
@@ -310,23 +310,13 @@ button[size="sm"] {
310
  height: 200px !important;
311
  }
312
 
313
- /* 保存功能区 */
314
  .save-as-row {
315
  margin-top: 15px;
316
  padding: 10px;
317
  border-top: 1px solid #eee;
318
  }
319
 
320
- .save-as-label {
321
- font-weight: bold;
322
- margin-right: 10px;
323
- align-self: center;
324
- }
325
-
326
- .save-buttons {
327
- gap: 5px; /* 按钮间距 */
328
- }
329
-
330
  /* Download files styling */
331
  .download-files {
332
  margin-top: 15px;
@@ -339,12 +329,12 @@ button[size="sm"] {
339
  with gr.Blocks(css=css) as demo:
340
  gr.Markdown("## NotaGen")
341
 
342
- # 用于保存 PDF 页数、当前页等信息
343
  pdf_state = gr.State()
344
 
345
  with gr.Column():
346
  with gr.Row():
347
- # 左侧栏
348
  with gr.Column():
349
  with gr.Row():
350
  period_dd = gr.Dropdown(
@@ -384,18 +374,16 @@ with gr.Blocks(css=css) as demo:
384
  placeholder="Post-processed ABC scores will be shown here..."
385
  )
386
 
387
- # 音频播放
388
  audio_player = gr.Audio(
389
  label="Audio Preview",
390
  format="mp3",
391
  interactive=False,
392
- # container=False,
393
- # elem_id="audio-preview"
394
  )
395
 
396
- # 右侧栏
397
  with gr.Column():
398
- # 图片容器
399
  pdf_image = gr.Image(
400
  label="Sheet Music Preview",
401
  show_label=False,
@@ -406,7 +394,7 @@ with gr.Blocks(css=css) as demo:
406
  show_download_button=False
407
  )
408
 
409
- # 翻页按钮
410
  with gr.Row():
411
  prev_btn = gr.Button(
412
  "⬅️ Last Page",
@@ -430,7 +418,7 @@ with gr.Blocks(css=css) as demo:
430
  type="filepath" # Make sure this is set to filepath
431
  )
432
 
433
- # 下拉框联动
434
  period_dd.change(
435
  update_components,
436
  inputs=[period_dd, composer_dd],
@@ -442,26 +430,26 @@ with gr.Blocks(css=css) as demo:
442
  outputs=[composer_dd, instrument_dd]
443
  )
444
 
445
- # 点击生成按钮,注意 outputs 要和 generate_music 里每次 yield 保持一致
446
  generate_btn.click(
447
  generate_music,
448
  inputs=[period_dd, composer_dd, instrument_dd],
449
  outputs=[process_output, final_output, pdf_image, audio_player, pdf_state, download_files]
450
  )
451
 
452
- # 翻页
453
  prev_signal = gr.Textbox(value="prev", visible=False)
454
  next_signal = gr.Textbox(value="next", visible=False)
455
 
456
  prev_btn.click(
457
  update_page,
458
- inputs=[prev_signal, pdf_state], # ✅ 使用组件
459
  outputs=[pdf_image, prev_btn, next_btn, pdf_state]
460
  )
461
 
462
  next_btn.click(
463
  update_page,
464
- inputs=[next_signal, pdf_state], # ✅ 使用组件
465
  outputs=[pdf_image, prev_btn, next_btn, pdf_state]
466
  )
467
 
 
15
  from convert import abc2xml, xml2, pdf2img
16
 
17
 
18
+ # Read prompt combinations
19
  with open('prompts.txt', 'r') as f:
20
  prompts = f.readlines()
21
 
 
25
  parts = prompt.split('_')
26
  valid_combinations.add((parts[0], parts[1], parts[2]))
27
 
28
+ # Prepare dropdown options
29
  periods = sorted({p for p, _, _ in valid_combinations})
30
  composers = sorted({c for _, c, _ in valid_combinations})
31
  instruments = sorted({i for _, _, i in valid_combinations})
32
 
33
+ # Dynamically update composer and instrument dropdown options
34
  def update_components(period, composer):
35
  if not period:
36
  return [
 
54
  )
55
  ]
56
 
57
+ # Custom realtime stream for outputting model inference process to frontend
58
  class RealtimeStream(TextIOBase):
59
  def __init__(self, queue):
60
  self.queue = queue
 
81
  with open(filename_base_postinst + ".abc", "w", encoding="utf-8") as f:
82
  f.write(postprocessed_inst_abc)
83
 
84
+ # Convert files
85
  file_paths = {'abc': abc_filename}
86
  try:
87
  # abc2xml
 
115
  })
116
 
117
  except Exception as e:
118
+ raise gr.Error(f"File processing failed: {str(e)}")
119
 
120
  return file_paths
121
 
122
 
123
+ # Page navigation control function
124
  def update_page(direction, data):
125
  """
126
+ data contains three key pieces of information: 'pages', 'current_page', and 'base'
127
  """
128
  if not data:
129
  return None, gr.update(interactive=False), gr.update(interactive=False), data
 
134
  data['current_page'] += 1
135
 
136
  current_page_index = data['current_page']
137
+ # Update image path
138
  new_image = f"{data['base']}_page_{current_page_index+1}.png"
139
+ # When current_page==0, prev_btn is disabled; when current_page==pages-1, next_btn is disabled
140
  prev_btn_state = gr.update(interactive=(current_page_index > 0))
141
  next_btn_state = gr.update(interactive=(current_page_index < data['pages'] - 1))
142
 
 
146
  @spaces.GPU(duration=600)
147
  def generate_music(period, composer, instrumentation):
148
  """
149
+ Must ensure each yield returns the same number of values.
150
+ We're preparing to return 5 values, corresponding to:
151
+ 1) process_output (intermediate inference information)
152
+ 2) final_output (final ABC)
153
+ 3) pdf_image (path to the PNG of the first page of the PDF)
154
+ 4) audio_player (mp3 path)
155
+ 5) pdf_state (state for page navigation)
156
  """
157
  # Set a different random seed each time based on current timestamp
158
  random_seed = int(time.time()) % 10000
 
175
  pass
176
 
177
  if (period, composer, instrumentation) not in valid_combinations:
178
+ # If the combination is invalid, raise an error
179
  raise gr.Error("Invalid prompt combination! Please re-select from the period options")
180
 
181
  output_queue = queue.Queue()
 
186
 
187
  def run_inference():
188
  try:
189
+ # Use downloaded model weights path for inference
190
  result = inference_patch(period, composer, instrumentation)
191
  result_container.append(result)
192
  finally:
 
201
  audio_file = None
202
  pdf_state = None
203
 
204
+ # First continuously read intermediate output
205
  while thread.is_alive():
206
  try:
207
  text = output_queue.get(timeout=0.1)
208
  process_output += text
209
+ # No final ABC yet, files not yet converted
210
  yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)
211
  except queue.Empty:
212
  continue
213
 
214
+ # After thread ends, get all remaining items from the queue
215
  while not output_queue.empty():
216
  text = output_queue.get()
217
  process_output += text
218
 
219
+ # Final inference result
220
  final_result = result_container[0] if result_container else ""
221
 
222
+ # Display file conversion prompt
223
  final_output_abc = "Converting files..."
224
  yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)
225
 
226
 
227
+ # Convert files
228
  try:
229
  file_paths = convert_files(final_result, period, composer, instrumentation)
230
  final_output_abc = final_result
231
+ # Get the first image and mp3 file
232
  if file_paths['pages'] > 0:
233
  pdf_image = f"{file_paths['base']}_page_1.png"
234
  audio_file = file_paths['mp3']
235
+ pdf_state = file_paths # Directly use the converted information dictionary as state
236
 
237
+ # Prepare download file list
238
  download_list = []
239
  if 'abc' in file_paths and os.path.exists(file_paths['abc']):
240
  download_list.append(file_paths['abc'])
 
247
  if 'mp3' in file_paths and os.path.exists(file_paths['mp3']):
248
  download_list.append(file_paths['mp3'])
249
  except Exception as e:
250
+ # If conversion fails, return error message to output box
251
  yield process_output, f"Error converting files: {str(e)}", None, None, None, gr.update(value=None, visible=False)
252
  return
253
 
254
+ # Final yield with all information - modify here to make component visible
255
  yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=download_list, visible=True)
256
 
257
 
258
  def get_file(file_type, period, composer, instrumentation):
259
  """
260
+ Returns the local file of specified type for Gradio download
261
  """
262
+ # Here you actually need to return based on specific file paths saved earlier, simplified for demo
263
+ # If matching by timestamp, you can store all converted files in a directory and get the latest
264
+ # This is just an example:
265
  possible_files = [f for f in os.listdir('.') if f.endswith(f'.{file_type}')]
266
  if not possible_files:
267
  return None
268
+ # Simply return the latest
269
  possible_files.sort(key=os.path.getmtime)
270
  return possible_files[-1]
271
 
272
 
273
  css = """
274
+ /* Compact button style */
275
  button[size="sm"] {
276
  padding: 4px 8px !important;
277
  margin: 2px !important;
278
  min-width: 60px;
279
  }
280
 
281
+ /* PDF preview area */
282
  #pdf-preview {
283
+ border-radius: 8px; /* Rounded corners */
284
+ box-shadow: 0 2px 8px rgba(0,0,0,0.1); /* Shadow */
285
  }
286
 
287
  .page-btn {
288
+ padding: 12px !important; /* Increase clickable area */
289
+ margin: auto !important; /* Vertical center */
290
  }
291
 
292
+ /* Button hover effect */
293
  .page-btn:hover {
294
  background: #f0f0f0 !important;
295
  transform: scale(1.05);
296
  }
297
 
298
+ /* Layout adjustment */
299
  .gr-row {
300
+ gap: 10px !important; /* Element spacing */
301
  }
302
 
303
+ /* Audio player */
304
  .audio-panel {
305
  margin-top: 15px !important;
306
  max-width: 400px;
 
310
  height: 200px !important;
311
  }
312
 
313
+ /* Save functionality area */
314
  .save-as-row {
315
  margin-top: 15px;
316
  padding: 10px;
317
  border-top: 1px solid #eee;
318
  }
319
 
 
 
 
 
 
 
 
 
 
 
320
  /* Download files styling */
321
  .download-files {
322
  margin-top: 15px;
 
329
  with gr.Blocks(css=css) as demo:
330
  gr.Markdown("## NotaGen")
331
 
332
+ # For storing PDF page count, current page and other information
333
  pdf_state = gr.State()
334
 
335
  with gr.Column():
336
  with gr.Row():
337
+ # Left sidebar
338
  with gr.Column():
339
  with gr.Row():
340
  period_dd = gr.Dropdown(
 
374
  placeholder="Post-processed ABC scores will be shown here..."
375
  )
376
 
377
+ # Audio playback
378
  audio_player = gr.Audio(
379
  label="Audio Preview",
380
  format="mp3",
381
  interactive=False,
 
 
382
  )
383
 
384
+ # Right sidebar
385
  with gr.Column():
386
+ # Image container
387
  pdf_image = gr.Image(
388
  label="Sheet Music Preview",
389
  show_label=False,
 
394
  show_download_button=False
395
  )
396
 
397
+ # Page navigation buttons
398
  with gr.Row():
399
  prev_btn = gr.Button(
400
  "⬅️ Last Page",
 
418
  type="filepath" # Make sure this is set to filepath
419
  )
420
 
421
+ # Dropdown linking
422
  period_dd.change(
423
  update_components,
424
  inputs=[period_dd, composer_dd],
 
430
  outputs=[composer_dd, instrument_dd]
431
  )
432
 
433
+ # Click generate button, note outputs must match each yield in generate_music
434
  generate_btn.click(
435
  generate_music,
436
  inputs=[period_dd, composer_dd, instrument_dd],
437
  outputs=[process_output, final_output, pdf_image, audio_player, pdf_state, download_files]
438
  )
439
 
440
+ # Page navigation
441
  prev_signal = gr.Textbox(value="prev", visible=False)
442
  next_signal = gr.Textbox(value="next", visible=False)
443
 
444
  prev_btn.click(
445
  update_page,
446
+ inputs=[prev_signal, pdf_state], # ✅ Use component
447
  outputs=[pdf_image, prev_btn, next_btn, pdf_state]
448
  )
449
 
450
  next_btn.click(
451
  update_page,
452
+ inputs=[next_signal, pdf_state], # ✅ Use component
453
  outputs=[pdf_image, prev_btn, next_btn, pdf_state]
454
  )
455
 
inference.py CHANGED
@@ -69,30 +69,30 @@ def download_model_weights():
69
 
70
  def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
71
  """
72
- k-bit 训练准备模型。
73
- 功能包括:
74
- 1. 将模型转换为混合精度(FP16)。
75
- 2. 禁用不需要的梯度计算。
76
- 3. 启用梯度检查点(可选)。
77
  """
78
- # 将模型转换为混合精度
79
  model = model.to(dtype=torch.float16)
80
 
81
- # 禁用嵌入层的梯度
82
  for param in model.parameters():
83
  if param.dtype == torch.float32:
84
  param.requires_grad = False
85
 
86
- # 启用梯度检查点
87
  if use_gradient_checkpointing:
88
  model.gradient_checkpointing_enable()
89
 
90
  return model
91
 
92
- # 应用量化配置
93
  model = prepare_model_for_kbit_training(
94
  model,
95
- use_gradient_checkpointing=False # 推理时不需要梯度检查
96
  )
97
 
98
  print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
@@ -146,19 +146,19 @@ def complete_brackets(s):
146
  stack = []
147
  bracket_map = {'{': '}', '[': ']', '(': ')'}
148
 
149
- # 遍历每个字符,处理括号匹配
150
  for char in s:
151
  if char in bracket_map:
152
  stack.append(char)
153
  elif char in bracket_map.values():
154
- # 查找对应的左括号
155
  for key, value in bracket_map.items():
156
  if value == char:
157
  if stack and stack[-1] == key:
158
  stack.pop()
159
- break # 找到对应的右括号,处理下一个字符
160
 
161
- # 补全缺失的右括号(按栈中剩余左括号的逆序)
162
  completion = ''.join(bracket_map[c] for c in reversed(stack))
163
  return s + completion
164
 
@@ -333,26 +333,24 @@ def inference_patch(period, composer, instrumentation):
333
  predicted_patch = torch.tensor([predicted_patch], device=device) # (1, 16)
334
  input_patches = torch.cat([input_patches, predicted_patch], dim=1) # (1, 16 * patch_len)
335
 
336
- if len(byte_list) > 102400: # 过长
337
  failure_flag = True
338
  break
339
- if time.time() - start_time > 20 * 60: # 生成时间不得超过20min
340
  failure_flag = True
341
  break
342
 
343
  if input_patches.shape[1] >= PATCH_LENGTH * PATCH_SIZE and not end_flag:
344
- # 做流式切片
345
  print('Stream generating...')
346
 
347
  metadata = ''.join(metadata_byte_list)
348
  context_tunebody = ''.join(context_tunebody_byte_list)
349
 
350
  if '\n' not in context_tunebody:
351
- # 生成的全是metadata,放弃
352
- break
353
 
354
  context_tunebody_liness = context_tunebody.split('\n')
355
- if not context_tunebody.endswith('\n'): # 如果生成结果最后一行未完结
356
  context_tunebody_liness = [context_tunebody_liness[i] + '\n' for i in range(len(context_tunebody_liness) - 1)] + [context_tunebody_liness[-1]]
357
  else:
358
  context_tunebody_liness = [context_tunebody_liness[i] + '\n' for i in range(len(context_tunebody_liness))]
 
69
 
70
  def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
71
  """
72
+ Prepare model for k-bit training.
73
+ Features include:
74
+ 1. Convert model to mixed precision (FP16).
75
+ 2. Disable unnecessary gradient computations.
76
+ 3. Enable gradient checkpointing (optional).
77
  """
78
+ # Convert model to mixed precision
79
  model = model.to(dtype=torch.float16)
80
 
81
+ # Disable gradients for embedding layers
82
  for param in model.parameters():
83
  if param.dtype == torch.float32:
84
  param.requires_grad = False
85
 
86
+ # Enable gradient checkpointing
87
  if use_gradient_checkpointing:
88
  model.gradient_checkpointing_enable()
89
 
90
  return model
91
 
92
+
93
  model = prepare_model_for_kbit_training(
94
  model,
95
+ use_gradient_checkpointing=False
96
  )
97
 
98
  print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
 
146
  stack = []
147
  bracket_map = {'{': '}', '[': ']', '(': ')'}
148
 
149
+ # Iterate through each character, handle bracket matching
150
  for char in s:
151
  if char in bracket_map:
152
  stack.append(char)
153
  elif char in bracket_map.values():
154
+ # Find the corresponding left bracket
155
  for key, value in bracket_map.items():
156
  if value == char:
157
  if stack and stack[-1] == key:
158
  stack.pop()
159
+ break # Found matching right bracket, process next character
160
 
161
+ # Complete missing right brackets (in reverse order of remaining left brackets in stack)
162
  completion = ''.join(bracket_map[c] for c in reversed(stack))
163
  return s + completion
164
 
 
333
  predicted_patch = torch.tensor([predicted_patch], device=device) # (1, 16)
334
  input_patches = torch.cat([input_patches, predicted_patch], dim=1) # (1, 16 * patch_len)
335
 
336
+ if len(byte_list) > 102400:
337
  failure_flag = True
338
  break
339
+ if time.time() - start_time > 10 * 60:
340
  failure_flag = True
341
  break
342
 
343
  if input_patches.shape[1] >= PATCH_LENGTH * PATCH_SIZE and not end_flag:
 
344
  print('Stream generating...')
345
 
346
  metadata = ''.join(metadata_byte_list)
347
  context_tunebody = ''.join(context_tunebody_byte_list)
348
 
349
  if '\n' not in context_tunebody:
350
+ break # Generated content is all metadata, abandon
 
351
 
352
  context_tunebody_liness = context_tunebody.split('\n')
353
+ if not context_tunebody.endswith('\n'):
354
  context_tunebody_liness = [context_tunebody_liness[i] + '\n' for i in range(len(context_tunebody_liness) - 1)] + [context_tunebody_liness[-1]]
355
  else:
356
  context_tunebody_liness = [context_tunebody_liness[i] + '\n' for i in range(len(context_tunebody_liness))]