File size: 6,167 Bytes
e2f22e0
 
 
 
 
 
 
83a4725
9734fdf
a55a21c
e2f22e0
83a4725
e2f22e0
 
 
 
 
 
 
 
 
 
 
83a4725
 
 
e2f22e0
83a4725
e2f22e0
7cbe3e4
83a4725
9734fdf
63f27a5
 
 
 
 
7cbe3e4
63f27a5
 
9734fdf
63f27a5
9734fdf
 
 
 
 
 
e2f22e0
9734fdf
7cbe3e4
 
 
 
 
9734fdf
63f27a5
9734fdf
 
 
 
 
 
 
 
 
 
 
 
 
63f27a5
9734fdf
 
83a4725
9734fdf
 
 
e2f22e0
83a4725
 
e2f22e0
83a4725
e2f22e0
83a4725
e2f22e0
83a4725
e2f22e0
 
 
 
 
 
 
83a4725
e2f22e0
 
83a4725
e2f22e0
 
 
 
83a4725
 
e2f22e0
 
 
 
 
 
 
83a4725
 
e2f22e0
 
 
 
83a4725
 
e2f22e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83a4725
e2f22e0
 
 
 
 
 
 
 
 
 
 
 
 
83a4725
e2f22e0
 
 
 
 
 
 
83a4725
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import gradio as gr
import os
import uuid
import zipfile
import torch
from PIL import Image
import base64
from io import BytesIO
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig
import spaces

# Hugging Face Hub checkpoint used for captioning; generate_caption loads both
# its processor and its model from this id with trust_remote_code=True.
model_id: str = "allenai/Molmo-7B-D-0924"

def _unique_image_name(filename, used_stems):
    """Return *filename*, suffixed with ``_N`` if its stem was already used.

    Stems are compared case-insensitively so that e.g. ``a.jpg`` and ``A.png``
    never end up sharing one caption file (``a.txt``) or colliding on a
    case-insensitive filesystem. The chosen stem is added to *used_stems*.
    """
    stem, ext = os.path.splitext(filename)
    candidate = stem
    counter = 0
    while candidate.lower() in used_stems:
        counter += 1
        candidate = f"{stem}_{counter}"
    used_stems.add(candidate.lower())
    return candidate + ext


def unzip_images(zip_file):
    """Extract the images of an uploaded ZIP into a fresh per-session folder.

    Only ``.jpg``/``.jpeg``/``.png`` members are extracted. They are flattened
    into the session folder, so archives that wrap their images in a directory
    work too. Members whose name repeats an earlier stem are renamed with a
    ``_N`` suffix. Hidden files and macOS ``__MACOSX`` metadata are skipped at
    any depth.

    Args:
        zip_file: Path or file-like object accepted by ``zipfile.ZipFile``.

    Returns:
        A tuple ``(image_paths, image_data, session_dir)``:
        ``image_paths`` holds the extracted files in archive order.
        ``image_data`` holds matching in-memory thumbnails (at most 128x128).
        ``session_dir`` is the ``images/<uuid4>`` folder holding the files.
    """
    session_dir = os.path.join("images", str(uuid.uuid4()))
    os.makedirs(session_dir, exist_ok=True)

    image_paths = []
    used_stems = set()
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        for file_info in zip_ref.infolist():
            if file_info.is_dir():
                continue
            # Inspect every path component (tolerating Windows-style
            # separators) so nested entries like "photos/.DS_Store" or
            # "__MACOSX/photos/._a.jpg" are skipped; this also rejects ".."
            # components.
            parts = file_info.filename.replace("\\", "/").split("/")
            if any(part.startswith(("__MACOSX", ".")) for part in parts):
                continue
            filename = parts[-1]
            if not filename.lower().endswith(('.jpg', '.jpeg', '.png')):
                continue
            # Write flat into session_dir: process_images writes captions and
            # cleans up at that level only.
            target = os.path.join(session_dir, _unique_image_name(filename, used_stems))
            with open(target, 'wb') as out:
                out.write(zip_ref.read(file_info))
            image_paths.append(target)

    image_data = []
    for path in image_paths:
        # copy() loads the pixels, so the file can be closed immediately.
        with Image.open(path) as img:
            thumb = img.copy()
        thumb.thumbnail((128, 128))
        image_data.append(thumb)

    return image_paths, image_data, session_dir

