joey1101 committed on
Commit
105a0a4
·
verified ·
1 Parent(s): 152d61c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -5
app.py CHANGED
@@ -145,8 +145,8 @@ def response_gen(user_review):
145
  inputs = tokenizer(prompt, return_tensors="pt") # Tokenize the input prompt
146
  outputs = model.generate(
147
  **inputs,
148
- max_new_tokens=300,
149
- min_length=75, # Ensure concise and complete responses
150
  no_repeat_ngram_size=2, # Avoid repetitive phrases
151
  temperature=0.7 # Add randomness for more natural responses
152
  )
@@ -168,9 +168,19 @@ def sound_gen(response):
168
  embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation") # Load speaker embeddings
169
  speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0) # Use a default embedding
170
 
171
- inputs = processor(text=response, return_tensors="pt") # Process text for spectrogram generation
172
- inputs["input_ids"] = inputs["input_ids"].to(torch.int32) # Match tensor format (fix runtime error)
173
- spectrogram = model.generate_speech(inputs["input_ids"], speaker_embeddings) # Generate the spectrogram
 
 
 
 
 
 
 
 
 
 
174
 
175
  with torch.no_grad():
176
  speech = vocoder(spectrogram) # Convert spectrogram to waveform
 
145
  inputs = tokenizer(prompt, return_tensors="pt") # Tokenize the input prompt
146
  outputs = model.generate(
147
  **inputs,
148
+ max_new_tokens=200,
149
+ min_length=50, # Ensure concise and complete responses
150
  no_repeat_ngram_size=2, # Avoid repetitive phrases
151
  temperature=0.7 # Add randomness for more natural responses
152
  )
 
168
  embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation") # Load speaker embeddings
169
  speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0) # Use a default embedding
170
 
171
+ # Ensure the response is not too long for the model's capacity
172
+ max_tokens = 300 # Limit the input text tokens to a maximum of 300
173
+ truncated_response = response[:max_tokens] # Truncate the response to fit within the limit
174
+
175
+ # Process the truncated text for spectrogram generation
176
+ inputs = processor(text=truncated_response, return_tensors="pt")# Process text for spectrogram generation
177
+ inputs_embeds_size = inputs["input_ids"].size(1)
178
+
179
+ # Ensure tensor dimensions align between input IDs and speaker embeddings
180
+ speaker_embeddings = speaker_embeddings[:, :inputs_embeds_size] # Match dimensions with input IDs
181
+
182
+ # Generate the spectrogram using the SpeechT5 model
183
+ spectrogram = model.generate_speech(inputs["input_ids"], speaker_embeddings) # Generate the spectrogram
184
 
185
  with torch.no_grad():
186
  speech = vocoder(spectrogram) # Convert spectrogram to waveform