preston-cell committed
Commit e9b130c · verified · 1 Parent(s): 6e08443

Update app.py

Files changed (1)
  1. app.py +20 -7
app.py CHANGED
@@ -1,5 +1,5 @@
  import gradio as gr
- from transformers import pipeline, AutoProcessor, AutoModelForCausalLM
+ from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, AutoTokenizer
  from datasets import load_dataset
  import torch
  import numpy as np
@@ -16,6 +16,11 @@ ocr_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
  ocr_model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=ocr_dtype, trust_remote_code=True).to(ocr_device)
  ocr_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)

+ # Load TxGemma model for text generation
+ text_gen_device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ text_gen_tokenizer = AutoTokenizer.from_pretrained("google/txgemma-9b-predict")
+ text_gen_model = AutoModelForCausalLM.from_pretrained("google/txgemma-9b-predict", device_map="auto")
+
  # Load speaker embedding
  embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
  speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
@@ -43,15 +48,21 @@ def process_image(image):
          )
          extracted_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

+         # Generate context from caption and extracted text using TxGemma
+         prompt = f"Instructions: Determine the context of the image based on the caption and extracted text.\nCaption: {caption}\nExtracted Text: {extracted_text}\nContext:"
+         input_ids = text_gen_tokenizer(prompt, return_tensors="pt").to(text_gen_device)
+         outputs = text_gen_model.generate(**input_ids, max_new_tokens=50)
+         context = text_gen_tokenizer.decode(outputs[0], skip_special_tokens=True)
+
          # Prepare audio data
          audio = np.array(speech["audio"])
          rate = speech["sampling_rate"]

-         # Return audio, caption, and extracted text
-         return (rate, audio), caption, extracted_text
+         # Return audio, caption, extracted text, and generated context
+         return (rate, audio), caption, extracted_text, context

      except Exception as e:
-         return None, f"Error: {str(e)}", ""
+         return None, f"Error: {str(e)}", "", ""


  # Gradio Interface
@@ -61,10 +72,12 @@ iface = gr.Interface(
      outputs=[
          gr.Audio(label="Generated Audio"),
          gr.Textbox(label="Generated Caption"),
-         gr.Textbox(label="Extracted Text (OCR)")
+         gr.Textbox(label="Extracted Text (OCR)"),
+         gr.Textbox(label="Generated Context")
      ],
-     title="SeeSay with SpeechT5 and Florence-2 OCR",
-     description="Upload an image to generate a caption, hear it described with SpeechT5's speech synthesis, and extract text using Florence-2 OCR."
+     title="SeeSay Contextualizer",
+     description="Upload an image to generate a caption, extract text, create audio, and determine the context using TxGemma."
  )

  iface.launch()
+
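For reference, the context-generation step added in this commit can be exercised on its own. The sketch below mirrors the diff (same checkpoint, prompt format, and max_new_tokens=50); the helper name generate_context and the example inputs are illustrative only and not part of the repository, and it assumes transformers plus accelerate, access to the checkpoint, and enough memory for the 9B model.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Same checkpoint as in the diff above.
MODEL_ID = "google/txgemma-9b-predict"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")

def generate_context(caption: str, extracted_text: str, max_new_tokens: int = 50) -> str:
    """Infer the image context from the caption and OCR text, as app.py now does."""
    prompt = (
        "Instructions: Determine the context of the image based on the caption and extracted text.\n"
        f"Caption: {caption}\nExtracted Text: {extracted_text}\nContext:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # The prompt is echoed back in the decoded sequence; strip it to keep only the new text.
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return text[len(prompt):].strip() if text.startswith(prompt) else text.strip()

# Hypothetical example call:
# print(generate_context("a bottle of pills on a shelf", "IBUPROFEN 200 mg"))

Since process_image now returns four values (audio, caption, extracted text, context), the Gradio outputs list gains a fourth component for the generated context, and the error path returns four values to match.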