import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, TextIteratorStreamer, BitsAndBytesConfig
import gradio as gr
from threading import Thread
import numpy as np
from PIL import Image
import subprocess
import spaces
# Install flash-attention
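# The FLASH_ATTENTION_SKIP_CUDA_BUILD flag tells the flash-attn install to skip compiling the
# CUDA kernels from source; flash-attn is needed because the vision model below is loaded with
# attn_implementation="flash_attention_2".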
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
# Constants
TITLE = "<h1><center>Phi 3.5 Multimodal (Text + Vision)</center></h1>"
DESCRIPTION = "# Phi-3.5 Multimodal Demo (Text + Vision)"
# Model configurations
TEXT_MODEL_ID = "microsoft/Phi-3.5-mini-instruct"
VISION_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
# Quantization config for text model
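# 4-bit NF4 weights with double quantization and bfloat16 compute reduce the memory footprint of
# the 3.8B-parameter Phi-3.5-mini to roughly a quarter of its bfloat16 size.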
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)
# Load models and tokenizers
text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID)
text_model = AutoModelForCausalLM.from_pretrained(
    TEXT_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config
)
vision_model = AutoModelForCausalLM.from_pretrained(
    VISION_MODEL_ID,
    trust_remote_code=True,
    torch_dtype="auto",
    attn_implementation="flash_attention_2"
).to(device).eval()
vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)
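# Note: the text model is placed by accelerate via device_map="auto", while the vision model is
# moved to `device` explicitly; both assume a CUDA GPU is available at runtime (on ZeroGPU Spaces
# the @spaces.GPU decorator below requests one for each decorated call).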
# Helper functions
@spaces.GPU
def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20):
    conversation = [{"role": "system", "content": system_prompt}]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})
    input_ids = text_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(text_model.device)
    streamer = TextIteratorStreamer(text_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        # Rely on the model's own generation_config for stop tokens; the previously hardcoded ids
        # (128001, 128008, 128009) are Llama-3 tokens and fall outside Phi-3.5's ~32k vocabulary,
        # so generation would never stop early.
        streamer=streamer,
    )
    # Run generation in a background thread so tokens can be yielded as they stream in.
    # (generate() disables gradient tracking internally, so no torch.no_grad() wrapper is needed.)
    thread = Thread(target=text_model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield history + [[message, buffer]]
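# Vision helper: Phi-3.5-vision expects image placeholders such as <|image_1|> inside the
# chat-style prompt; the single image passed to the processor is matched to that placeholder.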
@spaces.GPU
def process_vision_query(image, text_input):
    prompt = f"<|user|>\n<|image_1|>\n{text_input}<|end|>\n<|assistant|>\n"
    # Ensure the image is in the correct format
    if isinstance(image, np.ndarray):
        # Convert numpy array to PIL Image
        image = Image.fromarray(image).convert("RGB")
    elif not isinstance(image, Image.Image):
        raise ValueError("Invalid image type. Expected PIL.Image.Image or numpy.ndarray")
    # Now process the image
    inputs = vision_processor(prompt, images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        generate_ids = vision_model.generate(
            **inputs,
            max_new_tokens=1000,
            eos_token_id=vision_processor.tokenizer.eos_token_id
        )
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    response = vision_processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    return response
# Combined chat function: routes image queries to the vision model, streams text queries from the text model
def combined_chat(message, image, history, system_prompt, temperature, max_new_tokens, top_p, top_k):
    if image is not None:
        # Process image query
        response = process_vision_query(image, message)
        history.append((message, response))
        yield history, None
    else:
        # Process text query, yielding partial histories so the chatbot updates as tokens arrive
        # (returning the generator object directly would not stream in Gradio)
        for updated_history in stream_text_chat(message, history, system_prompt, temperature, max_new_tokens, top_p, top_k):
            yield updated_history, None
# Function to toggle between text and image input
def toggle_input(choice):
    if choice == "Text":
        return gr.update(visible=True), gr.update(visible=False)
    else:
        return gr.update(visible=False), gr.update(visible=True)
# Custom CSS
custom_css = """
body { background-color: #343541; color: #ececf1; font-family: 'Arial', sans-serif; }
.gradio-container { max-width: 800px !important; margin: auto; }
#chatbot { height: 400px; overflow-y: auto; }
#input-container { display: flex; align-items: center; }
#msg, #image-input { flex-grow: 1; margin-right: 10px; }
#submit-btn { min-width: 60px; }
footer { text-align: center; margin-top: 2rem; color: #acacbe; }
"""
# Gradio interface
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    chatbot = gr.Chatbot(elem_id="chatbot")
    with gr.Row(elem_id="input-container"):
        input_type = gr.Radio(["Text", "Image"], value="Text", label="Input Type")
        with gr.Column(visible=True) as text_input:
            msg = gr.Textbox(
                show_label=False,
                placeholder="Send a message...",
                elem_id="msg"
            )
        with gr.Column(visible=False) as image_input:
            image = gr.Image(type="pil", elem_id="image-input")
        submit_btn = gr.Button("Send", elem_id="submit-btn")
    clear_btn = gr.Button("Clear Chat", variant="secondary")
    with gr.Accordion("Advanced Options", open=False):
        system_prompt = gr.Textbox(value="You are a helpful assistant", label="System Prompt")
        temperature = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature")
        max_new_tokens = gr.Slider(minimum=128, maximum=8192, step=1, value=1024, label="Max new tokens")
        top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p")
        top_k = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k")
    input_type.change(toggle_input, input_type, [text_input, image_input])
    submit_btn.click(combined_chat, [msg, image, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k], [chatbot, image])
    clear_btn.click(lambda: ([], None), None, [chatbot, image], queue=False)
    gr.HTML("<footer>Powered by Phi 3.5 Multimodal AI</footer>")
if __name__ == "__main__":
    demo.launch()