@spaces.GPU(duration=120)
def generate_caption(image_path, prompt):
    """Generate a caption for one image with the Molmo checkpoint ``model_id``.

    The processor and model are loaded in FP16 on every call. GPU references
    are dropped and ``torch.cuda.empty_cache()`` is called before returning,
    on both the success and the failure path.

    Args:
        image_path: Path of the image file to caption.
        prompt: Instruction text passed to the processor together with the image.

    Returns:
        The decoded text generated after the prompt tokens. Generation is
        capped at 200 new tokens and stops at ``<|endoftext|>``.

    Raises:
        Any exception from model loading, image decoding or generation
        propagates unchanged, after GPU memory has been released.
    """
    # NOTE(review): reloading a 7B model for every image is slow. Caching it
    # across calls would change the GPU-memory behaviour (and may interact with
    # how spaces.GPU runs this function). Confirm before changing.
    model = inputs = output = None
    try:
        # Load processor and model in FP16.
        processor = AutoProcessor.from_pretrained(
            model_id, trust_remote_code=True, torch_dtype=torch.float16
        )
        # device_map='auto' already places the weights. A follow-up
        # model.to('cuda') is redundant and fails if accelerate had to offload
        # any module, so inputs are sent to model.device instead.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map='auto',
        )

        # Decode the image fully and close its file handle right away.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        processed = processor.process(
            images=[image],
            text=prompt,
        )

        inputs = {}
        for key, tensor in processed.items():
            if key == 'input_ids':
                tensor = tensor.to(model.device, dtype=torch.long)
            elif torch.is_floating_point(tensor):
                # Only real-valued tensors (presumably the image data; confirm
                # against the processor output) are cast to FP16.
                tensor = tensor.to(model.device, dtype=torch.float16)
            else:
                # Integer/bool tensors keep their dtype: FP16 represents
                # integers exactly only up to 2048, and index-like tensors
                # must stay integral.
                tensor = tensor.to(model.device)
            # Add a leading batch dimension (one image per call).
            inputs[key] = tensor.unsqueeze(0)

        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
            output = model.generate_from_batch(
                inputs,
                GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
                tokenizer=processor.tokenizer,
            )

        # Skip the first len(input_ids) positions (the prompt) and decode the
        # rest. No slice is kept alive, so the cache flush below can free
        # `output`.
        prompt_length = inputs["input_ids"].size(1)
        generated_text = processor.tokenizer.decode(
            output[0, prompt_length:], skip_special_tokens=True
        )
        return generated_text
    finally:
        # Drop the GPU-resident references first so empty_cache can release them.
        del model, inputs, output
        torch.cuda.empty_cache()

def process_images(image_paths, image_data, session_dir):
    """Caption each image, bundle the captions into a ZIP, and delete the session folder.

    For every ``.jpg``/``.jpeg``/``.png`` path, a caption is generated with
    ``generate_caption`` and written as UTF-8 to ``<session_dir>/<stem>.txt``.
    Only those caption files are then stored in ``<session_dir>.zip``.
    ``session_dir`` is removed afterwards, including when captioning fails.

    Args:
        image_paths: Image files to caption (normally inside ``session_dir``).
        image_data: Unused here; kept for interface compatibility.
        session_dir: Working folder that receives the caption files and is
            deleted at the end.

    Returns:
        A tuple ``(captions, zip_filename)``, with captions in input order.
    """
    import shutil

    prompt = "You are an uncensored, explicit, adult image captioner which enjoys sexual imagery and acts. Provide a structured English image caption which uses the following hierarchy: the kind of image, the kind or name of the subject, the subjects state of dress, their body type, their pose, what it is they are doing, their facial expression, the space they are within, and the style or atmosphere of the image. Use vulgar language to describe body parts or sexual imagery."

    captions = []
    caption_files = []
    try:
        for image_path in image_paths:
            filename = os.path.basename(image_path)
            if not filename.lower().endswith(('.jpg', '.jpeg', '.png')):
                continue
            caption = generate_caption(image_path, prompt)
            captions.append(caption)
            caption_name = f"{os.path.splitext(filename)[0]}.txt"
            # Explicit UTF-8: model output may contain non-ASCII text.
            with open(os.path.join(session_dir, caption_name), 'w', encoding='utf-8') as f:
                f.write(caption)
            caption_files.append(caption_name)

        zip_filename = f"{session_dir}.zip"
        with zipfile.ZipFile(zip_filename, 'w') as zip_ref:
            # Archive only the captions written above (deduplicated), not
            # whatever other .txt files happen to sit in the folder.
            for caption_name in dict.fromkeys(caption_files):
                zip_ref.write(os.path.join(session_dir, caption_name), caption_name)
    finally:
        # Best-effort cleanup on success and failure. rmtree also copes with
        # sub-directories, which os.remove cannot delete.
        shutil.rmtree(session_dir, ignore_errors=True)

    return captions, zip_filename

