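# app.py — Gradio front end for ContextGPT, a multimodal chatbot.
# A context (an image, an audio clip, or text ending with </context>) is set first;
# follow-up questions are then answered by a phi-2 model with a merged PEFT adapter.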
import gradio as gr
import os
import time
from PIL import Image
import torch
import whisperx
from transformers import CLIPVisionModel, CLIPImageProcessor, AutoModelForCausalLM, AutoTokenizer
from models.vision_projector_model import VisionProjector
from config import VisionProjectorConfig, app_config as cfg

device = 'cuda' if torch.cuda.is_available() else 'cpu'
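
# CLIP vision encoder and image preprocessor used to embed uploaded images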
clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
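
# Vision projector maps CLIP image features into phi-2's input embedding space;
# weights are restored from the checkpoint path configured in app_config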
vision_projector = VisionProjector(VisionProjectorConfig())
ckpt = torch.load(cfg['vision_projector_file'], map_location=torch.device(device))
vision_projector.load_state_dict(ckpt['model_state_dict'])
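
# Base language model: microsoft/phi-2 loaded in float32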
phi_base_model = AutoModelForCausalLM.from_pretrained(
    'microsoft/phi-2',
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float32,
    trust_remote_code=True
    # device_map=device_map,
)
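
# Load the fine-tuned PEFT adapter and merge it into the base model for inference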
from peft import PeftModel

phi_new_model = "models/phi_adapter"
phi_model = PeftModel.from_pretrained(phi_base_model, phi_new_model)
phi_model = phi_model.merge_and_unload().to(device)
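
# WhisperX ASR model for transcribing recorded or uploaded audio;
# float16 compute is used on GPU, float32 on CPU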
compute_type = 'float32'
if device != 'cpu':
    compute_type = 'float16'
audio_model = whisperx.load_model("small", device, compute_type=compute_type)
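
# phi-2 tokenizer; no pad token is configured by default, so the unk token is reused for padding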
tokenizer = AutoTokenizer.from_pretrained('microsoft/phi-2')
tokenizer.pad_token = tokenizer.unk_token


### app functions ###
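
# Module-level conversation state shared across the Gradio callbacks:
#   context       - the active context (projected image features or plain text)
#   context_type  - '', 'image', 'audio', 'text' or 'error'; wrapped in brackets
#                   (e.g. '[image]') once the context has been acknowledged by bot()
#   context_added - a newly added context is pending preprocessing / acknowledgement
#   query_added   - a user question is pending generation
#   bot_active    - set while a response is being streamed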
context_added = False
query_added = False
context = None
context_type = ''
query = ''
bot_active = False
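

# Simple logger for 👍/👎 feedback on chatbot messages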
def print_like_dislike(x: gr.LikeData):
    print(x.index, x.value, x.liked)
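

# add_text: classify a submitted message as either a new text context (marked with
# </context>), a query against an already-acknowledged context, or an error prompt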
def add_text(history, text):
    global context, context_type, context_added, query, query_added

    context_added = False
    if not context_type and '</context>' not in text:
        context = "**Please add a context (upload an image/audio clip, or enter text followed by \\</context\\>)**"
        context_type = 'error'
        context_added = True
        query_added = False
    elif '</context>' in text:
        context_type = 'text'
        context_added = True
        text = text.replace('</context>', ' ')
        context = text
        query_added = False
    elif context_type in ['[text]', '[image]', '[audio]']:
        query = 'Human### ' + text + '\n' + 'AI### '
        query_added = True
        context_added = False
    else:
        query_added = False
        context_added = True
        context_type = 'error'
        context = "**Please provide a valid context**"

    history = history + [(text, None)]
    return history, gr.Textbox(value="", interactive=False)
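

# add_file: an uploaded image becomes the new context and is echoed into the chat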
def add_file(history, file):
    global context_added, context, context_type, query_added

    context = file
    context_type = 'image'
    context_added = True
    query_added = False

    history = history + [((file.name,), None)]
    return history
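

# audio_upload: a recorded or uploaded audio file becomes the new context
# (transcription happens later in preprocess_fn)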
def audio_upload(history, audio_file):
    global context, context_type, context_added, query, query_added

    if audio_file:
        context_added = True
        context_type = 'audio'
        context = audio_file
        query_added = False

        history = history + [((audio_file,), None)]

    return history
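

# preprocess_fn: turn a freshly added context into model-ready form.
#   image -> CLIP hidden states (second-to-last layer) -> vision projector embeddings
#   audio -> WhisperX transcription (with alignment) -> plain text context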
def preprocess_fn(history):
    global context, context_added, query, context_type, query_added

    if context_added:
        if context_type == 'image':
            image = Image.open(context)
            inputs = clip_processor(images=image, return_tensors="pt")
            x = clip_model(**inputs, output_hidden_states=True)
            image_features = x.hidden_states[-2]
            context = vision_projector(image_features)
        elif context_type == 'audio':
            audio_file = context
            audio = whisperx.load_audio(audio_file)
            result = audio_model.transcribe(audio, batch_size=1)

            error = False
            if result.get('language', None) and result.get('segments', None):
                try:
                    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
                    result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
                except Exception as e:
                    error = True
                    print(result.get('language', None))

                if not error and result.get('segments', []) and len(result["segments"]) > 0 and result["segments"][0].get('text', None):
                    text = result["segments"][0].get('text', '')
                    print(text)
                    context_type = 'audio'
                    context_added = True
                    context = text
                    query_added = False
                    print(context)
                else:
                    error = True
            else:
                error = True

            if error:
                context_type = 'error'
                context_added = True
                context = "**Please provide a valid audio file / context**"
                query_added = False
                print("Here")

    return history
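

# bot: generate and stream the assistant reply.
#   * a newly added context is acknowledged and context_type is wrapped in brackets
#   * an image query concatenates the projected image embeddings with the embedded
#     "Human### ... AI###" prompt and generates from inputs_embeds
#   * text/audio queries prepend the entered/transcribed context as plain text
#   * the response is streamed character by character into the last chat turn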
def bot(history):
    global context, context_added, query, context_type, query_added, bot_active

    response = ''
    if context_added:
        context_added = False
        if context_type == 'error':
            response = context
            query = ''
        elif context_type in ['image', 'audio', 'text']:
            response = ''
            if context_type == 'audio':
                response = 'Context: \n🗣 ' + '"_' + context.strip() + '_"\n\n'
            response += "**Please proceed with your queries**"
            query = ''
            context_type = '[' + context_type + ']'
    elif query_added:
        query_added = False
        if context_type == '[image]':
            query_ids = tokenizer.encode(query)
            query_ids = torch.tensor(query_ids, dtype=torch.int32).unsqueeze(0).to(device)
            query_embeds = phi_model.get_input_embeddings()(query_ids)
            inputs_embeds = torch.cat([context.to(device), query_embeds], dim=1)

            out = phi_model.generate(inputs_embeds=inputs_embeds, min_new_tokens=10, max_new_tokens=50,
                                     bos_token_id=tokenizer.bos_token_id)
            response = tokenizer.decode(out[0], skip_special_tokens=True)
        elif context_type in ['[text]', '[audio]']:
            input_text = context + query
            input_tokens = tokenizer.encode(input_text)
            input_ids = torch.tensor(input_tokens, dtype=torch.int32).unsqueeze(0).to(device)
            inputs_embeds = phi_model.get_input_embeddings()(input_ids)

            out = phi_model.generate(inputs_embeds=inputs_embeds, min_new_tokens=10, max_new_tokens=50,
                                     bos_token_id=tokenizer.bos_token_id)
            response = tokenizer.decode(out[0], skip_special_tokens=True)
        else:
            query = ''
            response = "**Please provide a valid context**"

    if response:
        bot_active = True
        if history and len(history[-1]) > 1:
            history[-1][1] = ""
            for character in response:
                history[-1][1] += character
                time.sleep(0.05)
                yield history
        time.sleep(0.5)
    bot_active = False
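

# clear_fn: reset all conversation state and empty the chat window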
def clear_fn():
    global context_added, context_type, context, query, query_added

    context_added = False
    context_type = ''
    context = None
    query = ''
    query_added = False

    return {
        chatbot: None
    }
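

# Gradio Blocks UI: chat window, text box, audio recorder/uploader, image upload
# button, and a Clear button. Each input path runs add_* -> preprocess_fn -> bot.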
with gr.Blocks() as app:
    gr.Markdown(
        """
        # ContextGPT - A Multimodal Chatbot
        ### Upload an image or an audio clip to set the context, then ask questions about it.
        ### You can also enter text followed by \\</context\\> to set the context.
        """
    )

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False
    )

    with gr.Row():
        txt = gr.Textbox(
            scale=4,
            show_label=False,
            placeholder="Press enter to send",
            container=False,
        )

    with gr.Row():
        aud = gr.Audio(sources=['microphone', 'upload'], type='filepath', max_length=100,
                       show_download_button=True, show_share_button=True)
        btn = gr.UploadButton("📷", file_types=["image"])

    with gr.Row():
        clear = gr.Button("Clear")

    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        preprocess_fn, chatbot, chatbot
    ).then(
        bot, chatbot, chatbot, api_name="bot_response"
    )

    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

    file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(
        preprocess_fn, chatbot, chatbot
    ).then(
        bot, chatbot, chatbot, api_name="bot_response"
    )

    chatbot.like(print_like_dislike, None, None)

    clear.click(clear_fn, None, chatbot, queue=False)

    aud.stop_recording(audio_upload, [chatbot, aud], [chatbot], queue=False).then(
        preprocess_fn, chatbot, chatbot
    ).then(
        bot, chatbot, chatbot, api_name="bot_response"
    )

    aud.upload(audio_upload, [chatbot, aud], [chatbot], queue=False).then(
        preprocess_fn, chatbot, chatbot
    ).then(
        bot, chatbot, chatbot, api_name="bot_response"
    )

app.queue()
app.launch()