"""communicADR: a Gradio app that classifies free-text adverse drug reaction
reports as "Severe Reaction" / "Non-severe Reaction", explains the prediction
with SHAP, and highlights biomedical entities found by a NER model.

Not intended for medical diagnosis.
"""
import csv
import html
import sys

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import shap
import streamlit as st
import torch
import transformers
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    TFAutoModelForSequenceClassification,
    pipeline,
)

# NOTE: the original `from transformers import bert_based_case` was removed:
# no such name exists in transformers and the import aborted the whole app.

# sys.maxsize overflows the C long used by the csv module on some platforms
# (e.g. Windows), so fall back to the largest 32-bit signed value there.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Integer device index understood by every transformers version's pipeline():
# 0 = first GPU, -1 = CPU.
pipeline_device = 0 if torch.cuda.is_available() else -1

tokenizer = AutoTokenizer.from_pretrained("MiyunKim/Mod4Team5")
model = AutoModelForSequenceClassification.from_pretrained("MiyunKim/Mod4Team5").to(device)

# Pipeline wrapper around the same classifier so SHAP can call it on raw text.
# return_all_scores=True makes it return a score for every label, not only the top one.
pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=pipeline_device,
    return_all_scores=True,
)
explainer = shap.Explainer(pred)

ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
# aggregation_strategy="simple" merges sub-word tokens into whole entity spans
# with character offsets ("start"/"end") into the input text.
ner_pipe = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
    device=pipeline_device,
)

# Highlight colors per NER entity group. Groups not listed here (the model
# emits more labels than these) fall back to DEFAULT_ENTITY_COLOR instead of
# raising KeyError.
ENTITY_COLORS = {
    'Severity': 'red',
    'Sign_symptom': 'green',
    'Medication': 'lightblue',
    'Age': 'yellow',
    'Sex': 'yellow',
    'Diagnostic_procedure': 'gray',
    'Biological_structure': 'silver',
}
DEFAULT_ENTITY_COLOR = 'lavender'


def _highlight_entities(text, entities):
    """Return *text* as HTML with every entity span wrapped in a colored <mark>.

    ``entities`` are dicts with ``start``/``end`` character offsets into
    ``text`` and an ``entity_group`` label, as produced by ``ner_pipe``.
    Overlapping spans are skipped. All user-derived text is HTML-escaped,
    because the result is rendered verbatim by a ``gr.HTML`` component.
    """
    parts = []
    prev_end = 0
    for entity in sorted(entities, key=lambda e: e['start']):
        start, end = entity['start'], entity['end']
        if start < prev_end:  # overlaps an already-highlighted span
            continue
        group = entity['entity_group']
        color = ENTITY_COLORS.get(group, DEFAULT_ENTITY_COLOR)
        parts.append(html.escape(text[prev_end:start]))
        # Use the original surface text rather than the re-decoded token
        # string, so the output reproduces the input exactly.
        parts.append(
            f'<mark style="background-color: {color}" title="{html.escape(group)}">'
            f'{html.escape(text[start:end])}</mark>'
        )
        prev_end = end
    parts.append(html.escape(text[prev_end:]))
    return "".join(parts)


def adr_predict(x):
    """Classify *x* and build its explanation outputs.

    Returns a 3-tuple:
      * dict of class probabilities for ``gr.Label``,
      * HTML string of the SHAP text plot,
      * HTML string of *x* with NER entities highlighted.
    """
    # Inputs must live on the same device as the model. Truncation keeps
    # over-long inputs within the model's maximum sequence length.
    encoded_input = tokenizer(x, return_tensors='pt', truncation=True).to(device)
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(**encoded_input)[0][0]
    scores = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()

    shap_values = explainer([str(x).lower()])
    # shap.plots.text has no color argument; it uses SHAP's default
    # red (positive) / blue (negative) coloring.
    local_plot = shap.plots.text(shap_values[0], display=False)

    htext = _highlight_entities(x, ner_pipe(x))

    # Assumes label index 1 is the severe class and 0 the non-severe one
    # (same mapping as the original app) — TODO confirm against model config.
    probabilities = {
        "Severe Reaction": float(scores[1]),
        "Non-severe Reaction": float(scores[0]),
    }
    return probabilities, local_plot, htext


def main(prob1):
    """Gradio callback: normalize the textbox value and run ``adr_predict``.

    Empty or missing input returns empty outputs instead of running the
    models (and instead of classifying the literal string "none").
    """
    text = "" if prob1 is None else str(prob1).lower()
    if not text.strip():
        return None, "", ""
    return adr_predict(text)


title = "Welcome to **communicADR** 📈"
description1 = """This app takes user inputs and predicts adverse reactions to medications. 
Please do NOT use for medical diagnosis."""

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")
    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here ...")
    submit_btn = gr.Button("Analyze")

    with gr.Row():
        with gr.Column(visible=True) as label_col:
            label = gr.Label(label="Predicted Label")
        with gr.Column(visible=True) as explain_col:
            local_plot = gr.HTML(label='Shap:')
            htext = gr.HTML(label="NER")

    submit_btn.click(
        main,
        [prob1],
        [label, local_plot, htext],
        api_name="adr",
    )

    with gr.Row():
        gr.Markdown("### Click on any of the examples below to see how it works:")
    gr.Examples(
        examples=[
            ["A 65 year-old male had severe headache after taking Aspirin. The lab results were normal."],
            ["A 65 year-old female had minor pain in upper abdomen after taking Warfarin."],
        ],
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        fn=main,
        cache_examples=True,
    )

demo.launch()