ElectricAlexis committed on
Commit
d900b7e
·
verified ·
1 Parent(s): 272ffb6

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +38 -43
inference.py CHANGED
@@ -42,15 +42,16 @@ byte_config = GPT2Config(num_hidden_layers=CHAR_NUM_LAYERS,
42
 
43
  model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=byte_config).to(device)
44
 
 
45
  def download_model_weights():
46
  weights_path = "weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth"
47
  local_weights_path = os.path.join(os.getcwd(), weights_path)
48
-
49
  # Check if weights already exist locally
50
  if os.path.exists(local_weights_path):
51
  logger.info(f"Model weights already exist at {local_weights_path}")
52
  return local_weights_path
53
-
54
  logger.info("Downloading model weights from HuggingFace Hub...")
55
  try:
56
  # Download from HuggingFace
@@ -92,7 +93,7 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
92
 
93
  model = prepare_model_for_kbit_training(
94
  model,
95
- use_gradient_checkpointing=False
96
  )
97
 
98
  print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
@@ -107,13 +108,12 @@ model.eval()
107
 
108
 
109
  def postprocess_inst_names(abc_text):
110
-
111
  with open('standard_inst_names.txt', 'r', encoding='utf-8') as f:
112
  standard_instruments_list = [line.strip() for line in f if line.strip()]
113
 
114
  with open('instrument_mapping.json', 'r', encoding='utf-8') as f:
115
  instrument_mapping = json.load(f)
116
-
117
  abc_lines = abc_text.split('\n')
118
  abc_lines = list(filter(None, abc_lines))
119
  abc_lines = [line + '\n' for line in abc_lines]
@@ -123,20 +123,20 @@ def postprocess_inst_names(abc_text):
123
  match = re.search(r'nm="([^"]*)"', line)
124
  if match:
125
  inst_name = match.group(1)
126
-
127
  # Check if the instrument name is already standard
128
  if inst_name in standard_instruments_list:
129
  continue
130
-
131
  # Find the most similar key in instrument_mapping
132
  matching_key = difflib.get_close_matches(inst_name, list(instrument_mapping.keys()), n=1, cutoff=0.6)
133
-
134
  if matching_key:
135
  # Replace the instrument name with the standardized version
136
  replacement = instrument_mapping[matching_key[0]]
137
  new_line = line.replace(f'nm="{inst_name}"', f'nm="{replacement}"')
138
  abc_lines[i] = new_line
139
-
140
  # Combine the lines back into a single string
141
  processed_abc_text = ''.join(abc_lines)
142
  return processed_abc_text
@@ -145,7 +145,7 @@ def postprocess_inst_names(abc_text):
145
  def complete_brackets(s):
146
  stack = []
147
  bracket_map = {'{': '}', '[': ']', '(': ')'}
148
-
149
  # Iterate through each character, handle bracket matching
150
  for char in s:
151
  if char in bracket_map:
@@ -157,15 +157,13 @@ def complete_brackets(s):
157
  if stack and stack[-1] == key:
158
  stack.pop()
159
  break # Found matching right bracket, process next character
160
-
161
  # Complete missing right brackets (in reverse order of remaining left brackets in stack)
162
  completion = ''.join(bracket_map[c] for c in reversed(stack))
163
  return s + completion
164
 
165
 
166
-
167
  def rest_unreduce(abc_lines):
168
-
169
  tunebody_index = None
170
  for i in range(len(abc_lines)):
171
  if abc_lines[i].startswith('%%score'):
@@ -215,7 +213,7 @@ def rest_unreduce(abc_lines):
215
  line_bar_dict[key] = value
216
 
217
  # calculate duration and collect barline
218
- dur_dict = {}
219
  for symbol, bartext in line_bar_dict.items():
220
  right_barline = ''.join(re.split(Barline_regexPattern, bartext)[-2:])
221
  bartext = bartext[:-len(right_barline)]
@@ -232,7 +230,7 @@ def rest_unreduce(abc_lines):
232
  try:
233
  ref_dur = max(dur_dict, key=dur_dict.get)
234
  except:
235
- pass # use last ref_dur
236
 
237
  if i == 0:
238
  prefix_left_barline = line.split('[V:')[0]
@@ -256,16 +254,11 @@ def rest_unreduce(abc_lines):
256
  return unreduced_lines
257
 
258
 
259
-
260
-
261
-
262
-
263
  def inference_patch(period, composer, instrumentation):
264
-
265
- prompt_lines=[
266
- '%' + period + '\n',
267
- '%' + composer + '\n',
268
- '%' + instrumentation + '\n']
269
 
270
  while True:
271
 
@@ -294,21 +287,22 @@ def inference_patch(period, composer, instrumentation):
294
  tunebody_flag = False
295
 
