camanalo1 committed on
Commit
0f4f655
·
verified ·
1 Parent(s): da3b26a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -11
app.py CHANGED
@@ -1,12 +1,50 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
-
4
- pipe = pipeline(task="image-classification",
5
- # model that can do 22k-category classification
6
- model="microsoft/beit-base-patch16-224-pt22k-ft22k")
7
- gr.Interface.from_pipeline(pipe,
8
- title="22k Image Classification",
9
- description="Object Recognition using Microsoft BEIT",
10
- examples = ['wonder_cat.jpg', 'aki_dog.jpg',],
11
- article = "Author: <a href=\"https://huggingface.co/rowel\">Rowel Atienza</a>",
12
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import pipeline, VitsTokenizer, VitsModel, set_seed
3
+ import numpy as np
4
+ import torch
5
+ import io
6
+ import soundfile as sf
7
+
8
# Speech-to-text stage: ASR pipeline used to transcribe the recorded audio
transcriber = pipeline(
    task="automatic-speech-recognition",
    model="facebook/s2t-small-librispeech-asr",
)

# Text stage: language model that continues the transcribed text
generator = pipeline(task="text-generation", model="gpt2")

# Text-to-speech stage: tokenizer and model are loaded from the same checkpoint
tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
model = VitsModel.from_pretrained("facebook/mms-tts-eng")
17
+
18
def transcribe_and_generate_audio(audio):
    """Transcribe recorded speech, continue it with the LLM, and synthesize the result.

    Args:
        audio: A ``(sample_rate, samples)`` pair, where ``samples`` is a NumPy
            array (presumably the tuple Gradio's audio component delivers;
            confirm against the Gradio version in use).

    Returns:
        Path to a WAV file containing the synthesized speech.

    Raises:
        ValueError: If ``audio`` is ``None`` (nothing was recorded).
    """
    # Local import keeps this function self-contained; only used for the
    # per-request output file below.
    import tempfile

    if audio is None:
        raise ValueError("No audio received; please record something first.")

    sr, y = audio
    y = y.astype(np.float32)

    # The ASR pipeline expects a single channel. Multi-channel input is
    # assumed to be laid out as (samples, channels) -- TODO confirm.
    if y.ndim > 1:
        y = y.mean(axis=1)

    # Peak-normalize, but leave silent or empty input alone: dividing by a
    # zero peak would fill the signal with NaNs.
    peak = np.max(np.abs(y)) if y.size else 0.0
    if peak > 0:
        y /= peak

    # Transcribe audio
    asr_output = transcriber({"sampling_rate": sr, "raw": y})["text"]

    # Generate text based on ASR output
    generated_text = generator(
        asr_output, max_length=100, num_return_sequences=1
    )[0]['generated_text']

    # Generate audio from text; the seed is fixed before the forward pass so
    # repeated requests with the same text give the same waveform.
    inputs = tokenizer(text=generated_text, return_tensors="pt")
    set_seed(555)
    with torch.no_grad():
        outputs = model(**inputs)
    waveform = outputs.waveform[0]

    # Write to a unique temporary file so concurrent requests cannot overwrite
    # each other's output. The file is intentionally not deleted here because
    # the caller still has to read it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        waveform_path = tmp.name

    # Use the sampling rate the TTS model was trained with rather than a
    # hard-coded value, otherwise playback speed and pitch are wrong.
    sf.write(
        waveform_path,
        waveform.numpy(),
        model.config.sampling_rate,
        format='wav',
    )

    return waveform_path
39
+
40
# Build the Gradio UI: microphone recording in, synthesized audio out
audio_input = gr.Interface(
    fn=transcribe_and_generate_audio,
    inputs=gr.Audio(sources=["microphone"], label="Speak Here"),
    outputs="audio",
    title="ASR -> LLM -> TTS",
    description="Speak into the microphone and hear the generated audio.",
)

# Start serving the app
audio_input.launch()