# Dr-Brain / app.py
# thunder-007's picture
# Rename app and update context.
# 6bc78ec
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import openai
from dotenv import load_dotenv
import os
from monai.networks.nets import SegResNet
from monai.inferers import sliding_window_inference
from monai.transforms import (
Activations,
AsDiscrete,
Compose,
LoadImaged,
NormalizeIntensityd,
Orientationd,
EnsureChannelFirstd,
)
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import librosa
import torch
# Load variables from a local .env file (if present) into the environment
# before any setting, such as the OPENAI_KEY read further down, is looked up.
load_dotenv()
# Heading shown at the top of the Gradio page.
title = "Dr Brain 🧠"
# Description shown under the title (currently just a blank line).
description = '''
'''
# Maps the channel picked in the UI to the model output channel displayed
# next to it: 0 and 1 are swapped, 2 maps to itself.
# NOTE(review): the reason for the swap is not visible in this file — confirm
# against the label order used when the checkpoint was trained.
channel_mapping = {0: 1, 1: 0, 2: 2}
# Preprocessing pipeline for an uploaded scan. It operates on a dict and only
# touches the "image" entry; the steps run in the order listed.
preproc_transforms = Compose(
    [
        # Read the image file(s) listed under "image".
        LoadImaged(keys=["image"]),
        # Put the channel axis first.
        EnsureChannelFirstd(keys="image"),
        # Reorient the volume to RAS axis codes.
        Orientationd(keys=["image"], axcodes="RAS"),
        # Normalise intensities per channel, using non-zero voxels only.
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

# Post-processing for the network output: sigmoid, then binarise at 0.5.
post_trans = Compose(
    [
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ]
)
# Speech-to-text components used to transcribe spoken questions: the Whisper
# processor (audio features + token decoding) and the generation model.
# NOTE(review): "whisper-tiny" has no organisation prefix, so this presumably
# resolves to a local directory shipped with the app — confirm.
processor_whisper = WhisperProcessor.from_pretrained("whisper-tiny")
model_whisper = WhisperForConditionalGeneration.from_pretrained("whisper-tiny")
# Tumour segmentation network: a SegResNet taking 4 input channels and
# producing 3 output channels, built on the CPU and restored from the
# checkpoint bundled with the app.
model_tumor_seg = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
).to('cpu')
model_tumor_seg.load_state_dict(torch.load("models/best_metric_model_epoch_40.pth", map_location='cpu'))
# A freshly constructed nn.Module is in training mode, which keeps the dropout
# configured above (dropout_prob=0.2) active; torch.no_grad() does not turn it
# off. Switch to eval mode so predictions are deterministic.
model_tumor_seg.eval()
def inference(input):
    """Run the tumour segmentation model over a whole volume.

    The volume is processed with MONAI's sliding-window inferer using
    ``model_tumor_seg`` as the predictor, a (240, 240, 160) window, one window
    per forward pass and 50% overlap between neighbouring windows.

    Args:
        input: Batched image tensor passed straight to the inferer.

    Returns:
        The inferer's result for the whole volume. No activation or
        thresholding is applied here; the caller does that.
    """
    window_settings = dict(roi_size=(240, 240, 160), sw_batch_size=1, overlap=0.5)
    return sliding_window_inference(
        inputs=input,
        predictor=model_tumor_seg,
        **window_settings,
    )
# Demo rows offered in the UI. Columns follow the first five interface inputs:
# scan file, slice number, channel, language, spoken question.
# No typed question is supplied for any example.
examples = [
    ["examples/BRATS_225.nii.gz", 83, 2, "english", "examples/sample1_en.mp3"],
    ["examples/BRATS_485.nii.gz", 90, 1, "japanese", "examples/sample2_jp.mp3"],
    ["examples/BRATS_485.nii.gz", 110, 0, "german", "examples/sample3_gr.mp3"],
]
# Full-scale magnitude of signed 16-bit PCM (2 ** 15).
_PCM16_FULL_SCALE = 32768.0
# Sample rate the audio is resampled to before transcription.
_TARGET_SAMPLE_RATE = 16000
# Longest clip kept, in seconds; anything beyond is cut off.
_MAX_CLIP_SECONDS = 30


def process_audio(sampling_rate, waveform):
    """Turn a raw recording into mono 16 kHz audio of at most 30 seconds.

    Args:
        sampling_rate: Sample rate of ``waveform`` in Hz.
        waveform: NumPy array of samples. A 2-D array is transposed and
            down-mixed to mono (assumes a (samples, channels) layout holding
            16-bit PCM values, as delivered by the microphone input —
            TODO confirm).

    Returns:
        A 1-D ``torch.Tensor`` with the scaled, mono, 16 kHz samples.
    """
    # Scale by the int16 full-scale value. The previous divisor, 32678, was a
    # digit transposition of 32768 and left every sample slightly too loud.
    waveform = waveform / _PCM16_FULL_SCALE
    if waveform.ndim > 1:
        waveform = librosa.to_mono(waveform.T)
    if sampling_rate != _TARGET_SAMPLE_RATE:
        waveform = librosa.resample(waveform, orig_sr=sampling_rate, target_sr=_TARGET_SAMPLE_RATE)
    # Keep only the first 30 seconds.
    waveform = waveform[:_TARGET_SAMPLE_RATE * _MAX_CLIP_SECONDS]
    return torch.tensor(waveform)
# The OpenAI API key comes from the OPENAI_KEY environment variable; it is
# None when the variable is not set.
openai.api_key = os.environ.get("OPENAI_KEY")
def make_llm_call(
    prompt,
    context="You are a text generation model DR-Brain Developed by team brute force a team 4 AI engineers from RMKCET college they are HARSHA VARDHAN V AKA Thunder-007 , SAWIN KUMAR Y , CHARAN TEJA P, KISHORE S. Your specialized in medical stuff, when refering Dr-Brain refer your self also don't mention openai anywhere.",
):
    """Send a question to the chat model and return the text of its reply.

    Args:
        prompt: The user's question, sent as the "user" message.
        context: Optional system message placed before the question. A falsy
            value sends the question on its own.

    Returns:
        The message content of the first choice in the completion.
    """
    system_turn = [{"role": "system", "content": context}] if context else []
    conversation = system_turn + [{"role": "user", "content": prompt}]
    completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=conversation)
    first_choice = dict(dict(completion)['choices'][0])
    return first_choice["message"]["content"]