296
  with torch.inference_mode():
297
-
298
  while True:
299
  with torch.autocast(device_type='cuda', dtype=torch.float16):
300
  predicted_patch = model.generate(input_patches.unsqueeze(0),
301
- top_k=TOP_K,
302
- top_p=TOP_P,
303
- temperature=TEMPERATURE)
304
- if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith('[r:'): # On first entry into the tunebody, it must start with [r:0/
 
305
  tunebody_flag = True
306
  r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(device)
307
  temp_input_patches = torch.concat([input_patches, r0_patch], axis=-1)
308
  predicted_patch = model.generate(temp_input_patches.unsqueeze(0),
309
- top_k=TOP_K,
310
- top_p=TOP_P,
311
- temperature=TEMPERATURE)
312
  predicted_patch = [ord(c) for c in '[r:0/'] + predicted_patch
313
  if predicted_patch[0] == patchilizer.bos_token_id and predicted_patch[1] == patchilizer.eos_token_id:
314
  end_flag = True
@@ -336,7 +330,7 @@ def inference_patch(period, composer, instrumentation):
336
  if len(byte_list) > 102400:
337
  failure_flag = True
338
  break
339
- if time.time() - start_time > 10 * 60:
340
  failure_flag = True
341
  break
342
 
@@ -347,16 +341,19 @@ def inference_patch(period, composer, instrumentation):
347
  context_tunebody = ''.join(context_tunebody_byte_list)
348
 
349
  if '\n' not in context_tunebody:
350
- break # Generated content is all metadata, abandon
 
 
351
 
352
- context_tunebody_liness = context_tunebody.split('\n')
353
  if not context_tunebody.endswith('\n'):
354
- context_tunebody_liness = [context_tunebody_liness[i] + '\n' for i in range(len(context_tunebody_liness) - 1)] + [context_tunebody_liness[-1]]
 
355
  else:
356
- context_tunebody_liness = [context_tunebody_liness[i] + '\n' for i in range(len(context_tunebody_liness))]
 
357
 
358
- cut_index = len(context_tunebody_liness) // 2
359
- abc_code_slice = metadata + ''.join(context_tunebody_liness[-cut_index:])
360
 
361
  input_patches = patchilizer.encode_generate(abc_code_slice)
362
 
@@ -379,15 +376,13 @@ def inference_patch(period, composer, instrumentation):
379
  failure_flag = True
380
  pass
381
  else:
382
- unreduced_abc_lines = [line for line in unreduced_abc_lines if not(line.startswith('%') and not line.startswith('%%'))]
 
383
  unreduced_abc_lines = ['X:1\n'] + unreduced_abc_lines
384
  unreduced_abc_text = ''.join(unreduced_abc_lines)
385
  return unreduced_abc_text
386
 
387
 
388
-
389
-
390
  if __name__ == '__main__':
391
-
392
  inference_patch('Classical', 'Beethoven, Ludwig van', 'Orchestral')
393
 
 
42
 
43
  model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=byte_config).to(device)
44
 
45
+
46
  def download_model_weights():
47
  weights_path = "weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth"
48
  local_weights_path = os.path.join(os.getcwd(), weights_path)
49
+
50
  # Check if weights already exist locally
51
  if os.path.exists(local_weights_path):
52
  logger.info(f"Model weights already exist at {local_weights_path}")
53
  return local_weights_path
54
+
55
  logger.info("Downloading model weights from HuggingFace Hub...")
56
  try:
57
  # Download from HuggingFace
 
93
 
94
  model = prepare_model_for_kbit_training(
95
  model,
96
+ use_gradient_checkpointing=False
97
  )
98
 
99
  print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
 
108
 
109
 
110
  def postprocess_inst_names(abc_text):
 
111
  with open('standard_inst_names.txt', 'r', encoding='utf-8') as f:
112
  standard_instruments_list = [line.strip() for line in f if line.strip()]
113
 
114
  with open('instrument_mapping.json', 'r', encoding='utf-8') as f:
115
  instrument_mapping = json.load(f)
116
+
117
  abc_lines = abc_text.split('\n')
118
  abc_lines = list(filter(None, abc_lines))
119
  abc_lines = [line + '\n' for line in abc_lines]
 
123
  match = re.search(r'nm="([^"]*)"', line)
124
  if match:
125
  inst_name = match.group(1)
126
+
127
  # Check if the instrument name is already standard
128
  if inst_name in standard_instruments_list:
129
  continue
130
+
131
  # Find the most similar key in instrument_mapping
132
  matching_key = difflib.get_close_matches(inst_name, list(instrument_mapping.keys()), n=1, cutoff=0.6)
133
+
134
  if matching_key:
135
  # Replace the instrument name with the standardized version
136
  replacement = instrument_mapping[matching_key[0]]
