MohamedRashad commited on
Commit
ec99653
·
verified ·
1 Parent(s): 375c73a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -25
app.py CHANGED
@@ -7,7 +7,6 @@ from huggingface_hub import snapshot_download
7
  from dotenv import load_dotenv
8
  load_dotenv()
9
 
10
-
11
  # Check if CUDA is available
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
@@ -59,23 +58,6 @@ def process_prompt(prompt, voice, tokenizer, device):
59
 
60
  return modified_input_ids.to(device), attention_mask.to(device)
61
 
62
- # Generate speech tokens
63
- @spaces.GPU()
64
- def generate_speech_tokens(input_ids, attention_mask, model, params):
65
- with torch.no_grad():
66
- generated_ids = model.generate(
67
- input_ids=input_ids,
68
- attention_mask=attention_mask,
69
- max_new_tokens=params["max_new_tokens"],
70
- do_sample=True,
71
- temperature=params["temperature"],
72
- top_p=params["top_p"],
73
- repetition_penalty=params["repetition_penalty"],
74
- num_return_sequences=1,
75
- eos_token_id=128258,
76
- )
77
- return generated_ids
78
-
79
  # Parse output tokens to audio
80
  def parse_output(generated_ids):
81
  token_to_find = 128257
@@ -131,6 +113,7 @@ def redistribute_codes(code_list, snac_model):
131
  return audio_hat.detach().squeeze().cpu().numpy() # Always return CPU numpy array
132
 
133
  # Main generation function
 
134
  def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
135
  if not text.strip():
136
  return None
@@ -140,13 +123,18 @@ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new
140
  input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
141
 
142
  progress(0.3, "Generating speech tokens...")
143
- params = {
144
- "temperature": temperature,
145
- "top_p": top_p,
146
- "repetition_penalty": repetition_penalty,
147
- "max_new_tokens": max_new_tokens
148
- }
149
- generated_ids = generate_speech_tokens(input_ids, attention_mask, model, params)
 
 
 
 
 
150
 
151
  progress(0.6, "Processing speech tokens...")
152
  code_list = parse_output(generated_ids)
 
7
  from dotenv import load_dotenv
8
  load_dotenv()
9
 
 
10
  # Check if CUDA is available
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
 
58
 
59
  return modified_input_ids.to(device), attention_mask.to(device)
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # Parse output tokens to audio
62
  def parse_output(generated_ids):
63
  token_to_find = 128257
 
113
  return audio_hat.detach().squeeze().cpu().numpy() # Always return CPU numpy array
114
 
115
  # Main generation function
116
+ @spaces.GPU()
117
  def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
118
  if not text.strip():
119
  return None
 
123
  input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
124
 
125
  progress(0.3, "Generating speech tokens...")
126
+ with torch.no_grad():
127
+ generated_ids = model.generate(
128
+ input_ids=input_ids,
129
+ attention_mask=attention_mask,
130
+ max_new_tokens=max_new_tokens,
131
+ do_sample=True,
132
+ temperature=temperature,
133
+ top_p=top_p,
134
+ repetition_penalty=repetition_penalty,
135
+ num_return_sequences=1,
136
+ eos_token_id=128258,
137
+ )
138
 
139
  progress(0.6, "Processing speech tokens...")
140
  code_list = parse_output(generated_ids)