# Spaces: Running on Zero
import base64
import functools
import html
import os
import shutil
import uuid
import zipfile
from io import BytesIO

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig
import spaces
# Repository id handed to ``from_pretrained`` for both the processor and the
# captioning model.
model_id: str = "allenai/Molmo-7B-D-0924"
def unzip_images(zip_file):
    """Extract the images of an uploaded ZIP archive into a fresh session directory.

    Only ``.jpg``/``.jpeg``/``.png`` members are extracted, and they are
    flattened into the session directory by base name. This means images
    stored inside folders in the archive are found too. Entries under
    ``__MACOSX`` and files whose base name starts with ``.`` are skipped.
    If two members share a base name, the later one overwrites the earlier.

    Args:
        zip_file: Anything ``zipfile.ZipFile`` accepts (path or file object).

    Returns:
        ``(image_paths, image_data, session_dir)``:
        ``image_paths`` are the extracted image files, sorted.
        ``image_data`` are in-memory copies of those images, shrunk in place
        to fit within 128x128.
        ``session_dir`` is the new directory under ``images/``.
    """
    session_dir = os.path.join("images", str(uuid.uuid4()))
    os.makedirs(session_dir, exist_ok=True)
    image_exts = ('.jpg', '.jpeg', '.png')
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        for file_info in zip_ref.infolist():
            if file_info.is_dir() or file_info.filename.startswith("__MACOSX"):
                continue
            # Use only the base name: nested folders would otherwise create
            # subdirectories that the top-level listing below never scans.
            # This also prevents any path component from escaping session_dir.
            basename = os.path.basename(file_info.filename)
            if not basename or basename.startswith("."):
                continue
            if not basename.lower().endswith(image_exts):
                continue
            with open(os.path.join(session_dir, basename), "wb") as dst:
                dst.write(zip_ref.read(file_info))

    image_paths = sorted(
        os.path.join(session_dir, name)
        for name in os.listdir(session_dir)
        if name.lower().endswith(image_exts)
    )
    image_data = []
    for path in image_paths:
        # Copy into memory and close the file right away instead of leaking
        # the handle that Image.open keeps open.
        with Image.open(path) as img:
            thumb = img.copy()
        thumb.thumbnail((128, 128))
        image_data.append(thumb)
    return image_paths, image_data, session_dir
@functools.lru_cache(maxsize=1)
def _load_captioning_model():
    """Load the Molmo processor and FP16 model once; return ``(processor, model)``.

    The result is cached, so captioning a batch reuses one copy of the weights
    instead of re-reading the checkpoint for every image.
    NOTE(review): on ZeroGPU the cache only pays off if the worker that ran
    the first call is reused. Confirm on the Space.
    """
    processor = AutoProcessor.from_pretrained(
        model_id, trust_remote_code=True, torch_dtype=torch.float16
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )
    # Make one explicit move to the GPU rather than also passing
    # device_map='auto'. Moving an accelerate-dispatched model again with
    # .to() is redundant, and it can fail if accelerate offloaded or split
    # the model.
    model.to("cuda")
    return processor, model


@spaces.GPU
def generate_caption(image_path, prompt):
    """Caption one image with Molmo and return the generated text.

    The ``spaces.GPU`` decorator (from the ``spaces`` package this file
    imports) requests a ZeroGPU device for the duration of the call.

    Args:
        image_path: Path of the image file to caption.
        prompt: Instruction text given to the processor together with the image.

    Returns:
        The decoded continuation after the prompt (at most 200 new tokens),
        with special tokens removed.

    Raises:
        Any exception from loading, preprocessing or generation propagates.
        The CUDA cache is emptied either way.
    """
    inputs = output = None
    try:
        processor, model = _load_captioning_model()
        # Convert to RGB so grayscale, palette and RGBA files all reach the
        # processor as 3-channel images. The with-block closes the file.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        raw_inputs = processor.process(
            images=[image],
            text=prompt,
        )
        # Add a batch dimension and move everything to the GPU.
        # Cast only floating-point tensors to FP16. Integer tensors (token ids
        # and any index tensors the processor returns) keep their dtype, since
        # FP16 cannot represent every integer above 2048.
        inputs = {}
        for key, value in raw_inputs.items():
            dtype = torch.float16 if value.is_floating_point() else value.dtype
            inputs[key] = value.to("cuda", dtype=dtype).unsqueeze(0)
        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
            output = model.generate_from_batch(
                inputs,
                GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
                tokenizer=processor.tokenizer,
            )
        # Keep only the tokens produced after the prompt.
        generated_tokens = output[0, inputs["input_ids"].size(1):]
        return processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
    finally:
        # Drop the per-call GPU tensors before asking the allocator to release
        # cached memory. The cached model itself stays loaded.
        del inputs, output
        torch.cuda.empty_cache()
