Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import gradio as gr
|
2 |
-
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM
|
3 |
from datasets import load_dataset
|
4 |
import torch
|
5 |
import numpy as np
|
@@ -16,6 +16,11 @@ ocr_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
|
16 |
ocr_model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=ocr_dtype, trust_remote_code=True).to(ocr_device)
|
17 |
ocr_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
|
18 |
|
|
|
|
|
|
|
|
|
|
|
19 |
# Load speaker embedding
|
20 |
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
|
21 |
speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
|
@@ -43,15 +48,21 @@ def process_image(image):
|
|
43 |
)
|
44 |
extracted_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
# Prepare audio data
|
47 |
audio = np.array(speech["audio"])
|
48 |
rate = speech["sampling_rate"]
|
49 |
|
50 |
-
# Return audio, caption, and
|
51 |
-
return (rate, audio), caption, extracted_text
|
52 |
|
53 |
except Exception as e:
|
54 |
-
return None, f"Error: {str(e)}", ""
|
55 |
|
56 |
|
57 |
# Gradio Interface
|
@@ -61,10 +72,12 @@ iface = gr.Interface(
|
|
61 |
outputs=[
|
62 |
gr.Audio(label="Generated Audio"),
|
63 |
gr.Textbox(label="Generated Caption"),
|
64 |
-
gr.Textbox(label="Extracted Text (OCR)")
|
|
|
65 |
],
|
66 |
-
title="SeeSay
|
67 |
-
description="Upload an image to generate a caption,
|
68 |
)
|
69 |
|
70 |
iface.launch()
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, AutoTokenizer
|
3 |
from datasets import load_dataset
|
4 |
import torch
|
5 |
import numpy as np
|
|
|
16 |
ocr_model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=ocr_dtype, trust_remote_code=True).to(ocr_device)
|
17 |
ocr_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
|
18 |
|
19 |
+
# Load TxGemma model for text generation
|
20 |
+
text_gen_device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
21 |
+
text_gen_tokenizer = AutoTokenizer.from_pretrained("google/txgemma-9b-predict")
|
22 |
+
text_gen_model = AutoModelForCausalLM.from_pretrained("google/txgemma-9b-predict", device_map="auto")
|
23 |
+
|
24 |
# Load speaker embedding
|
25 |
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
|
26 |
speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
|
|
|
48 |
)
|
49 |
extracted_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
50 |
|
51 |
+
# Generate context from caption and extracted text using TxGemma
|
52 |
+
prompt = f"Instructions: Determine the context of the image based on the caption and extracted text.\nCaption: {caption}\nExtracted Text: {extracted_text}\nContext:"
|
53 |
+
input_ids = text_gen_tokenizer(prompt, return_tensors="pt").to(text_gen_device)
|
54 |
+
outputs = text_gen_model.generate(**input_ids, max_new_tokens=50)
|
55 |
+
context = text_gen_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
56 |
+
|
57 |
# Prepare audio data
|
58 |
audio = np.array(speech["audio"])
|
59 |
rate = speech["sampling_rate"]
|
60 |
|
61 |
+
# Return audio, caption, extracted text, and generated context
|
62 |
+
return (rate, audio), caption, extracted_text, context
|
63 |
|
64 |
except Exception as e:
|
65 |
+
return None, f"Error: {str(e)}", "", ""
|
66 |
|
67 |
|
68 |
# Gradio Interface
|
|
|
72 |
outputs=[
|
73 |
gr.Audio(label="Generated Audio"),
|
74 |
gr.Textbox(label="Generated Caption"),
|
75 |
+
gr.Textbox(label="Extracted Text (OCR)"),
|
76 |
+
gr.Textbox(label="Generated Context")
|
77 |
],
|
78 |
+
title="SeeSay Contextualizer",
|
79 |
+
description="Upload an image to generate a caption, extract text, create audio, and determine the context using TxGemma."
|
80 |
)
|
81 |
|
82 |
iface.launch()
|
83 |
+
|