# Source: Hugging Face Space file view (Space status: Running; file size: 4,060 bytes).
from transformers import Blip2Processor, Blip2ForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
import torch
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import torch
import os
# Check CUDA availability. On CPU-only hosts, hide GPUs from CUDA-aware
# libraries and adjust environment variables read by bitsandbytes / the linker.
if not torch.cuda.is_available():
    os.environ["BITSANDBYTES_NOWELCOME"] = "1"
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    # os.environ performs no shell expansion, so a literal "$LIBRARY_PATH"
    # would be stored verbatim. Prepend the CUDA stub directory to the
    # actual existing value instead (skipping it when unset/empty).
    _prev_library_path = os.environ.get("LIBRARY_PATH", "")
    os.environ["LIBRARY_PATH"] = os.pathsep.join(
        part
        for part in ("/usr/local/cuda/lib64/stubs", _prev_library_path)
        if part
    )
# Initialize environment: a local directory named for disk offloading of model
# weights, and the optional Hugging Face access token (None when unset).
os.makedirs("./offload", exist_ok=True)
HF_TOKEN = os.environ.get("HF_TOKEN")
# Memory optimization.
# The CUDA caching allocator reads PYTORCH_CUDA_ALLOC_CONF once, when it is
# first initialized, so set it before any torch.cuda call that might touch
# the allocator (including empty_cache below).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
torch.cuda.empty_cache()
# Permit TF32 math for CUDA matmuls and cuDNN ops (only has an effect on GPUs
# that support TF32).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Load BLIP-2 (the image-captioning model used by analyze_image).
blip_processor = Blip2Processor.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    token=HF_TOKEN,
)
blip_model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    torch_dtype=torch.float16,
    device_map="auto",
    # With device_map="auto", weights that do not fit in GPU/CPU memory are
    # placed on disk, which transformers only allows when an offload_folder
    # is given; use the ./offload directory this script creates.
    offload_folder="./offload",
    token=HF_TOKEN,
).eval()
# Load Phi-3 (the caption-writing LLM used by generate_meme_caption).
# Local import keeps this block self-contained; transformers is already a
# dependency of this file.
from transformers import BitsAndBytesConfig

# 4-bit bitsandbytes quantization needs a CUDA GPU, so only request it when
# one is available. Passing load_in_4bit= directly to from_pretrained is
# deprecated in favour of quantization_config.
_phi_quant_config = (
    BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
    if torch.cuda.is_available()
    else None
)
phi_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=_phi_quant_config,
    # Lets device_map="auto" spill unquantized weights to disk if RAM runs out.
    offload_folder="./offload",
    token=HF_TOKEN,
).eval()
phi_tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    token=HF_TOKEN,
)
def analyze_image(image):
    """Describe ``image`` with a short BLIP-2 caption.

    Args:
        image: The image to caption; the Gradio UI passes a PIL image
            (``gr.Image(type="pil")``).

    Returns:
        The decoded caption for the single input image, with surrounding
        whitespace removed.
    """
    # The processor produces float32 pixel values; move them to the model's
    # device and cast floating tensors to its dtype (float16 here) so the
    # vision encoder receives matching inputs.
    inputs = blip_processor(images=image, return_tensors="pt").to(
        blip_model.device, blip_model.dtype
    )
    # max_new_tokens bounds only the generated caption, regardless of how the
    # model accounts for its internal query/prompt tokens in max_length.
    generated_ids = blip_model.generate(**inputs, max_new_tokens=50)
    caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return caption.strip()
def generate_meme_caption(image_desc, user_prompt):
    """Ask Phi-3 to write meme captions for an image.

    Args:
        image_desc: Text description of the image (e.g. from ``analyze_image``).
        user_prompt: Free-form theme/prompt entered by the user.

    Returns:
        Only the model's newly generated reply text (the chat prompt is not
        echoed back). The model is asked for "TOP TEXT | BOTTOM TEXT" lines,
        but compliance is not guaranteed.
    """
    messages = [
        {"role": "system", "content": "You are a meme expert. Create funny captions in format: TOP TEXT | BOTTOM TEXT"},
        {"role": "user", "content": f"Image context: {image_desc}\nUser input: {user_prompt}\nGenerate 3 meme captions (max 10 words each):"},
    ]
    input_ids = phi_tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(phi_model.device)
    outputs = phi_model.generate(
        input_ids,
        # Single unpadded sequence: every position is real input.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True,
    )
    # generate() returns prompt + completion. Decode only the completion;
    # otherwise the system prompt's "TOP TEXT | BOTTOM TEXT" example line
    # would be parsed downstream as if it were a caption.
    new_tokens = outputs[0, input_ids.shape[-1]:]
    return phi_tokenizer.decode(new_tokens, skip_special_tokens=True)
