"""Gradio demo for Phi-3.5: a streaming text-chat tab and an image Q&A tab."""

import os
import subprocess
from threading import Thread

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

# Install flash-attention at startup (best effort, return code not checked).
# The child environment extends os.environ: passing only the one-key dict
# would replace the whole environment and drop PATH, HOME, proxy settings...
subprocess.run(
    ["pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

# Constants
# Written with explicit "\n" escapes: a plain double-quoted literal cannot
# span several physical lines.
TITLE = "\n\nPhi 3.5 Multimodal (Text + Vision)\n\n"
DESCRIPTION = "# Phi-3.5 Multimodal Demo (Text + Vision)"

# Model configurations
TEXT_MODEL_ID = "microsoft/Phi-3.5-mini-instruct"
VISION_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Quantization config for text model (4-bit NF4, bf16 compute)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Load models and tokenizers
text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID)
text_model = AutoModelForCausalLM.from_pretrained(
    TEXT_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
)

vision_model = AutoModelForCausalLM.from_pretrained(
    VISION_MODEL_ID,
    trust_remote_code=True,
    torch_dtype="auto",
    # FlashAttention-2 only runs on CUDA; use eager attention when `device`
    # has fallen back to CPU so the model can still be loaded.
    attn_implementation="flash_attention_2" if device == "cuda" else "eager",
).to(device).eval()
vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)


