implax commited on
Commit
ba12ca2
·
1 Parent(s): d03fc1d
Files changed (2) hide show
  1. app.py +6 -682
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,690 +1,11 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
-
5
-
6
- from datasets import load_dataset
7
- import pandas as pd
8
- import re
9
- import json
10
-
11
- # Load dataset
12
- ds = load_dataset("AGBonnet/augmented-clinical-notes")
13
- df = ds["train"].to_pandas() # Convert to pandas DataFrame for easier manipulation
14
-
15
-
16
-
17
-
18
- from snorkel.labeling import labeling_function
19
-
20
- # Define pulmonary keywords (ICD-10 inspired)
21
- PULMONARY_TERMS = {
22
- "asthma", "copd", "pneumonia", "pulmonary fibrosis", "bronchitis",
23
- "tuberculosis", "lung cancer", "emphysema", "pneumothorax",
24
- "cystic fibrosis", "ARDS", "pulmonary embolism", "chronic bronchitis"
25
- }
26
-
27
- # Define regex patterns for variations (e.g., "COPD exacerbation")
28
- PULMONARY_REGEX = re.compile(
29
- r'('
30
- r'\b(asthma|asthmatic|bronchial asthma)\b|'
31
- r'\b(COPD|chronic obstructive pulmonary disease|chronic obstructive lung disease)\b|'
32
- r'\b(pneumonia|CAP|HAP|VAP|community-acquired pneumonia|hospital-acquired pneumonia|ventilator-associated pneumonia)\b|'
33
- r'\b(pulmonary embolism|PE|pulmonary thromboembolism)\b|'
34
- r'\b(tuberculosis|TB|mycobacterium tuberculosis|pulmonary TB)\b|'
35
- r'\b(lung cancer|lung carcinoma|bronchogenic carcinoma|NSCLC|SCLC|non-small cell lung cancer|small cell lung cancer)\b|'
36
- r'\b(bronchitis|acute bronchitis|chronic bronchitis)\b|'
37
- r'\b(pulmonary fibrosis|idiopathic pulmonary fibrosis|IPF)\b|'
38
- r'\b(cystic fibrosis|CF)\b|'
39
- r'\b(pneumothorax|collapsed lung)\b|'
40
- r'\b(ARDS|acute respiratory distress syndrome)\b|'
41
- r'\b(emphysema|pulmonary emphysema)\b|'
42
- r'\b(interstitial lung disease|ILD)\b|'
43
- r'\b(pulmonary hypertension|PH)\b|'
44
- r'\b(pleural effusion|hydrothorax)\b|'
45
- r'\b(silicosis|occupational lung disease)\b|'
46
- r'\b(COVID-19|SARS-CoV-2|coronavirus)\b'
47
- r')',
48
- flags=re.IGNORECASE # Match case-insensitively
49
- )
50
-
51
- # Labeling Function 1: Check structured JSON summary for diagnoses
52
- @labeling_function()
53
- def lf_summary_diagnosis(row):
54
- try:
55
- summary = json.loads(row["summary"])
56
- diagnoses = summary.get("diagnosis", [])
57
- # Ensure diagnoses is a list
58
- if not isinstance(diagnoses, list):
59
- diagnoses = [diagnoses]
60
- for d in diagnoses:
61
- if any(term in d.lower() for term in PULMONARY_TERMS):
62
- return 1
63
- except Exception as e:
64
- pass
65
- return 0 # non-pulmonary
66
-
67
- # Labeling Function 2: Keyword search in notes
68
- @labeling_function()
69
- def lf_note_keywords(row):
70
- note_text = ((row.get("note") or "") + " " + (row.get("full_note") or "")).lower()
71
- if any(term in note_text for term in PULMONARY_TERMS):
72
- return 1
73
- return 0
74
-
75
- # Improved negation-aware regex (checks for negation near pulmonary terms)
76
- NEGATION_REGEX = re.compile(
77
- r'\b(no history of|ruled out|denies|negative for|no|without)\b\s*' # Negation trigger
78
- r'(?:\w+\s+){0,5}' # Allow up to 5 words between negation and pulmonary term
79
- r'(' + PULMONARY_REGEX.pattern + r')', # Pulmonary terms from your regex
80
- flags=re.IGNORECASE
81
- )
82
-
83
- @labeling_function()
84
- def lf_note_regex(row):
85
- note_text = row["note"] + " " + row["full_note"]
86
- # Check for pulmonary terms
87
- pulmonary_match = PULMONARY_REGEX.search(note_text)
88
- if not pulmonary_match:
89
- return 0 # No pulmonary term found
90
-
91
- # Check if the pulmonary term is negated
92
- if NEGATION_REGEX.search(note_text):
93
- return 0 # Pulmonary term is negated
94
- return 1 # Pulmonary term is affirmed
95
-
96
-
97
-
98
-
99
- from snorkel.labeling import PandasLFApplier, LFAnalysis
100
- from snorkel.labeling.model import LabelModel
101
-
102
- # Combine labeling functions
103
- lfs = [lf_summary_diagnosis, lf_note_keywords, lf_note_regex]
104
-
105
- # Apply labeling functions to the DataFrame using Snorkel's PandasLFApplier
106
- applier = PandasLFApplier(lfs)
107
- L_train = applier.apply(df)
108
-
109
- # Analyze LF performance (coverage, conflicts)
110
- analysis = LFAnalysis(L_train, lfs)
111
- analysis.lf_summary() # This prints a summary of your labeling functions
112
-
113
- # Train a LabelModel to combine LF outputs
114
- label_model = LabelModel(cardinality=2, verbose=True)
115
- label_model.fit(L_train, n_epochs=500, log_freq=100)
116
-
117
- # Predict probabilistic labels; here, tie_break_policy="abstain" will mark ties as abstentions (-1)
118
- df["label_pulmonary"] = label_model.predict(L_train, tie_break_policy="abstain")
119
-
120
- # Filter for pulmonary cases (label == 1)
121
- pulmonary_df = df[df["label_pulmonary"] == 1].reset_index(drop=True)
122
-
123
- # Optionally, inspect the results
124
- print("Pulmonary cases:", len(pulmonary_df))
125
-
126
-
127
-
128
-
129
- # Display a random sample of rows
130
- print(df[['note', 'summary', 'label_pulmonary']].sample(10))
131
-
132
-
133
-
134
-
135
- # Define regex patterns for target conditions
136
- CONDITION_REGEX = {
137
- "Asthma": re.compile(
138
- r'\b(asthma|asthmatic|bronchial asthma)\b',
139
- flags=re.IGNORECASE
140
- ),
141
- "COPD": re.compile(
142
- r'\b(COPD|chronic obstructive pulmonary disease|chronic obstructive lung disease|emphysema|chronic bronchitis)\b',
143
- flags=re.IGNORECASE
144
- ),
145
- "Pneumonia": re.compile(
146
- r'\b(pneumonia|CAP|HAP|VAP|community-acquired pneumonia|hospital-acquired pneumonia|ventilator-associated pneumonia)\b',
147
- flags=re.IGNORECASE
148
- ),
149
- "Lung Cancer": re.compile(
150
- r'\b(lung cancer|lung carcinoma|bronchogenic carcinoma|NSCLC|SCLC|non-small cell lung cancer|small cell lung cancer)\b',
151
- flags=re.IGNORECASE
152
- ),
153
- "Tuberculosis": re.compile(
154
- r'\b(tuberculosis|TB|mycobacterium tuberculosis|pulmonary TB)\b',
155
- flags=re.IGNORECASE
156
- ),
157
- "Pleural Effusion": re.compile(
158
- r'\b(pleural effusion|hydrothorax)\b',
159
- flags=re.IGNORECASE
160
- )
161
- }
162
-
163
- # Negation regex (improved to check proximity to condition terms)
164
- NEGATION_REGEX = re.compile(
165
- r'\b(no history of|ruled out|denies|negative for|no|without)\b\s*' # Negation trigger
166
- r'(?:\w+\s+){0,5}' # Allow up to 5 words between negation and condition
167
- r'(' + '|'.join([pattern.pattern for pattern in CONDITION_REGEX.values()]) + r')', # Combined condition terms
168
- flags=re.IGNORECASE
169
- )
170
-
171
- def get_condition_labels(row):
172
- note_text = row["note"] + " " + row["full_note"]
173
- labels = []
174
-
175
- # Check for negations first
176
- negation_match = NEGATION_REGEX.search(note_text)
177
-
178
- for condition, pattern in CONDITION_REGEX.items():
179
- # Skip if the condition term is negated
180
- if negation_match and pattern.search(negation_match.group(0)):
181
- continue
182
- # Check if condition is mentioned
183
- if pattern.search(note_text):
184
- labels.append(condition)
185
-
186
- return labels
187
-
188
- # Apply labeling to pulmonary cases
189
- pulmonary_df["conditions"] = pulmonary_df.apply(get_condition_labels, axis=1)
190
-
191
- # Classify remaining cases as "Other Pulmonary"
192
- pulmonary_df["conditions"] = pulmonary_df["conditions"].apply(
193
- lambda x: x if x else ["Other Pulmonary"]
194
- )
195
-
196
-
197
-
198
-
199
- from collections import defaultdict
200
-
201
- label_counts = defaultdict(int)
202
- for labels in pulmonary_df["conditions"]:
203
- for label in labels:
204
- label_counts[label] += 1
205
-
206
- print("Label distribution:")
207
- for k, v in label_counts.items():
208
- print(f"{k}: {v}")
209
-
210
-
211
-
212
-
213
- import pandas as pd
214
-
215
- # Label distribution data
216
- label_counts = {
217
- "Asthma": 509,
218
- "Pneumonia": 1294,
219
- "Other Pulmonary": 1907,
220
- "Tuberculosis": 851,
221
- "Pleural Effusion": 743,
222
- "COPD": 697,
223
- "Lung Cancer": 415
224
- }
225
-
226
- # Convert to DataFrame for easier plotting
227
- df_counts = pd.DataFrame(list(label_counts.items()), columns=["Condition", "Count"])
228
-
229
- import matplotlib.pyplot as plt
230
- import seaborn as sns
231
-
232
- # Set style
233
- sns.set(style="whitegrid")
234
-
235
- # Create bar plot
236
- plt.figure(figsize=(10, 6))
237
- sns.barplot(x="Condition", y="Count", data=df_counts, palette="viridis")
238
-
239
- # Add labels and title
240
- plt.title("Distribution of Pulmonary Conditions", fontsize=16)
241
- plt.xlabel("Condition", fontsize=14)
242
- plt.ylabel("Count", fontsize=14)
243
- plt.xticks(rotation=45, ha="right") # Rotate x-axis labels for readability
244
-
245
- # Show plot
246
- plt.tight_layout()
247
- plt.show()
248
-
249
-
250
-
251
-
252
- import nltk
253
- from nltk.corpus import stopwords
254
- from wordcloud import WordCloud
255
- import matplotlib.pyplot as plt
256
-
257
- # Download stop words from nltk (do this once)
258
- nltk.download('stopwords')
259
-
260
- # Get the list of stop words
261
- stop_words = set(stopwords.words('english'))
262
-
263
- # Combine all notes into one large string
264
- text = " ".join(pulmonary_df['note'].dropna()) # Combine all notes into a single string
265
-
266
- # Tokenize the text and remove stop words
267
- filtered_words = [word for word in text.split() if word.lower() not in stop_words]
268
-
269
- # Join the filtered words back into a single string
270
- cleaned_text = " ".join(filtered_words)
271
-
272
- # Create a WordCloud object
273
- wordcloud = WordCloud(width=800, height=400, background_color='white').generate(cleaned_text)
274
-
275
- # Plot the WordCloud image
276
- plt.figure(figsize=(10, 5))
277
- plt.imshow(wordcloud, interpolation='bilinear')
278
- plt.axis('off')
279
- plt.show()
280
-
281
-
282
-
283
-
284
- from sklearn.feature_extraction.text import TfidfVectorizer
285
-
286
- vectorizer = TfidfVectorizer(max_features=5000, stop_words="english",ngram_range=(1, 2))
287
- X = vectorizer.fit_transform(pulmonary_df['note']) # Note column
288
-
289
-
290
-
291
-
292
- #Transform Object data type to string
293
- pulmonary_df["conditions"] = pulmonary_df["conditions"].apply(lambda x: x[0])
294
-
295
- from sklearn.model_selection import train_test_split
296
- from sklearn.linear_model import LogisticRegression
297
- from sklearn.metrics import classification_report
298
-
299
- X_train, X_test, y_train, y_test = train_test_split(X, pulmonary_df['conditions'], test_size=0.1, random_state=42)
300
-
301
- #Logistic Regression Classification Report
302
- model = LogisticRegression(max_iter=1000)
303
- model.fit(X_train, y_train)
304
- y_pred_before_smote = model.predict(X_test)
305
- print(classification_report(y_test, y_pred_before_smote))
306
- report_before_smote = classification_report(y_test, y_pred_before_smote, output_dict=True)
307
-
308
-
309
-
310
-
311
- from imblearn.over_sampling import SMOTE
312
-
313
- # Logistic Regression after SMOTE
314
- smote = SMOTE(random_state=42)
315
- X_train_res, y_train_res = smote.fit_resample(X_train, y_train)
316
- model.fit(X_train_res, y_train_res)
317
- y_pred_after_smote = model.predict(X_test)
318
- print(classification_report(y_test, y_pred_after_smote))
319
- report_after_smote = classification_report(y_test, y_pred_after_smote, output_dict=True)
320
-
321
-
322
-
323
-
324
- # Convert y_resampled to a pandas Series to get the distribution
325
- y_resampled = pd.Series(y_train_res)
326
-
327
- # Get class distribution after SMOTE
328
- label_counts_smote = y_resampled.value_counts()
329
-
330
- print("Label distribution after SMOTE:")
331
- print(label_counts_smote)
332
-
333
-
334
-
335
-
336
- from sklearn.ensemble import RandomForestClassifier
337
- from sklearn.metrics import classification_report
338
- from sklearn.model_selection import train_test_split
339
- from prettytable import PrettyTable
340
- import pandas as pd
341
-
342
- # Split the data before applying SMOTE
343
- X_train, X_test, y_train, y_test = train_test_split(X, pulmonary_df['conditions'], test_size=0.3, random_state=42)
344
-
345
- # Train Random Forest without SMOTE
346
- rf_model_before_smote = RandomForestClassifier(n_estimators=100, random_state=42)
347
- rf_model_before_smote.fit(X_train, y_train)
348
-
349
- # Make Predictions
350
- y_pred_before_smote = rf_model_before_smote.predict(X_test)
351
-
352
- # Generate classification report
353
- report_before_smote = classification_report(y_test, y_pred_before_smote, output_dict=True)
354
-
355
- # Convert to DataFrame
356
- df_report_before_smote = pd.DataFrame(report_before_smote).transpose()
357
-
358
- # Use PrettyTable for a more structured look
359
- table_before_smote = PrettyTable()
360
- table_before_smote.field_names = ["Class", "Precision", "Recall", "F1-Score", "Support"]
361
-
362
- for index, row in df_report_before_smote.iterrows():
363
- table_before_smote.add_row([index, round(row['precision'], 2), round(row['recall'], 2), round(row['f1-score'], 2), int(row['support'])])
364
-
365
- print(table_before_smote)
366
-
367
-
368
-
369
-
370
- # %pip install prettytable
371
- from prettytable import PrettyTable
372
-
373
- #Random Forest Now Model with SMOTE
374
-
375
- from sklearn.ensemble import RandomForestClassifier
376
- from sklearn.metrics import classification_report
377
- from imblearn.over_sampling import SMOTE
378
- from sklearn.model_selection import train_test_split
379
-
380
- # Apply SMOTE
381
- smote = SMOTE(random_state=42)
382
- X_resampled, y_resampled = smote.fit_resample(X, pulmonary_df['conditions'])
383
-
384
- # Split the data
385
- X_train, X_test, y_train, y_test = train_test_split(X_resampled, y_resampled, test_size=0.3, random_state=42)
386
-
387
- # Train Random Forest
388
- rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
389
- rf_model.fit(X_train, y_train)
390
-
391
- # Make Predictions
392
- y_pred = rf_model.predict(X_test)
393
-
394
- # Generate classification report
395
- report = classification_report(y_test, y_pred, output_dict=True)
396
-
397
- # Convert to DataFrame
398
- df_report = pd.DataFrame(report).transpose()
399
-
400
- # Use PrettyTable for a more structured look
401
- table = PrettyTable()
402
- table.field_names = ["Class", "Precision", "Recall", "F1-Score", "Support"]
403
-
404
- for index, row in df_report.iterrows():
405
- table.add_row([index, round(row['precision'], 2), round(row['recall'], 2), round(row['f1-score'], 2), int(row['support'])])
406
-
407
- print(table)
408
-
409
-
410
-
411
-
412
- import matplotlib.pyplot as plt
413
- import pandas as pd
414
-
415
- # Assuming report_after_smote and df_report are already generated as DataFrames
416
-
417
- # Convert the necessary columns to DataFrame for easy plotting
418
- df_lr = pd.DataFrame(report_after_smote).transpose() # Logistic Regression after SMOTE
419
- df_rf = pd.DataFrame(report).transpose() # Random Forest
420
-
421
- # Extract relevant columns (precision, recall, and f1-score)
422
- metrics = ['precision', 'recall', 'f1-score']
423
-
424
- # Set up the figure for the plot
425
- plt.figure(figsize=(10, 6))
426
-
427
- # Plot for each metric
428
- for metric in metrics:
429
- plt.plot(df_lr.index, df_lr[metric], label=f'LR After SMOTE - {metric.capitalize()}', marker='o')
430
- plt.plot(df_rf.index, df_rf[metric], label=f'RF - {metric.capitalize()}', marker='x')
431
-
432
- # Add labels and title
433
- plt.title('Comparison of Logistic Regression and Random Forest Performance')
434
- plt.xlabel('Class Labels')
435
- plt.ylabel('Score')
436
- plt.legend(title="Model and Metric")
437
-
438
- # Show the plot
439
- plt.xticks(rotation=45)
440
- plt.tight_layout()
441
- plt.show()
442
-
443
-
444
-
445
-
446
- from sklearn.utils import resample
447
-
448
- #Separate majority and minority classes
449
- grouped = pulmonary_df.groupby("Conditions")
450
- max_size = grouped.size().max()
451
-
452
- #Oversample each class to the same count as the majority class
453
- oversampled_df = grouped.apply(
454
- lambda x: resample(x, replace=True, n_samples=max_size, random_state=42)
455
-
456
- ).reset_index(drop=True)
457
-
458
- print(oversampled_df["conditions"].value_counts())
459
-
460
-
461
-
462
-
463
-
464
- from datasets import Dataset
465
- from transformers import AutoTokenizer
466
-
467
- dataset = Dataset.from_pandas(oversampled_df)
468
-
469
- # Tokenize using ClinicalBERT
470
- model_checkpoint = "emilyalsentzer/Bio_ClinicalBERT"
471
- tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
472
-
473
- def tokenize_function(example):
474
- return tokenizer(example["note"], truncation=True, padding="max_length", max_length=512)
475
-
476
- tokenized_dataset = dataset.map(tokenize_function, batched=True)
477
-
478
- # Label encoding
479
- from sklearn.preprocessing import LabelEncoder
480
- label_encoder = LabelEncoder()
481
- tokenized_dataset = tokenized_dataset.add_column("label", label_encoder.fit_transform(tokenized_dataset["conditions"]))
482
-
483
- # Final formatting
484
- tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
485
-
486
-
487
-
488
 