# TrueType fonts to try, in order. "arial.ttf" is typically only present on
# Windows/macOS; DejaVu fonts are common on Linux images. Pillow also searches
# the system font directories when given a bare file name.
_MEME_FONT_CANDIDATES = ("arial.ttf", "DejaVuSans-Bold.ttf", "DejaVuSans.ttf")


def _load_meme_font(size):
    """Return a font of roughly ``size`` pixels, falling back to Pillow's default."""
    for font_name in _MEME_FONT_CANDIDATES:
        try:
            return ImageFont.truetype(font_name, size=size)
        except OSError:  # font file not found / unreadable
            continue
    try:
        # Pillow >= 10.1 can return a scalable default font of a given size.
        return ImageFont.load_default(size=size)
    except (TypeError, OSError, ImportError):
        # Older Pillow: fixed-size bitmap font, which may not honour the
        # anchor/stroke options used below.
        return ImageFont.load_default()


def create_meme(image, top_text, bottom_text):
    """Render classic meme text onto a copy of ``image``.

    Args:
        image: Source PIL image; it is not modified.
        top_text: Caption centered along the top edge.
        bottom_text: Caption centered along the bottom edge.

    Returns:
        A new image with white, black-outlined text drawn on it.
    """
    img = image.copy()
    draw = ImageDraw.Draw(img)
    # Scale text with the image; clamp to >= 1 because Pillow rejects a zero
    # font size (possible for images smaller than 12 px).
    font = _load_meme_font(max(1, min(img.size) // 12))
    # (position, text, anchor): "mt" = middle-top, "mb" = middle-bottom, so
    # each line is horizontally centered and kept 10 px inside the edge.
    placements = (
        ((img.width / 2, 10), top_text, "mt"),
        ((img.width / 2, img.height - 10), bottom_text, "mb"),
    )
    for position, text, anchor in placements:
        draw.text(
            position,
            text,
            font=font,
            fill="white",
            anchor=anchor,
            stroke_width=2,
            stroke_fill="black",
        )
    return img
def _parse_captions(raw_output, limit=3):
    """Extract up to ``limit`` (top, bottom) caption pairs from LLM output.

    Each line containing "|" is split at its first "|". Leading list markers
    such as "1. ", "2) " or "- " are removed. Surrounding spaces, double
    quotes and markdown asterisks are also stripped. Lines where both halves
    end up empty are skipped.
    """
    import re  # function-scope import: the file header does not import re

    list_marker = re.compile(r"^\s*(?:\d+[.)]|[-*\u2022])\s+")
    strip_chars = ' "*\u201c\u201d'
    captions = []
    for line in raw_output.splitlines():
        if len(captions) >= limit:
            break
        if "|" not in line:
            continue
        top, bottom = line.split("|", 1)
        top = list_marker.sub("", top).strip(strip_chars)
        bottom = bottom.strip(strip_chars)
        if top or bottom:
            captions.append((top, bottom))
    return captions


def process_meme(image, user_prompt):
    """Gradio callback: caption the image, ask the LLM for memes, render them.

    Args:
        image: Uploaded PIL image, or None if the user did not upload one.
        user_prompt: Theme/prompt text from the textbox.

    Returns:
        A list of up to three rendered meme images for the gallery.

    Raises:
        gr.Error: If no image was provided, or the model produced no line in
            "TOP | BOTTOM" form (shown to the user by Gradio).
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    image_desc = analyze_image(image)
    raw_output = generate_meme_caption(image_desc, user_prompt)
    captions = _parse_captions(raw_output, limit=3)
    if not captions:
        raise gr.Error(
            "The model did not return any captions in 'TOP | BOTTOM' format. Please try again."
        )
    return [create_meme(image, top, bottom) for top, bottom in captions]
# Gradio UI: an uploaded image plus a free-text theme go in; a three-column
# gallery of generated memes comes out when the button is pressed.
with gr.Blocks(title="AI Meme Generator") as demo:
    gr.Markdown("# 🚀 AI Meme Generator")

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        text_input = gr.Textbox(label="Meme Theme/Prompt")

    submit_btn = gr.Button("Generate Memes!")
    gallery = gr.Gallery(label="Generated Memes", columns=3)

    # Positional form of click(fn, inputs, outputs).
    submit_btn.click(process_meme, [image_input, text_input], gallery)
# Script entry point: start the Gradio server when run directly.
if __name__ == "__main__":
    demo.launch()