# Helper functions
@spaces.GPU
def stream_text_chat(message, history, system_prompt, temperature=0.8,
                     max_new_tokens=1024, top_p=1.0, top_k=20):
    """Stream a reply from the text model for the Gradio chatbot.

    Args:
        message: The new user message.
        history: Previous turns as ``[user, assistant]`` pairs; ``None`` is
            treated as an empty history.
        system_prompt: Text placed in the leading system turn.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).

    Yields:
        ``history`` extended with ``[message, partial_reply]``, once per
        chunk of text produced by the streamer.
    """
    history = history or []

    conversation = [{"role": "system", "content": system_prompt}]
    for prompt, answer in history:
        conversation.append({"role": "user", "content": prompt})
        conversation.append({"role": "assistant", "content": answer})
    conversation.append({"role": "user", "content": message})

    input_ids = text_tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(text_model.device)
    streamer = TextIteratorStreamer(
        text_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )

    # No explicit eos_token_id: the model's own generation config supplies
    # the stop tokens. NOTE(review): ids in the 128k range do not exist in
    # Phi-3.5's ~32k vocabulary, so overriding with such ids would mean
    # generation never stops early -- confirm against generation_config.json.
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=int(max_new_tokens),
        streamer=streamer,
    )
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),
        )
    else:
        # Greedy decoding: sampling knobs (in particular temperature=0) are
        # left out because they are meaningless without sampling.
        generate_kwargs["do_sample"] = False

    # generate() runs in a worker thread so this generator can consume the
    # streamer concurrently. A torch.no_grad() block here would not reach
    # that thread (grad mode is thread-local), so none is used.
    thread = Thread(target=text_model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield history + [[message, buffer]]

    # The streamer is exhausted, so generation is done; reap the worker.
    thread.join()
@spaces.GPU
def process_vision_query(image, text_input):
    """Answer a question about an image with the vision model.

    Args:
        image: A PIL image or a numpy array (converted to a PIL image), or
            ``None`` when nothing was uploaded.
        text_input: The user's question about the image.

    Returns:
        The decoded model response, or a short prompt to upload an image
        when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image first."

    # Normalise to an RGB PIL image and hand that to the processor directly.
    # NOTE(review): the processor is documented with PIL images as input;
    # raw PNG bytes are not expected to be accepted -- confirm.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if isinstance(image, Image.Image):
        image = image.convert("RGB")

    prompt = f"<|user|>\n<|image_1|>\n{text_input}<|end|>\n<|assistant|>\n"
    inputs = vision_processor(prompt, images=[image], return_tensors="pt").to(device)

    with torch.no_grad():
        generate_ids = vision_model.generate(
            **inputs,
            max_new_tokens=1000,
            eos_token_id=vision_processor.tokenizer.eos_token_id,
        )

    # Drop the prompt tokens so only the newly generated part is decoded.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    response = vision_processor.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return response


# Custom CSS
custom_css = """
body { background-color: #0b0f19; color: #e2e8f0; font-family: 'Arial', sans-serif;}
#custom-header { text-align: center; padding: 20px 0; background-color: #1a202c; margin-bottom: 20px; border-radius: 10px;}
#custom-header h1 { font-size: 2.5rem; margin-bottom: 0.5rem;}
#custom-header h1 .blue { color: #60a5fa;}
#custom-header h1 .pink { color: #f472b6;}
#custom-header h2 { font-size: 1.5rem; color: #94a3b8;}
.suggestions { display: flex; justify-content: center; flex-wrap: wrap; gap: 1rem; margin: 20px 0;}
.suggestion { background-color: #1e293b; border-radius: 0.5rem; padding: 1rem; display: flex; align-items: center; transition: transform 0.3s ease; width: 200px;}
.suggestion:hover { transform: translateY(-5px);}
.suggestion-icon { font-size: 1.5rem; margin-right: 1rem; background-color: #2d3748; padding: 0.5rem; border-radius: 50%;}
.gradio-container { max-width: 100% !important;}
#component-0, #component-1, #component-2 { max-width: 100% !important;}
footer { text-align: center; margin-top: 2rem; color: #64748b;}
"""

# Custom HTML for the header
# NOTE(review): the wrapper elements are reconstructed from the selectors in
# custom_css (#custom-header, h1, h2); the visible text is unchanged. The
# .blue/.pink spans are not restored because the intended split is unknown.
custom_header = """
<div id="custom-header">
    <h1>Phi 3.5 Multimodal Assistant</h1>
    <h2>Text and Vision AI at Your Service</h2>
</div>
"""

# Custom HTML for suggestions
# NOTE(review): structure reconstructed from the .suggestions / .suggestion /
# .suggestion-icon rules in custom_css; confirm against the intended design.
custom_suggestions = """
<div class="suggestions">
    <div class="suggestion">
        <span class="suggestion-icon">💬</span>
        <p>Chat with the Text Model</p>
    </div>
    <div class="suggestion">
        <span class="suggestion-icon">🖼️</span>
        <p>Analyze Images with Vision Model</p>
    </div>
    <div class="suggestion">
        <span class="suggestion-icon">🤖</span>
        <p>Get AI-generated responses</p>
    </div>
    <div class="suggestion">
        <span class="suggestion-icon">🔍</span>
        <p>Explore advanced options</p>
    </div>
</div>
"""

# Gradio interface
with gr.Blocks(
    css=custom_css,
    theme=gr.themes.Base().set(
        body_background_fill="#0b0f19",
        body_text_color="#e2e8f0",
        button_primary_background_fill="#3b82f6",
        button_primary_background_fill_hover="#2563eb",
        button_primary_text_color="white",
        block_title_text_color="#94a3b8",
        block_label_text_color="#94a3b8",
    ),
) as demo:
    gr.HTML(custom_header)
    gr.HTML(custom_suggestions)

    with gr.Tab("Text Model (Phi-3.5-mini)"):
        chatbot = gr.Chatbot(height=400)
        msg = gr.Textbox(label="Message", placeholder="Type your message here...")
        with gr.Accordion("Advanced Options", open=False):
            system_prompt = gr.Textbox(value="You are a helpful assistant", label="System Prompt")
            temperature = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature")
            max_new_tokens = gr.Slider(minimum=128, maximum=8192, step=1, value=1024, label="Max new tokens")
            top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p")
            top_k = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k")

        submit_btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.Button("Clear Chat", variant="secondary")

        submit_btn.click(
            stream_text_chat,
            [msg, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k],
            [chatbot],
        )
        # Resetting the chatbot value to None clears the conversation.
        clear_btn.click(lambda: None, None, chatbot, queue=False)

    with gr.Tab("Vision Model (Phi-3.5-vision)"):
        with gr.Row():
            with gr.Column(scale=1):
                vision_input_img = gr.Image(label="Upload an Image", type="pil")
                vision_text_input = gr.Textbox(
                    label="Ask a question about the image",
                    placeholder="What do you see in this image?",
                )
                vision_submit_btn = gr.Button("Analyze Image", variant="primary")
            with gr.Column(scale=1):
                vision_output_text = gr.Textbox(label="AI Analysis", lines=10)

        vision_submit_btn.click(
            process_vision_query,
            [vision_input_img, vision_text_input],
            [vision_output_text],
        )

    # NOTE(review): empty element; custom_css styles a `footer` selector, so
    # footer markup may have been intended here -- confirm.
    gr.HTML("")

if __name__ == "__main__":
    demo.launch()