# Spaces:
# Paused
# Paused
import gradio as gr | |
import os | |
import time | |
from PIL import Image | |
import torch | |
import whisperx | |
from transformers import CLIPVisionModel, CLIPImageProcessor, AutoModelForCausalLM, AutoTokenizer | |
from models.vision_projector_model import VisionProjector | |
from config import VisionProjectorConfig, app_config as cfg | |
from peft import PeftModel

# ---------------------------------------------------------------------------
# Model setup (runs once, at import time).
# Only the WhisperX model is created on `device`; the CLIP encoder, vision
# projector and Phi-2 model are never moved to `device` in this file, so they
# run on the CPU.
# ---------------------------------------------------------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# CLIP vision encoder + preprocessing; add_file feeds its hidden states to the
# vision projector.
clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Trained projector whose output bot concatenates with Phi-2 input embeddings.
vision_projector = VisionProjector(VisionProjectorConfig())
# The projector's parameters are not moved off the CPU, so map the checkpoint
# there directly instead of staging it on `device`; load_state_dict copies the
# values into the existing parameters either way.
ckpt = torch.load(cfg['vision_projector_file'], map_location='cpu')
vision_projector.load_state_dict(ckpt['model_state_dict'])
# Inference only. NOTE(review): assumes VisionProjector is a torch.nn.Module
# (it exposes load_state_dict); eval() turns off any dropout / batch-norm
# training behaviour it may contain.
vision_projector.eval()

# Phi-2 with the fine-tuned PEFT adapter merged into its weights, so inference
# runs on a plain causal LM.
phi_base_model = AutoModelForCausalLM.from_pretrained(
    'microsoft/phi-2',
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float32,
    trust_remote_code=True,
)
phi_new_model = "models/phi_adapter"
phi_model = PeftModel.from_pretrained(phi_base_model, phi_new_model)
phi_model = phi_model.merge_and_unload()

# WhisperX speech recognition; half precision only when running on the GPU.
compute_type = 'float16' if device != 'cpu' else 'float32'
audi_model = whisperx.load_model("large-v2", device, compute_type=compute_type)

# Phi-2 tokenizer; the unk token doubles as the pad token.
tokenizer = AutoTokenizer.from_pretrained('microsoft/phi-2', trust_remote_code=True)
tokenizer.pad_token = tokenizer.unk_token
# ---- Conversation state shared by the app callbacks below ----
# NOTE(review): these are module-level globals mutated via `global`, so all
# users/sessions of the running app share one state.
query = ''             # pending prompt, formatted 'Human### <text>\nAI### '
context_type = ''      # '' (no context yet), 'text' or 'image'
context = None         # str for text/audio context, projector output for images
context_added = False  # True right after a new context was set, until bot runs
def print_like_dislike(x: gr.LikeData):
    """Print a chatbot like/dislike event to stdout.

    Args:
        x: Gradio like event data; its ``index``, ``value`` and ``liked``
            fields are printed on one line.
    """
    index, value, liked = x.index, x.value, x.liked
    print(index, value, liked)
def add_text(history, text):
    """Handle a message submitted from the textbox.

    Text containing ``</context>`` becomes the new text context (the marker is
    replaced by a space). Otherwise, once a text or image context exists, the
    message is stored as the pending ``query`` that ``bot`` answers next.
    Without any context the message is only recorded; ``bot`` then replies
    asking the user to provide one.

    Args:
        history: Chatbot history, a list of (user, bot) message pairs.
        text: The submitted textbox content.

    Returns:
        The history with the user's message appended as a new pair, and a
        cleared, non-interactive textbox (re-enabled once ``bot`` finishes).
    """
    global context, context_type, context_added, query
    context_added = False
    if '</context>' in text:
        # The user supplies (or replaces) the context as plain text.
        context_type = 'text'
        context_added = True
        text = text.replace('</context>', ' ')
        context = text
    elif context_type in ('text', 'image'):
        # A context exists: format the message as the next prompt turn.
        query = 'Human### ' + text + '\n' + 'AI### '
    # Append exactly one (user, bot) pair. Never `history += <str>`: that would
    # extend the list one character per element and break the Chatbot value.
    history = history + [(text, None)]
    return history, gr.Textbox(value="", interactive=False)
def add_file(history, file):
    """Use an uploaded image as the new conversation context.

    The image goes through the CLIP processor and vision encoder; the
    second-to-last hidden layer is passed through the vision projector and the
    result is stored in ``context``, which ``bot`` prepends to the query
    embeddings.

    Args:
        history: Chatbot history, a list of (user, bot) message pairs.
        file: The uploaded file from the UploadButton; opened with
            ``Image.open`` and shown in the chat via its ``.name``.

    Returns:
        The history with the image and a "Building context..." entry appended.
    """
    global context_added, context, context_type
    # Clear the previous context first so a failed upload leaves none behind.
    context_added = False
    context_type = ''
    context = None
    history = history + [((file.name,), None)]
    history += [("Building context...", None)]
    # no_grad: inference only. Otherwise the autograd graph of CLIP + projector
    # would stay alive for as long as the global `context` holds the result.
    # The `with` on Image.open closes the underlying file once encoded.
    with torch.no_grad(), Image.open(file) as image:
        inputs = clip_processor(images=image, return_tensors="pt")
        outputs = clip_model(**inputs, output_hidden_states=True)
        image_features = outputs.hidden_states[-2]
        context = vision_projector(image_features)
    context_type = 'image'
    context_added = True
    return history
