from transformers import Blip2Processor, Blip2ForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
import torch
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import os

# Check CUDA availability; fall back to CPU-only mode if no GPU is present
if not torch.cuda.is_available():
    os.environ["BITSANDBYTES_NOWELCOME"] = "1"
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    # "$LIBRARY_PATH" is not expanded inside os.environ values; append the existing value explicitly
    os.environ["LIBRARY_PATH"] = "/usr/local/cuda/lib64/stubs:" + os.environ.get("LIBRARY_PATH", "")

# Initialize environment
os.makedirs("./offload", exist_ok=True)
HF_TOKEN = os.environ.get("HF_TOKEN")

# Memory optimization
if torch.cuda.is_available():
    torch.cuda.empty_cache()
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Load BLIP-2 (image captioning)
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
blip_model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    torch_dtype=torch.float16,
    device_map="auto"
).eval()

# Load Phi-3 (caption generation)
phi_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
    load_in_4bit=torch.cuda.is_available(),  # Only use 4-bit quantization if CUDA is available
    token=HF_TOKEN
).eval()

phi_tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    token=HF_TOKEN
)


def analyze_image(image):
    """Describe the uploaded image with BLIP-2."""
    # Cast pixel values to float16 to match the model's dtype
    inputs = blip_processor(image, return_tensors="pt").to(blip_model.device, torch.float16)
    generated_ids = blip_model.generate(**inputs, max_length=50)
    return blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def generate_meme_caption(image_desc, user_prompt):
    """Ask Phi-3 for meme captions in TOP TEXT | BOTTOM TEXT format."""
    messages = [
        {"role": "system", "content": "You are a meme expert. Create funny captions in format: TOP TEXT | BOTTOM TEXT"},
        {"role": "user", "content": f"Image context: {image_desc}\nUser input: {user_prompt}\nGenerate 3 meme captions (max 10 words each):"}
    ]
    inputs = phi_tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True
    ).to(phi_model.device)
    outputs = phi_model.generate(
        inputs,
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True
    )
    # Decode only the newly generated tokens, so the prompt (which itself
    # contains a "|") is not parsed as a caption downstream
    return phi_tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)


def create_meme(image, top_text, bottom_text):
    """Draw top and bottom captions onto a copy of the image."""
    img = image.copy()
    draw = ImageDraw.Draw(img)

    # Use an available font (works in Colab/Spaces); fall back to PIL's default
    try:
        font = ImageFont.truetype("arial.ttf", size=min(img.size) // 12)
    except OSError:
        font = ImageFont.load_default()

    # Top text
    draw.text(
        (img.width / 2, 10),
        top_text,
        font=font,
        fill="white",
        anchor="mt",
        stroke_width=2,
        stroke_fill="black"
    )

    # Bottom text
    draw.text(
        (img.width / 2, img.height - 10),
        bottom_text,
        font=font,
        fill="white",
        anchor="mb",
        stroke_width=2,
        stroke_fill="black"
    )
    return img


def process_meme(image, user_prompt):
    """Full pipeline: describe the image, generate captions, render up to 3 memes."""
    image_desc = analyze_image(image)
    raw_output = generate_meme_caption(image_desc, user_prompt)

    # Parse "TOP | BOTTOM" pairs out of the model output
    captions = []
    for line in raw_output.split("\n"):
        if "|" in line:
            parts = line.split("|", 1)
            if len(parts) == 2:
                captions.append((parts[0].strip(), parts[1].strip()))

    memes = [create_meme(image, top, bottom) for top, bottom in captions[:3]]
    return memes


with gr.Blocks(title="AI Meme Generator") as demo:
    gr.Markdown("# 🚀 AI Meme Generator")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        text_input = gr.Textbox(label="Meme Theme/Prompt")
    submit_btn = gr.Button("Generate Memes!")
    gallery = gr.Gallery(label="Generated Memes", columns=3)

    submit_btn.click(
        fn=process_meme,
        inputs=[image_input, text_input],
        outputs=gallery
    )

if __name__ == "__main__":
    demo.launch()