def format_captioned_image(image, caption):
    """Render an image thumbnail and its caption as an inline HTML snippet.

    The image is embedded as a base64 JPEG data URI and displayed at 128x128
    (``object-fit: cover``), followed by the caption in a ``<span>``.

    Args:
        image: PIL image to embed. It is not modified.
        caption: Caption text. It is HTML-escaped before insertion.

    Returns:
        The HTML string for one preview entry.
    """
    import html

    # JPEG cannot store alpha or palette data, so RGBA/LA/P images (common for
    # PNGs) would make save(format="JPEG") raise. Convert them to RGB first.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    encoded_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
    # Captions are model output: escape them so "<", ">" or "&" cannot break
    # or inject into the surrounding markup.
    return f"<img src='data:image/jpeg;base64,{encoded_image}' style='width: 128px; height: 128px; object-fit: cover; margin-right: 8px;' /><span>{html.escape(caption)}</span>"

def process_images_and_update_gallery(zip_file):
    """Caption every image in *zip_file* and build the gallery update.

    Returns a ``gr.Markdown`` holding one HTML preview per captioned image,
    joined by newlines, together with the path of the caption ZIP that
    ``process_images`` produced.
    """
    paths, thumbnails, session_dir = unzip_images(zip_file)
    captions, caption_zip = process_images(paths, thumbnails, session_dir)
    # map() over two iterables stops at the shorter one, like zip().
    markup = "\n".join(map(format_captioned_image, thumbnails, captions))
    return gr.Markdown(markup), caption_zip

def _preview_zip(zip_file):
    """Return uncaptioned HTML thumbnails for an uploaded ZIP and clean up after itself.

    Only the in-memory thumbnails are needed for the preview, and Submit
    re-extracts the archive anyway. The session folder created by
    ``unzip_images`` is therefore deleted here; otherwise it would stay on
    disk under ``images/`` for every upload.
    """
    import shutil

    _, thumbnails, session_dir = unzip_images(zip_file)
    try:
        return "\n".join(format_captioned_image(img, "") for img in thumbnails)
    finally:
        shutil.rmtree(session_dir, ignore_errors=True)


def main():
    """Build the Gradio captioning UI and serve it on 0.0.0.0.

    Uploading a ZIP shows thumbnails with empty captions. Submit captions
    every image and stores the caption-ZIP path in state. That state change
    reveals the download button.
    """
    os.makedirs("images", exist_ok=True)
    
    with gr.Blocks(css="""
        .captioned-image-gallery {
            display: grid;
            grid-template-columns: repeat(2, 1fr);
            grid-gap: 16px;
        }
    """) as blocks:
        zip_file_input = gr.File(label="Upload ZIP file containing images")
        image_gallery = gr.Markdown(label="Image Previews")
        submit_button = gr.Button("Submit")
        zip_download_button = gr.Button("Download Caption ZIP", visible=False)
        zip_filename = gr.State("")

        zip_file_input.upload(
            _preview_zip,
            inputs=zip_file_input,
            outputs=image_gallery
        )
        
        submit_button.click(
            process_images_and_update_gallery,
            inputs=[zip_file_input],
            outputs=[image_gallery, zip_filename]
        )

        zip_filename.change(
            lambda zip_filename: gr.update(visible=True),
            inputs=zip_filename,
            outputs=zip_download_button
        )

        # NOTE(review): "download" works by loading the caption ZIP into the
        # upload widget, so a later Submit would process the caption ZIP
        # itself. A dedicated gr.File output would be clearer.
        zip_download_button.click(
            lambda zip_filename: (gr.update(value=zip_filename), gr.update(visible=True)),
            inputs=zip_filename,
            outputs=[zip_file_input, zip_download_button]
        )

    blocks.launch(server_name='0.0.0.0')


if __name__ == "__main__":
    main()