def audio_file(history, audio_file):
    """Transcribe a recorded or uploaded clip and use the text as context.

    Args:
        history: Chatbot history, a list of (user, bot) message pairs.
        audio_file: Path of the audio clip (the Audio component uses
            ``type='filepath'``), or a falsy value when there is no audio.

    Returns:
        The history with the clip appended and, when speech was recognised,
        an italicised transcript entry.
    """
    global context, context_type, context_added
    if not audio_file:
        return history
    history = history + [((audio_file,), None)]
    context_added = False
    audio = whisperx.load_audio(audio_file)
    result = audi_model.transcribe(audio, batch_size=1)
    segments = result["segments"]
    if segments:
        align_model, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
        aligned = whisperx.align(segments, align_model, metadata, audio, device, return_char_alignments=False)
        segments = aligned["segments"]
    # The transcript is a list of segments; join all of them rather than
    # keeping only the first one.
    text = " ".join(segment["text"].strip() for segment in segments).strip()
    if not text:
        # Nothing recognised: drop the context so `bot` asks for a valid one
        # instead of this function failing on an empty segment list.
        context_type = ''
        context = None
        return history
    history += [("🗣" + "_" + text + "_", None)]
    context_type = 'text'
    context_added = True
    context = text
    return history
def _generate_reply(inputs_embeds):
    """Generate 10-50 new tokens with Phi-2 from ``inputs_embeds`` and decode them."""
    out = phi_model.generate(inputs_embeds=inputs_embeds, min_new_tokens=10, max_new_tokens=50,
                             bos_token_id=tokenizer.bos_token_id)
    return tokenizer.decode(out[0], skip_special_tokens=True)


def bot(history):
    """Stream the assistant's reply into the last chat message.

    The reply depends on the shared state:

    * right after a new context was added: a prompt to ask questions (and the
      pending query is cleared);
    * image context: Phi-2 generates from the projector output followed by
      the query embeddings;
    * text/audio context: Phi-2 generates from ``context + query``;
    * otherwise: a request for a valid context.

    Args:
        history: Chatbot history, a list of [user, bot] message pairs.

    Yields:
        The history after each added character, so the reply is typed out.
    """
    global context, context_added, query, context_type
    if not history:
        # Nothing to attach a reply to; return the history unchanged instead
        # of raising IndexError on history[-1].
        yield history
        return
    if context_added:
        response = "**Please proceed with your queries**"
        context_added = False
        query = ''
    elif context_type == 'image':
        with torch.no_grad():
            query_ids = torch.tensor(tokenizer.encode(query), dtype=torch.int32).unsqueeze(0)
            query_embeds = phi_model.get_input_embeddings()(query_ids)
            # Image embeddings first, then the question; `context` must share
            # batch and hidden size with query_embeds for this dim=1 concat.
            inputs_embeds = torch.cat([context, query_embeds], dim=1)
            response = _generate_reply(inputs_embeds)
    elif context_type in ['text', 'audio']:
        with torch.no_grad():
            input_ids = torch.tensor(tokenizer.encode(context + query), dtype=torch.int32).unsqueeze(0)
            inputs_embeds = phi_model.get_input_embeddings()(input_ids)
            response = _generate_reply(inputs_embeds)
    else:
        response = "**Please provide a valid context**"
    history[-1][1] = ""
    for character in response:
        history[-1][1] += character
        time.sleep(0.05)
        yield history
def clear_fn():
    """Forget the context and pending query, and empty the chat window.

    Returns:
        A mapping from the ``chatbot`` component to ``None``, which Gradio
        applies as that component's new value.
    """
    global context_added, context_type, context, query
    context, context_type, context_added, query = None, '', False, ''
    return {chatbot: None}
with gr.Blocks() as app:
    # Raw string: `\<` is markdown escaping, not a Python escape sequence.
    gr.Markdown(
        r"""
# ContextGPT - A Multimodal chatbot
### Upload image or audio to add a context. And then ask questions.
### You can also enter text followed by \</context\> to set the context in text format.
"""
    )
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
    )
    with gr.Row():
        aud = gr.Audio(sources=['microphone', 'upload'], type='filepath', max_length=100,
                       show_download_button=True, show_share_button=True)
        btn = gr.UploadButton("📷", file_types=["image"])
    with gr.Row():
        txt = gr.Textbox(
            scale=4,
            show_label=False,
            placeholder="Press enter to send ",
            container=False,
        )
    with gr.Row():
        clear = gr.Button("Clear")

    # Text: record the message (textbox locked), reply, then unlock the textbox.
    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        bot, chatbot, chatbot, api_name="bot_response"
    )
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
    # Image upload: build the image context, then acknowledge it.
    file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    chatbot.like(print_like_dislike, None, None)
    clear.click(clear_fn, None, chatbot, queue=False)
    # Audio (recorded or uploaded): transcribe into a text context, then reply.
    # Each event needs its own api_name. Gradio renames a duplicate
    # "bot_response" with a numeric suffix and warns, so these names keep the
    # endpoints it derived.
    aud.stop_recording(audio_file, [chatbot, aud], [chatbot], queue=False).then(
        bot, chatbot, chatbot, api_name="bot_response_1"
    )
    aud.upload(audio_file, [chatbot, aud], [chatbot], queue=False).then(
        bot, chatbot, chatbot, api_name="bot_response_2"
    )

app.queue()
app.launch()