File size: 4,303 Bytes
f30b843
8c3caa4
 
 
 
 
 
 
 
e6ea13d
602e80d
 
608498c
e6ea13d
 
 
 
ed4af8f
83cd235
e6ea13d
 
 
 
8c3caa4
efa273d
8c3caa4
 
 
 
 
efa273d
 
8c3caa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9b130c
6a2189c
602e80d
dbabbd4
 
 
 
 
 
 
 
 
 
e6ea13d
602e80d
629e04f
8c3caa4
ed4af8f
602e80d
8c3caa4
efa273d
 
 
 
 
 
 
 
 
 
8c3caa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7203e8
8c3caa4
e6ea13d
 
f7203e8
 
e9b130c
e6ea13d
 
602e80d
e6ea13d
602e80d
629e04f
e9b130c
26dbd13
e6ea13d
602e80d
26dbd13
602e80d
ed4af8f
602e80d
 
efa273d
e9b130c
e6ea13d
602e80d
6a2189c
8c3caa4
629e04f
5c86456
84402c4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
from transformers import (
    pipeline,
    AutoProcessor,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    set_seed
)
from datasets import load_dataset
import torch
import numpy as np

# Fix every RNG (Python, NumPy, torch) so the sampled Doge output is repeatable.
set_seed(42)

# BLIP base checkpoint: image -> short natural-language caption.
caption_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")

# SpeechT5: text -> waveform; a speaker embedding is supplied per call.
synthesiser = pipeline("text-to-speech", model="microsoft/speecht5_tts")

# One device/precision policy for the explicitly placed models:
# GPU + fp16 when CUDA is present, otherwise CPU + fp32.
_cuda_available = torch.cuda.is_available()
ocr_device = "cuda" if _cuda_available else "cpu"
ocr_dtype = torch.float16 if _cuda_available else torch.float32

# Florence-2 (remote code) used with the "<OCR>" task prompt.
FLORENCE_CHECKPOINT = "microsoft/Florence-2-large"
ocr_model = AutoModelForCausalLM.from_pretrained(
    FLORENCE_CHECKPOINT,
    torch_dtype=ocr_dtype,
    trust_remote_code=True,
).to(ocr_device)
ocr_processor = AutoProcessor.from_pretrained(FLORENCE_CHECKPOINT, trust_remote_code=True)

# Doge-320M-Instruct chat model that turns caption + OCR text into a context
# description; placed with the same device/precision policy as the OCR model.
DOGE_CHECKPOINT = "SmallDoge/Doge-320M-Instruct"
doge_tokenizer = AutoTokenizer.from_pretrained(DOGE_CHECKPOINT)
doge_model = AutoModelForCausalLM.from_pretrained(
    DOGE_CHECKPOINT,
    torch_dtype=ocr_dtype,
    trust_remote_code=True,
).to(ocr_device)

# Nucleus sampling capped at 100 new tokens; repetition_penalty=1.0 is neutral.
doge_generation_config = GenerationConfig(
    max_new_tokens=100,
    use_cache=True,
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
    repetition_penalty=1.0,
)

# Load one CMU-Arctic x-vector to use as the SpeechT5 voice.
# SpeechT5 concatenates the speaker embedding with its decoder hidden state and
# feeds the result to a Linear layer sized by `config.speaker_embedding_dim`.
# The vector must therefore have exactly that width; any other width (e.g. a
# hard-coded 600) raises a shape-mismatch error on every synthesis call.
SPEAKER_EMBEDDING_DIM = getattr(synthesiser.model.config, "speaker_embedding_dim", 512)

embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
raw_vec = list(embeddings_dataset[0]["xvector"])

# Defensive fit: truncate or zero-pad only if the stored vector's length
# differs from what the model expects.
if len(raw_vec) > SPEAKER_EMBEDDING_DIM:
    raw_vec = raw_vec[:SPEAKER_EMBEDDING_DIM]
elif len(raw_vec) < SPEAKER_EMBEDDING_DIM:
    raw_vec = raw_vec + [0.0] * (SPEAKER_EMBEDDING_DIM - len(raw_vec))

