Amarthya7 commited on
Commit
06e0c89
·
verified ·
1 Parent(s): eabaeca

Upload 3 files

Browse files
models/image_analyzer.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
3
+
4
+
5
+ class ImageAnalyzer:
6
+ def __init__(self):
7
+ # Load the chest X-ray analysis model
8
+ try:
9
+ model_name = "facebook/deit-base-patch16-224-medical-cxr"
10
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
11
+ self.model = AutoModelForImageClassification.from_pretrained(model_name)
12
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ self.model.to(self.device)
14
+ print(f"Image model loaded on {self.device}")
15
+ except Exception as e:
16
+ print(f"Error loading image model: {e}")
17
+ # Fallback to a dummy model
18
+ self.model = None
19
+ self.feature_extractor = None
20
+
21
+ def analyze(self, image):
22
+ """Analyze an X-ray image and return predictions with confidence scores"""
23
+ if self.model is None or self.feature_extractor is None:
24
+ return {"No findings": 0.7, "Abnormal": 0.3} # Dummy results
25
+
26
+ try:
27
+ inputs = self.feature_extractor(images=image, return_tensors="pt").to(
28
+ self.device
29
+ )
30
+ with torch.no_grad():
31
+ outputs = self.model(**inputs)
32
+
33
+ # Process outputs to get predicted class and confidence
34
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
35
+ predictions = {}
36
+ for i, prob in enumerate(probabilities):
37
+ label = self.model.config.id2label[i]
38
+ predictions[label] = float(prob)
39
+
40
+ return predictions
41
+ except Exception as e:
42
+ print(f"Error during image analysis: {e}")
43
+ return {"Error": "Could not analyze image"}
models/multimodal_fusion.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class MultimodalFusion:
2
+ """
3
+ Combines insights from image analysis and text analysis
4
+ to provide comprehensive medical assessment
5
+ """
6
+
7
+ def __init__(self):
8
+ pass
9
+
10
+ def fuse_insights(self, image_results, text_results):
11
+ """
12
+ Fuse insights from image and text analysis
13
+
14
+ Args:
15
+ image_results (dict): Results from image analysis
16
+ text_results (dict): Results from text analysis
17
+
18
+ Returns:
19
+ dict: Combined insights with recommendation
20
+ """
21
+ # In a real implementation, this would use more sophisticated fusion techniques
22
+ combined_insights = {
23
+ "Image findings": image_results,
24
+ "Text findings": text_results,
25
+ }
26
+
27
+ # Simple fusion logic
28
+ confidence_scores = [
29
+ value for key, value in image_results.items() if isinstance(value, float)
30
+ ]
31
+ avg_confidence = (
32
+ sum(confidence_scores) / len(confidence_scores) if confidence_scores else 0
33
+ )
34
+
35
+ # Determine if any abnormalities are detected in image
36
+ image_abnormal = any(
37
+ key != "No findings" and value > 0.5
38
+ for key, value in image_results.items()
39
+ if isinstance(value, float)
40
+ )
41
+
42
+ # Check if text analysis found concerning elements
43
+ text_concerning = text_results.get("Sentiment") == "Concerning"
44
+
45
+ # Generate recommendation
46
+ if image_abnormal and text_concerning:
47
+ recommendation = "High priority: Both image and text indicate abnormalities"
48
+ elif image_abnormal:
49
+ recommendation = "Medium priority: Image shows potential abnormalities"
50
+ elif text_concerning:
51
+ recommendation = "Medium priority: Text report indicates concerns"
52
+ else:
53
+ recommendation = "Low priority: No significant findings detected"
54
+
55
+ combined_insights["Recommendation"] = recommendation
56
+ combined_insights["Confidence"] = f"{avg_confidence:.2f}"
57
+
58
+ return combined_insights
models/text_analyzer.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
3
+
4
+
5
+ class TextAnalyzer:
6
+ def __init__(self):
7
+ # Load the medical text analysis model
8
+ try:
9
+ model_name = "medicalai/ClinicalBERT"
10
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
12
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ self.model.to(self.device)
14
+
15
+ # NER pipeline for medical entities
16
+ self.ner_pipeline = pipeline(
17
+ "ner", model="samrawal/bert-base-uncased_medical-ner"
18
+ )
19
+ print(f"Text model loaded on {self.device}")
20
+ except Exception as e:
21
+ print(f"Error loading text model: {e}")
22
+ # Fallback to dummy functionality
23
+ self.model = None
24
+ self.tokenizer = None
25
+ self.ner_pipeline = None
26
+
27
+ def analyze(self, text):
28
+ """Analyze medical report text and extract key insights"""
29
+ if text.strip() == "":
30
+ return {"Insights": "No text provided"}
31
+
32
+ if self.model is None or self.tokenizer is None:
33
+ # Dummy analysis
34
+ return {
35
+ "Entities": ["fever", "cough"],
36
+ "Sentiment": "Concerning",
37
+ "Key findings": "Patient shows symptoms of respiratory illness",
38
+ }
39
+
40
+ try:
41
+ # Extract medical entities
42
+ if self.ner_pipeline:
43
+ entities = self.ner_pipeline(text)
44
+ unique_entities = list(set([entity["word"] for entity in entities]))
45
+ else:
46
+ unique_entities = []
47
+
48
+ # Simple text classification (in real app, would be more sophisticated)
49
+ inputs = self.tokenizer(
50
+ text, return_tensors="pt", padding=True, truncation=True
51
+ ).to(self.device)
52
+ with torch.no_grad():
53
+ outputs = self.model(**inputs)
54
+
55
+ # This is a placeholder - in reality would depend on the actual model output
56
+ sentiment = (
57
+ "Concerning" if torch.sigmoid(outputs.logits).item() > 0.5 else "Normal"
58
+ )
59
+
60
+ # Generate key findings (simplified)
61
+ key_findings = f"Report indicates {'abnormal' if sentiment == 'Concerning' else 'normal'} findings"
62
+ if unique_entities:
63
+ key_findings += f" with mentions of {', '.join(unique_entities[:5])}"
64
+
65
+ return {
66
+ "Entities": unique_entities[:10],
67
+ "Sentiment": sentiment,
68
+ "Key findings": key_findings,
69
+ }
70
+ except Exception as e:
71
+ print(f"Error during text analysis: {e}")
72
+ return {"Error": "Could not analyze text"}