Spaces:
Runtime error
Runtime error
Upload inference.py
Browse files- inference.py +38 -43
inference.py
CHANGED
@@ -42,15 +42,16 @@ byte_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
|
|
42 |
|
43 |
model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=byte_config).to(device)
|
44 |
|
|
|
45 |
def download_model_weights():
|
46 |
weights_path = "weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth"
|
47 |
local_weights_path = os.path.join(os.getcwd(), weights_path)
|
48 |
-
|
49 |
# Check if weights already exist locally
|
50 |
if os.path.exists(local_weights_path):
|
51 |
logger.info(f"Model weights already exist at {local_weights_path}")
|
52 |
return local_weights_path
|
53 |
-
|
54 |
logger.info("Downloading model weights from HuggingFace Hub...")
|
55 |
try:
|
56 |
# Download from HuggingFace
|
@@ -92,7 +93,7 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
|
|
92 |
|
93 |
model = prepare_model_for_kbit_training(
|
94 |
model,
|
95 |
-
use_gradient_checkpointing=False
|
96 |
)
|
97 |
|
98 |
print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
|
@@ -107,13 +108,12 @@ model.eval()
|
|
107 |
|
108 |
|
109 |
def postprocess_inst_names(abc_text):
|
110 |
-
|
111 |
with open('standard_inst_names.txt', 'r', encoding='utf-8') as f:
|
112 |
standard_instruments_list = [line.strip() for line in f if line.strip()]
|
113 |
|
114 |
with open('instrument_mapping.json', 'r', encoding='utf-8') as f:
|
115 |
instrument_mapping = json.load(f)
|
116 |
-
|
117 |
abc_lines = abc_text.split('\n')
|
118 |
abc_lines = list(filter(None, abc_lines))
|
119 |
abc_lines = [line + '\n' for line in abc_lines]
|
@@ -123,20 +123,20 @@ def postprocess_inst_names(abc_text):
|
|
123 |
match = re.search(r'nm="([^"]*)"', line)
|
124 |
if match:
|
125 |
inst_name = match.group(1)
|
126 |
-
|
127 |
# Check if the instrument name is already standard
|
128 |
if inst_name in standard_instruments_list:
|
129 |
continue
|
130 |
-
|
131 |
# Find the most similar key in instrument_mapping
|
132 |
matching_key = difflib.get_close_matches(inst_name, list(instrument_mapping.keys()), n=1, cutoff=0.6)
|
133 |
-
|
134 |
if matching_key:
|
135 |
# Replace the instrument name with the standardized version
|
136 |
replacement = instrument_mapping[matching_key[0]]
|
137 |
new_line = line.replace(f'nm="{inst_name}"', f'nm="{replacement}"')
|
138 |
abc_lines[i] = new_line
|
139 |
-
|
140 |
# Combine the lines back into a single string
|
141 |
processed_abc_text = ''.join(abc_lines)
|
142 |
return processed_abc_text
|
@@ -145,7 +145,7 @@ def postprocess_inst_names(abc_text):
|
|
145 |
def complete_brackets(s):
|
146 |
stack = []
|
147 |
bracket_map = {'{': '}', '[': ']', '(': ')'}
|
148 |
-
|
149 |
# Iterate through each character, handle bracket matching
|
150 |
for char in s:
|
151 |
if char in bracket_map:
|
@@ -157,15 +157,13 @@ def complete_brackets(s):
|
|
157 |
if stack and stack[-1] == key:
|
158 |
stack.pop()
|
159 |
break # Found matching right bracket, process next character
|
160 |
-
|
161 |
# Complete missing right brackets (in reverse order of remaining left brackets in stack)
|
162 |
completion = ''.join(bracket_map[c] for c in reversed(stack))
|
163 |
return s + completion
|
164 |
|
165 |
|
166 |
-
|
167 |
def rest_unreduce(abc_lines):
|
168 |
-
|
169 |
tunebody_index = None
|
170 |
for i in range(len(abc_lines)):
|
171 |
if abc_lines[i].startswith('%%score'):
|
@@ -215,7 +213,7 @@ def rest_unreduce(abc_lines):
|
|
215 |
line_bar_dict[key] = value
|
216 |
|
217 |
# calculate duration and collect barline
|
218 |
-
dur_dict = {}
|
219 |
for symbol, bartext in line_bar_dict.items():
|
220 |
right_barline = ''.join(re.split(Barline_regexPattern, bartext)[-2:])
|
221 |
bartext = bartext[:-len(right_barline)]
|
@@ -232,7 +230,7 @@ def rest_unreduce(abc_lines):
|
|
232 |
try:
|
233 |
ref_dur = max(dur_dict, key=dur_dict.get)
|
234 |
except:
|
235 |
-
pass
|
236 |
|
237 |
if i == 0:
|
238 |
prefix_left_barline = line.split('[V:')[0]
|
@@ -256,16 +254,11 @@ def rest_unreduce(abc_lines):
|
|
256 |
return unreduced_lines
|
257 |
|
258 |
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
def inference_patch(period, composer, instrumentation):
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
'%' + instrumentation + '\n']
|
269 |
|
270 |
while True:
|
271 |
|
@@ -294,21 +287,22 @@ def inference_patch(period, composer, instrumentation):
|
|
294 |
tunebody_flag = False
|
295 |
|
296 |
with torch.inference_mode():
|
297 |
-
|
298 |
while True:
|
299 |
with torch.autocast(device_type='cuda', dtype=torch.float16):
|
300 |
predicted_patch = model.generate(input_patches.unsqueeze(0),
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith(
|
|
|
305 |
tunebody_flag = True
|
306 |
r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(device)
|
307 |
temp_input_patches = torch.concat([input_patches, r0_patch], axis=-1)
|
308 |
predicted_patch = model.generate(temp_input_patches.unsqueeze(0),
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
predicted_patch = [ord(c) for c in '[r:0/'] + predicted_patch
|
313 |
if predicted_patch[0] == patchilizer.bos_token_id and predicted_patch[1] == patchilizer.eos_token_id:
|
314 |
end_flag = True
|
@@ -336,7 +330,7 @@ def inference_patch(period, composer, instrumentation):
|
|
336 |
if len(byte_list) > 102400:
|
337 |
failure_flag = True
|
338 |
break
|
339 |
-
if time.time() - start_time > 10 * 60:
|
340 |
failure_flag = True
|
341 |
break
|
342 |
|
@@ -347,16 +341,19 @@ def inference_patch(period, composer, instrumentation):
|
|
347 |
context_tunebody = ''.join(context_tunebody_byte_list)
|
348 |
|
349 |
if '\n' not in context_tunebody:
|
350 |
-
break
|
|
|
|
|
351 |
|
352 |
-
context_tunebody_liness = context_tunebody.split('\n')
|
353 |
if not context_tunebody.endswith('\n'):
|
354 |
-
|
|
|
355 |
else:
|
356 |
-
|
|
|
357 |
|
358 |
-
cut_index = len(
|
359 |
-
abc_code_slice = metadata + ''.join(
|
360 |
|
361 |
input_patches = patchilizer.encode_generate(abc_code_slice)
|
362 |
|
@@ -379,15 +376,13 @@ def inference_patch(period, composer, instrumentation):
|
|
379 |
failure_flag = True
|
380 |
pass
|
381 |
else:
|
382 |
-
unreduced_abc_lines = [line for line in unreduced_abc_lines if
|
|
|
383 |
unreduced_abc_lines = ['X:1\n'] + unreduced_abc_lines
|
384 |
unreduced_abc_text = ''.join(unreduced_abc_lines)
|
385 |
return unreduced_abc_text
|
386 |
|
387 |
|
388 |
-
|
389 |
-
|
390 |
if __name__ == '__main__':
|
391 |
-
|
392 |
inference_patch('Classical', 'Beethoven, Ludwig van', 'Orchestral')
|
393 |
|
|
|
42 |
|
43 |
model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=byte_config).to(device)
|
44 |
|
45 |
+
|
46 |
def download_model_weights():
|
47 |
weights_path = "weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth"
|
48 |
local_weights_path = os.path.join(os.getcwd(), weights_path)
|
49 |
+
|
50 |
# Check if weights already exist locally
|
51 |
if os.path.exists(local_weights_path):
|
52 |
logger.info(f"Model weights already exist at {local_weights_path}")
|
53 |
return local_weights_path
|
54 |
+
|
55 |
logger.info("Downloading model weights from HuggingFace Hub...")
|
56 |
try:
|
57 |
# Download from HuggingFace
|
|
|
93 |
|
94 |
model = prepare_model_for_kbit_training(
|
95 |
model,
|
96 |
+
use_gradient_checkpointing=False
|
97 |
)
|
98 |
|
99 |
print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
|
|
|
108 |
|
109 |
|
110 |
def postprocess_inst_names(abc_text):
|
|
|
111 |
with open('standard_inst_names.txt', 'r', encoding='utf-8') as f:
|
112 |
standard_instruments_list = [line.strip() for line in f if line.strip()]
|
113 |
|
114 |
with open('instrument_mapping.json', 'r', encoding='utf-8') as f:
|
115 |
instrument_mapping = json.load(f)
|
116 |
+
|
117 |
abc_lines = abc_text.split('\n')
|
118 |
abc_lines = list(filter(None, abc_lines))
|
119 |
abc_lines = [line + '\n' for line in abc_lines]
|
|
|
123 |
match = re.search(r'nm="([^"]*)"', line)
|
124 |
if match:
|
125 |
inst_name = match.group(1)
|
126 |
+
|
127 |
# Check if the instrument name is already standard
|
128 |
if inst_name in standard_instruments_list:
|
129 |
continue
|
130 |
+
|
131 |
# Find the most similar key in instrument_mapping
|
132 |
matching_key = difflib.get_close_matches(inst_name, list(instrument_mapping.keys()), n=1, cutoff=0.6)
|
133 |
+
|
134 |
if matching_key:
|
135 |
# Replace the instrument name with the standardized version
|
136 |
replacement = instrument_mapping[matching_key[0]]
|
137 |
new_line = line.replace(f'nm="{inst_name}"', f'nm="{replacement}"')
|
138 |
abc_lines[i] = new_line
|
139 |
+
|
140 |
# Combine the lines back into a single string
|
141 |
processed_abc_text = ''.join(abc_lines)
|
142 |
return processed_abc_text
|
|
|
145 |
def complete_brackets(s):
|
146 |
stack = []
|
147 |
bracket_map = {'{': '}', '[': ']', '(': ')'}
|
148 |
+
|
149 |
# Iterate through each character, handle bracket matching
|
150 |
for char in s:
|
151 |
if char in bracket_map:
|
|
|
157 |
if stack and stack[-1] == key:
|
158 |
stack.pop()
|
159 |
break # Found matching right bracket, process next character
|
160 |
+
|
161 |
# Complete missing right brackets (in reverse order of remaining left brackets in stack)
|
162 |
completion = ''.join(bracket_map[c] for c in reversed(stack))
|
163 |
return s + completion
|
164 |
|
165 |
|
|
|
166 |
def rest_unreduce(abc_lines):
|
|
|
167 |
tunebody_index = None
|
168 |
for i in range(len(abc_lines)):
|
169 |
if abc_lines[i].startswith('%%score'):
|
|
|
213 |
line_bar_dict[key] = value
|
214 |
|
215 |
# calculate duration and collect barline
|
216 |
+
dur_dict = {}
|
217 |
for symbol, bartext in line_bar_dict.items():
|
218 |
right_barline = ''.join(re.split(Barline_regexPattern, bartext)[-2:])
|
219 |
bartext = bartext[:-len(right_barline)]
|
|
|
230 |
try:
|
231 |
ref_dur = max(dur_dict, key=dur_dict.get)
|
232 |
except:
|
233 |
+
pass # use last ref_dur
|
234 |
|
235 |
if i == 0:
|
236 |
prefix_left_barline = line.split('[V:')[0]
|
|
|
254 |
return unreduced_lines
|
255 |
|
256 |
|
|
|
|
|
|
|
|
|
257 |
def inference_patch(period, composer, instrumentation):
|
258 |
+
prompt_lines = [
|
259 |
+
'%' + period + '\n',
|
260 |
+
'%' + composer + '\n',
|
261 |
+
'%' + instrumentation + '\n']
|
|
|
262 |
|
263 |
while True:
|
264 |
|
|
|
287 |
tunebody_flag = False
|
288 |
|
289 |
with torch.inference_mode():
|
290 |
+
|
291 |
while True:
|
292 |
with torch.autocast(device_type='cuda', dtype=torch.float16):
|
293 |
predicted_patch = model.generate(input_patches.unsqueeze(0),
|
294 |
+
top_k=TOP_K,
|
295 |
+
top_p=TOP_P,
|
296 |
+
temperature=TEMPERATURE)
|
297 |
+
if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith(
|
298 |
+
'[r:'): # ε欑θΏε
₯tunebodyοΌεΏ
ι‘»δ»₯[r:0/εΌε€΄
|
299 |
tunebody_flag = True
|
300 |
r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(device)
|
301 |
temp_input_patches = torch.concat([input_patches, r0_patch], axis=-1)
|
302 |
predicted_patch = model.generate(temp_input_patches.unsqueeze(0),
|
303 |
+
top_k=TOP_K,
|
304 |
+
top_p=TOP_P,
|
305 |
+
temperature=TEMPERATURE)
|
306 |
predicted_patch = [ord(c) for c in '[r:0/'] + predicted_patch
|
307 |
if predicted_patch[0] == patchilizer.bos_token_id and predicted_patch[1] == patchilizer.eos_token_id:
|
308 |
end_flag = True
|
|
|
330 |
if len(byte_list) > 102400:
|
331 |
failure_flag = True
|
332 |
break
|
333 |
+
if time.time() - start_time > 10 * 60:
|
334 |
failure_flag = True
|
335 |
break
|
336 |
|
|
|
341 |
context_tunebody = ''.join(context_tunebody_byte_list)
|
342 |
|
343 |
if '\n' not in context_tunebody:
|
344 |
+
break # Generated content is all metadata, abandon
|
345 |
+
|
346 |
+
context_tunebody_lines = context_tunebody.strip().split('\n')
|
347 |
|
|
|
348 |
if not context_tunebody.endswith('\n'):
|
349 |
+
context_tunebody_lines = [context_tunebody_lines[i] + '\n' for i in
|
350 |
+
range(len(context_tunebody_lines) - 1)] + [context_tunebody_lines[-1]]
|
351 |
else:
|
352 |
+
context_tunebody_lines = [context_tunebody_lines[i] + '\n' for i in
|
353 |
+
range(len(context_tunebody_lines))]
|
354 |
|
355 |
+
cut_index = len(context_tunebody_lines) // 2
|
356 |
+
abc_code_slice = metadata + ''.join(context_tunebody_lines[-cut_index:])
|
357 |
|
358 |
input_patches = patchilizer.encode_generate(abc_code_slice)
|
359 |
|
|
|
376 |
failure_flag = True
|
377 |
pass
|
378 |
else:
|
379 |
+
unreduced_abc_lines = [line for line in unreduced_abc_lines if
|
380 |
+
not (line.startswith('%') and not line.startswith('%%'))]
|
381 |
unreduced_abc_lines = ['X:1\n'] + unreduced_abc_lines
|
382 |
unreduced_abc_text = ''.join(unreduced_abc_lines)
|
383 |
return unreduced_abc_text
|
384 |
|
385 |
|
|
|
|
|
386 |
if __name__ == '__main__':
|
|
|
387 |
inference_patch('Classical', 'Beethoven, Ludwig van', 'Orchestral')
|
388 |
|