# MultiModelGPT / app.py
import gradio as gr
import os
import time
from PIL import Image
import torch
import whisperx
from transformers import CLIPVisionModel, CLIPImageProcessor, AutoModelForCausalLM, AutoTokenizer
from models.vision_projector_model import VisionProjector
from config import VisionProjectorConfig, app_config as cfg
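# Pipeline overview (as wired below):
#   image -> CLIP vision encoder -> VisionProjector -> prefix embeddings for phi-2
#   audio -> whisperx transcription -> plain-text context for phi-2
#   text  -> used directly as context, delimited by an explicit </context> marker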
device = 'cuda' if torch.cuda.is_available() else 'cpu'
clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
vision_projector = VisionProjector(VisionProjectorConfig())
ckpt = torch.load(cfg['vision_projector_file'], map_location=torch.device(device))
vision_projector.load_state_dict(ckpt['model_state_dict'])
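# The projector maps CLIP image features into phi-2's input-embedding space; its output
# is concatenated with the query token embeddings in bot() below.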
phi_base_model = AutoModelForCausalLM.from_pretrained(
    'microsoft/phi-2',
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float32,
    trust_remote_code=True
    # device_map=device_map,
)
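# Load the fine-tuned LoRA adapter on top of phi-2 and merge it into the base weights,
# so the merged model can be used like a plain causal LM for generation.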
from peft import PeftModel
phi_new_model = "models/phi_adapter"
phi_model = PeftModel.from_pretrained(phi_base_model, phi_new_model)
phi_model = phi_model.merge_and_unload().to(device)
# Pick a compute type the device supports: float16 on GPU, float32 on CPU.
compute_type = 'float32'
if device != 'cpu':
    compute_type = 'float16'
audio_model = whisperx.load_model("small", device, compute_type=compute_type)
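# phi-2's tokenizer has no pad token by default, so the unk token is reused for padding.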
tokenizer = AutoTokenizer.from_pretrained('microsoft/phi-2')
tokenizer.pad_token = tokenizer.unk_token
### app functions ##
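# Global conversation state: a context (image, audio, or text) must be added first;
# subsequent messages are treated as queries against that context.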
context_added = False
query_added = False
context = None
context_type = ''
query = ''
bot_active = False
def print_like_dislike(x: gr.LikeData):
    print(x.index, x.value, x.liked)
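# Route typed input: text containing </context> sets a new text context; otherwise,
# once a context exists, the message is treated as a query for the model.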
def add_text(history, text):
    global context, context_type, context_added, query, query_added

    context_added = False
    if not context_type and '</context>' not in text:
        context = "**Please add a context (upload an image/audio clip, or enter text followed by \\</context\\>)**"
        context_type = 'error'
        context_added = True
        query_added = False
    elif '</context>' in text:
        context_type = 'text'
        context_added = True
        text = text.replace('</context>', ' ')
        context = text
        query_added = False
    elif context_type in ['[text]', '[image]', '[audio]']:
        query = 'Human### ' + text + '\n' + 'AI### '
        query_added = True
        context_added = False
    else:
        query_added = False
        context_added = True
        context_type = 'error'
        context = "**Please provide a valid context**"

    history = history + [(text, None)]
    return history, gr.Textbox(value="", interactive=False)
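# Image upload handler: the uploaded file becomes the pending image context.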
def add_file(history, file):
    global context_added, context, context_type, query_added

    context = file
    context_type = 'image'
    context_added = True
    query_added = False
    history = history + [((file.name,), None)]
    return history
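# Audio handler (microphone recording or file upload): the clip becomes the pending audio context.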
def audio_upload(history, audio_file):
    global context, context_type, context_added, query, query_added

    if audio_file:
        context_added = True
        context_type = 'audio'
        context = audio_file
        query_added = False
        history = history + [((audio_file,), None)]
    return history
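# Convert the raw context into model-ready form:
#   images: CLIP penultimate-layer hidden states passed through the vision projector
#   audio:  whisperx transcription (with alignment) reduced to plain text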
def preprocess_fn(history):
    global context, context_added, query, context_type, query_added

    if context_added:
        if context_type == 'image':
            image = Image.open(context)
            inputs = clip_processor(images=image, return_tensors="pt")
            x = clip_model(**inputs, output_hidden_states=True)
            image_features = x.hidden_states[-2]
            context = vision_projector(image_features)
        elif context_type == 'audio':
            audio_file = context
            audio = whisperx.load_audio(audio_file)
            result = audio_model.transcribe(audio, batch_size=1)

            error = False
            if result.get('language', None) and result.get('segments', None):
                try:
                    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
                    result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
                except Exception:
                    error = True
                    print(result.get('language', None))

                if not error and result.get('segments', []) and len(result["segments"]) > 0 and result["segments"][0].get('text', None):
                    text = result["segments"][0].get('text', '')
                    print(text)
                    context_type = 'audio'
                    context_added = True
                    context = text
                    query_added = False
                    print(context)
                else:
                    error = True
            else:
                error = True

            if error:
                context_type = 'error'
                context_added = True
                context = "**Please provide a valid audio file / context**"
                query_added = False

    return history
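# Produce the assistant turn: either acknowledge a newly added context or generate a
# phi-2 completion for the pending query, streamed into the chat character by character.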
def bot(history):
    global context, context_added, query, context_type, query_added, bot_active

    response = ''
    if context_added:
        context_added = False
        if context_type == 'error':
            response = context
            query = ''
        elif context_type in ['image', 'audio', 'text']:
            response = ''
            if context_type == 'audio':
                response = 'Context: \n🗣 ' + '"_' + context.strip() + '_"\n\n'
            response += "**Please proceed with your queries**"
            query = ''
            context_type = '[' + context_type + ']'
    elif query_added:
        query_added = False
        if context_type == '[image]':
            query_ids = tokenizer.encode(query)
            query_ids = torch.tensor(query_ids, dtype=torch.int32).unsqueeze(0).to(device)
            query_embeds = phi_model.get_input_embeddings()(query_ids)
            inputs_embeds = torch.cat([context.to(device), query_embeds], dim=1)
            out = phi_model.generate(inputs_embeds=inputs_embeds, min_new_tokens=10, max_new_tokens=50,
                                     bos_token_id=tokenizer.bos_token_id)
            response = tokenizer.decode(out[0], skip_special_tokens=True)
        elif context_type in ['[text]', '[audio]']:
            input_text = context + query
            input_tokens = tokenizer.encode(input_text)
            input_ids = torch.tensor(input_tokens, dtype=torch.int32).unsqueeze(0).to(device)
            inputs_embeds = phi_model.get_input_embeddings()(input_ids)
            out = phi_model.generate(inputs_embeds=inputs_embeds, min_new_tokens=10, max_new_tokens=50,
                                     bos_token_id=tokenizer.bos_token_id)
            response = tokenizer.decode(out[0], skip_special_tokens=True)
    else:
        query = ''
        response = "**Please provide a valid context**"

    if response:
        bot_active = True
        if history and len(history[-1]) > 1:
            history[-1][1] = ""
            for character in response:
                history[-1][1] += character
                time.sleep(0.05)
                yield history
        time.sleep(0.5)
    bot_active = False
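# Reset all conversation state and clear the chat window.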
def clear_fn():
    global context_added, context_type, context, query, query_added

    context_added = False
    context_type = ''
    context = None
    query = ''
    query_added = False

    return {
        chatbot: None
    }
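# Gradio UI: chatbot pane, a textbox for text context/queries, an audio recorder/uploader,
# an image upload button, and a clear button.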
with gr.Blocks() as app:
    gr.Markdown(
        """
        # ContextGPT - A Multimodal Chatbot
        ### Upload an image or audio clip to add a context, then ask your questions.
        ### You can also enter text followed by \\</context\\> to set the context.
        """
    )

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False
    )
    with gr.Row():
        txt = gr.Textbox(
            scale=4,
            show_label=False,
            placeholder="Press enter to send ",
            container=False,
        )

    with gr.Row():
        aud = gr.Audio(sources=['microphone', 'upload'], type='filepath', max_length=100,
                       show_download_button=True, show_share_button=True)
        btn = gr.UploadButton("📷", file_types=["image"])

    with gr.Row():
        clear = gr.Button("Clear")
    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        preprocess_fn, chatbot, chatbot
    ).then(
        bot, chatbot, chatbot, api_name="bot_response"
    )
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

    file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(
        preprocess_fn, chatbot, chatbot
    ).then(
        bot, chatbot, chatbot, api_name="bot_response"
    )

    chatbot.like(print_like_dislike, None, None)
    clear.click(clear_fn, None, chatbot, queue=False)

    aud.stop_recording(audio_upload, [chatbot, aud], [chatbot], queue=False).then(
        preprocess_fn, chatbot, chatbot
    ).then(
        bot, chatbot, chatbot, api_name="bot_response"
    )
    aud.upload(audio_upload, [chatbot, aud], [chatbot], queue=False).then(
        preprocess_fn, chatbot, chatbot
    ).then(
        bot, chatbot, chatbot, api_name="bot_response"
    )
app.queue()
app.launch()