asigalov61 commited on
Commit
9cd9938
·
verified ·
1 Parent(s): 0403a1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -5
app.py CHANGED
@@ -156,7 +156,7 @@ print('Loading karaoke words list and dict...')
156
  kar_words_list_dict_pickle = hf_hub_download(repo_id='asigalov61/Karaoke-Transformer', filename='all_words_list_dict.pickle')
157
 
158
  with open(kar_words_list_dict_pickle, 'rb') as f:
159
- kar_words_list_dict = pickle.load(f)
160
 
161
  print('Done!')
162
  print('=' * 70)
@@ -213,10 +213,6 @@ def Generate_Karaoke(input_lyrics,
213
  start_time = reqtime.time()
214
  print('=' * 70)
215
 
216
- fn = os.path.basename(input_midi)
217
- fn1 = fn.split('.')[0]
218
-
219
- print('=' * 70)
220
  print('Requested settings:')
221
  print('=' * 70)
222
  print('Input lyrics:', input_lyrics)
@@ -229,6 +225,8 @@ def Generate_Karaoke(input_lyrics,
229
  print('=' * 70)
230
  print('Generating...')
231
 
 
 
232
  kar_model.to(device_type)
233
  kar_model.eval()
234
 
@@ -237,6 +235,34 @@ def Generate_Karaoke(input_lyrics,
237
 
238
  #==================================================================
239
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  start_score_seq = [1792] + score + [1793]
241
 
242
  #==================================================================
 
156
  kar_words_list_dict_pickle = hf_hub_download(repo_id='asigalov61/Karaoke-Transformer', filename='all_words_list_dict.pickle')
157
 
158
  with open(kar_words_list_dict_pickle, 'rb') as f:
159
+ all_words_list, all_words_dict = pickle.load(f)
160
 
161
  print('Done!')
162
  print('=' * 70)
 
213
  start_time = reqtime.time()
214
  print('=' * 70)
215
 
 
 
 
 
216
  print('Requested settings:')
217
  print('=' * 70)
218
  print('Input lyrics:', input_lyrics)
 
225
  print('=' * 70)
226
  print('Generating...')
227
 
228
+ #==================================================================
229
+
230
  kar_model.to(device_type)
231
  kar_model.eval()
232
 
 
235
 
236
  #==================================================================
237
 
238
+ x = torch.LongTensor([20384]).cuda()
239
+
240
+ with ctx:
241
+ out = kar_model.generate(x,
242
+ 1024,
243
+ temperature=0.85,
244
+ filter_logits_fn=top_p,
245
+ filter_kwargs={'thres': 0.96},
246
+ return_prime=True,
247
+ eos_token=20386,
248
+ verbose=True)
249
+
250
+ y = out.tolist()
251
+
252
+ #==================================================================
253
+
254
+ decoded_lyrics = []
255
+
256
+ for tok in y[0]:
257
+ if 383 < tok < 20384:
258
+ decoded_lyrics.append(all_words_list[tok-384])
259
+
260
+ #==================================================================
261
+
262
+ score = [t for t in y[0] if t < 384]
263
+
264
+ #==================================================================
265
+
266
  start_score_seq = [1792] + score + [1793]
267
 
268
  #==================================================================