import os
import time

import gradio as gr
import torch
import whisperx
from PIL import Image
from peft import PeftModel
from transformers import CLIPVisionModel, CLIPImageProcessor, AutoModelForCausalLM, AutoTokenizer

from models.vision_projector_model import VisionProjector
from config import VisionProjectorConfig, app_config as cfg

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# CLIP vision encoder and the projector that maps image features into phi-2's embedding space
clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

vision_projector = VisionProjector(VisionProjectorConfig())
ckpt = torch.load(cfg['vision_projector_file'], map_location=torch.device(device))
vision_projector.load_state_dict(ckpt['model_state_dict'])

# phi-2 base model with the trained LoRA adapter merged in
phi_base_model = AutoModelForCausalLM.from_pretrained(
    'microsoft/phi-2',
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float32,
    trust_remote_code=True
)

phi_new_model = "models/phi_adapter"
phi_model = PeftModel.from_pretrained(phi_base_model, phi_new_model)
phi_model = phi_model.merge_and_unload().to(device)

# whisperx does not support float16 inference on CPU, so pick the compute type per device
compute_type = 'float16' if device != 'cpu' else 'float32'
audio_model = whisperx.load_model("small", device, compute_type=compute_type)

tokenizer = AutoTokenizer.from_pretrained('microsoft/phi-2')
tokenizer.pad_token = tokenizer.unk_token

### app functions ##

CONTEXT_MARKER = '<context>'  # sentinel a user appends to a text message to mark it as context (tag string assumed)

context_added = False
query_added = False
context = None
context_type = ''
query = ''
bot_active = False


def print_like_dislike(x: gr.LikeData):
    print(x.index, x.value, x.liked)


def add_text(history, text):
    global context, context_type, context_added, query, query_added

    context_added = False
    if not context_type and CONTEXT_MARKER not in text:
        context = f'**Please add context (upload image/audio or enter text followed by "{CONTEXT_MARKER}")**'
        context_type = 'error'
        context_added = True
        query_added = False
    elif CONTEXT_MARKER in text:
        context_type = 'text'
        context_added = True
        text = text.replace(CONTEXT_MARKER, ' ')
        context = text
        query_added = False
    elif context_type in ['[text]', '[image]', '[audio]']:
        query = 'Human### ' + text + '\n' + 'AI### '
        query_added = True
        context_added = False
    else:
        query_added = False
        context_added = True
        context_type = 'error'
        context = "**Please provide a valid context**"

    history = history + [(text, None)]
    return history, gr.Textbox(value="", interactive=False)


def add_file(history, file):
    global context_added, context, context_type, query_added

    context = file
    context_type = 'image'
    context_added = True
    query_added = False

    history = history + [((file.name,), None)]
    return history


def audio_upload(history, audio_file):
    global context, context_type, context_added, query, query_added

    if audio_file:
        context_added = True
        context_type = 'audio'
        context = audio_file
        query_added = False

        history = history + [((audio_file,), None)]

    return history


def preprocess_fn(history):
    global context, context_added, query, context_type, query_added

    if context_added:
        if context_type == 'image':
            # Encode the image with CLIP and project the penultimate hidden states
            # into phi-2's embedding space.
            image = Image.open(context)
            inputs = clip_processor(images=image, return_tensors="pt")
            with torch.no_grad():
                x = clip_model(**inputs, output_hidden_states=True)
                image_features = x.hidden_states[-2]
                context = vision_projector(image_features)
        elif context_type == 'audio':
            # Transcribe the audio with whisperx and use the transcript as context.
            audio_file = context
            audio = whisperx.load_audio(audio_file)
            result = audio_model.transcribe(audio, batch_size=1)

            error = False
            if result.get('language', None) and result.get('segments', None):
                try:
                    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
                    result = whisperx.align(result["segments"], model_a, metadata, audio, device,
                                            return_char_alignments=False)
                except Exception as e:
                    error = True
                    print(result.get('language', None))

                if not error and result.get('segments', []) and len(result["segments"]) > 0 \
                        and result["segments"][0].get('text', None):
                    text = result["segments"][0].get('text', '')
                    print(text)
                    context_type = 'audio'
                    context_added = True
                    context = text
                    query_added = False
                    print(context)
                else:
                    error = True
            else:
                error = True

            if error:
                context_type = 'error'
                context_added = True
                context = "**Please provide a valid audio file / context**"
                query_added = False
                print("Here")

    return history


