Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files- app.py +57 -69
- inference.py +18 -20
app.py
CHANGED
@@ -15,7 +15,7 @@ from inference import inference_patch
|
|
15 |
from convert import abc2xml, xml2, pdf2img
|
16 |
|
17 |
|
18 |
-
#
|
19 |
with open('prompts.txt', 'r') as f:
|
20 |
prompts = f.readlines()
|
21 |
|
@@ -25,12 +25,12 @@ for prompt in prompts:
|
|
25 |
parts = prompt.split('_')
|
26 |
valid_combinations.add((parts[0], parts[1], parts[2]))
|
27 |
|
28 |
-
#
|
29 |
periods = sorted({p for p, _, _ in valid_combinations})
|
30 |
composers = sorted({c for _, c, _ in valid_combinations})
|
31 |
instruments = sorted({i for _, _, i in valid_combinations})
|
32 |
|
33 |
-
#
|
34 |
def update_components(period, composer):
|
35 |
if not period:
|
36 |
return [
|
@@ -54,7 +54,7 @@ def update_components(period, composer):
|
|
54 |
)
|
55 |
]
|
56 |
|
57 |
-
#
|
58 |
class RealtimeStream(TextIOBase):
|
59 |
def __init__(self, queue):
|
60 |
self.queue = queue
|
@@ -81,7 +81,7 @@ def convert_files(abc_content, period, composer, instrumentation):
|
|
81 |
with open(filename_base_postinst + ".abc", "w", encoding="utf-8") as f:
|
82 |
f.write(postprocessed_inst_abc)
|
83 |
|
84 |
-
#
|
85 |
file_paths = {'abc': abc_filename}
|
86 |
try:
|
87 |
# abc2xml
|
@@ -115,15 +115,15 @@ def convert_files(abc_content, period, composer, instrumentation):
|
|
115 |
})
|
116 |
|
117 |
except Exception as e:
|
118 |
-
raise gr.Error(f"
|
119 |
|
120 |
return file_paths
|
121 |
|
122 |
|
123 |
-
#
|
124 |
def update_page(direction, data):
|
125 |
"""
|
126 |
-
data
|
127 |
"""
|
128 |
if not data:
|
129 |
return None, gr.update(interactive=False), gr.update(interactive=False), data
|
@@ -134,9 +134,9 @@ def update_page(direction, data):
|
|
134 |
data['current_page'] += 1
|
135 |
|
136 |
current_page_index = data['current_page']
|
137 |
-
#
|
138 |
new_image = f"{data['base']}_page_{current_page_index+1}.png"
|
139 |
-
#
|
140 |
prev_btn_state = gr.update(interactive=(current_page_index > 0))
|
141 |
next_btn_state = gr.update(interactive=(current_page_index < data['pages'] - 1))
|
142 |
|
@@ -146,13 +146,13 @@ def update_page(direction, data):
|
|
146 |
@spaces.GPU(duration=600)
|
147 |
def generate_music(period, composer, instrumentation):
|
148 |
"""
|
149 |
-
|
150 |
-
|
151 |
-
1) process_output (
|
152 |
-
2) final_output (
|
153 |
-
3) pdf_image (
|
154 |
-
4) audio_player (mp3
|
155 |
-
5) pdf_state (
|
156 |
"""
|
157 |
# Set a different random seed each time based on current timestamp
|
158 |
random_seed = int(time.time()) % 10000
|
@@ -175,7 +175,7 @@ def generate_music(period, composer, instrumentation):
|
|
175 |
pass
|
176 |
|
177 |
if (period, composer, instrumentation) not in valid_combinations:
|
178 |
-
#
|
179 |
raise gr.Error("Invalid prompt combination! Please re-select from the period options")
|
180 |
|
181 |
output_queue = queue.Queue()
|
@@ -186,7 +186,7 @@ def generate_music(period, composer, instrumentation):
|
|
186 |
|
187 |
def run_inference():
|
188 |
try:
|
189 |
-
#
|
190 |
result = inference_patch(period, composer, instrumentation)
|
191 |
result_container.append(result)
|
192 |
finally:
|
@@ -201,40 +201,40 @@ def generate_music(period, composer, instrumentation):
|
|
201 |
audio_file = None
|
202 |
pdf_state = None
|
203 |
|
204 |
-
#
|
205 |
while thread.is_alive():
|
206 |
try:
|
207 |
text = output_queue.get(timeout=0.1)
|
208 |
process_output += text
|
209 |
-
#
|
210 |
yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)
|
211 |
except queue.Empty:
|
212 |
continue
|
213 |
|
214 |
-
#
|
215 |
while not output_queue.empty():
|
216 |
text = output_queue.get()
|
217 |
process_output += text
|
218 |
|
219 |
-
#
|
220 |
final_result = result_container[0] if result_container else ""
|
221 |
|
222 |
-
#
|
223 |
final_output_abc = "Converting files..."
|
224 |
yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)
|
225 |
|
226 |
|
227 |
-
#
|
228 |
try:
|
229 |
file_paths = convert_files(final_result, period, composer, instrumentation)
|
230 |
final_output_abc = final_result
|
231 |
-
#
|
232 |
if file_paths['pages'] > 0:
|
233 |
pdf_image = f"{file_paths['base']}_page_1.png"
|
234 |
audio_file = file_paths['mp3']
|
235 |
-
pdf_state = file_paths #
|
236 |
|
237 |
-
#
|
238 |
download_list = []
|
239 |
if 'abc' in file_paths and os.path.exists(file_paths['abc']):
|
240 |
download_list.append(file_paths['abc'])
|
@@ -247,60 +247,60 @@ def generate_music(period, composer, instrumentation):
|
|
247 |
if 'mp3' in file_paths and os.path.exists(file_paths['mp3']):
|
248 |
download_list.append(file_paths['mp3'])
|
249 |
except Exception as e:
|
250 |
-
#
|
251 |
yield process_output, f"Error converting files: {str(e)}", None, None, None, gr.update(value=None, visible=False)
|
252 |
return
|
253 |
|
254 |
-
|
255 |
yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=download_list, visible=True)
|
256 |
|
257 |
|
258 |
def get_file(file_type, period, composer, instrumentation):
|
259 |
"""
|
260 |
-
|
261 |
"""
|
262 |
-
#
|
263 |
-
#
|
264 |
-
#
|
265 |
possible_files = [f for f in os.listdir('.') if f.endswith(f'.{file_type}')]
|
266 |
if not possible_files:
|
267 |
return None
|
268 |
-
#
|
269 |
possible_files.sort(key=os.path.getmtime)
|
270 |
return possible_files[-1]
|
271 |
|
272 |
|
273 |
css = """
|
274 |
-
/*
|
275 |
button[size="sm"] {
|
276 |
padding: 4px 8px !important;
|
277 |
margin: 2px !important;
|
278 |
min-width: 60px;
|
279 |
}
|
280 |
|
281 |
-
/* PDF
|
282 |
#pdf-preview {
|
283 |
-
border-radius: 8px; /*
|
284 |
-
box-shadow: 0 2px 8px rgba(0,0,0,0.1); /*
|
285 |
}
|
286 |
|
287 |
.page-btn {
|
288 |
-
padding: 12px !important; /*
|
289 |
-
margin: auto !important; /*
|
290 |
}
|
291 |
|
292 |
-
/*
|
293 |
.page-btn:hover {
|
294 |
background: #f0f0f0 !important;
|
295 |
transform: scale(1.05);
|
296 |
}
|
297 |
|
298 |
-
/*
|
299 |
.gr-row {
|
300 |
-
gap: 10px !important; /*
|
301 |
}
|
302 |
|
303 |
-
/*
|
304 |
.audio-panel {
|
305 |
margin-top: 15px !important;
|
306 |
max-width: 400px;
|
@@ -310,23 +310,13 @@ button[size="sm"] {
|
|
310 |
height: 200px !important;
|
311 |
}
|
312 |
|
313 |
-
/*
|
314 |
.save-as-row {
|
315 |
margin-top: 15px;
|
316 |
padding: 10px;
|
317 |
border-top: 1px solid #eee;
|
318 |
}
|
319 |
|
320 |
-
.save-as-label {
|
321 |
-
font-weight: bold;
|
322 |
-
margin-right: 10px;
|
323 |
-
align-self: center;
|
324 |
-
}
|
325 |
-
|
326 |
-
.save-buttons {
|
327 |
-
gap: 5px; /* 按钮间距 */
|
328 |
-
}
|
329 |
-
|
330 |
/* Download files styling */
|
331 |
.download-files {
|
332 |
margin-top: 15px;
|
@@ -339,12 +329,12 @@ button[size="sm"] {
|
|
339 |
with gr.Blocks(css=css) as demo:
|
340 |
gr.Markdown("## NotaGen")
|
341 |
|
342 |
-
#
|
343 |
pdf_state = gr.State()
|
344 |
|
345 |
with gr.Column():
|
346 |
with gr.Row():
|
347 |
-
#
|
348 |
with gr.Column():
|
349 |
with gr.Row():
|
350 |
period_dd = gr.Dropdown(
|
@@ -384,18 +374,16 @@ with gr.Blocks(css=css) as demo:
|
|
384 |
placeholder="Post-processed ABC scores will be shown here..."
|
385 |
)
|
386 |
|
387 |
-
#
|
388 |
audio_player = gr.Audio(
|
389 |
label="Audio Preview",
|
390 |
format="mp3",
|
391 |
interactive=False,
|
392 |
-
# container=False,
|
393 |
-
# elem_id="audio-preview"
|
394 |
)
|
395 |
|
396 |
-
#
|
397 |
with gr.Column():
|
398 |
-
#
|
399 |
pdf_image = gr.Image(
|
400 |
label="Sheet Music Preview",
|
401 |
show_label=False,
|
@@ -406,7 +394,7 @@ with gr.Blocks(css=css) as demo:
|
|
406 |
show_download_button=False
|
407 |
)
|
408 |
|
409 |
-
#
|
410 |
with gr.Row():
|
411 |
prev_btn = gr.Button(
|
412 |
"⬅️ Last Page",
|
@@ -430,7 +418,7 @@ with gr.Blocks(css=css) as demo:
|
|
430 |
type="filepath" # Make sure this is set to filepath
|
431 |
)
|
432 |
|
433 |
-
#
|
434 |
period_dd.change(
|
435 |
update_components,
|
436 |
inputs=[period_dd, composer_dd],
|
@@ -442,26 +430,26 @@ with gr.Blocks(css=css) as demo:
|
|
442 |
outputs=[composer_dd, instrument_dd]
|
443 |
)
|
444 |
|
445 |
-
#
|
446 |
generate_btn.click(
|
447 |
generate_music,
|
448 |
inputs=[period_dd, composer_dd, instrument_dd],
|
449 |
outputs=[process_output, final_output, pdf_image, audio_player, pdf_state, download_files]
|
450 |
)
|
451 |
|
452 |
-
#
|
453 |
prev_signal = gr.Textbox(value="prev", visible=False)
|
454 |
next_signal = gr.Textbox(value="next", visible=False)
|
455 |
|
456 |
prev_btn.click(
|
457 |
update_page,
|
458 |
-
inputs=[prev_signal, pdf_state], # ✅
|
459 |
outputs=[pdf_image, prev_btn, next_btn, pdf_state]
|
460 |
)
|
461 |
|
462 |
next_btn.click(
|
463 |
update_page,
|
464 |
-
inputs=[next_signal, pdf_state], # ✅
|
465 |
outputs=[pdf_image, prev_btn, next_btn, pdf_state]
|
466 |
)
|
467 |
|
|
|
15 |
from convert import abc2xml, xml2, pdf2img
|
16 |
|
17 |
|
18 |
+
# Read prompt combinations
|
19 |
with open('prompts.txt', 'r') as f:
|
20 |
prompts = f.readlines()
|
21 |
|
|
|
25 |
parts = prompt.split('_')
|
26 |
valid_combinations.add((parts[0], parts[1], parts[2]))
|
27 |
|
28 |
+
# Prepare dropdown options
|
29 |
periods = sorted({p for p, _, _ in valid_combinations})
|
30 |
composers = sorted({c for _, c, _ in valid_combinations})
|
31 |
instruments = sorted({i for _, _, i in valid_combinations})
|
32 |
|
33 |
+
# Dynamically update composer and instrument dropdown options
|
34 |
def update_components(period, composer):
|
35 |
if not period:
|
36 |
return [
|
|
|
54 |
)
|
55 |
]
|
56 |
|
57 |
+
# Custom realtime stream for outputting model inference process to frontend
|
58 |
class RealtimeStream(TextIOBase):
|
59 |
def __init__(self, queue):
|
60 |
self.queue = queue
|
|
|
81 |
with open(filename_base_postinst + ".abc", "w", encoding="utf-8") as f:
|
82 |
f.write(postprocessed_inst_abc)
|
83 |
|
84 |
+
# Convert files
|
85 |
file_paths = {'abc': abc_filename}
|
86 |
try:
|
87 |
# abc2xml
|
|
|
115 |
})
|
116 |
|
117 |
except Exception as e:
|
118 |
+
raise gr.Error(f"File processing failed: {str(e)}")
|
119 |
|
120 |
return file_paths
|
121 |
|
122 |
|
123 |
+
# Page navigation control function
|
124 |
def update_page(direction, data):
|
125 |
"""
|
126 |
+
data contains three key pieces of information: 'pages', 'current_page', and 'base'
|
127 |
"""
|
128 |
if not data:
|
129 |
return None, gr.update(interactive=False), gr.update(interactive=False), data
|
|
|
134 |
data['current_page'] += 1
|
135 |
|
136 |
current_page_index = data['current_page']
|
137 |
+
# Update image path
|
138 |
new_image = f"{data['base']}_page_{current_page_index+1}.png"
|
139 |
+
# When current_page==0, prev_btn is disabled; when current_page==pages-1, next_btn is disabled
|
140 |
prev_btn_state = gr.update(interactive=(current_page_index > 0))
|
141 |
next_btn_state = gr.update(interactive=(current_page_index < data['pages'] - 1))
|
142 |
|
|
|
146 |
@spaces.GPU(duration=600)
|
147 |
def generate_music(period, composer, instrumentation):
|
148 |
"""
|
149 |
+
Must ensure each yield returns the same number of values.
|
150 |
+
We're preparing to return 5 values, corresponding to:
|
151 |
+
1) process_output (intermediate inference information)
|
152 |
+
2) final_output (final ABC)
|
153 |
+
3) pdf_image (path to the PNG of the first page of the PDF)
|
154 |
+
4) audio_player (mp3 path)
|
155 |
+
5) pdf_state (state for page navigation)
|
156 |
"""
|
157 |
# Set a different random seed each time based on current timestamp
|
158 |
random_seed = int(time.time()) % 10000
|
|
|
175 |
pass
|
176 |
|
177 |
if (period, composer, instrumentation) not in valid_combinations:
|
178 |
+
# If the combination is invalid, raise an error
|
179 |
raise gr.Error("Invalid prompt combination! Please re-select from the period options")
|
180 |
|
181 |
output_queue = queue.Queue()
|
|
|
186 |
|
187 |
def run_inference():
|
188 |
try:
|
189 |
+
# Use downloaded model weights path for inference
|
190 |
result = inference_patch(period, composer, instrumentation)
|
191 |
result_container.append(result)
|
192 |
finally:
|
|
|
201 |
audio_file = None
|
202 |
pdf_state = None
|
203 |
|
204 |
+
# First continuously read intermediate output
|
205 |
while thread.is_alive():
|
206 |
try:
|
207 |
text = output_queue.get(timeout=0.1)
|
208 |
process_output += text
|
209 |
+
# No final ABC yet, files not yet converted
|
210 |
yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)
|
211 |
except queue.Empty:
|
212 |
continue
|
213 |
|
214 |
+
# After thread ends, get all remaining items from the queue
|
215 |
while not output_queue.empty():
|
216 |
text = output_queue.get()
|
217 |
process_output += text
|
218 |
|
219 |
+
# Final inference result
|
220 |
final_result = result_container[0] if result_container else ""
|
221 |
|
222 |
+
# Display file conversion prompt
|
223 |
final_output_abc = "Converting files..."
|
224 |
yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)
|
225 |
|
226 |
|
227 |
+
# Convert files
|
228 |
try:
|
229 |
file_paths = convert_files(final_result, period, composer, instrumentation)
|
230 |
final_output_abc = final_result
|
231 |
+
# Get the first image and mp3 file
|
232 |
if file_paths['pages'] > 0:
|
233 |
pdf_image = f"{file_paths['base']}_page_1.png"
|
234 |
audio_file = file_paths['mp3']
|
235 |
+
pdf_state = file_paths # Directly use the converted information dictionary as state
|
236 |
|
237 |
+
# Prepare download file list
|
238 |
download_list = []
|
239 |
if 'abc' in file_paths and os.path.exists(file_paths['abc']):
|
240 |
download_list.append(file_paths['abc'])
|
|
|
247 |
if 'mp3' in file_paths and os.path.exists(file_paths['mp3']):
|
248 |
download_list.append(file_paths['mp3'])
|
249 |
except Exception as e:
|
250 |
+
# If conversion fails, return error message to output box
|
251 |
yield process_output, f"Error converting files: {str(e)}", None, None, None, gr.update(value=None, visible=False)
|
252 |
return
|
253 |
|
254 |
+
# Final yield with all information - modify here to make component visible
|
255 |
yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=download_list, visible=True)
|
256 |
|
257 |
|
258 |
def get_file(file_type, period, composer, instrumentation):
|
259 |
"""
|
260 |
+
Returns the local file of specified type for Gradio download
|
261 |
"""
|
262 |
+
# Here you actually need to return based on specific file paths saved earlier, simplified for demo
|
263 |
+
# If matching by timestamp, you can store all converted files in a directory and get the latest
|
264 |
+
# This is just an example:
|
265 |
possible_files = [f for f in os.listdir('.') if f.endswith(f'.{file_type}')]
|
266 |
if not possible_files:
|
267 |
return None
|
268 |
+
# Simply return the latest
|
269 |
possible_files.sort(key=os.path.getmtime)
|
270 |
return possible_files[-1]
|
271 |
|
272 |
|
273 |
css = """
|
274 |
+
/* Compact button style */
|
275 |
button[size="sm"] {
|
276 |
padding: 4px 8px !important;
|
277 |
margin: 2px !important;
|
278 |
min-width: 60px;
|
279 |
}
|
280 |
|
281 |
+
/* PDF preview area */
|
282 |
#pdf-preview {
|
283 |
+
border-radius: 8px; /* Rounded corners */
|
284 |
+
box-shadow: 0 2px 8px rgba(0,0,0,0.1); /* Shadow */
|
285 |
}
|
286 |
|
287 |
.page-btn {
|
288 |
+
padding: 12px !important; /* Increase clickable area */
|
289 |
+
margin: auto !important; /* Vertical center */
|
290 |
}
|
291 |
|
292 |
+
/* Button hover effect */
|
293 |
.page-btn:hover {
|
294 |
background: #f0f0f0 !important;
|
295 |
transform: scale(1.05);
|
296 |
}
|
297 |
|
298 |
+
/* Layout adjustment */
|
299 |
.gr-row {
|
300 |
+
gap: 10px !important; /* Element spacing */
|
301 |
}
|
302 |
|
303 |
+
/* Audio player */
|
304 |
.audio-panel {
|
305 |
margin-top: 15px !important;
|
306 |
max-width: 400px;
|
|
|
310 |
height: 200px !important;
|
311 |
}
|
312 |
|
313 |
+
/* Save functionality area */
|
314 |
.save-as-row {
|
315 |
margin-top: 15px;
|
316 |
padding: 10px;
|
317 |
border-top: 1px solid #eee;
|
318 |
}
|
319 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
/* Download files styling */
|
321 |
.download-files {
|
322 |
margin-top: 15px;
|
|
|
329 |
with gr.Blocks(css=css) as demo:
|
330 |
gr.Markdown("## NotaGen")
|
331 |
|
332 |
+
# For storing PDF page count, current page and other information
|
333 |
pdf_state = gr.State()
|
334 |
|
335 |
with gr.Column():
|
336 |
with gr.Row():
|
337 |
+
# Left sidebar
|
338 |
with gr.Column():
|
339 |
with gr.Row():
|
340 |
period_dd = gr.Dropdown(
|
|
|
374 |
placeholder="Post-processed ABC scores will be shown here..."
|
375 |
)
|
376 |
|
377 |
+
# Audio playback
|
378 |
audio_player = gr.Audio(
|
379 |
label="Audio Preview",
|
380 |
format="mp3",
|
381 |
interactive=False,
|
|
|
|
|
382 |
)
|
383 |
|
384 |
+
# Right sidebar
|
385 |
with gr.Column():
|
386 |
+
# Image container
|
387 |
pdf_image = gr.Image(
|
388 |
label="Sheet Music Preview",
|
389 |
show_label=False,
|
|
|
394 |
show_download_button=False
|
395 |
)
|
396 |
|
397 |
+
# Page navigation buttons
|
398 |
with gr.Row():
|
399 |
prev_btn = gr.Button(
|
400 |
"⬅️ Last Page",
|
|
|
418 |
type="filepath" # Make sure this is set to filepath
|
419 |
)
|
420 |
|
421 |
+
# Dropdown linking
|
422 |
period_dd.change(
|
423 |
update_components,
|
424 |
inputs=[period_dd, composer_dd],
|
|
|
430 |
outputs=[composer_dd, instrument_dd]
|
431 |
)
|
432 |
|
433 |
+
# Click generate button, note outputs must match each yield in generate_music
|
434 |
generate_btn.click(
|
435 |
generate_music,
|
436 |
inputs=[period_dd, composer_dd, instrument_dd],
|
437 |
outputs=[process_output, final_output, pdf_image, audio_player, pdf_state, download_files]
|
438 |
)
|
439 |
|
440 |
+
# Page navigation
|
441 |
prev_signal = gr.Textbox(value="prev", visible=False)
|
442 |
next_signal = gr.Textbox(value="next", visible=False)
|
443 |
|
444 |
prev_btn.click(
|
445 |
update_page,
|
446 |
+
inputs=[prev_signal, pdf_state], # ✅ Use component
|
447 |
outputs=[pdf_image, prev_btn, next_btn, pdf_state]
|
448 |
)
|
449 |
|
450 |
next_btn.click(
|
451 |
update_page,
|
452 |
+
inputs=[next_signal, pdf_state], # ✅ Use component
|
453 |
outputs=[pdf_image, prev_btn, next_btn, pdf_state]
|
454 |
)
|
455 |
|
inference.py
CHANGED
@@ -69,30 +69,30 @@ def download_model_weights():
|
|
69 |
|
70 |
def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
|
71 |
"""
|
72 |
-
|
73 |
-
|
74 |
-
1.
|
75 |
-
2.
|
76 |
-
3.
|
77 |
"""
|
78 |
-
#
|
79 |
model = model.to(dtype=torch.float16)
|
80 |
|
81 |
-
#
|
82 |
for param in model.parameters():
|
83 |
if param.dtype == torch.float32:
|
84 |
param.requires_grad = False
|
85 |
|
86 |
-
#
|
87 |
if use_gradient_checkpointing:
|
88 |
model.gradient_checkpointing_enable()
|
89 |
|
90 |
return model
|
91 |
|
92 |
-
|
93 |
model = prepare_model_for_kbit_training(
|
94 |
model,
|
95 |
-
use_gradient_checkpointing=False
|
96 |
)
|
97 |
|
98 |
print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
|
@@ -146,19 +146,19 @@ def complete_brackets(s):
|
|
146 |
stack = []
|
147 |
bracket_map = {'{': '}', '[': ']', '(': ')'}
|
148 |
|
149 |
-
#
|
150 |
for char in s:
|
151 |
if char in bracket_map:
|
152 |
stack.append(char)
|
153 |
elif char in bracket_map.values():
|
154 |
-
#
|
155 |
for key, value in bracket_map.items():
|
156 |
if value == char:
|
157 |
if stack and stack[-1] == key:
|
158 |
stack.pop()
|
159 |
-
break #
|
160 |
|
161 |
-
#
|
162 |
completion = ''.join(bracket_map[c] for c in reversed(stack))
|
163 |
return s + completion
|
164 |
|
@@ -333,26 +333,24 @@ def inference_patch(period, composer, instrumentation):
|
|
333 |
predicted_patch = torch.tensor([predicted_patch], device=device) # (1, 16)
|
334 |
input_patches = torch.cat([input_patches, predicted_patch], dim=1) # (1, 16 * patch_len)
|
335 |
|
336 |
-
if len(byte_list) > 102400:
|
337 |
failure_flag = True
|
338 |
break
|
339 |
-
if time.time() - start_time >
|
340 |
failure_flag = True
|
341 |
break
|
342 |
|
343 |
if input_patches.shape[1] >= PATCH_LENGTH * PATCH_SIZE and not end_flag:
|
344 |
-
# 做流式切片
|
345 |
print('Stream generating...')
|
346 |
|
347 |
metadata = ''.join(metadata_byte_list)
|
348 |
context_tunebody = ''.join(context_tunebody_byte_list)
|
349 |
|
350 |
if '\n' not in context_tunebody:
|
351 |
-
#
|
352 |
-
break
|
353 |
|
354 |
context_tunebody_liness = context_tunebody.split('\n')
|
355 |
-
if not context_tunebody.endswith('\n'):
|
356 |
context_tunebody_liness = [context_tunebody_liness[i] + '\n' for i in range(len(context_tunebody_liness) - 1)] + [context_tunebody_liness[-1]]
|
357 |
else:
|
358 |
context_tunebody_liness = [context_tunebody_liness[i] + '\n' for i in range(len(context_tunebody_liness))]
|
|
|
69 |
|
70 |
def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
|
71 |
"""
|
72 |
+
Prepare model for k-bit training.
|
73 |
+
Features include:
|
74 |
+
1. Convert model to mixed precision (FP16).
|
75 |
+
2. Disable unnecessary gradient computations.
|
76 |
+
3. Enable gradient checkpointing (optional).
|
77 |
"""
|
78 |
+
# Convert model to mixed precision
|
79 |
model = model.to(dtype=torch.float16)
|
80 |
|
81 |
+
# Disable gradients for embedding layers
|
82 |
for param in model.parameters():
|
83 |
if param.dtype == torch.float32:
|
84 |
param.requires_grad = False
|
85 |
|
86 |
+
# Enable gradient checkpointing
|
87 |
if use_gradient_checkpointing:
|
88 |
model.gradient_checkpointing_enable()
|
89 |
|
90 |
return model
|
91 |
|
92 |
+
|
93 |
model = prepare_model_for_kbit_training(
|
94 |
model,
|
95 |
+
use_gradient_checkpointing=False
|
96 |
)
|
97 |
|
98 |
print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
|
|
|
146 |
stack = []
|
147 |
bracket_map = {'{': '}', '[': ']', '(': ')'}
|
148 |
|
149 |
+
# Iterate through each character, handle bracket matching
|
150 |
for char in s:
|
151 |
if char in bracket_map:
|
152 |
stack.append(char)
|
153 |
elif char in bracket_map.values():
|
154 |
+
# Find the corresponding left bracket
|
155 |
for key, value in bracket_map.items():
|
156 |
if value == char:
|
157 |
if stack and stack[-1] == key:
|
158 |
stack.pop()
|
159 |
+
break # Found matching right bracket, process next character
|
160 |
|
161 |
+
# Complete missing right brackets (in reverse order of remaining left brackets in stack)
|
162 |
completion = ''.join(bracket_map[c] for c in reversed(stack))
|
163 |
return s + completion
|
164 |
|
|
|
333 |
predicted_patch = torch.tensor([predicted_patch], device=device) # (1, 16)
|
334 |
input_patches = torch.cat([input_patches, predicted_patch], dim=1) # (1, 16 * patch_len)
|
335 |
|
336 |
+
if len(byte_list) > 102400:
|
337 |
failure_flag = True
|
338 |
break
|
339 |
+
if time.time() - start_time > 10 * 60:
|
340 |
failure_flag = True
|
341 |
break
|
342 |
|
343 |
if input_patches.shape[1] >= PATCH_LENGTH * PATCH_SIZE and not end_flag:
|
|
|
344 |
print('Stream generating...')
|
345 |
|
346 |
metadata = ''.join(metadata_byte_list)
|
347 |
context_tunebody = ''.join(context_tunebody_byte_list)
|
348 |
|
349 |
if '\n' not in context_tunebody:
|
350 |
+
break # Generated content is all metadata, abandon
|
|
|
351 |
|
352 |
context_tunebody_liness = context_tunebody.split('\n')
|
353 |
+
if not context_tunebody.endswith('\n'):
|
354 |
context_tunebody_liness = [context_tunebody_liness[i] + '\n' for i in range(len(context_tunebody_liness) - 1)] + [context_tunebody_liness[-1]]
|
355 |
else:
|
356 |
context_tunebody_liness = [context_tunebody_liness[i] + '\n' for i in range(len(context_tunebody_liness))]
|