preston-cell committed on
Commit f984625 · verified · 1 Parent(s): a0edfdb

Update app.py

Files changed (1)
  1. app.py +39 -59
app.py CHANGED
@@ -1,80 +1,60 @@
 import gradio as gr
-import torch
-import numpy as np
-from PIL import Image
 from transformers import (
-    pipeline,
-    AutoModelForCausalLM,
-    AutoProcessor,
-    AutoTokenizer,
-    GenerationConfig,
-    TextStreamer,
+    pipeline,
+    AutoProcessor,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    set_seed
 )
 from datasets import load_dataset
+import torch
+import numpy as np

-# Use CPU if no GPU is available
-device = "cuda" if torch.cuda.is_available() else "cpu"
-dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-print(f"Device set to use {device}")
+# Set seed
+set_seed(42)

-# Load image captioning model (BLIP)
-caption_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base", device=device)
+# Captioning model
+caption_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")

-# Load OCR model (Florence-2-base)
-ocr_model = AutoModelForCausalLM.from_pretrained(
-    "microsoft/Florence-2-base", trust_remote_code=True, torch_dtype=dtype
-).to(device)
-ocr_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
+# GPT-2 model for context generation
+gpt2_generator = pipeline("text-generation", model="gpt2")

-# Load SmallDoge model for context generation
-doge_tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-320M-Instruct")
-doge_model = AutoModelForCausalLM.from_pretrained(
-    "SmallDoge/Doge-320M-Instruct", trust_remote_code=True
-).to(device)
-doge_config = GenerationConfig(
-    max_new_tokens=100,
-    use_cache=True,
-    do_sample=True,
-    temperature=0.8,
-    top_p=0.9,
-    repetition_penalty=1.0
-)
+# SpeechT5 for text-to-speech
+synthesiser = pipeline("text-to-speech", model="microsoft/speecht5_tts")

-# Load SpeechT5 for TTS
-synthesiser = pipeline("text-to-speech", model="microsoft/speecht5_tts", device=device)
+# Load Florence-2-base for OCR
+ocr_device = "cuda" if torch.cuda.is_available() else "cpu"
+ocr_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+ocr_model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base", torch_dtype=ocr_dtype, trust_remote_code=True).to(ocr_device)
+ocr_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)

-# Use known compatible 600-dim speaker embedding
+# Load speaker embedding
 embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
-speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)  # Shape: [1, 600]
+speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

 def process_image(image):
     try:
-        # Caption generation
+        # Generate caption
         caption = caption_model(image)[0]['generated_text']

-        # OCR extraction
-        ocr_inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(device)
-        ocr_outputs = ocr_model.generate(
-            input_ids=ocr_inputs["input_ids"],
-            pixel_values=ocr_inputs["pixel_values"],
+        # Extract OCR text
+        inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(ocr_device, ocr_dtype)
+        generated_ids = ocr_model.generate(
+            input_ids=inputs["input_ids"],
+            pixel_values=inputs["pixel_values"],
             max_new_tokens=1024,
             num_beams=3,
-            do_sample=False,
+            do_sample=False
         )
-        extracted_text = ocr_processor.batch_decode(ocr_outputs, skip_special_tokens=True)[0]
+        extracted_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

-        # Context generation using Doge
-        prompt = f"Determine the context of this image based on the caption and extracted text.\nCaption: {caption}\nExtracted text: {extracted_text}\nContext:"
-        conversation = [{"role": "user", "content": prompt}]
-        doge_inputs = doge_tokenizer.apply_chat_template(conversation, tokenize=True, return_tensors="pt").to(device)
-        doge_output = doge_model.generate(doge_inputs, generation_config=doge_config)
-        context = doge_tokenizer.decode(doge_output[0], skip_special_tokens=True)
+        # Generate context with GPT-2
+        prompt = f"Determine the context of this image based on the caption and extracted text. Caption: {caption}. Extracted text: {extracted_text}. Context:"
+        context_output = gpt2_generator(prompt, max_length=100, num_return_sequences=1)
+        context = context_output[0]['generated_text']

-        # Convert context to speech
-        speech = synthesiser(
-            context,
-            forward_params={"speaker_embeddings": speaker_embedding}
-        )
+        # Text-to-speech
+        speech = synthesiser(context, forward_params={"speaker_embeddings": speaker_embedding})
         audio = np.array(speech["audio"])
         rate = speech["sampling_rate"]

@@ -93,8 +73,8 @@ iface = gr.Interface(
         gr.Textbox(label="Extracted Text (OCR)"),
         gr.Textbox(label="Generated Context")
     ],
-    title="SeeSay Contextualizer with Doge + SpeechT5",
-    description="Upload an image to generate a caption, extract OCR text, determine context with Doge-320M, and hear it with SpeechT5."
+    title="SeeSay Contextualizer",
+    description="Upload an image to generate a caption, extract text, create audio from context, and determine the context using GPT-2 and Florence-2-base."
 )

-iface.launch(share=True)
+iface.launch()
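
For quick local verification outside the Gradio UI, a minimal sketch of the pipeline chain the new app.py wires together (caption, then GPT-2 context, then SpeechT5 audio). This is not part of the commit: the test image path example.jpg is hypothetical, the Florence-2 OCR stage is omitted to keep the sketch light, and soundfile is assumed only for writing the WAV output.

import numpy as np
import soundfile as sf
import torch
from datasets import load_dataset
from PIL import Image
from transformers import pipeline

# Same pipelines as in the updated app.py (CPU by default).
caption_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
gpt2_generator = pipeline("text-generation", model="gpt2")
synthesiser = pipeline("text-to-speech", model="microsoft/speecht5_tts")

# Speaker embedding used by the app: entry 7306 of the CMU ARCTIC xvectors.
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

image = Image.open("example.jpg")  # hypothetical local test image
caption = caption_model(image)[0]["generated_text"]

prompt = f"Determine the context of this image based on the caption. Caption: {caption}. Context:"
context = gpt2_generator(prompt, max_new_tokens=50, num_return_sequences=1)[0]["generated_text"]

# Synthesize speech for the generated context and save it for listening.
speech = synthesiser(context, forward_params={"speaker_embeddings": speaker_embedding})
sf.write("context.wav", np.array(speech["audio"]), samplerate=speech["sampling_rate"])
print(caption)
print(context)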