# Batch dimension of 1: shape [1, SPEAKER_EMBEDDING_DIM].
speaker_embedding = torch.tensor(raw_vec, dtype=torch.float32).unsqueeze(0)
# Explicit check rather than `assert`, which is stripped under `python -O`.
if speaker_embedding.shape != (1, SPEAKER_EMBEDDING_DIM):
    raise ValueError(
        f"Speaker embedding shape is {tuple(speaker_embedding.shape)}, "
        f"expected (1, {SPEAKER_EMBEDDING_DIM})"
    )

def _generate_caption(image):
    """Return the BLIP caption string for *image*."""
    return caption_model(image)[0]["generated_text"]


def _extract_text(image):
    """Return the text Florence-2 reads from *image* using the ``<OCR>`` task."""
    inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(ocr_device, ocr_dtype)
    generated_ids = ocr_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=4096,
        num_beams=3,
        do_sample=False,
    )
    return ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def _generate_context(caption, extracted_text):
    """Ask the Doge chat model for the image's context and return only its reply."""
    prompt = f"Determine the context of this image based on the caption and extracted text.\nCaption: {caption}\nExtracted text: {extracted_text}\nContext:"
    conversation = [{"role": "user", "content": prompt}]
    # add_generation_prompt=True appends the assistant-turn header (when the
    # chat template defines one) so the model answers rather than continuing
    # the user message.
    prompt_ids = doge_tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(doge_model.device)

    output_ids = doge_model.generate(
        prompt_ids,
        generation_config=doge_generation_config,
    )

    # For a causal LM, generate() returns prompt + continuation. Decode only
    # the newly generated tokens so the prompt (and the whole OCR dump) is not
    # echoed into the displayed context or the synthesized speech.
    new_tokens = output_ids[0][prompt_ids.shape[-1]:]
    return doge_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def _synthesize_speech(text):
    """Return ``(sampling_rate, waveform)`` for *text* using SpeechT5."""
    speech = synthesiser(
        text,
        forward_params={"speaker_embeddings": speaker_embedding},
    )
    return speech["sampling_rate"], np.array(speech["audio"])


def process_image(image):
    """Caption, OCR, contextualize and voice an image.

    Args:
        image: The uploaded image as passed by the Gradio ``Image(type='pil')``
            input, or ``None`` when nothing was uploaded.

    Returns:
        A 4-tuple ``(audio, caption, extracted_text, context)``. ``audio`` is
        ``(sampling_rate, waveform)``, or ``None`` if there was nothing to speak
        or an error occurred. On error, the second element carries an
        ``"Error: ..."`` message and the last two are empty strings.
    """
    if image is None:
        return None, "Error: Please upload an image.", "", ""

    try:
        caption = _generate_caption(image)
        extracted_text = _extract_text(image)
        context = _generate_context(caption, extracted_text)
        # An empty reply gives SpeechT5 nothing to say; skip synthesis.
        audio = _synthesize_speech(context) if context else None
        return audio, caption, extracted_text, context

    except Exception as e:
        # UI boundary: report the failure in the interface but keep the
        # traceback in the server log instead of silently discarding it.
        import logging

        logging.getLogger(__name__).exception("process_image failed")
        return None, f"Error: {str(e)}", "", ""


# Gradio UI: one image in; four outputs in the same order as the tuple
# returned by process_image (audio, caption, OCR text, generated context).
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type='pil', label="Upload an Image"),
    outputs=[
        gr.Audio(label="Generated Audio"),
        gr.Textbox(label="Generated Caption"),
        gr.Textbox(label="Extracted Text (OCR)"),
        gr.Textbox(label="Generated Context"),
    ],
    title="SeeSay Contextualizer with Doge-320M",
    description="Upload an image to generate a caption, extract text (OCR), generate context using Doge, and turn it into speech using SpeechT5."
)

# Launch only when run as a script, so importing this module (tests, tooling)
# does not start a web server.
if __name__ == "__main__":
    iface.launch()