489
  import torch
490
-
491
- # print("Number of GPU: ", torch.cuda.device_count())
492
- # print("GPU Name: ", torch.cuda.get_device_name())
493
-
494
-
495
-
496
-
497
- # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
498
- # print('Using device:', device)
499
-
500
-
501
-
502
-
503
- from sklearn.metrics import accuracy_score, precision_recall_fscore_support
504
- from transformers import Trainer, AutoModelForSequenceClassification, AutoTokenizer
505
-
506
-
507
- split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
508
- train_dataset = split_dataset["train"]
509
- eval_dataset = split_dataset["test"]
510
-
511
-
512
- # Load saved model + tokenizer
513
- model = AutoModelForSequenceClassification.from_pretrained("./trained_clinicalbert")
514
- tokenizer = AutoTokenizer.from_pretrained("./trained_clinicalbert")
515
-
516
- # Load model with the correct number of classes
517
- # num_classes = len(label_encoder.classes_)
518
- # model = AutoModelForSequenceClassification.from_pretrained(
519
- # model_checkpoint,
520
- # num_labels=num_classes
521
- # )
522
-
523
- from transformers import TrainingArguments
524
-
525
- training_args = TrainingArguments(
526
- output_dir="./results",
527
- eval_strategy="epoch",
528
- save_strategy="epoch",
529
- logging_strategy="epoch",
530
- per_device_train_batch_size=8,
531
- per_device_eval_batch_size=8,
532
- num_train_epochs=4,
533
- learning_rate=2e-5,
534
- weight_decay=0.01,
535
- load_best_model_at_end=True,
536
- metric_for_best_model="f1",
537
- )
538
-
539
- def compute_metrics(p):
540
- preds = p.predictions.argmax(axis=1)
541
- labels = p.label_ids
542
- precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="weighted")
543
- acc = accuracy_score(labels, preds)
544
- return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}
545
-
546
-
547
- trainer = Trainer(
548
- model=model,
549
- args=training_args,
550
- train_dataset=train_dataset,
551
- eval_dataset=eval_dataset,
552
- tokenizer=tokenizer,
553
- compute_metrics=compute_metrics
554
- )
555
-
556
- # trainer.train()
557
-
558
- # trainer.save_model("./trained_clinicalbert")
559
- # tokenizer.save_pretrained("./trained_clinicalbert")
560
-
561
- trainer.evaluate()
562
-
563
-
564
-
565
-
566
- predictions_output = trainer.predict(eval_dataset)
567
- y_pred = predictions_output.predictions.argmax(axis=1)
568
- y_prob = predictions_output.predictions # softmax scores (for ROC/AUC)
569
- y_true = predictions_output.label_ids
570
-
571
- from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, roc_curve, precision_recall_curve, average_precision_score
572
- import matplotlib.pyplot as plt
573
- import seaborn as sns
574
-
575
- # Confusion Matrix
576
- cm = confusion_matrix(y_true, y_pred)
577
- plt.figure(figsize=(8,6))
578
- sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
579
- xticklabels=label_encoder.classes_,
580
- yticklabels=label_encoder.classes_)
581
- plt.xlabel("Predicted")
582
- plt.ylabel("Actual")
583
- plt.title("Confusion Matrix")
584
- plt.show()
585
-
586
- # Classification Report (includes F1, Precision, Recall per class)
587
- print(classification_report(y_true, y_pred, target_names=label_encoder.classes_))
588
-
589
-
590
-
591
-
592
- from sklearn.preprocessing import label_binarize
593
- from scipy.special import softmax
594
-
595
-
596
- # Apply softmax to get probabilities
597
- y_probs = softmax(y_prob, axis=1)
598
-
599
- # Binarize the true labels for multi-class ROC (One-vs-Rest)
600
- y_true_bin = label_binarize(y_true, classes=list(range(len(label_encoder.classes_))))
601
-
602
- # Plot ROC curve per class
603
- plt.figure(figsize=(10, 6))
604
-
605
- for i, class_name in enumerate(label_encoder.classes_):
606
- fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_probs[:, i])
607
- auc_score = roc_auc_score(y_true_bin[:, i], y_probs[:, i])
608
- plt.plot(fpr, tpr, label=f"{class_name} (AUC = {auc_score:.2f})")
609
-
610
- # Plot random classifier line
611
- plt.plot([0, 1], [0, 1], 'k--', label="Random (AUC = 0.50)")
612
-
613
- # Plot formatting
614
- plt.xlabel("False Positive Rate")
615
- plt.ylabel("True Positive Rate")
616
- plt.title("ROC Curves (One-vs-Rest for 7 Classes)")
617
- plt.legend(loc="lower right")
618
- plt.grid(True)
619
- plt.tight_layout()
620
- plt.show()
621
-
622
-
623
-
624
-
625
- #PR Curves
626
-
627
- #Plot PR Curve per class
628
- plt.figure(figsize=(8, 6))
629
-
630
- for i, class_name in enumerate(label_encoder.classes_):
631
- precision, recall, _ = precision_recall_curve(y_true_bin[:, i], y_probs[:, i])
632
- pr_auc = average_precision_score(y_true_bin[:, i], y_probs[:, i])
633
- plt.plot(recall, precision, label=f"{class_name} (AP = {pr_auc:.2f})")
634
-
635
- # Plot formatting
636
- plt.xlabel("Recall")
637
- plt.ylabel("Precision")
638
- plt.title("Precision-Recall Curves (One-vs-Rest for 7 Classes)")
639
- plt.legend(loc="lower left")
640
- plt.grid(True)
641
- plt.tight_layout()
642
- plt.show()
643
-
644
-
645
-
646
-
647
  import pandas as pd
 
