Gapeleon committed
Commit 2c4caee · verified · 1 Parent(s): 346fcfd

Create app.py

Files changed (1):
  app.py +224 -0
app.py ADDED
@@ -0,0 +1,224 @@
+ import torch
+ import torchaudio
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, BitsAndBytesConfig
+ import gradio as gr
+ import os
+ import time
+
+ # --- Configuration ---
+ model_name = "ibm-granite/granite-speech-3.2-8b"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+
+ # --- Load Model and Processor (runs only once on startup) ---
+ print("Loading processor...")
+ speech_granite_processor = AutoProcessor.from_pretrained(
+     model_name, trust_remote_code=True)
+ tokenizer = speech_granite_processor.tokenizer
+ print("Processor loaded.")
+
+ print("Configuring quantization...")
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_compute_dtype=torch.float16,
+     bnb_4bit_quant_type="nf4",  # TODO: Try fp4 as an alternative.
+     bnb_4bit_use_double_quant=True
+ )
+ print("Quantization configured.")
+
+ print("Loading model...")
+ speech_granite = AutoModelForSpeechSeq2Seq.from_pretrained(
+     model_name,
+     quantization_config=quantization_config,
+     device_map="auto",
+     trust_remote_code=True
+ )
+ speech_granite.eval()
+ print("Model loaded.")
+
+ # --- Core Transcription Function ---
+ def transcribe_audio(audio_input):
+     """
+     Transcribes audio using the loaded Granite model.
+
+     Args:
+         audio_input (tuple or str): Audio data from Gradio.
+             If from microphone: a tuple (sample_rate, numpy_array).
+             If from file upload: a string filepath.
+
+     Returns:
+         str: The transcribed text.
+         float: Processing time in seconds.
+     """
+     start_time = time.time()
+
+     if audio_input is None:
+         return "Error: No audio provided.", 0.0
+
+     print(f"Received audio input type: {type(audio_input)}")
+
+     # --- Load and Preprocess Audio ---
+     try:
+         if isinstance(audio_input, str):  # File upload (or microphone with type="filepath")
+             audio_path = audio_input
+             wav, sr = torchaudio.load(audio_path, normalize=True)
+             print(f"Loaded from file: {audio_path}")
+         elif isinstance(audio_input, tuple):  # Microphone input (type="numpy")
+             sr, wav_np = audio_input
+             wav = torch.from_numpy(wav_np).float()
+             # Gradio may return shape [N] (mono) or [N, channels] (stereo);
+             # bring everything to [channels, N].
+             wav = wav.unsqueeze(0) if wav.dim() == 1 else wav.t()
+             # Normalize microphone input (it arrives as raw integer amplitudes),
+             # guarding against division by zero on silent input.
+             peak = torch.max(torch.abs(wav))
+             if peak > 0:
+                 wav = wav / peak
+             print(f"Loaded from microphone input. Sample rate: {sr}, Shape: {wav.shape}")
+         else:
+             return f"Error: Unsupported audio input type: {type(audio_input)}.", 0.0
+
+         print(f"Original sample rate: {sr}, Channels: {wav.shape[0] if wav.dim() > 1 else 1}")
+
+         # Convert to mono if stereo
+         if wav.dim() > 1 and wav.shape[0] > 1:
+             wav = torch.mean(wav, dim=0, keepdim=True)
+             print("Converted stereo to mono")
+
+         # Ensure it's 2D [1, N]
+         if wav.dim() == 1:
+             wav = wav.unsqueeze(0)
+
+         # Resample to 16kHz if necessary
+         if sr != 16000:
+             print(f"Resampling from {sr}Hz to 16000Hz...")
+             resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
+             wav = resampler(wav)
+             sr = 16000
+             print(f"Resampled to {sr}Hz")
+
+         print(f"Final audio: sample rate {sr}Hz, shape {wav.shape}")
+         assert wav.shape[0] == 1 and sr == 16000, "Audio preprocessing failed"
+
+     except Exception as e:
+         print(f"Error during audio loading/processing: {e}")
+         return f"Error processing audio: {e}", 0.0
+
+     # --- Prepare Prompt ---
+     chat = [
+         {
+             "role": "system",
+             "content": "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant",
+         },
+         {
+             "role": "user",
+             "content": "<|audio|>can you transcribe the speech into a written format?",
+         }
+     ]
+     text = tokenizer.apply_chat_template(
+         chat, tokenize=False, add_generation_prompt=True
+     )
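+     # Optional debug aid (an addition): peek at the rendered prompt; it should
+     # contain the <|audio|> placeholder that the processor fills with audio features.
+     print(f"Prompt preview: {text[:160]!r}")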
+
+     # --- Process and Generate ---
+     try:
+         print("Processing inputs...")
+         # Pass the audio tensor (wav) directly, not the filepath
+         model_inputs = speech_granite_processor(
+             text,
+             audios=wav.squeeze(0).numpy(),
+             sampling_rate=sr,
+             device=device,  # Compute embeddings on the target device
+             return_tensors="pt",
+         ).to(device)  # Move tensors to the target device (GPU/CPU)
+         print("Inputs processed.")
+
+         print("Generating transcription...")
+         # Generate on the same device as the model. With do_sample=False decoding
+         # is deterministic beam search; top_p and temperature have no effect and
+         # are kept only to mirror the example settings.
+         model_outputs = speech_granite.generate(
+             **model_inputs,
+             max_new_tokens=1000,
+             num_beams=4,
+             do_sample=False,
+             min_length=1,
+             top_p=1.0,
+             repetition_penalty=3.0,
+             length_penalty=1.0,
+             temperature=1.0,
+             bos_token_id=tokenizer.bos_token_id,
+             eos_token_id=tokenizer.eos_token_id,
+             pad_token_id=tokenizer.pad_token_id,
+         )
+         print("Generation complete.")
+
+         # --- Decode Output ---
+         num_input_tokens = model_inputs["input_ids"].shape[-1]
+         # Move the generated tokens to CPU before decoding
+         new_tokens = model_outputs[0, num_input_tokens:].cpu()
+
+         output_text = tokenizer.batch_decode(
+             [new_tokens],  # Wrap in a list for batch_decode
+             skip_special_tokens=True
+         )
+         transcription = output_text[0].strip().upper()  # First item, stripped and uppercased
+         print(f"Raw output: {output_text[0]}")
+         print(f"Final Transcription: {transcription}")
+
+     except Exception as e:
+         print(f"Error during generation/decoding: {e}")
+         import traceback
+         traceback.print_exc()  # Print full traceback for debugging
+         return f"Error during transcription: {e}", 0.0
+
+     end_time = time.time()
+     processing_time = round(end_time - start_time, 2)
+     print(f"Processing time: {processing_time} seconds")
+
+     # Clean up temporary file if it was created by a Gradio upload.
+     # NOTE: Gradio typically handles cleanup itself; this is belt-and-suspenders.
+     if isinstance(audio_input, str) and os.path.exists(audio_input):
+         try:
+             # Check that it looks like a temp file before deleting
+             if "gradio" in audio_input or "tmp" in audio_input:
+                 # os.remove(audio_input)  # Be cautious enabling this
+                 print(f"Skipping deletion of temp file: {audio_input}")
+         except OSError as e:
+             print(f"Warning: Could not delete temp file {audio_input}: {e}")
+
+     return transcription, processing_time
+
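+ # Quick local sanity check (a sketch; "sample.wav" is a hypothetical path):
+ #   text, seconds = transcribe_audio("sample.wav")
+ #   print(text, seconds)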
+ # --- Gradio Interface Definition ---
+ # Download example files (replace with actual URLs if needed, or use local paths if packaged).
+ # Example using a LibriSpeech sample from HF datasets.
+ try:
+     from datasets import load_dataset
+     ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+     example_audio_path = ds[0]["file"]  # Use the cached filepath directly if possible
+     example_list = [[example_audio_path]]  # Gradio expects a list of lists for examples
+ except Exception as e:
+     print(f"Could not load example dataset: {e}. Examples will be empty.")
+     example_list = []
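+ # Alternatively (a sketch): ship a local example with the app and point to it,
+ # e.g. example_list = [["examples/sample.wav"]]  # "examples/sample.wav" is hypothetical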
+
+ title = "IBM Granite Speech-to-Text (8B Quantized)"
+ description = """
+ Transcribe speech audio using the `ibm-granite/granite-speech-3.2-8b` model (4-bit quantized).
+ Upload an audio file or use your microphone. The model expects **English** speech.
+ Processing might take some time depending on the audio length and hardware (especially on CPU or less powerful GPUs).
+ """
+
+ # Define inputs and outputs
+ audio_in = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Input Audio")  # filepath suits torchaudio.load
+ text_out = gr.Textbox(label="Transcription", lines=5)
+ time_out = gr.Number(label="Processing Time (s)")
+
+ # Create and launch the interface
+ iface = gr.Interface(
+     fn=transcribe_audio,
+     inputs=audio_in,
+     outputs=[text_out, time_out],
+     title=title,
+     description=description,
+     examples=example_list,
+     cache_examples=False  # Disable caching if examples change or have issues
+ )
+
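+ # Optional (an addition, not in the original): enable request queuing so long
+ # transcriptions don't block concurrent users.
+ # iface.queue(max_size=8)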
+ if __name__ == "__main__":
+     iface.launch(debug=True)  # debug=True gives more detailed logs locally