SimpleFrog commited on
Commit
f45cdfd
·
verified ·
1 Parent(s): e5d0693

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import tempfile
4
+ import os
5
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
6
+ from peft import PeftModel
7
+
8
+ # Configuration de l'interface Streamlit
9
+ st.title("🔊 Transcription Audio avec Whisper Fine-tuné (LoRA)")
10
+ st.write("Upload un fichier audio et laisse ton modèle fine-tuné faire le travail !")
11
+
12
+ # 🔹 Charger le modèle Whisper Large et appliquer l’adaptateur LoRA
13
+ @st.cache_resource # Permet de ne charger qu'une seule fois le modèle
14
+ def load_model():
15
+ base_model_name = "openai/whisper-large" # Modèle de base
16
+ adapter_model_name = "SimpleFrog/whisper_finetuned" # Adaptateur LoRA
17
+
18
+ # Charger le modèle de base
19
+ model = WhisperForConditionalGeneration.from_pretrained(base_model_name)
20
+
21
+ # Charger l'adaptateur LoRA et l'appliquer au modèle
22
+ model = PeftModel.from_pretrained(model, adapter_model_name)
23
+
24
+ # Charger le processeur audio
25
+ processor = WhisperProcessor.from_pretrained(base_model_name)
26
+
27
+ model.eval() # Mode évaluation
28
+ return processor, model
29
+
30
+ processor, model = load_model()
31
+
32
+ # 🔹 Upload d'un fichier audio
33
+ uploaded_file = st.file_uploader("Upload un fichier audio", type=["mp3", "wav", "m4a"])
34
+
35
+ if uploaded_file is not None:
36
+ # Sauvegarder temporairement l'audio
37
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
38
+ temp_audio.write(uploaded_file.read())
39
+ temp_audio_path = temp_audio.name
40
+
41
+ # Charger et traiter l'audio
42
+ st.write("📄 **Transcription en cours...**")
43
+
44
+ # Charger l'audio
45
+ audio_input = processor(temp_audio_path, return_tensors="pt", sampling_rate=16000)
46
+ input_features = audio_input.input_features
47
+
48
+ # Générer la transcription
49
+ with torch.no_grad():
50
+ predicted_ids = model.generate(input_features)
51
+
52
+ # Décoder la sortie
53
+ transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
54
+
55
+ # Afficher la transcription
56
+ st.subheader("📝 Transcription :")
57
+ st.text_area("", transcription, height=200)
58
+
59
+ # Supprimer le fichier temporaire après l'affichage
60
+ os.remove(temp_audio_path)
61
+
62
+ st.write("🔹 Modèle utilisé :", "Whisper Large + Adaptateur LoRA (SimpleFrog/whisper_finetuned)")