648
  from scipy.special import softmax
649
 
650
- # Sample clinical notes
651
- demo_notes = [
652
- "Patient presents with high fever, chills, shortness of breath, and crackles heard on auscultation. Chest X-ray shows consolidation in the right lower lobe.",
653
- "Patient complains of chest tightness and wheezing that worsens at night and after physical activity. Symptoms relieved by use of albuterol inhaler.",
654
- "The patient is a 68-year-old male with a 40-pack-year smoking history who presents with worsening shortness of breath over the past 6 months. He reports a chronic productive cough that is worse in the mornings, occasional wheezing, and fatigue with mild exertion. On physical examination, breath sounds are diminished bilaterally with prolonged expiratory phase. Pulmonary function tests show reduced FEV1/FVC ratio consistent with obstructive lung disease. There are no signs of active infection. He denies fever or chills. Chest X-ray reveals hyperinflated lungs and flattened diaphragms.",
655
- "Patient has persistent cough, night sweats, weight loss, and hemoptysis. Sputum test positive for acid-fast bacilli.",
656
- "Patient presents with shortness of breath and pleuritic chest pain. Physical exam shows decreased breath sounds and dullness to percussion on the left side. Ultrasound confirms fluid accumulation."
657
- ]
658
-
659
- # Predict function
660
- def batch_predict(notes, model, tokenizer, label_encoder):
661
- predictions = []
662
- for note in notes:
663
- inputs = tokenizer(note, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
664
- inputs = {key: val.to(model.device) for key, val in inputs.items()}
665
- outputs = model(**inputs)
666
- probs = softmax(outputs.logits.detach().cpu().numpy(), axis=1)
667
- pred_idx = probs.argmax(axis=1)[0]
668
- pred_class = label_encoder.inverse_transform([pred_idx])[0]
669
- confidence = probs[0][pred_idx]
670
- predictions.append((pred_class, round(float(confidence), 4)))
671
- return predictions
672
-
673
- # Create DataFrame
674
- demo_df = pd.DataFrame({"Clinical Note": demo_notes})
675
- demo_df[["Predicted Label", "Confidence"]] = batch_predict(demo_notes, model, tokenizer, label_encoder)
676
-
677
- # View table
678
- demo_df
679
-
680
-
681
-
682
-
683
- import gradio as gr
684
-
685
- # Extract class names dynamically from the DataFrame
686
- classes = sorted(oversampled_df["conditions"].unique().tolist())
687
-
688
  # Load model and tokenizer
689
  model_path = "trained_clinicalbert"
690
  model = AutoModelForSequenceClassification.from_pretrained(model_path)
@@ -694,6 +15,9 @@ model.eval()
694
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
695
  model.to(device)
696
 
 
 
 
697
  # Prediction function
698
  def predict_clinical_note(note):
699
  inputs = tokenizer(note, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
 
1
+ #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import torch
4
+ import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import pandas as pd
6
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
7
  from scipy.special import softmax
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  # Load model and tokenizer
10
  model_path = "trained_clinicalbert"
11
  model = AutoModelForSequenceClassification.from_pretrained(model_path)
 
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  model.to(device)
17
 
18
+ # Define class labels manually (ensure it matches the trained model)
19
+ classes = ["Asthma", "COPD", "Lung Cancer", "Other Pulmonary", "Pleural Effusion", "Pneumonia", "Tuberculosis"]
20
+
21
  # Prediction function
22
  def predict_clinical_note(note):
23
  inputs = tokenizer(note, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
requirements.txt CHANGED
@@ -29,6 +29,7 @@ httpcore==1.0.8
29
  httpx==0.28.1
30
  huggingface-hub==0.30.2
31
  idna==3.10
 
32
  Jinja2==3.1.6
33
  joblib==1.4.2
34
  kiwisolver==1.4.8
 
29
  httpx==0.28.1
30
  huggingface-hub==0.30.2
31
  idna==3.10
32
+ imbalanced-learn==0.11.0
33
  Jinja2==3.1.6
34
  joblib==1.4.2
35
  kiwisolver==1.4.8