preston-cell committed on
Commit a0edfdb · verified · 1 Parent(s): 1b524ac

Update app.py

Files changed (1)
  1. app.py +48 -51
app.py CHANGED
@@ -1,32 +1,37 @@
 import gradio as gr
-from transformers import (
-    pipeline, AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
-    GenerationConfig, TextStreamer
-)
-from datasets import load_dataset
 import torch
 import numpy as np
 from PIL import Image
+from transformers import (
+    pipeline,
+    AutoModelForCausalLM,
+    AutoProcessor,
+    AutoTokenizer,
+    GenerationConfig,
+    TextStreamer,
+)
+from datasets import load_dataset
 
-# Device and dtype setup
-device = "cuda:0" if torch.cuda.is_available() else "cpu"
-torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+# Use CPU if no GPU is available
+device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+print(f"Device set to use {device}")
 
-# Caption model (BLIP)
-caption_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
+# Load image captioning model (BLIP)
+caption_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base", device=device)
 
-# Florence-2-base model for OCR
+# Load OCR model (Florence-2-base)
 ocr_model = AutoModelForCausalLM.from_pretrained(
-    "microsoft/Florence-2-base",
-    torch_dtype=torch_dtype,
-    trust_remote_code=True
+    "microsoft/Florence-2-base", trust_remote_code=True, torch_dtype=dtype
 ).to(device)
 ocr_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
 
-# Doge model for context generation
-context_tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-320M-Instruct")
-context_model = AutoModelForCausalLM.from_pretrained("SmallDoge/Doge-320M-Instruct", trust_remote_code=True).to(device)
-generation_config = GenerationConfig(
+# Load SmallDoge model for context generation
+doge_tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-320M-Instruct")
+doge_model = AutoModelForCausalLM.from_pretrained(
+    "SmallDoge/Doge-320M-Instruct", trust_remote_code=True
+).to(device)
+doge_config = GenerationConfig(
     max_new_tokens=100,
     use_cache=True,
     do_sample=True,
@@ -34,50 +39,42 @@ generation_config = GenerationConfig(
     top_p=0.9,
     repetition_penalty=1.0
 )
-streamer = TextStreamer(tokenizer=context_tokenizer, skip_prompt=True)
 
-# SpeechT5 for TTS
-tts = pipeline("text-to-speech", model="microsoft/speecht5_tts")
+# Load SpeechT5 for TTS
+synthesiser = pipeline("text-to-speech", model="microsoft/speecht5_tts", device=device)
 
-# Load valid 600-dim speaker embedding
-speaker_embedding = None
-embedding_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
-for item in embedding_dataset:
-    vec = torch.tensor(item["xvector"])
-    if vec.shape[0] == 600:
-        speaker_embedding = vec.unsqueeze(0)
-        break
-if speaker_embedding is None:
-    raise ValueError("No suitable speaker embedding of 600 dimensions found.")
+# Use known compatible 600-dim speaker embedding
+embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
+speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)  # Shape: [1, 600]
 
 def process_image(image):
     try:
-        # Generate caption
+        # Caption generation
         caption = caption_model(image)[0]['generated_text']
 
-        # Extract text using Florence-2
-        inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(device, torch_dtype)
-        generated_ids = ocr_model.generate(
-            input_ids=inputs["input_ids"],
-            pixel_values=inputs["pixel_values"],
+        # OCR extraction
+        ocr_inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(device)
+        ocr_outputs = ocr_model.generate(
+            input_ids=ocr_inputs["input_ids"],
+            pixel_values=ocr_inputs["pixel_values"],
             max_new_tokens=1024,
-            do_sample=False,
             num_beams=3,
+            do_sample=False,
         )
-        extracted_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        extracted_text = ocr_processor.batch_decode(ocr_outputs, skip_special_tokens=True)[0]
 
-        # Generate context using Doge
-        prompt = f"Determine the context of this image based on the caption and extracted text. Caption: {caption}. Extracted text: {extracted_text}. Context:"
+        # Context generation using Doge
+        prompt = f"Determine the context of this image based on the caption and extracted text.\nCaption: {caption}\nExtracted text: {extracted_text}\nContext:"
         conversation = [{"role": "user", "content": prompt}]
-        context_inputs = context_tokenizer.apply_chat_template(conversation=conversation, tokenize=True, return_tensors="pt").to(device)
-        output = context_model.generate(
-            context_inputs,
-            generation_config=generation_config,
-        )
-        context = context_tokenizer.decode(output[0], skip_special_tokens=True)
+        doge_inputs = doge_tokenizer.apply_chat_template(conversation, tokenize=True, return_tensors="pt").to(device)
+        doge_output = doge_model.generate(doge_inputs, generation_config=doge_config)
+        context = doge_tokenizer.decode(doge_output[0], skip_special_tokens=True)
 
         # Convert context to speech
-        speech = tts(context, forward_params={"speaker_embeddings": speaker_embedding})
+        speech = synthesiser(
+            context,
+            forward_params={"speaker_embeddings": speaker_embedding}
+        )
        audio = np.array(speech["audio"])
        rate = speech["sampling_rate"]
 
@@ -86,7 +83,7 @@ def process_image(image):
     except Exception as e:
         return None, f"Error: {str(e)}", "", ""
 
-# Gradio Interface
+# Gradio UI
 iface = gr.Interface(
     fn=process_image,
     inputs=gr.Image(type='pil', label="Upload an Image"),
@@ -96,8 +93,8 @@ iface = gr.Interface(
         gr.Textbox(label="Extracted Text (OCR)"),
         gr.Textbox(label="Generated Context")
     ],
-    title="SeeSay Contextualizer",
-    description="Upload an image to generate a caption, extract text, generate context with Doge, and convert to speech."
+    title="SeeSay Contextualizer with Doge + SpeechT5",
+    description="Upload an image to generate a caption, extract OCR text, determine context with Doge-320M, and hear it with SpeechT5."
 )
 
 iface.launch(share=True)
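
For reference, the fixed-index speaker embedding this commit switches to (replacing the old loop that searched the dataset for a suitable x-vector) can be exercised on its own, outside the Gradio app. A minimal sketch, assuming Hub access to the same model and dataset names used above; soundfile is an extra dependency used here only to write the result to disk and is not part of app.py:

import torch
import soundfile as sf
from datasets import load_dataset
from transformers import pipeline

# Same TTS pipeline and fixed-index speaker embedding as the updated app.py
synthesiser = pipeline("text-to-speech", model="microsoft/speecht5_tts")
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

# Synthesize a short sentence to confirm the embedding works end to end
speech = synthesiser("Hello from SeeSay.", forward_params={"speaker_embeddings": speaker_embedding})
sf.write("check.wav", speech["audio"], samplerate=speech["sampling_rate"])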