# Meme_Generator / app.py
import os

import torch
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
from transformers import (
    Blip2Processor,
    Blip2ForConditionalGeneration,
    AutoModelForCausalLM,
    AutoTokenizer,
)
# CPU-only fallback: silence bitsandbytes and hide CUDA devices
if not torch.cuda.is_available():
    os.environ["BITSANDBYTES_NOWELCOME"] = "1"
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    # os.environ does not expand shell variables, so append the old value explicitly
    os.environ["LIBRARY_PATH"] = "/usr/local/cuda/lib64/stubs:" + os.environ.get("LIBRARY_PATH", "")
# Initialize environment
os.makedirs("./offload", exist_ok=True)
HF_TOKEN = os.environ.get("HF_TOKEN")
# Memory optimization (the allocator config must be set before the first CUDA allocation)
torch.cuda.empty_cache()  # no-op if CUDA has not been initialized
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 matmuls on Ampere+ GPUs
torch.backends.cudnn.allow_tf32 = True
# Load BLIP-2 for image captioning (fp16 only on GPU; many CPU ops lack half-precision kernels)
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
blip_model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
).eval()
# Load Phi-3 for caption generation
phi_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    load_in_4bit=torch.cuda.is_available(),  # 4-bit quantization needs CUDA + bitsandbytes
    token=HF_TOKEN,
).eval()
phi_tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    token=HF_TOKEN,
)
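
# Note: passing load_in_4bit directly to from_pretrained is deprecated in newer
# transformers releases. A hedged sketch of the current equivalent (assuming
# bitsandbytes is installed and a CUDA GPU is present):
#
#   from transformers import BitsAndBytesConfig
#   phi_model = AutoModelForCausalLM.from_pretrained(
#       "microsoft/Phi-3-mini-4k-instruct",
#       quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#       device_map="auto",
#       token=HF_TOKEN,
#   )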
def analyze_image(image):
    # Cast pixel values to the model's dtype to avoid fp16/fp32 mismatches
    inputs = blip_processor(images=image, return_tensors="pt").to(blip_model.device, blip_model.dtype)
    generated_ids = blip_model.generate(**inputs, max_new_tokens=50)
    return blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
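
# Illustrative only: given a photo of a cat on a couch, analyze_image returns a
# short BLIP-2 caption along the lines of "a cat sitting on a couch" (exact
# wording depends on the model):
#
#   desc = analyze_image(Image.open("cat.jpg"))  # hypothetical input file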
def generate_meme_caption(image_desc, user_prompt):
    messages = [
        {"role": "system", "content": "You are a meme expert. Create funny captions in format: TOP TEXT | BOTTOM TEXT"},
        {"role": "user", "content": f"Image context: {image_desc}\nUser input: {user_prompt}\nGenerate 3 meme captions (max 10 words each):"},
    ]
    inputs = phi_tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(phi_model.device)
    outputs = phi_model.generate(
        inputs,
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True,
    )
    # Decode only the newly generated tokens; otherwise the prompt itself
    # (which contains a "|") would be parsed as a caption downstream
    return phi_tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
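
# The downstream parser expects one caption per line in "TOP | BOTTOM" form,
# e.g. (illustrative output only, not guaranteed by the model):
#
#   WHEN THE CODE COMPILES | BUT THE TESTS FAIL
#   ME EXPLAINING THE BUG | THE BUG EXPLAINING ME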
def create_meme(image, top_text, bottom_text):
    img = image.copy()
    draw = ImageDraw.Draw(img)
    # Prefer a bold TrueType font if one is installed (DejaVu ships with most
    # Debian-based images, including Colab/Spaces); fall back to PIL's bitmap font
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=min(img.size) // 12)
    except OSError:
        font = ImageFont.load_default()
    # Top text, anchored middle-top with a black outline for readability
    draw.text(
        (img.width / 2, 10),
        top_text,
        font=font,
        fill="white",
        anchor="mt",
        stroke_width=2,
        stroke_fill="black",
    )
    # Bottom text, anchored middle-bottom
    draw.text(
        (img.width / 2, img.height - 10),
        bottom_text,
        font=font,
        fill="white",
        anchor="mb",
        stroke_width=2,
        stroke_fill="black",
    )
    return img
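
# Long captions can overflow the image width. One simple mitigation (a sketch,
# not tuned): shrink the font for long strings before drawing. The 25-character
# threshold is an assumption, not something the original app defines.
#
#   size = min(img.size) // 12
#   if max(len(top_text), len(bottom_text)) > 25:
#       size = min(img.size) // 18
#   font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=size)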
def process_meme(image, user_prompt):
    image_desc = analyze_image(image)
    raw_output = generate_meme_caption(image_desc, user_prompt)
    # Parse lines of the form "TOP TEXT | BOTTOM TEXT"
    captions = []
    for line in raw_output.split("\n"):
        if "|" in line:
            top, bottom = line.split("|", 1)
            captions.append((top.strip(), bottom.strip()))
    memes = [create_meme(image, top, bottom) for top, bottom in captions[:3]]
    return memes
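
# End-to-end sketch of the pipeline outside Gradio (hypothetical file paths):
#
#   img = Image.open("example.jpg")
#   memes = process_meme(img, "monday mornings")
#   for i, meme in enumerate(memes):
#       meme.save(f"meme_{i}.png")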
with gr.Blocks(title="AI Meme Generator") as demo:
    gr.Markdown("# 🚀 AI Meme Generator")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        text_input = gr.Textbox(label="Meme Theme/Prompt")
    submit_btn = gr.Button("Generate Memes!")
    gallery = gr.Gallery(label="Generated Memes", columns=3)
    submit_btn.click(
        fn=process_meme,
        inputs=[image_input, text_input],
        outputs=gallery,
    )
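
# Optional: on Spaces, enabling Gradio's request queue helps avoid timeouts on
# slow generations, e.g. demo.queue().launch() instead of demo.launch().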
if __name__ == "__main__":
    demo.launch()