preston-cell committed
Commit e6ea13d · verified · 1 Parent(s): 6734c03

Update app.py

Files changed (1)
app.py +32 -45
app.py CHANGED
@@ -1,42 +1,38 @@
 import gradio as gr
-from transformers import pipeline, AutoProcessor, AutoModelForCausalLM
+from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, AutoTokenizer, set_seed
+from datasets import load_dataset
 import torch
 import numpy as np
-from datasets import load_dataset
-from PIL import Image
 
-# 1) IMAGE CAPTIONING MODEL
+# Set seed for reproducibility
+set_seed(42)
+
+# Load BLIP model for image captioning
 caption_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
 
-# 2) OCR MODEL (Florence-2)
+# Load SpeechT5 model for text-to-speech
+synthesiser = pipeline("text-to-speech", model="microsoft/speecht5_tts")
+
+# Load Florence-2 model for OCR
 ocr_device = "cuda:0" if torch.cuda.is_available() else "cpu"
 ocr_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-ocr_model = AutoModelForCausalLM.from_pretrained(
-    "microsoft/Florence-2-large",
-    torch_dtype=ocr_dtype,
-    trust_remote_code=True
-).to(ocr_device)
+ocr_model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=ocr_dtype, trust_remote_code=True).to(ocr_device)
 ocr_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
 
-# 3) QUESTION-ANSWERING MODEL
-qa_model = pipeline(
-    "question-answering",
-    model="timpal0l/mdeberta-v3-base-squad2"
-)
-
-# 4) TEXT-TO-SPEECH MODEL
-tts_pipeline = pipeline("text-to-speech", model="microsoft/speecht5_tts")
+# Load GPT-2 (124M) model for text generation
+gpt2_generator = pipeline('text-generation', model='gpt2')
 
 # Load speaker embedding
 embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
 speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
 
+
 def process_image(image):
     try:
-        # 1) Generate caption from the image
+        # Generate caption from the image
         caption = caption_model(image)[0]['generated_text']
 
-        # 2) Extract text from the image using Florence-2
+        # Extract text (OCR) using Florence-2
         inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(ocr_device, ocr_dtype)
         generated_ids = ocr_model.generate(
             input_ids=inputs["input_ids"],
@@ -47,32 +43,28 @@ def process_image(image):
         )
         extracted_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
-        # 3) Use QA model to derive context from caption + extracted text
-        # We treat the "context" string as the knowledge base and ask a question about it
-        question = "What is the context of this image?"
-        combined_context = f"Caption: {caption}\nExtracted Text: {extracted_text}"
-        qa_result = qa_model(question=question, context=combined_context)
+        # Generate context using GPT-2 (124M)
+        prompt = f"Determine the context of this image based on the caption and extracted text. Caption: {caption}. Extracted text: {extracted_text}. Context:"
+        context_output = gpt2_generator(prompt, max_length=100, num_return_sequences=1)
+        context = context_output[0]['generated_text']
 
-        # The QA model returns an extracted "answer" from the combined context
-        # If the model can't find a direct span, it may return an empty string or a short phrase
-        final_context = qa_result["answer"]
-
-        # 4) Convert the final context to speech
-        speech_data = tts_pipeline(
-            final_context,
+        # Convert context to speech
+        speech = synthesiser(
+            context,
             forward_params={"speaker_embeddings": speaker_embedding}
         )
 
-        # Prepare audio data for Gradio
-        audio = np.array(speech_data["audio"])
-        rate = speech_data["sampling_rate"]
+        # Prepare audio data
+        audio = np.array(speech["audio"])
+        rate = speech["sampling_rate"]
 
-        # Return audio, caption, extracted text, and final context
-        return (rate, audio), caption, extracted_text, final_context
+        # Return audio, caption, extracted text, and context
+        return (rate, audio), caption, extracted_text, context
 
     except Exception as e:
         return None, f"Error: {str(e)}", "", ""
 
+
 # Gradio Interface
 iface = gr.Interface(
     fn=process_image,
@@ -81,15 +73,10 @@ iface = gr.Interface(
         gr.Audio(label="Generated Audio"),
         gr.Textbox(label="Generated Caption"),
         gr.Textbox(label="Extracted Text (OCR)"),
-        gr.Textbox(label="QA-derived Context")
+        gr.Textbox(label="Generated Context")
     ],
-    title="Contextual Image QA with SpeechT5",
-    description=(
-        "1) Generate a caption via BLIP. "
-        "2) Extract text using Florence-2. "
-        "3) Use QA with mDeBERTa to find a 'context' from caption + text. "
-        "4) Convert it to audio via SpeechT5."
-    ),
+    title="SeeSay Contextualizer with GPT-2 (124M)",
+    description="Upload an image to generate a caption, extract text, create audio from context, and determine the context using GPT-2."
 )
 
 iface.launch()
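
Below is a minimal sketch (not part of this commit) of how the new GPT-2 context step and SpeechT5 synthesis from the updated app.py could be exercised outside the Gradio UI. The caption and OCR strings are illustrative stand-ins for the BLIP and Florence-2 outputs, and the soundfile dependency and output filename are assumptions, not something the commit introduces.

# Minimal sketch: run the new GPT-2 -> SpeechT5 path from app.py without Gradio.
import numpy as np
import soundfile as sf  # assumed extra dependency, only used to save the audio
import torch
from datasets import load_dataset
from transformers import pipeline, set_seed

set_seed(42)

# Same checkpoints the commit uses
gpt2_generator = pipeline("text-generation", model="gpt2")
synthesiser = pipeline("text-to-speech", model="microsoft/speecht5_tts")

# Speaker embedding, as in app.py
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

# Stand-ins for the BLIP caption and Florence-2 OCR output (illustrative only)
caption = "a street sign next to a tree"
extracted_text = "MAIN ST"

# Same prompt shape and generation settings as in the updated process_image
prompt = (
    "Determine the context of this image based on the caption and extracted text. "
    f"Caption: {caption}. Extracted text: {extracted_text}. Context:"
)
context = gpt2_generator(prompt, max_length=100, num_return_sequences=1)[0]["generated_text"]

# Synthesize speech for the generated context and save it to disk
speech = synthesiser(context, forward_params={"speaker_embeddings": speaker_embedding})
sf.write("context.wav", np.array(speech["audio"]), samplerate=speech["sampling_rate"])
print(context)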