def _answer_question(text_question, audio_question, language):
    """Answer the typed question if given, else the spoken one, else greet."""
    if text_question:
        return make_llm_call(text_question)
    if not audio_question:
        return "Hi I'm Dr brain please enter a question to answer"
    sampling_rate, waveform = audio_question
    forced_decoder_ids = processor_whisper.get_decoder_prompt_ids(language=language, task="transcribe")
    waveform = process_audio(sampling_rate, waveform)
    audio_inputs = processor_whisper(audio=waveform, sampling_rate=16000, return_tensors="pt")
    predicted_ids = model_whisper.generate(**audio_inputs, max_length=400, forced_decoder_ids=forced_decoder_ids)
    transcription = processor_whisper.batch_decode(predicted_ids, skip_special_tokens=True)
    return make_llm_call(transcription[0])


def _render_to_array(image, **imshow_kwargs):
    """Draw ``image`` on its own figure and return the cropped PNG as an array."""
    import io  # only needed here; keeps this block self-contained

    # A private figure per call: nothing piles up on the global pyplot axes,
    # and the figure is always released, even if drawing fails.
    fig, ax = plt.subplots()
    try:
        ax.imshow(image, **imshow_kwargs)
        ax.axis('off')
        # Render in memory instead of a fixed-name file in the working
        # directory, which concurrent requests would overwrite or delete.
        with io.BytesIO() as buffer:
            fig.savefig(buffer, format='png', bbox_inches='tight', pad_inches=0)
            buffer.seek(0)
            return np.asarray(Image.open(buffer))
    finally:
        plt.close(fig)


def detector(tumor_file, slice_number, channel, language, audio_question, text_question):
    """Answer the user's question and segment one slice of the uploaded scan.

    Args:
        tumor_file: Uploaded scan; its ``.name`` is used as the path to load.
            May be None, in which case only the question is answered.
        slice_number: Index along the last axis of the volume to display.
        channel: Input channel to display; must be a key of
            ``channel_mapping``, which selects the matching output channel.
        language: Language passed to Whisper when transcribing audio.
        audio_question: ``(sampling_rate, waveform)`` pair, or a falsy value.
            Only used when ``text_question`` is empty.
        text_question: Typed question; takes precedence over the audio one.

    Returns:
        ``(channel_image, segment_image, llm_answer)``. Both images are arrays
        decoded from PNG renderings, or None when no scan was uploaded.

    Raises:
        gr.Error: If a scan is given but no valid channel is selected, or the
            slice index lies outside the volume.
    """
    has_scan = tumor_file is not None
    # Reject an unusable channel before any transcription or API call is made.
    if has_scan and channel not in channel_mapping:
        raise gr.Error("Please select a channel (0, 1 or 2) to display.")

    llm_answer = _answer_question(text_question, audio_question, language)
    if not has_scan:
        # Question-only request: return the answer instead of failing on a
        # missing file.
        return (None, None, llm_answer)

    processed_data = preproc_transforms({'image': [tumor_file.name]})
    tensor_3d_input = processed_data['image'].unsqueeze(0).to('cpu')

    # The slider range is fixed in the UI, so check it against this volume.
    slice_index = int(slice_number)
    depth = tensor_3d_input.shape[-1]
    if not 0 <= slice_index < depth:
        raise gr.Error(
            f"Slice {slice_index} is out of range: this scan has {depth} slices (0 to {depth - 1})."
        )

    with torch.no_grad():
        output = inference(tensor_3d_input)

    channel_image = _render_to_array(tensor_3d_input[0][channel, :, :, slice_index], cmap='gray')
    segment_image = _render_to_array(post_trans(output[0][channel_mapping[channel], :, :, slice_index]))
    return (channel_image, segment_image, llm_answer)
# Custom theme: red primary buttons (grey in dark mode). It is defined before
# the Interface so it can be passed in. Previously it was created after the
# Interface and never used, while the Interface received theme='dark', which
# is not one of Gradio's built-in theme names.
theme = gr.themes.Default().set(
    button_primary_background_fill="#FF0000",
    button_primary_background_fill_dark="#AAAAAA",
)

# Web UI. The inputs are passed to `detector` in the order listed, and the
# three outputs receive its return tuple.
interface = gr.Interface(
    fn=detector,
    inputs=[
        gr.File(label="Tumor File"),
        gr.Slider(0, 200, 50, step=1, label="Slice Number"),
        gr.Radio((0, 1, 2), label="Channel"),
        gr.Radio(("english", "japanese", "german", "spanish"), label="Language"),
        gr.Audio(source="microphone"),
        gr.Textbox(label='Text Question'),
    ],
    outputs=[
        gr.Image(label='channel', shape=(1, 1)),
        gr.Image(label='Segmented Tumor', shape=(1, 1)),
        gr.Textbox(label="Dr brain response"),
    ],
    title=title,
    examples=examples,
    description=description,
    theme=theme,
)

interface.launch(debug=True)