SrujayReddy31 commited on
Commit
3980c90
·
verified ·
1 Parent(s): b5473c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -69
app.py CHANGED
@@ -1,87 +1,40 @@
1
  import gradio as gr
2
  import torch
3
- from torch import nn
4
- from transformers import BertTokenizer, BertModel
5
 
6
- # Define the BertClassifier class
7
- class BertClassifier(nn.Module):
8
- def __init__(self, bert: BertModel, num_classes: int):
9
- super().__init__()
10
- self.bert = bert
11
- self.classifier = nn.Linear(bert.config.hidden_size, num_classes)
12
- self.criterion = nn.BCELoss()
13
-
14
- def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, labels=None):
15
- outputs = self.bert(
16
- input_ids=input_ids,
17
- attention_mask=attention_mask,
18
- token_type_ids=token_type_ids,
19
- position_ids=position_ids,
20
- head_mask=head_mask
21
- )
22
- cls_output = outputs.pooler_output
23
- cls_output = self.classifier(cls_output)
24
- cls_output = torch.sigmoid(cls_output)
25
-
26
- loss = 0
27
- if labels is not None:
28
- loss = self.criterion(cls_output, labels)
29
- return loss, cls_output
30
-
31
- # Load the tokenizer and model
32
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
33
- bert_model = BertModel.from_pretrained('bert-base-uncased')
34
- model = BertClassifier(bert_model, num_classes=7)
35
-
36
- # Load the model weights from the .pkl file
37
- model.load_state_dict(torch.load('bert_classifier_mltc.pkl', map_location=torch.device('cpu')))
38
- model.eval()
39
 
40
  # Define prediction function
41
- def predict(text):
42
- tokens = tokenizer.encode(text, add_special_tokens=True, max_length=512, truncation=True)
43
- input_ids = torch.tensor([tokens])
44
- mask = (input_ids != tokenizer.pad_token_id).float()
45
 
46
- with torch.no_grad():
47
- _, outputs = model(input_ids, attention_mask=mask)
 
48
 
49
- # Assuming outputs[0] contains the probability scores for each class
50
- confidence_scores = outputs[0].tolist()
51
-
52
- # Convert to a dictionary mapping labels to confidence scores
53
- labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate', 'clean']
54
- prediction = {label: score for label, score in zip(labels, confidence_scores)}
55
-
56
- return prediction
57
 
58
  # Add example texts
59
  examples = [
60
- ["This is a positive and friendly comment."],
61
- ["I'm not happy with your service."],
62
- ["You are a great person!"],
63
- ["How can one defame someone who thinks the Fort Hood Shooting is justified by God's law?"],
64
- ["A Bisexual, like a homosexual or a heterosexual, is not defined by sexual activity. (Much like a 15 year old boy who is attracted to a girl sexually but has never had sex is still straight). A person who is actually sexually attracted/aroused by the same sex as well as the opposite sex is bisexual."],
65
- ["I'LL SAY NIGGER IF I WANT TO BRO"],
66
- ["Do you need a verifiable source for that statement? Check your own user page you fat, miserable balding fuck."],
67
- ["Fuck you bitch! I will go to your house and kill you with my gun as well! Prepare to die!"],
68
- ["you're hot, i will rape you ;)"],
69
- ["WOULDN'T BE THE FIRST TIME BITCH. FUCK YOU I'LL FIND OUT WHERE YOU LIVE, SODOMIZE YOUR WIFE AND THEN BURN YOUR HOUSE DOWN. FUCK YOU YOU FUCKING QUEER."],
70
- ["What a motherfucking piece of crap those fuckheads for blocking us!"],
71
- ["Get a life you animal fucker. Shut up you fucking nigger. Fuck off and shit your pants full of all the shit you can fill them with. 144.131.176.126"],
72
- ["HOPE YOUR HEAD GETS CUT OFF AND SOMEONE WIPS THERE ASS WITH IT AND THEN STABS YOU IN YOUR HEART"],
73
- ["you people are pretty overzealous with this whole free thing. get a fucking life, you fucking niggers !!!23 16!!!"],
74
- ["Stupid peace of shit stop deleting my stuff asshole go die and fall in a hole go to hell!"],
75
- ["Bye! Don't look, come or think of comming back! Tosser."]
76
  ]
 
77
  # Create Gradio interface
78
  iface = gr.Interface(
79
  fn=predict,
80
- inputs=gr.Textbox(lines=10, placeholder="Enter a comment here..."),
81
- outputs=gr.Label(num_top_classes=7),
82
  examples=examples,
83
- title="Toxic Comment Classification",
84
- description="Classify comments into toxic and non-toxic categories using BERT and GNN model.",
85
  )
86
 
87
  iface.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
 
4
 
5
+ # Load the tokenizer, retriever, and model
6
+ tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
7
+ retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
8
+ model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  # Define prediction function
11
+ def predict(input_text):
12
+ # Tokenize input
13
+ input_ids = tokenizer([input_text], return_tensors="pt").input_ids
 
14
 
15
+ # Generate response
16
+ outputs = model.generate(input_ids)
17
+ response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
18
 
19
+ return response
 
 
 
 
 
 
 
20
 
21
  # Add example texts
22
  examples = [
23
+ ["Patient admitted with a history of heart failure and requires detailed follow-up on cardiovascular treatment."],
24
+ ["What are the complications of diabetes mellitus that need to be monitored in this patient?"],
25
+ ["Describe the appropriate treatment for acute respiratory distress syndrome in a critical care setting."],
26
+ ["Explain the signs and symptoms that indicate a neurological emergency in a stroke patient."],
27
+ ["What are the best practices for managing an infectious disease outbreak in a hospital setting?"]
 
 
 
 
 
 
 
 
 
 
 
28
  ]
29
+
30
  # Create Gradio interface
31
  iface = gr.Interface(
32
  fn=predict,
33
+ inputs=gr.Textbox(lines=10, placeholder="Enter your medical question or clinical notes here..."),
34
+ outputs="text",
35
  examples=examples,
36
+ title="MIMIC-IV RAG Implementation",
37
+ description="Use RAG (Retrieval-Augmented Generation) to generate responses or provide additional information based on clinical notes and medical questions. This model helps in generating relevant information based on existing medical literature.",
38
  )
39
 
40
  iface.launch()