137
  new_line = line.replace(f'nm="{inst_name}"', f'nm="{replacement}"')
138
  abc_lines[i] = new_line
139
+
140
  # Combine the lines back into a single string
141
  processed_abc_text = ''.join(abc_lines)
142
  return processed_abc_text
 
145
  def complete_brackets(s):
146
  stack = []
147
  bracket_map = {'{': '}', '[': ']', '(': ')'}
148
+
149
  # Iterate through each character, handle bracket matching
150
  for char in s:
151
  if char in bracket_map:
 
157
  if stack and stack[-1] == key:
158
  stack.pop()
159
  break # Found matching right bracket, process next character
160
+
161
  # Complete missing right brackets (in reverse order of remaining left brackets in stack)
162
  completion = ''.join(bracket_map[c] for c in reversed(stack))
163
  return s + completion
164
 
165
 
 
166
  def rest_unreduce(abc_lines):
 
167
  tunebody_index = None
168
  for i in range(len(abc_lines)):
169
  if abc_lines[i].startswith('%%score'):
 
213
  line_bar_dict[key] = value
214
 
215
  # calculate duration and collect barline
216
+ dur_dict = {}
217
  for symbol, bartext in line_bar_dict.items():
218
  right_barline = ''.join(re.split(Barline_regexPattern, bartext)[-2:])
219
  bartext = bartext[:-len(right_barline)]
 
230
  try:
231
  ref_dur = max(dur_dict, key=dur_dict.get)
232
  except:
233
+ pass # use last ref_dur
234
 
235
  if i == 0:
236
  prefix_left_barline = line.split('[V:')[0]
 
254
  return unreduced_lines
255
 
256
 
 
 
 
 
257
  def inference_patch(period, composer, instrumentation):
258
+ prompt_lines = [
259
+ '%' + period + '\n',
260
+ '%' + composer + '\n',
261
+ '%' + instrumentation + '\n']
 
262
 
263
  while True:
264
 
 
287
  tunebody_flag = False
288
 
289
  with torch.inference_mode():
290
+
291
  while True:
292
  with torch.autocast(device_type='cuda', dtype=torch.float16):
293
  predicted_patch = model.generate(input_patches.unsqueeze(0),
294
+ top_k=TOP_K,
295
+ top_p=TOP_P,
296
+ temperature=TEMPERATURE)
297
+ if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith(
298
+ '[r:'): # On first entry into the tunebody, it must start with [r:0/
299
  tunebody_flag = True
300
  r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(device)
301
  temp_input_patches = torch.concat([input_patches, r0_patch], axis=-1)
302
  predicted_patch = model.generate(temp_input_patches.unsqueeze(0),
303
+ top_k=TOP_K,
304
+ top_p=TOP_P,
305
+ temperature=TEMPERATURE)
306
  predicted_patch = [ord(c) for c in '[r:0/'] + predicted_patch
307
  if predicted_patch[0] == patchilizer.bos_token_id and predicted_patch[1] == patchilizer.eos_token_id:
308
  end_flag = True
 
330
  if len(byte_list) > 102400:
331
  failure_flag = True
332
  break
333
+ if time.time() - start_time > 10 * 60:
334
  failure_flag = True
335
  break
336
 
 
341
  context_tunebody = ''.join(context_tunebody_byte_list)
342
 
343
  if '\n' not in context_tunebody:
344
+ break # Generated content is all metadata, abandon
345
+
346
+ context_tunebody_lines = context_tunebody.strip().split('\n')
347
 
 
348
  if not context_tunebody.endswith('\n'):
349
+ context_tunebody_lines = [context_tunebody_lines[i] + '\n' for i in
350
+ range(len(context_tunebody_lines) - 1)] + [context_tunebody_lines[-1]]
351
  else:
352
+ context_tunebody_lines = [context_tunebody_lines[i] + '\n' for i in
353
+ range(len(context_tunebody_lines))]
354
 
355
+ cut_index = len(context_tunebody_lines) // 2
356
+ abc_code_slice = metadata + ''.join(context_tunebody_lines[-cut_index:])
357
 
358
  input_patches = patchilizer.encode_generate(abc_code_slice)
359
 
 
376
  failure_flag = True
377
  pass
378
  else:
379
+ unreduced_abc_lines = [line for line in unreduced_abc_lines if
380
+ not (line.startswith('%') and not line.startswith('%%'))]
381
  unreduced_abc_lines = ['X:1\n'] + unreduced_abc_lines
382
  unreduced_abc_text = ''.join(unreduced_abc_lines)
383
  return unreduced_abc_text
384
 
385
 
 
 
386
  if __name__ == '__main__':
 
387
  inference_patch('Classical', 'Beethoven, Ludwig van', 'Orchestral')
388