GavinHuang commited on
Commit
18b21ee
Β·
1 Parent(s): e888ead

add actual code

Browse files
Files changed (2) hide show
  1. app.py +111 -8
  2. requirements.txt +6 -0
app.py CHANGED
@@ -1,14 +1,117 @@
 
1
  import gradio as gr
2
- import spaces
3
  import torch
 
 
 
 
4
 
5
- zero = torch.Tensor([0]).cuda()
6
- print(zero.device) # <-- 'cpu' πŸ€”
 
 
7
 
 
8
  @spaces.GPU
9
- def greet(n):
10
- print(zero.device) # <-- 'cuda:0' πŸ€—
11
- return f"Hello {zero + n} Tensor"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
14
- demo.launch()
 
 
1
+ import os
2
  import gradio as gr
 
3
  import torch
4
+ import spaces
5
+ import nemo.collections.asr as nemo_asr
6
+ from omegaconf import OmegaConf
7
+ import time
8
 
9
+ # Check if CUDA is available
10
+ print(f"CUDA available: {torch.cuda.is_available()}")
11
+ if torch.cuda.is_available():
12
+ print(f"CUDA device: {torch.cuda.get_device_name(0)}")
13
 
14
+ # Initialize the ASR model
15
  @spaces.GPU
16
+ def load_model():
17
+ print("Loading ASR model...")
18
+ # Load the NVIDIA NeMo ASR model
19
+ model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
20
+ # Move model to GPU if available
21
+ if torch.cuda.is_available():
22
+ model = model.cuda()
23
+ print(f"Model loaded on device: {model.device}")
24
+ return model
25
+
26
+ # Global variable to store the model
27
+ model = load_model()
28
+
29
+ def transcribe(audio, state=""):
30
+ """
31
+ Transcribe audio in real-time
32
+ """
33
+ # Skip processing if no audio is provided
34
+ if audio is None:
35
+ return state, state
36
+
37
+ # Get the sample rate from the audio
38
+ sample_rate = 16000 # Default to 16kHz if not specified
39
+
40
+ # Process the audio with the ASR model
41
+ with torch.no_grad():
42
+ transcription = model.transcribe([audio])[0]
43
+
44
+ # Append new transcription to the state
45
+ if state == "":
46
+ new_state = transcription
47
+ else:
48
+ new_state = state + " " + transcription
49
+
50
+ return new_state, new_state
51
+
52
+ # Define the Gradio interface
53
+ with gr.Blocks(title="Real-time Speech-to-Text with NeMo") as demo:
54
+ gr.Markdown("# πŸŽ™οΈ Real-time Speech-to-Text Transcription")
55
+ gr.Markdown("Powered by NVIDIA NeMo and the parakeet-tdt-0.6b-v2 model")
56
+
57
+ with gr.Row():
58
+ with gr.Column(scale=2):
59
+ audio_input = gr.Audio(
60
+ source="microphone",
61
+ type="numpy",
62
+ streaming=True,
63
+ label="Speak into your microphone"
64
+ )
65
+
66
+ clear_btn = gr.Button("Clear Transcript")
67
+
68
+ with gr.Column(scale=3):
69
+ text_output = gr.Textbox(
70
+ label="Transcription",
71
+ placeholder="Your speech will appear here...",
72
+ lines=10
73
+ )
74
+ streaming_text = gr.Textbox(
75
+ label="Real-time Transcription",
76
+ placeholder="Real-time results will appear here...",
77
+ lines=2
78
+ )
79
+
80
+ # State to store the ongoing transcription
81
+ state = gr.State("")
82
+
83
+ # Handle the audio stream
84
+ audio_input.stream(
85
+ fn=transcribe,
86
+ inputs=[audio_input, state],
87
+ outputs=[state, streaming_text],
88
+ )
89
+
90
+ # Clear the transcription
91
+ def clear_transcription():
92
+ return "", "", ""
93
+
94
+ clear_btn.click(
95
+ fn=clear_transcription,
96
+ inputs=[],
97
+ outputs=[text_output, streaming_text, state]
98
+ )
99
+
100
+ # Update the main text output when the state changes
101
+ state.change(
102
+ fn=lambda s: s,
103
+ inputs=[state],
104
+ outputs=[text_output]
105
+ )
106
+
107
+ gr.Markdown("## πŸ“ Instructions")
108
+ gr.Markdown("""
109
+ 1. Click the microphone button to start recording
110
+ 2. Speak clearly into your microphone
111
+ 3. The transcription will appear in real-time
112
+ 4. Click 'Clear Transcript' to start a new transcription
113
+ """)
114
 
115
+ # Launch the app
116
+ if __name__ == "__main__":
117
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=1.13.0
2
+ gradio>=3.32.0
3
+ nemo_toolkit[asr]>=1.18.0
4
+ omegaconf>=2.2.0
5
+ spaces>=0.15.0
6
+ numpy>=1.22.0