preston-cell committed
Commit 8c3caa4 · verified · 1 Parent(s): 33c79cb

Update app.py

Files changed (1):
  1. app.py +52 -18
app.py CHANGED
@@ -1,5 +1,12 @@
import gradio as gr
- from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, AutoTokenizer, set_seed
+ from transformers import (
+     pipeline,
+     AutoProcessor,
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     GenerationConfig,
+     set_seed
+ )
from datasets import load_dataset
import torch
import numpy as np
@@ -14,25 +21,43 @@ caption_model = pipeline("image-to-text", model="Salesforce/blip-image-captionin
synthesiser = pipeline("text-to-speech", model="microsoft/speecht5_tts")

# Load Florence-2 model for OCR
- ocr_device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ ocr_device = "cuda" if torch.cuda.is_available() else "cpu"
ocr_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
- ocr_model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=ocr_dtype, trust_remote_code=True).to(ocr_device)
+ ocr_model = AutoModelForCausalLM.from_pretrained(
+     "microsoft/Florence-2-large",
+     torch_dtype=ocr_dtype,
+     trust_remote_code=True
+ ).to(ocr_device)
ocr_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)

- # Load GPT-2 (124M) model for text generation
- gpt2_generator = pipeline('text-generation', model='gpt2')
+ # Load Doge-320M-Instruct model for context generation
+ doge_tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-320M-Instruct")
+ doge_model = AutoModelForCausalLM.from_pretrained(
+     "SmallDoge/Doge-320M-Instruct",
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     trust_remote_code=True
+ ).to("cuda" if torch.cuda.is_available() else "cpu")
+
+ doge_generation_config = GenerationConfig(
+     max_new_tokens=100,
+     use_cache=True,
+     do_sample=True,
+     temperature=0.8,
+     top_p=0.9,
+     repetition_penalty=1.0
+ )

- # Load speaker embedding
+ # Load speaker embedding for SpeechT5
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)


def process_image(image):
    try:
-         # Generate caption from the image
+         # Step 1: Generate caption
        caption = caption_model(image)[0]['generated_text']

-         # Extract text (OCR) using Florence-2
+         # Step 2: OCR to extract text
        inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(ocr_device, ocr_dtype)
        generated_ids = ocr_model.generate(
            input_ids=inputs["input_ids"],
@@ -43,22 +68,31 @@ def process_image(image):
        )
        extracted_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

-         # Generate context using GPT-2 (124M)
-         prompt = f"Determine the context of this image based on the caption and extracted text. Caption: {caption}. Extracted text: {extracted_text}. Context:"
-         context_output = gpt2_generator(prompt, max_length=100, num_return_sequences=1)
-         context = context_output[0]['generated_text']
+         # Step 3: Generate context using Doge model
+         prompt = f"Determine the context of this image based on the caption and extracted text.\nCaption: {caption}\nExtracted text: {extracted_text}\nContext:"
+         conversation = [{"role": "user", "content": prompt}]
+         doge_inputs = doge_tokenizer.apply_chat_template(
+             conversation=conversation,
+             tokenize=True,
+             return_tensors="pt"
+         ).to(doge_model.device)
+
+         doge_outputs = doge_model.generate(
+             doge_inputs,
+             generation_config=doge_generation_config
+         )
+
+         context = doge_tokenizer.decode(doge_outputs[0], skip_special_tokens=True).strip()

-         # Convert context to speech
+         # Step 4: Convert context to speech
        speech = synthesiser(
            context,
            forward_params={"speaker_embeddings": speaker_embedding}
        )

-         # Prepare audio data
        audio = np.array(speech["audio"])
        rate = speech["sampling_rate"]

-         # Return audio, caption, extracted text, and context
        return (rate, audio), caption, extracted_text, context

    except Exception as e:
@@ -75,8 +109,8 @@ iface = gr.Interface(
        gr.Textbox(label="Extracted Text (OCR)"),
        gr.Textbox(label="Generated Context")
    ],
-     title="SeeSay Contextualizer with GPT-2 (124M)",
-     description="Upload an image to generate a caption, extract text, create audio from context, and determine the context using GPT-2."
+     title="SeeSay Contextualizer with Doge-320M",
+     description="Upload an image to generate a caption, extract text (OCR), generate context using Doge, and turn it into speech using SpeechT5."
)

- iface.launch()
+ iface.launch()
 
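A note on the OCR step: the hunk above cuts off mid-way through the ocr_model.generate(...) call, so only the input_ids argument is visible. For reference, a typical Florence-2 <OCR> invocation per the model's example code also passes pixel_values and post-processes the decoded string; the sketch below follows that pattern, and every generate() argument after input_ids is an assumption, since the diff elides them:

    # Sketch of a typical Florence-2 <OCR> call; arguments beyond
    # input_ids are assumed, since the diff elides them.
    inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(ocr_device, ocr_dtype)
    generated_ids = ocr_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],  # image features, not just the task prompt
        max_new_tokens=1024,
        num_beams=3
    )
    raw_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    # post_process_generation strips the task token and returns {"<OCR>": "..."}
    parsed = ocr_processor.post_process_generation(
        raw_text, task="<OCR>", image_size=(image.width, image.height)
    )
    extracted_text = parsed["<OCR>"]

Note that image.width/image.height assumes a PIL image, which matches how Gradio usually hands images to the callback.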
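One caveat in Step 3: doge_outputs[0] contains the prompt tokens followed by the generated ones, so decoding the whole sequence leaves the instruction text at the front of context, and Step 4 will read it aloud. A minimal sketch of the usual fix, assuming the standard transformers chat-template API: request a generation prompt and decode only the new tokens.

    doge_inputs = doge_tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=True,
        add_generation_prompt=True,  # append the assistant-turn marker
        return_tensors="pt"
    ).to(doge_model.device)

    doge_outputs = doge_model.generate(doge_inputs, generation_config=doge_generation_config)

    # Drop the prompt tokens; keep only the model's continuation.
    new_tokens = doge_outputs[0][doge_inputs.shape[-1]:]
    context = doge_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()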
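In Step 4, the text-to-speech pipeline returns a dict with "audio" (a float waveform) and "sampling_rate" (16 kHz for SpeechT5), which is why the function can hand Gradio a (rate, audio) tuple directly. The same output can be smoke-tested outside Gradio, for example with soundfile:

    import soundfile as sf

    speech = synthesiser(
        "A quick smoke test for the speaker embedding.",
        forward_params={"speaker_embeddings": speaker_embedding}
    )
    sf.write("speech.wav", speech["audio"], samplerate=speech["sampling_rate"])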
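Finally, the interface hunk shows only the last two output components, the title, and the description; the fn, inputs, and first outputs arguments fall outside the diff context. As orientation only, the constructor plausibly looks like the sketch below, where the inputs line and the first two output labels are assumptions rather than part of the diff:

    # Hypothetical reconstruction; only the lines visible in the diff
    # above are confirmed.
    iface = gr.Interface(
        fn=process_image,
        inputs=gr.Image(type="pil"),  # assumed input component
        outputs=[
            gr.Audio(label="Generated Speech"),  # assumed label
            gr.Textbox(label="Caption"),         # assumed label
            gr.Textbox(label="Extracted Text (OCR)"),
            gr.Textbox(label="Generated Context")
        ],
        title="SeeSay Contextualizer with Doge-320M",
        description="Upload an image to generate a caption, extract text (OCR), generate context using Doge, and turn it into speech using SpeechT5."
    )

    iface.launch()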