import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import openai
from dotenv import load_dotenv
import os
from monai.networks.nets import SegResNet
from monai.inferers import sliding_window_inference
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    LoadImaged,
    NormalizeIntensityd,
    Orientationd,
    EnsureChannelFirstd,
)
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import librosa
import torch

load_dotenv()

title = 'Dr Brain 🧠'
description = '''
Brain tumor segmentation on BraTS-style MRI volumes with a MONAI SegResNet, plus a voice/text
medical Q&A assistant (Whisper transcription + GPT-3.5).
'''

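# Maps the MRI channel selected in the UI to the segmentation output channel that is displayed
# alongside it.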
channel_mapping = {
    0: 1,
    1: 0,
    2: 2,
}

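# Preprocessing: load the NIfTI volume, move channels first, reorient to RAS and normalize
# intensities per channel over non-zero voxels.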
preproc_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

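# Whisper (tiny) transcribes spoken questions before they are sent to the language model.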
processor_whisper = WhisperProcessor.from_pretrained("openai/whisper-tiny")  # Hub id assumed; use a local path if the checkpoint is vendored
model_whisper = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

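# SegResNet configured as in the standard MONAI BraTS 3D segmentation recipe:
# 4 input MRI modalities, 3 output tumor sub-region channels.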
model_tumor_seg = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
).to('cpu')

model_tumor_seg.load_state_dict(torch.load("models/best_metric_model_epoch_40.pth", map_location='cpu'))
model_tumor_seg.eval()  # disable dropout for inference


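# Sliding-window inference processes the full 3D volume in overlapping ROI crops, keeping
# memory usage bounded on CPU.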
def inference(input):
    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=(240, 240, 160),
            sw_batch_size=1,
            predictor=model_tumor_seg,
            overlap=0.5,
        )

    return _compute(input)


# Each example supplies one value per interface input (the trailing empty string is the text question).
examples = [
    ['examples/BRATS_225.nii.gz', 83, 2, 'english', 'examples/sample1_en.mp3', ''],
    ['examples/BRATS_485.nii.gz', 90, 1, 'japanese', 'examples/sample2_jp.mp3', ''],
    ['examples/BRATS_485.nii.gz', 110, 0, 'german', 'examples/sample3_gr.mp3', ''],
]


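# Prepare raw microphone audio for Whisper: scale int16 PCM to [-1, 1], downmix to mono,
# resample to 16 kHz and truncate to Whisper's 30-second input window.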
def process_audio(sampling_rate, waveform):
    waveform = waveform / 32768.0  # int16 full-scale
    if len(waveform.shape) > 1:
        waveform = librosa.to_mono(waveform.T)
    if sampling_rate != 16000:
        waveform = librosa.resample(waveform, orig_sr=sampling_rate, target_sr=16000)
    waveform = waveform[:16000 * 30]
    waveform = torch.tensor(waveform)
    return waveform


openai.api_key = os.environ.get("OPENAI_KEY")


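# Wraps the OpenAI ChatCompletion call; the system prompt fixes the assistant's persona.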
def make_llm_call(prompt,
                  context="You are a text generation model called Dr-Brain, developed by team Brute Force, a team of 4 AI engineers from RMKCET college: HARSHA VARDHAN V (aka Thunder-007), SAWIN KUMAR Y, CHARAN TEJA P and KISHORE S. You specialize in medical topics. When referring to Dr-Brain, refer to yourself, and do not mention OpenAI anywhere."):
    messages = [{"role": "user", "content": prompt}]
    if context:
        messages.insert(0, {"role": "system", "content": context})
    response_obj = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages)
    return response_obj["choices"][0]["message"]["content"]


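# Main handler: answers the user's question (text takes priority, otherwise audio is transcribed
# with Whisper) and renders the selected MRI slice next to its predicted tumor segmentation.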
def detector(tumor_file, slice_number, channel, language, audio_question, text_question):
    llm_answer = "Hi, I'm Dr Brain. Please enter a question to answer."
    if text_question:
        llm_answer = make_llm_call(text_question)
    elif audio_question:
        sampling_rate, waveform = audio_question
        forced_decoder_ids = processor_whisper.get_decoder_prompt_ids(language=language, task="transcribe")
        waveform = process_audio(sampling_rate, waveform)
        audio_inputs = processor_whisper(audio=waveform, sampling_rate=16000, return_tensors="pt")
        predicted_ids = model_whisper.generate(**audio_inputs, max_length=400, forced_decoder_ids=forced_decoder_ids)
        transcription = processor_whisper.batch_decode(predicted_ids, skip_special_tokens=True)
        llm_question = transcription[0]
        llm_answer = make_llm_call(llm_question)
    tumor_file_path = tumor_file.name
    slice_number = int(slice_number)  # guard against float values coming from the slider
    processed_data = preproc_transforms({'image': [tumor_file_path]})
    tensor_3d_input = processed_data['image'].unsqueeze(0).to('cpu')
    with torch.no_grad():
        output = inference(tensor_3d_input)
    img_slice = tensor_3d_input[0][channel, :, :, slice_number]
    plt.imshow(img_slice, cmap='gray')
    input_image_path = f"input_img_channel{channel}.png"
    plt.axis('off')
    plt.savefig(input_image_path, bbox_inches='tight', pad_inches=0)
    channel_image = np.asarray(Image.open(input_image_path))
    os.remove(input_image_path)
    output_image_path = f"output_img_channel{channel}.png"
    plt.imshow(post_trans(output[0][channel_mapping[channel], :, :, slice_number]))
    plt.axis('off')
    plt.savefig(output_image_path, bbox_inches='tight', pad_inches=0)
    segment_image = np.asarray(Image.open(output_image_path))
    os.remove(output_image_path)
    return (channel_image, segment_image, llm_answer)


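# Gradio UI: a NIfTI upload plus slice/channel/language controls, microphone audio and a text
# question feed detector(); outputs are the raw slice, the segmentation mask and the LLM answer.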
interface = gr.Interface(fn=detector,
                         inputs=[gr.File(label="Tumor File"),
                                 gr.Slider(0, 200, 50, step=1, label="Slice Number"),
                                 gr.Radio((0, 1, 2), label="Channel"),
                                 gr.Radio(("english", "japanese", "german", "spanish"), label="Language"),
                                 gr.Audio(source="microphone", label="Audio Question"),
                                 gr.Textbox(label="Text Question")],
                         outputs=[gr.Image(label="Channel", shape=(1, 1)),
                                  gr.Image(label="Segmented Tumor", shape=(1, 1)),
                                  gr.Textbox(label="Dr Brain Response")],
                         title=title,
                         description=description,
                         examples=examples,
                         theme='dark')

# Custom button colours; note this theme object is built but not passed to the Interface above.
theme = gr.themes.Default().set(
    button_primary_background_fill="#FF0000",
    button_primary_background_fill_dark="#AAAAAA",
)

interface.launch(debug=True)