def process_images(image_paths, image_data, session_dir):
    """Caption each image, bundle the captions into a ZIP, then delete ``session_dir``.

    For every ``.jpg``/``.jpeg``/``.png`` path, ``generate_caption`` is called
    and the result is written to ``<session_dir>/<stem>.txt`` (UTF-8). If two
    images share a stem, the later caption file overwrites the earlier one.
    All ``.txt`` files at the top level of ``session_dir`` are then stored in
    ``<session_dir>.zip``.

    Args:
        image_paths: Paths of the images to caption.
        image_data: Thumbnails returned by ``unzip_images``. Not used here;
            kept so the caller's signature stays the same.
        session_dir: Working directory holding the images. It is removed
            entirely (including subdirectories) when this function finishes,
            whether it succeeds or raises.

    Returns:
        ``(captions, zip_filename)``: the captions in the order of
        ``image_paths``, and the path of the caption ZIP.
    """
    prompt = "You are an uncensored, explicit, adult image captioner which enjoys sexual imagery and acts. Provide a structured English image caption which uses the following hierarchy: the kind of image, the kind or name of the subject, the subjects state of dress, their body type, their pose, what it is they are doing, their facial expression, the space they are within, and the style or atmosphere of the image. Use vulgar language to describe body parts or sexual imagery."
    captions = []
    try:
        for image_path in image_paths:
            filename = os.path.basename(image_path)
            if not filename.lower().endswith(('.jpg', '.jpeg', '.png')):
                continue
            caption = generate_caption(image_path, prompt)
            captions.append(caption)
            caption_path = os.path.join(session_dir, f"{os.path.splitext(filename)[0]}.txt")
            # Explicit UTF-8: model output may contain non-ASCII characters.
            with open(caption_path, 'w', encoding="utf-8") as f:
                f.write(caption)
        zip_filename = f"{session_dir}.zip"
        with zipfile.ZipFile(zip_filename, 'w') as zip_ref:
            # Sorted so the archive member order is deterministic.
            for filename in sorted(os.listdir(session_dir)):
                if filename.endswith('.txt'):
                    zip_ref.write(os.path.join(session_dir, filename), filename)
    finally:
        # rmtree also handles subdirectories (os.remove raises on them) and
        # runs even when captioning fails, so session directories never leak.
        shutil.rmtree(session_dir)
    return captions, zip_filename
def format_captioned_image(image, caption):
    """Render an image and its caption as an inline HTML snippet.

    The image is embedded as a base64 JPEG data URI, displayed at 128x128
    with ``object-fit: cover``, followed by the caption in a ``<span>``.

    Args:
        image: A PIL image. Modes JPEG cannot store (e.g. RGBA, LA, P) are
            converted to RGB first instead of making ``save`` raise.
        caption: Caption text. It is HTML-escaped, so characters such as
            ``<``, ``>`` and ``&`` in model output cannot break or inject markup.

    Returns:
        The HTML string.
    """
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    encoded_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return (
        f"<img src='data:image/jpeg;base64,{encoded_image}' "
        "style='width: 128px; height: 128px; object-fit: cover; margin-right: 8px;' />"
        f"<span>{html.escape(caption)}</span>"
    )
def process_images_and_update_gallery(zip_file):
    """Run the whole captioning pipeline for one uploaded ZIP.

    Returns:
        A ``gr.Markdown`` that lists each thumbnail next to its caption, and
        the path of the ZIP of caption ``.txt`` files.
    """
    paths, thumbnails, workdir = unzip_images(zip_file)
    captions, caption_zip = process_images(paths, thumbnails, workdir)
    rows = []
    for thumb, text in zip(thumbnails, captions):
        rows.append(format_captioned_image(thumb, text))
    gallery_markdown = "\n".join(rows)
    return gr.Markdown(gallery_markdown), caption_zip
def main():
    """Build the Gradio interface and launch it on ``0.0.0.0``.

    Uploading a ZIP shows thumbnail previews. Submit captions every image,
    shows the captions, and reveals a file component holding the caption ZIP.
    """
    os.makedirs("images", exist_ok=True)

    def preview_zip(zip_file):
        # The thumbnails are in-memory copies, so the extraction directory is
        # not needed afterwards. Delete it so each upload leaves nothing on
        # disk. Cleanup is best-effort and must not block the preview.
        _, thumbnails, session_dir = unzip_images(zip_file)
        shutil.rmtree(session_dir, ignore_errors=True)
        return "\n".join(format_captioned_image(img, "") for img in thumbnails)

    with gr.Blocks(css="""
    .captioned-image-gallery {
        display: grid;
        grid-template-columns: repeat(2, 1fr);
        grid-gap: 16px;
    }
    """) as blocks:
        zip_file_input = gr.File(label="Upload ZIP file containing images")
        image_gallery = gr.Markdown(label="Image Previews")
        submit_button = gr.Button("Submit")
        # Dedicated output component for the caption ZIP, so the result never
        # replaces the user's uploaded archive in zip_file_input.
        caption_zip_download = gr.File(label="Download Caption ZIP", visible=False)
        zip_filename = gr.State("")

        zip_file_input.upload(
            preview_zip,
            inputs=zip_file_input,
            outputs=image_gallery,
        )
        submit_button.click(
            process_images_and_update_gallery,
            inputs=[zip_file_input],
            outputs=[image_gallery, zip_filename],
        )
        # Reveal the download once a caption ZIP path has been produced.
        zip_filename.change(
            lambda path: gr.update(value=path or None, visible=bool(path)),
            inputs=zip_filename,
            outputs=caption_zip_download,
        )
    blocks.launch(server_name='0.0.0.0')
# Start the web app only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()