from transformers import Blip2Processor, Blip2ForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
import torch
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import os

# Configure a CPU-only fallback when no GPU is available
if not torch.cuda.is_available():
    os.environ["BITSANDBYTES_NOWELCOME"] = "1"
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    # "$LIBRARY_PATH" is not expanded by os.environ, so append the existing value explicitly
    os.environ["LIBRARY_PATH"] = "/usr/local/cuda/lib64/stubs:" + os.environ.get("LIBRARY_PATH", "")

# Initialize environment
os.makedirs("./offload", exist_ok=True)
HF_TOKEN = os.environ.get("HF_TOKEN")

# Memory optimization
torch.cuda.empty_cache()
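# Limit how large a cached block the CUDA allocator will split (reduces fragmentation on small GPUs)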
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Load BLIP-2
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
blip_model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    torch_dtype=torch.float16,
    device_map="auto"
).eval()

# Load Phi-3
phi_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
    load_in_4bit=torch.cuda.is_available(),  # Only use 4bit if CUDA available
    token=HF_TOKEN
).eval()
phi_tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    token=HF_TOKEN
)

def analyze_image(image):
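    """Describe the image with BLIP-2; the caption gives the language model visual context."""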
    # Cast inputs to the model's dtype so fp16 weights and fp32 pixel values don't mismatch
    inputs = blip_processor(image, return_tensors="pt").to(blip_model.device, blip_model.dtype)
    generated_ids = blip_model.generate(**inputs, max_length=50)
    return blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

def generate_meme_caption(image_desc, user_prompt):
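    """Ask Phi-3 for meme captions in 'TOP TEXT | BOTTOM TEXT' form, based on the image description and user prompt."""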
    messages = [
        {"role": "system", "content": "You are a meme expert. Create funny captions in format: TOP TEXT | BOTTOM TEXT"},
        {"role": "user", "content": f"Image context: {image_desc}\nUser input: {user_prompt}\nGenerate 3 meme captions (max 10 words each):"}
    ]
    
    inputs = phi_tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True
    ).to(phi_model.device)
    
    outputs = phi_model.generate(
        inputs,
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True
    )
    # Decode only the newly generated tokens, not the echoed prompt
    return phi_tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)

def create_meme(image, top_text, bottom_text):
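    """Draw white, black-outlined top and bottom captions onto a copy of the image."""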
    img = image.copy()
    draw = ImageDraw.Draw(img)
    
    # Prefer a TrueType font; DejaVu ships with most Linux images (Colab/Spaces), else fall back to PIL's default
    try:
        font = ImageFont.truetype("arial.ttf", size=min(img.size)//12)
    except OSError:
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", size=min(img.size)//12)
        except OSError:
            font = ImageFont.load_default()
    
    # Top text
    draw.text(
        (img.width/2, 10), 
        top_text, 
        font=font, 
        fill="white", 
        anchor="mt",
        stroke_width=2, 
        stroke_fill="black"
    )
    
    # Bottom text
    draw.text(
        (img.width/2, img.height-10), 
        bottom_text, 
        font=font, 
        fill="white", 
        anchor="mb",
        stroke_width=2, 
        stroke_fill="black"
    )
    return img

def process_meme(image, user_prompt):
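    """Full pipeline: describe the image, generate candidate captions, and render up to three memes."""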
    image_desc = analyze_image(image)
    raw_output = generate_meme_caption(image_desc, user_prompt)
    
    captions = []
    for line in raw_output.split("\n"):
        if "|" in line:
            parts = line.split("|", 1)
            if len(parts) == 2:
                captions.append((parts[0].strip(), parts[1].strip()))
    
    memes = [create_meme(image, top, bottom) for top, bottom in captions[:3]]
    return memes

with gr.Blocks(title="AI Meme Generator") as demo:
    gr.Markdown("# 🚀 AI Meme Generator")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        text_input = gr.Textbox(label="Meme Theme/Prompt")
    submit_btn = gr.Button("Generate Memes!")
    gallery = gr.Gallery(label="Generated Memes", columns=3)
    
    submit_btn.click(
        fn=process_meme,
        inputs=[image_input, text_input],
        outputs=gallery
    )

if __name__ == "__main__":
    demo.launch()