raygiles3 committed · Commit cb301f9 · verified · 1 Parent(s): bb7420f

Update app.py

Files changed (1)
  1. app.py +4 -4
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 import torch
-from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration, AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
 from huggingface_hub import login
 import os
 
@@ -18,7 +18,7 @@ whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-base")
 whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
 
 # Initialize the summarization model and tokenizer
-summarization_model = AutoModelForSeq2SeqLM.from_pretrained("meta-llama/Llama-2-7b-hf")
+summarization_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
 summarization_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
 
 # Function to transcribe audio
@@ -32,8 +32,8 @@ def transcribe_audio(audio_file):
 
 # Function to summarize text
 def summarize_text(text):
-    inputs = summarization_tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=512, truncation=True)
-    summary_ids = summarization_model.generate(inputs, max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
+    inputs = summarization_tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
+    summary_ids = summarization_model.generate(inputs.input_ids, max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
     summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
     return summary
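
Why this change is needed: Llama-2-7b-hf is a decoder-only model, so it cannot be loaded through AutoModelForSeq2SeqLM (the auto class rejects LlamaConfig with an "Unrecognized configuration class" ValueError); AutoModelForCausalLM is the matching class, and the plain tokenizer call replaces the T5-style "summarize: " prefix encoding. One pitfall remains in the committed generate() call: for a causal LM, max_length counts the prompt tokens as well, so an input near the 512-token truncation limit leaves no room under max_length=150, and the returned ids echo the prompt. Below is a minimal sketch of summarize_text with those two points addressed; the max_new_tokens argument and the prompt-slicing step are assumptions of this sketch, not changes made by the commit.

# Minimal sketch, not part of the commit: summarize_text adjusted for a
# decoder-only model. max_new_tokens and the prompt slicing are assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer

summarization_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
summarization_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

def summarize_text(text):
    # Truncate the input to 512 tokens, as in the committed code
    inputs = summarization_tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
    summary_ids = summarization_model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=150,   # caps generated tokens only; max_length would also count the prompt
        num_beams=4,
        length_penalty=2.0,
        early_stopping=True,
    )
    # A causal LM echoes the prompt in its output, so decode only the new tokens
    new_tokens = summary_ids[0][inputs.input_ids.shape[1]:]
    return summarization_tokenizer.decode(new_tokens, skip_special_tokens=True)

Note also that a base causal LM given raw text tends to continue it rather than summarize it, so in practice an instruction-style prompt or an instruct/chat-tuned variant of the model is usually needed on top of this fix.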