def bot(history):
    global context, context_added, query, context_type, query_added, bot_active

    response = ''
    if context_added:
        context_added = False

        if context_type == 'error':
            response = context
            query = ''
        elif context_type in ['image', 'audio', 'text']:
            response = ''
            if context_type == 'audio':
                response = 'Context: \nšŸ—£ ' + '"_' + context.strip() + '_"\n\n'
            response += "**Please proceed with your queries**"
            query = ''
            # Bracket the type to mark the context as consumed; queries are accepted from now on.
            context_type = '[' + context_type + ']'

    elif query_added:
        query_added = False
        if context_type == '[image]':
            # Prepend the projected image embeddings to the embedded query tokens.
            query_ids = tokenizer.encode(query)
            query_ids = torch.tensor(query_ids, dtype=torch.int32).unsqueeze(0).to(device)
            query_embeds = phi_model.get_input_embeddings()(query_ids)
            inputs_embeds = torch.cat([context.to(device), query_embeds], dim=1)

            out = phi_model.generate(inputs_embeds=inputs_embeds,
                                     min_new_tokens=10,
                                     max_new_tokens=50,
                                     bos_token_id=tokenizer.bos_token_id)
            response = tokenizer.decode(out[0], skip_special_tokens=True)
        elif context_type in ['[text]', '[audio]']:
            input_text = context + query
            input_tokens = tokenizer.encode(input_text)
            input_ids = torch.tensor(input_tokens, dtype=torch.int32).unsqueeze(0).to(device)
            inputs_embeds = phi_model.get_input_embeddings()(input_ids)

            out = phi_model.generate(inputs_embeds=inputs_embeds,
                                     min_new_tokens=10,
                                     max_new_tokens=50,
                                     bos_token_id=tokenizer.bos_token_id)
            response = tokenizer.decode(out[0], skip_special_tokens=True)
        else:
            query = ''
            response = "**Please provide a valid context**"

    if response:
        bot_active = True
        if history and len(history[-1]) > 1:
            # Stream the response into the chat window character by character.
            history[-1][1] = ""
            for character in response:
                history[-1][1] += character
                time.sleep(0.05)
                yield history

    time.sleep(0.5)
    bot_active = False


def clear_fn():
    global context_added, context_type, context, query, query_added

    context_added = False
    context_type = ''
    context = None
    query = ''
    query_added = False

    return {
        chatbot: None
    }


with gr.Blocks() as app:
    gr.Markdown(
        """
# ContextGPT - A Multimodal chatbot
### Upload an image or an audio clip to add a context, then ask questions.
### You can also enter text followed by `<context>` to set the context.
""" ) chatbot = gr.Chatbot( [], elem_id="chatbot", bubble_full_width=False ) with gr.Row(): txt = gr.Textbox( scale=4, show_label=False, placeholder="Press enter to send ", container=False, ) with gr.Row(): aud = gr.Audio(sources=['microphone', 'upload'], type='filepath', max_length=100, show_download_button=True, show_share_button=True) btn = gr.UploadButton("šŸ“·", file_types=["image"]) with gr.Row(): clear = gr.Button("Clear") txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then( preprocess_fn, chatbot, chatbot ).then( bot, chatbot, chatbot, api_name="bot_response" ) txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False) file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then( preprocess_fn, chatbot, chatbot ).then( bot, chatbot, chatbot, api_name="bot_response" ) chatbot.like(print_like_dislike, None, None) clear.click(clear_fn, None, chatbot, queue=False) aud.stop_recording(audio_upload, [chatbot, aud], [chatbot], queue=False).then( preprocess_fn, chatbot, chatbot ).then( bot, chatbot, chatbot, api_name="bot_response" ) aud.upload(audio_upload, [chatbot, aud], [chatbot], queue=False).then( preprocess_fn, chatbot, chatbot ).then( bot, chatbot, chatbot, api_name="bot_response" ) app.queue() app.launch()