jytole committed
Commit e8233e7 · 1 Parent(s): ca5e437

Added multi-waveform generation option

Files changed (1): app.py (+20 -2)
app.py CHANGED
@@ -9,13 +9,16 @@ pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float32)
 #pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
 pipe = pipe.to("cpu")
 
+clap_model = ClapModel.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full").to(device)
+processor = AutoProcessor.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full")
+
 generator = torch.Generator("cpu")
 
 def texttoaudio(prompt, neg_prompt, seed, inf_steps, guidance_scale):
     if prompt is None:
         raise gr.Error("Please provide a text input.")
 
-    audio = pipe(
+    waveforms = pipe(
         prompt,
         negative_prompt=neg_prompt,
         num_inference_steps=int(inf_steps),
@@ -26,7 +29,22 @@ def texttoaudio(prompt, neg_prompt, seed, inf_steps, guidance_scale):
 
     # save the audio sample as a .wav file
     # scipy.io.wavfile.write("output.wav", rate=16000, data=audio)
-    return (16000, audio[0])
+    if waveforms.shape[0] > 1:
+        waveform = score_waveforms(prompt, waveforms)
+    else:
+        waveform = waveforms[0]
+
+    return (16000, waveform)
+
+def score_waveforms(text, waveforms):
+    inputs = processor(text=text, audios=list(waveforms), return_tensors="pt", padding=True)
+    inputs = {key: inputs[key].to(device) for key in inputs}
+    with torch.no_grad():
+        logits_per_text = clap_model(**inputs).logits_per_text  # this is the audio-text similarity score
+        probs = logits_per_text.softmax(dim=-1)  # we can take the softmax to get the label probabilities
+        most_probable = torch.argmax(probs)  # and now select the most likely audio waveform
+    waveform = waveforms[most_probable]
+    return waveform
 
 iface = gr.Interface(fn=texttoaudio, title="Prompt, Neg Prompt, Seed, Inf Steps, Guidance Scale", inputs=["text", "text", "number", "number", "number"], outputs="audio")
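
For reference, the pattern this commit introduces (generate several candidate waveforms for one prompt, score each against the prompt with CLAP, keep the best match) can be run standalone. Below is a minimal sketch: the CLAP checkpoint id is taken from the diff, while the AudioLDM checkpoint id, the example prompt, and the num_waveforms_per_prompt call are assumptions, since repo_id and the full pipe() arguments are not visible in the hunks above.

import torch
from diffusers import AudioLDMPipeline
from transformers import AutoProcessor, ClapModel

device = "cpu"

# "cvssp/audioldm" is an assumed stand-in; the space's actual repo_id is
# defined outside the visible diff context.
pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm", torch_dtype=torch.float32).to(device)
clap_model = ClapModel.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full").to(device)
processor = AutoProcessor.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full")

prompt = "a hammer hitting a wooden surface"

# num_waveforms_per_prompt makes the pipeline return several candidates;
# .audios is a numpy array of shape (num_waveforms, num_samples) at 16 kHz.
waveforms = pipe(prompt, num_inference_steps=10, num_waveforms_per_prompt=3).audios

# Score every candidate against the prompt. Note that CLAP's feature extractor
# nominally expects 48 kHz audio; the commit feeds AudioLDM's 16 kHz output
# straight in, and this sketch mirrors that.
inputs = processor(text=prompt, audios=list(waveforms), return_tensors="pt", padding=True)
with torch.no_grad():
    logits_per_text = clap_model(**inputs).logits_per_text  # (1, num_waveforms) similarity scores
best = int(logits_per_text.softmax(dim=-1).argmax())
waveform = waveforms[best]  # (16000, waveform) is the tuple Gradio's audio output expects

The shape[0] > 1 guard in the commit keeps the single-waveform path free of this extra CLAP forward pass, since reranking only pays off when there is more than one candidate to choose from.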