File size: 4,965 Bytes
09daea9
 
7795444
 
 
 
 
 
 
 
 
 
b1cdcda
 
 
7795444
 
 
 
 
 
64bbe5e
7795444
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a07a29c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load models
def load_models():
    """Load the LeX-Enhancer causal LM and its tokenizer.

    Returns:
        tuple: (model, tokenizer) — the AutoModelForCausalLM placed on the
        best available device, and the matching AutoTokenizer.
    """
    model_name = "X-ART/LeX-Enhancer-full"

    # Pick the device once so dtype and placement stay consistent:
    # bfloat16 on GPU for memory savings, float32 on CPU (bf16 matmuls
    # are slow/unsupported on many CPUs). The original hard-coded
    # .to("cuda"), which crashed on CPU-only hosts.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
        # device_map="auto"
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    return model, tokenizer

model, tokenizer = load_models()

# @spaces.GPU()
def generate_enhanced_caption(image_caption, text_caption):
    """Generate enhanced caption using the LeX-Enhancer model.

    Args:
        image_caption: plain-language description of the image's visual content.
        text_caption: description of the text that should appear in the image.

    Returns:
        tuple[str, str]: (combined simple caption, model-enhanced caption).
    """
    combined_caption = f"{image_caption}, with the text on it: {text_caption}."
    
    instruction = """
Below is the simple caption of an image with text. Please deduce the detailed description of the image based on this simple caption. Note: 
1. The description should only include visual elements and should not contain any extended meanings. 
2. The visual elements should be as rich as possible, such as the main objects in the image, their respective attributes, the spatial relationships between the objects, lighting and shadows, color style, any text in the image and its style, etc. 
3. The output description should be a single paragraph and should not be structured. 
4. The description should avoid certain situations, such as pure white or black backgrounds, blurry text, excessive rendering of text, or harsh visual styles. 
5. The detailed caption should be human readable and fluent. 
6. Avoid using vague expressions such as "may be" or "might be"; the generated caption must be in a definitive, narrative tone. 
7. Do not use negative sentence structures, such as "there is nothing in the image," etc. The entire caption should directly describe the content of the image. 
8. The entire output should be limited to 200 words.
"""
    
    messages = [
        {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
        {"role": "user", "content": instruction + "\nSimple Caption:\n" + combined_caption}
    ]
    
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Pure inference: disable autograd so generate() does not allocate
    # gradient/bookkeeping state (the original ran without this guard).
    with torch.inference_mode():
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=1024
        )
    
    # Strip the prompt tokens so only the newly generated continuation remains.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # The model may emit a reasoning block terminated by "</think>"; keep only
    # the text after the last such marker, then trim spaces and newlines.
    enhanced_caption = response.split("</think>", -1)[-1].strip(" ").strip("\n")
    
    # Clear memory; empty_cache is CUDA-only housekeeping, so guard it on
    # CPU-only hosts (the original called it unconditionally).
    del model_inputs, generated_ids
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    return combined_caption, enhanced_caption

# Gradio interface
# Two-column layout: inputs + submit button on the left, read-only outputs on
# the right. `demo` is launched from the __main__ guard at the bottom of the file.
with gr.Blocks() as demo:
    gr.Markdown("# LeX-Enhancer Demo")
    gr.Markdown("## Enhance your image captions with detailed visual descriptions")
    gr.Markdown("Project Page: https://zhaoshitian.github.io/lexart/")
    
    with gr.Row():
        with gr.Column():
            # Left column: the two free-text inputs, pre-filled with a sample.
            image_caption = gr.Textbox(
                lines=2,
                label="Image Caption",
                placeholder="Describe the visual content of the image",
                value="A picture of a group of people gathered in front of a world map"
            )
            text_caption = gr.Textbox(
                lines=2,
                label="Text Caption",
                placeholder="Describe any text that should appear in the image",
                value="\"Communicate\" in purple, \"Execute\" in yellow"
            )
            
            submit_btn = gr.Button("Enhance Caption", variant="primary")
        
        with gr.Column():
            # Right column: outputs are display-only (interactive=False).
            combined_caption_box = gr.Textbox(
                label="Combined Caption",
                interactive=False
            )
            enhanced_caption_box = gr.Textbox(
                label="Enhanced Caption",
                interactive=False,
                lines=8
            )
    
    # Example prompts
    # Each entry is [image_caption, text_caption]; clicking one populates the
    # inputs and runs generate_enhanced_caption via gr.Examples.
    examples = [
        ["A modern office workspace", "\"Innovation\" in bold blue letters at the center"],
        ["A beach sunset scene", "\"Relax\" in cursive white text in the corner"],
        ["A futuristic city skyline", "\"The Future is Now\" in neon pink glowing letters"]
    ]
    
    gr.Examples(
        examples=examples,
        inputs=[image_caption, text_caption],
        outputs=[combined_caption_box, enhanced_caption_box],
        fn=generate_enhanced_caption,
        label="Example Inputs - Click on any to see the enhancement"
    )
    
    # Wire the button to the generation function; it returns a 2-tuple that
    # maps onto the two output textboxes in order.
    submit_btn.click(
        fn=generate_enhanced_caption,
        inputs=[image_caption, text_caption],
        outputs=[combined_caption_box, enhanced_caption_box]
    )

# Start the Gradio server only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()