import os
import time

import gradio as gr
import torch
import whisperx
from PIL import Image
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    CLIPImageProcessor,
    CLIPVisionModel,
)

from models.vision_projector_model import VisionProjector
from config import VisionProjectorConfig, app_config as cfg

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Marker a user appends to a text message to flag it as context
# (assumed token; images and audio uploads also work as context).
CONTEXT_MARKER = '<context>'

# CLIP vision tower: image features are taken from the second-to-last hidden layer.
clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Vision projector maps CLIP features into the phi-2 embedding space.
vision_projector = VisionProjector(VisionProjectorConfig())
ckpt = torch.load(cfg['vision_projector_file'], map_location=torch.device(device))
vision_projector.load_state_dict(ckpt['model_state_dict'])

# Base phi-2 model plus the fine-tuned LoRA adapter, merged for inference.
phi_base_model = AutoModelForCausalLM.from_pretrained(
    'microsoft/phi-2',
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float32,
    trust_remote_code=True
    # device_map=device_map,
)

phi_new_model = "models/phi_adapter"
phi_model = PeftModel.from_pretrained(phi_base_model, phi_new_model)
phi_model = phi_model.merge_and_unload()

# WhisperX for speech-to-text; use float16 on GPU, float32 on CPU.
compute_type = 'float16' if device != 'cpu' else 'float32'
audio_model = whisperx.load_model("large-v2", device, compute_type=compute_type)

tokenizer = AutoTokenizer.from_pretrained('microsoft/phi-2', trust_remote_code=True)
tokenizer.pad_token = tokenizer.unk_token


### app functions ##

context_added = False
context = None
context_type = ''
query = ''


def print_like_dislike(x: gr.LikeData):
    print(x.index, x.value, x.liked)


def add_text(history, text):
    global context, context_type, context_added, query

    context_added = False
    if not context_type and CONTEXT_MARKER not in text:
        # No context yet and none supplied: ask the user to add one first.
        history += [(None, "**Please add context (upload image/audio or enter text "
                           f"followed by {CONTEXT_MARKER})**")]
    elif not context_type:
        # First marked message becomes the text context.
        context_type = 'text'
        context_added = True
        text = text.replace(CONTEXT_MARKER, ' ')
        context = text
    else:
        if CONTEXT_MARKER in text:
            # A newly marked message replaces the existing context.
            context_type = 'text'
            context_added = True
            text = text.replace(CONTEXT_MARKER, ' ')
            context = text
        elif context_type in ['text', 'image']:
            # Normal turn: build the prompt for the model.
            query = 'Human### ' + text + '\n' + 'AI### '

    history = history + [(text, None)]
    return history, gr.Textbox(value="", interactive=False)


def add_file(history, file):
    global context_added, context, context_type

    context_added = False
    context_type = ''
    context = None

    history = history + [((file.name,), None)]
    history += [("Building context...", None)]

    # Encode the image with CLIP and project it into the phi-2 embedding space.
    image = Image.open(file.name)
    inputs = clip_processor(images=image, return_tensors="pt")
    x = clip_model(**inputs, output_hidden_states=True)
    image_features = x.hidden_states[-2]

    context = vision_projector(image_features)
    context_type = 'image'
    context_added = True

    return history


def audio_file(history, audio_path):
    global context, context_type, context_added, query

    if audio_path:
        history = history + [((audio_path,), None)]

        context_added = False

        # Transcribe and align the recording; the transcript becomes the text context.
        audio = whisperx.load_audio(audio_path)
        result = audio_model.transcribe(audio, batch_size=1)

        model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
        result = whisperx.align(result["segments"], model_a, metadata, audio, device,
                                return_char_alignments=False)

        text = result["segments"][0]["text"]
        resp = "🗣" + "_" + text.strip() + "_"
        history += [(resp, None)]

        context_type = 'text'
        context_added = True
        context = text

    return history


def bot(history):
    global context, context_added, query, context_type

    if context_added:
        response = "**Please proceed with your queries**"
        context_added = False
        query = ''
    else:
        if context_type == 'image':
            # Prepend the projected image embeddings to the query token embeddings.
            query_ids = tokenizer.encode(query)
            query_ids = torch.tensor(query_ids, dtype=torch.int32).unsqueeze(0)
            query_embeds = phi_model.get_input_embeddings()(query_ids)
            inputs_embeds = torch.cat([context, query_embeds], dim=1)

            out = phi_model.generate(inputs_embeds=inputs_embeds, min_new_tokens=10,
                                     max_new_tokens=50, bos_token_id=tokenizer.bos_token_id)
            response = tokenizer.decode(out[0], skip_special_tokens=True)
        elif context_type in ['text', 'audio']:
            # Text/audio context: concatenate context and query as plain text.
            input_text = context + query
            input_tokens = tokenizer.encode(input_text)
            input_ids = torch.tensor(input_tokens, dtype=torch.int32).unsqueeze(0)
            inputs_embeds = phi_model.get_input_embeddings()(input_ids)

            out = phi_model.generate(inputs_embeds=inputs_embeds, min_new_tokens=10,
                                     max_new_tokens=50, bos_token_id=tokenizer.bos_token_id)
            response = tokenizer.decode(out[0], skip_special_tokens=True)
        else:
            response = "**Please provide a valid context**"

    # Stream the response into the last chat turn, character by character.
    history[-1] = [history[-1][0], ""]
    for character in response:
        history[-1][1] += character
        time.sleep(0.05)
        yield history


def clear_fn():
    global context_added, context_type, context, query

    context_added = False
    context_type = ''
    context = None
    query = ''

    return {
        chatbot: None
    }


with gr.Blocks() as app:
    gr.Markdown(
        f"""
        # ContextGPT - A Multimodal Chatbot
        ### Upload an image or audio clip to add a context, then ask questions about it.
        ### You can also enter text followed by `{CONTEXT_MARKER}` to set the context in text format.
        """
    )

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False
    )

    with gr.Row():
        aud = gr.Audio(sources=['microphone', 'upload'], type='filepath', max_length=100,
                       show_download_button=True, show_share_button=True)
        btn = gr.UploadButton("📷", file_types=["image"])

    with gr.Row():
        txt = gr.Textbox(
            scale=4,
            show_label=False,
            placeholder="Press enter to send",
            container=False,
        )

    with gr.Row():
        clear = gr.Button("Clear")

    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        bot, chatbot, chatbot, api_name="bot_response"
    )

    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

    file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(
        bot, chatbot, chatbot
    )

    chatbot.like(print_like_dislike, None, None)

    clear.click(clear_fn, None, chatbot, queue=False)

    aud.stop_recording(audio_file, [chatbot, aud], [chatbot], queue=False).then(
        bot, chatbot, chatbot, api_name="bot_response"
    )

    aud.upload(audio_file, [chatbot, aud], [chatbot], queue=False).then(
        bot, chatbot, chatbot, api_name="bot_response"
    )

app.queue()
app.launch()