Spaces:
Sleeping
Sleeping
Issue fix
Browse files- SkinCancerDiagnosis.py +268 -0
- app.py +4 -344
- rag_pipeline.py +59 -0
SkinCancerDiagnosis.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torchvision.transforms as transforms
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from evo_vit import EvoViTModel
|
8 |
+
import io
|
9 |
+
import os
|
10 |
+
from fpdf import FPDF
|
11 |
+
from torchvision.models import resnet50
|
12 |
+
import nest_asyncio
|
13 |
+
from huggingface_hub import hf_hub_download
|
14 |
+
|
15 |
+
device='cuda' if torch.cuda.is_available() else 'cpu'
|
16 |
+
|
17 |
+
def load_model(repo_id, filename):
|
18 |
+
model_path = hf_hub_download(
|
19 |
+
repo_id=repo_id,
|
20 |
+
filename=filename,
|
21 |
+
)
|
22 |
+
model = EvoViTModel(img_size=224, patch_size=16, in_channels=3, embed_dim=768, num_classes=2, hidden_dim=512)
|
23 |
+
model.classifier = nn.Linear(512, 1)
|
24 |
+
state_dict = torch.load(model_path, map_location=device)
|
25 |
+
new_state_dict = {}
|
26 |
+
for key, value in state_dict.items():
|
27 |
+
if key.startswith("backbone."):
|
28 |
+
new_key = key[len("backbone."):]
|
29 |
+
else:
|
30 |
+
new_key = key
|
31 |
+
new_state_dict[new_key] = value
|
32 |
+
|
33 |
+
if "classifier.weight" in new_state_dict:
|
34 |
+
original_weight = new_state_dict["classifier.weight"]
|
35 |
+
new_state_dict["classifier.weight"] = original_weight[0:1, :]
|
36 |
+
if "classifier.bias" in new_state_dict:
|
37 |
+
original_bias = new_state_dict["classifier.bias"]
|
38 |
+
new_state_dict["classifier.bias"] = original_bias[0:1]
|
39 |
+
model.load_state_dict(new_state_dict, strict=False)
|
40 |
+
model.to(device)
|
41 |
+
model.eval()
|
42 |
+
return model
|
43 |
+
|
44 |
+
def load_binary_models():
|
45 |
+
base_models = []
|
46 |
+
class_models_mapping = {
|
47 |
+
"Acne and Rosacea Photos": 'santhosh/10fold_model_acne.pth',
|
48 |
+
"Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions": 'santhosh/5fold_model_actinic.pth',
|
49 |
+
"Atopic Dermatitis Photos": 'keerthi/Atopic/best_global_model_5fold.pth',
|
50 |
+
"Bullous Disease Photos": 'santhosh/10fold_model_bullous.pth',
|
51 |
+
"Cellulitis Impetigo and other Bacterial Infections": 'santhosh/10fold_model_cellulitis.pth',
|
52 |
+
"Eczema Photos": 'santhosh/5fold_model_eczema.pth',
|
53 |
+
"ExanthemsandDrugEruptions": 'santhosh/10fold_model_exantherms.pth',
|
54 |
+
"Hair Loss Photos Alopecia and other Hair Diseases": 'keerthi/HairLoss/best_global_model_5fold.pth',
|
55 |
+
"Herpes HPV and other STDs Photos": 'keerthi/Herpes/best_global_model_5fold.pth',
|
56 |
+
"Light Diseases and Disorders of Pigmentation": 'santhosh/5fold_model_light.pth',
|
57 |
+
"Lupus and other Connective Tissue diseases": 'keerthi/Lupus/best_global_model_5fold.pth',
|
58 |
+
"Melanoma Skin Cancer Nevi and Moles": 'keerthi/Melanoma/best_global_model_10fold.pth',
|
59 |
+
"Nail Fungus and other Nail Disease": 'santhosh/5fold_model_nail.pth',
|
60 |
+
"Poison Ivy Photos and other Contact Dermatitis": 'santhosh/5fold_model_poison.pth',
|
61 |
+
"Psoriasis pictures Lichen Planus and related diseases": 'santhosh/10fold_model_psoriasis.pth',
|
62 |
+
"Scabies Lyme Disease and other Infestations and Bites": 'santhosh/5fold_model_scabies.pth',
|
63 |
+
"Seborrheic Keratoses and other Benign Tumors": 'santhosh/10fold_model_seboh.pth',
|
64 |
+
"Systemic Disease": 'keerthi/Systemic/best_global_model_5fold.pth',
|
65 |
+
"Tinea Ringworm Candidiasis and other Fungal Infections": 'santhosh/10fold_model_tinea.pth',
|
66 |
+
"Urticaria Hives": 'keerthi/Urticaria/best_global_model_10fold.pth',
|
67 |
+
"Vascular Tumors": 'keerthi/Vascular/best_global_model_5fold.pth',
|
68 |
+
"Vasculitis Photos": 'keerthi/Vasculitis/best_global_model_10fold.pth',
|
69 |
+
"Warts Molluscum and other Viral Infections": 'santhosh/10fold_model_warts.pth'
|
70 |
+
}
|
71 |
+
repo_id = "KeerthiVM/SkinCancerDiagnosis" # Your Hugging Face repo
|
72 |
+
|
73 |
+
for class_name, filename in class_models_mapping.items():
|
74 |
+
# model_path = os.path.join("best_models_overall", rel_path)
|
75 |
+
model = load_model(repo_id, filename)
|
76 |
+
base_models.append(model)
|
77 |
+
return base_models
|
78 |
+
|
79 |
+
|
80 |
+
class DynamicCNN(nn.Module):
|
81 |
+
def __init__(self, input_channels, fc_layers, num_classes, dropout_rate=0.3):
|
82 |
+
super(DynamicCNN, self).__init__()
|
83 |
+
fc_layers_list = []
|
84 |
+
in_dim = input_channels
|
85 |
+
|
86 |
+
for fc_dim in fc_layers:
|
87 |
+
fc_layers_list.append(nn.Linear(in_dim, fc_dim))
|
88 |
+
fc_layers_list.append(nn.BatchNorm1d(fc_dim))
|
89 |
+
fc_layers_list.append(nn.ReLU())
|
90 |
+
fc_layers_list.append(nn.Dropout(dropout_rate))
|
91 |
+
in_dim = fc_dim
|
92 |
+
|
93 |
+
fc_layers_list.append(nn.Linear(in_dim, num_classes))
|
94 |
+
self.fc = nn.Sequential(*fc_layers_list)
|
95 |
+
|
96 |
+
def forward(self, x):
|
97 |
+
x = self.fc(x)
|
98 |
+
return x
|
99 |
+
|
100 |
+
|
101 |
+
class SkinDiseaseClassifier:
|
102 |
+
def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
|
103 |
+
self.device = torch.device(device)
|
104 |
+
self.class_names = [
|
105 |
+
"Acne and Rosacea Photos",
|
106 |
+
"Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
|
107 |
+
"Atopic Dermatitis Photos",
|
108 |
+
"Bullous Disease Photos",
|
109 |
+
"Cellulitis Impetigo and other Bacterial Infections",
|
110 |
+
"Eczema Photos",
|
111 |
+
"ExanthemsandDrugEruptions",
|
112 |
+
"Hair Loss Photos Alopecia and other Hair Diseases",
|
113 |
+
"Herpes HPV and other STDs Photos",
|
114 |
+
"Light Diseases and Disorders of Pigmentation",
|
115 |
+
"Lupus and other Connective Tissue diseases",
|
116 |
+
"Melanoma Skin Cancer Nevi and Moles",
|
117 |
+
"Nail Fungus and other Nail Disease",
|
118 |
+
"Poison Ivy Photos and other Contact Dermatitis",
|
119 |
+
"Psoriasis pictures Lichen Planus and related diseases",
|
120 |
+
"Scabies Lyme Disease and other Infestations and Bites",
|
121 |
+
"Seborrheic Keratoses and other Benign Tumors",
|
122 |
+
"Systemic Disease",
|
123 |
+
"Tinea Ringworm Candidiasis and other Fungal Infections",
|
124 |
+
"Urticaria Hives",
|
125 |
+
"Vascular Tumors",
|
126 |
+
"Vasculitis Photos",
|
127 |
+
"Warts Molluscum and other Viral Infections"
|
128 |
+
]
|
129 |
+
|
130 |
+
# Initialize models (they'll be loaded when needed)
|
131 |
+
self.base_models = None
|
132 |
+
self.meta_model = None
|
133 |
+
self.resnet_feature_extractor = None
|
134 |
+
|
135 |
+
# Image transformations
|
136 |
+
self.transform = transforms.Compose([
|
137 |
+
transforms.Resize((224, 224)),
|
138 |
+
transforms.ToTensor(),
|
139 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
140 |
+
])
|
141 |
+
|
142 |
+
def load_models(self):
|
143 |
+
"""Load all required models"""
|
144 |
+
# Load binary models
|
145 |
+
self.base_models = load_binary_models()
|
146 |
+
for model in self.base_models:
|
147 |
+
model.to(self.device)
|
148 |
+
model.eval()
|
149 |
+
|
150 |
+
# Load ResNet feature extractor
|
151 |
+
model = resnet50(pretrained=True)
|
152 |
+
layers = [model.layer1, model.layer2, model.layer3, model.layer4]
|
153 |
+
self.resnet_feature_extractor = nn.Sequential(
|
154 |
+
model.conv1, model.bn1, model.relu, model.maxpool, *layers
|
155 |
+
)
|
156 |
+
self.resnet_feature_extractor.to(self.device)
|
157 |
+
self.resnet_feature_extractor.eval()
|
158 |
+
|
159 |
+
# Load meta model
|
160 |
+
print("=== Loading model with weights_only=False ===")
|
161 |
+
meta_model_path = hf_hub_download(
|
162 |
+
repo_id="KeerthiVM/SkinCancerDiagnosis",
|
163 |
+
filename="best_meta_model_two_layer_version4.pth"
|
164 |
+
)
|
165 |
+
checkpoint = torch.load(meta_model_path, map_location=self.device, weights_only=False)
|
166 |
+
|
167 |
+
correct_input_size = checkpoint['state_dict']['fc.0.weight'].shape[1]
|
168 |
+
input_size = 23 + 4 + 7168 # Adjust based on your actual feature size
|
169 |
+
fc_layers = [1024, 512, 256] # Use whatever was in your best model
|
170 |
+
|
171 |
+
self.meta_model = DynamicCNN(
|
172 |
+
input_channels=correct_input_size,
|
173 |
+
fc_layers=fc_layers,
|
174 |
+
num_classes=23,
|
175 |
+
dropout_rate=0.5
|
176 |
+
)
|
177 |
+
|
178 |
+
self.meta_model.load_state_dict(checkpoint['state_dict'])
|
179 |
+
self.meta_model.to(self.device)
|
180 |
+
self.meta_model.eval()
|
181 |
+
|
182 |
+
def extract_image_features(self, image_tensor):
|
183 |
+
"""Extract features using ResNet"""
|
184 |
+
with torch.no_grad():
|
185 |
+
features = []
|
186 |
+
x = image_tensor
|
187 |
+
for layer in self.resnet_feature_extractor.children():
|
188 |
+
x = layer(x)
|
189 |
+
if isinstance(layer, nn.Sequential): # For residual blocks
|
190 |
+
# features.append(F.adaptive_avg_pool2d(x, (1, 1)).flatten(1))
|
191 |
+
# features = torch.cat(features, dim=1)
|
192 |
+
pooled = F.adaptive_avg_pool2d(x, (1, 1)).flatten(1)
|
193 |
+
features.append(pooled)
|
194 |
+
features = torch.cat(features, dim=1)
|
195 |
+
return features.cpu().numpy()
|
196 |
+
|
197 |
+
def predict(self, image, top_k=3):
|
198 |
+
"""Make prediction for a single image"""
|
199 |
+
if self.base_models is None or self.meta_model is None:
|
200 |
+
# self.load_models()
|
201 |
+
raise RuntimeError("Models not loaded - call load_models() first")
|
202 |
+
|
203 |
+
# Load and preprocess image
|
204 |
+
try:
|
205 |
+
# image = Image.open(image_path).convert('RGB')
|
206 |
+
image = image.convert('RGB')
|
207 |
+
except:
|
208 |
+
raise ValueError("Could not load image from path")
|
209 |
+
|
210 |
+
image_tensor = self.transform(image).unsqueeze(0).to(self.device)
|
211 |
+
|
212 |
+
# Extract features
|
213 |
+
with torch.no_grad():
|
214 |
+
# Get probabilities from each binary model
|
215 |
+
binary_probs = []
|
216 |
+
for model in self.base_models:
|
217 |
+
outputs = model(image_tensor)
|
218 |
+
probs = torch.sigmoid(outputs).squeeze(1)
|
219 |
+
binary_probs.append(probs)
|
220 |
+
|
221 |
+
binary_features = torch.stack(binary_probs, dim=1)
|
222 |
+
|
223 |
+
# Get image features
|
224 |
+
image_features = self.extract_image_features(image_tensor)
|
225 |
+
image_features = torch.from_numpy(image_features).float().to(self.device)
|
226 |
+
|
227 |
+
# Calculate probability statistics
|
228 |
+
top3_probs = torch.topk(binary_features, 3, dim=1).values
|
229 |
+
prob_stats = torch.stack([
|
230 |
+
binary_features.mean(dim=1, keepdim=True),
|
231 |
+
binary_features.std(dim=1, keepdim=True),
|
232 |
+
top3_probs.mean(dim=1, keepdim=True),
|
233 |
+
(top3_probs[:, 0] - top3_probs[:, 2]).unsqueeze(1) # Confidence gap
|
234 |
+
], dim=1).squeeze(-1)
|
235 |
+
|
236 |
+
# Combine all features
|
237 |
+
combined_features = torch.cat([
|
238 |
+
binary_features,
|
239 |
+
image_features,
|
240 |
+
prob_stats
|
241 |
+
], dim=1)
|
242 |
+
|
243 |
+
# Make prediction with meta-model
|
244 |
+
with torch.no_grad():
|
245 |
+
outputs = self.meta_model(combined_features)
|
246 |
+
probabilities = torch.softmax(outputs, dim=1).squeeze().cpu().numpy()
|
247 |
+
|
248 |
+
# Get top predictions
|
249 |
+
top_indices = np.argsort(probabilities)[-top_k:][::-1]
|
250 |
+
top_predictions = [
|
251 |
+
(self.class_names[i], float(probabilities[i]))
|
252 |
+
for i in top_indices
|
253 |
+
]
|
254 |
+
|
255 |
+
return {
|
256 |
+
"top_predictions": top_predictions,
|
257 |
+
"all_probabilities": {name: float(prob) for name, prob in zip(self.class_names, probabilities)}
|
258 |
+
}
|
259 |
+
|
260 |
+
def initialize_classifier():
|
261 |
+
print("⚙️ Initializing skin disease classifier...")
|
262 |
+
classifier = SkinDiseaseClassifier()
|
263 |
+
classifier.load_models()
|
264 |
+
dummy_img = Image.new('RGB', (224, 224))
|
265 |
+
classifier.predict(dummy_img)
|
266 |
+
|
267 |
+
print("⚙️ Initialization successful")
|
268 |
+
return classifier
|
app.py
CHANGED
@@ -1,7 +1,3 @@
|
|
1 |
-
from qdrant_client import QdrantClient
|
2 |
-
from langchain_qdrant import Qdrant
|
3 |
-
from langchain_community.embeddings import HuggingFaceEmbeddings
|
4 |
-
from langchain_community.embeddings import SentenceTransformerEmbeddings
|
5 |
import streamlit as st
|
6 |
import torchvision.transforms as transforms
|
7 |
import torch
|
@@ -13,60 +9,21 @@ from evo_vit import EvoViTModel
|
|
13 |
import io
|
14 |
import os
|
15 |
from fpdf import FPDF
|
16 |
-
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
17 |
-
from langchain.chains import RetrievalQA
|
18 |
-
from langchain.prompts import PromptTemplate
|
19 |
from torchvision.models import resnet50
|
20 |
import nest_asyncio
|
21 |
-
from sentence_transformers import SentenceTransformer
|
22 |
from huggingface_hub import hf_hub_download
|
23 |
-
|
24 |
-
|
|
|
25 |
|
26 |
nest_asyncio.apply()
|
27 |
device='cuda' if torch.cuda.is_available() else 'cpu'
|
28 |
-
|
29 |
st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
|
30 |
|
31 |
-
# os.environ["PGVECTOR_CONNECTION_STRING"] = "postgresql+psycopg2://postgres:postgres@localhost:5432/VectorDB"
|
32 |
-
|
33 |
# === Model Selection ===
|
34 |
available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro"]
|
35 |
st.session_state["selected_model"] = st.sidebar.selectbox("Select LLM Model", available_models)
|
36 |
|
37 |
-
# === Qdrant DB Setup ===
|
38 |
-
qdrant_client = QdrantClient(
|
39 |
-
url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
|
40 |
-
api_key="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.HXzezXdWMFeeR16F7zvqgjzsqrcm8hqa-StXdToFP9Q"
|
41 |
-
)
|
42 |
-
collection_name = "ks_collection_1.5BE"
|
43 |
-
# embedding_model = SentenceTransformer("D:\DR\RAG\gte-Qwen2-1.5B-instruct", trust_remote_code=True)
|
44 |
-
# embedding_model.max_seq_length = 8192
|
45 |
-
# local_embedding = SentenceTransformerEmbeddings(model=embedding_model)
|
46 |
-
|
47 |
-
model = SentenceTransformer("Alibaba-NLP/gte-Qwen2-1.5B-instruct", trust_remote_code=True)
|
48 |
-
|
49 |
-
|
50 |
-
local_embedding = HuggingFaceEmbeddings(
|
51 |
-
model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
52 |
-
model_kwargs={"trust_remote_code": True, "device": "cuda" if torch.cuda.is_available() else "cpu"}
|
53 |
-
)
|
54 |
-
print(" Qwen2-1.5B local embedding model loaded.")
|
55 |
-
|
56 |
-
vector_store = Qdrant(
|
57 |
-
client=qdrant_client,
|
58 |
-
collection_name=collection_name,
|
59 |
-
embeddings=local_embedding
|
60 |
-
)
|
61 |
-
retriever = vector_store.as_retriever()
|
62 |
-
|
63 |
-
'''
|
64 |
-
# === Init LLM and Vector DB ===
|
65 |
-
|
66 |
-
CONNECTION_STRING = "postgresql+psycopg2://postgres:postgres@localhost:5432/VectorDB"
|
67 |
-
engine = create_engine(CONNECTION_STRING)
|
68 |
-
embedding_model = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
|
69 |
-
'''
|
70 |
# Dynamically initialize LLM based on selection
|
71 |
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
|
72 |
selected_model = st.session_state["selected_model"]
|
@@ -82,304 +39,7 @@ else:
|
|
82 |
st.error("Unsupported model selected.")
|
83 |
st.stop()
|
84 |
|
85 |
-
'''
|
86 |
-
vector_store = PGVector.from_existing_index(
|
87 |
-
embedding=embedding_model,
|
88 |
-
connection=engine,
|
89 |
-
collection_name="documents"
|
90 |
-
)
|
91 |
-
'''
|
92 |
-
# retriever = vector_store.as_retriever()
|
93 |
-
|
94 |
-
AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
|
95 |
-
You provide accurate, compassionate, and detailed explanations while using correct medical terminology.
|
96 |
-
|
97 |
-
Guidelines:
|
98 |
-
1. Symptoms - Explain in simple terms with proper medical definitions.
|
99 |
-
2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
|
100 |
-
3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
|
101 |
-
4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
|
102 |
-
5. Emergency Note - If symptoms worsen or include difficulty breathing, **advise calling 911 immediately.
|
103 |
-
|
104 |
-
Query: {question}
|
105 |
-
Relevant Information: {context}
|
106 |
-
Answer:
|
107 |
-
"""
|
108 |
-
prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
|
109 |
-
|
110 |
-
rag_chain = RetrievalQA.from_chain_type(
|
111 |
-
llm=llm,
|
112 |
-
retriever=retriever,
|
113 |
-
chain_type="stuff",
|
114 |
-
chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
|
115 |
-
)
|
116 |
-
|
117 |
-
|
118 |
-
''''
|
119 |
-
Load My models
|
120 |
-
'''
|
121 |
-
|
122 |
-
|
123 |
-
def load_model(repo_id, filename):
|
124 |
-
model_path = hf_hub_download(
|
125 |
-
repo_id=repo_id,
|
126 |
-
filename=filename,
|
127 |
-
)
|
128 |
-
model = EvoViTModel(img_size=224, patch_size=16, in_channels=3, embed_dim=768, num_classes=2, hidden_dim=512)
|
129 |
-
model.classifier = nn.Linear(512, 1)
|
130 |
-
state_dict = torch.load(model_path, map_location=device)
|
131 |
-
new_state_dict = {}
|
132 |
-
for key, value in state_dict.items():
|
133 |
-
if key.startswith("backbone."):
|
134 |
-
new_key = key[len("backbone."):]
|
135 |
-
else:
|
136 |
-
new_key = key
|
137 |
-
new_state_dict[new_key] = value
|
138 |
-
|
139 |
-
if "classifier.weight" in new_state_dict:
|
140 |
-
original_weight = new_state_dict["classifier.weight"]
|
141 |
-
new_state_dict["classifier.weight"] = original_weight[0:1, :]
|
142 |
-
if "classifier.bias" in new_state_dict:
|
143 |
-
original_bias = new_state_dict["classifier.bias"]
|
144 |
-
new_state_dict["classifier.bias"] = original_bias[0:1]
|
145 |
-
model.load_state_dict(new_state_dict, strict=False)
|
146 |
-
model.to(device)
|
147 |
-
model.eval()
|
148 |
-
return model
|
149 |
-
|
150 |
-
def load_binary_models():
|
151 |
-
base_models = []
|
152 |
-
class_models_mapping = {
|
153 |
-
"Acne and Rosacea Photos": 'santhosh/10fold_model_acne.pth',
|
154 |
-
"Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions": 'santhosh/5fold_model_actinic.pth',
|
155 |
-
"Atopic Dermatitis Photos": 'keerthi/Atopic/best_global_model_5fold.pth',
|
156 |
-
"Bullous Disease Photos": 'santhosh/10fold_model_bullous.pth',
|
157 |
-
"Cellulitis Impetigo and other Bacterial Infections": 'santhosh/10fold_model_cellulitis.pth',
|
158 |
-
"Eczema Photos": 'santhosh/5fold_model_eczema.pth',
|
159 |
-
"ExanthemsandDrugEruptions": 'santhosh/10fold_model_exantherms.pth',
|
160 |
-
"Hair Loss Photos Alopecia and other Hair Diseases": 'keerthi/HairLoss/best_global_model_5fold.pth',
|
161 |
-
"Herpes HPV and other STDs Photos": 'keerthi/Herpes/best_global_model_5fold.pth',
|
162 |
-
"Light Diseases and Disorders of Pigmentation": 'santhosh/5fold_model_light.pth',
|
163 |
-
"Lupus and other Connective Tissue diseases": 'keerthi/Lupus/best_global_model_5fold.pth',
|
164 |
-
"Melanoma Skin Cancer Nevi and Moles": 'keerthi/Melanoma/best_global_model_10fold.pth',
|
165 |
-
"Nail Fungus and other Nail Disease": 'santhosh/5fold_model_nail.pth',
|
166 |
-
"Poison Ivy Photos and other Contact Dermatitis": 'santhosh/5fold_model_poison.pth',
|
167 |
-
"Psoriasis pictures Lichen Planus and related diseases": 'santhosh/10fold_model_psoriasis.pth',
|
168 |
-
"Scabies Lyme Disease and other Infestations and Bites": 'santhosh/5fold_model_scabies.pth',
|
169 |
-
"Seborrheic Keratoses and other Benign Tumors": 'santhosh/10fold_model_seboh.pth',
|
170 |
-
"Systemic Disease": 'keerthi/Systemic/best_global_model_5fold.pth',
|
171 |
-
"Tinea Ringworm Candidiasis and other Fungal Infections": 'santhosh/10fold_model_tinea.pth',
|
172 |
-
"Urticaria Hives": 'keerthi/Urticaria/best_global_model_10fold.pth',
|
173 |
-
"Vascular Tumors": 'keerthi/Vascular/best_global_model_5fold.pth',
|
174 |
-
"Vasculitis Photos": 'keerthi/Vasculitis/best_global_model_10fold.pth',
|
175 |
-
"Warts Molluscum and other Viral Infections": 'santhosh/10fold_model_warts.pth'
|
176 |
-
}
|
177 |
-
repo_id = "KeerthiVM/SkinCancerDiagnosis" # Your Hugging Face repo
|
178 |
-
|
179 |
-
for class_name, filename in class_models_mapping.items():
|
180 |
-
# model_path = os.path.join("best_models_overall", rel_path)
|
181 |
-
model = load_model(repo_id, filename)
|
182 |
-
base_models.append(model)
|
183 |
-
return base_models
|
184 |
-
|
185 |
-
|
186 |
-
class DynamicCNN(nn.Module):
|
187 |
-
def __init__(self, input_channels, fc_layers, num_classes, dropout_rate=0.3):
|
188 |
-
super(DynamicCNN, self).__init__()
|
189 |
-
fc_layers_list = []
|
190 |
-
in_dim = input_channels
|
191 |
-
|
192 |
-
for fc_dim in fc_layers:
|
193 |
-
fc_layers_list.append(nn.Linear(in_dim, fc_dim))
|
194 |
-
fc_layers_list.append(nn.BatchNorm1d(fc_dim))
|
195 |
-
fc_layers_list.append(nn.ReLU())
|
196 |
-
fc_layers_list.append(nn.Dropout(dropout_rate))
|
197 |
-
in_dim = fc_dim
|
198 |
-
|
199 |
-
fc_layers_list.append(nn.Linear(in_dim, num_classes))
|
200 |
-
self.fc = nn.Sequential(*fc_layers_list)
|
201 |
-
|
202 |
-
def forward(self, x):
|
203 |
-
x = self.fc(x)
|
204 |
-
return x
|
205 |
-
|
206 |
-
|
207 |
-
class SkinDiseaseClassifier:
|
208 |
-
def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
|
209 |
-
self.device = torch.device(device)
|
210 |
-
self.class_names = [
|
211 |
-
"Acne and Rosacea Photos",
|
212 |
-
"Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
|
213 |
-
"Atopic Dermatitis Photos",
|
214 |
-
"Bullous Disease Photos",
|
215 |
-
"Cellulitis Impetigo and other Bacterial Infections",
|
216 |
-
"Eczema Photos",
|
217 |
-
"ExanthemsandDrugEruptions",
|
218 |
-
"Hair Loss Photos Alopecia and other Hair Diseases",
|
219 |
-
"Herpes HPV and other STDs Photos",
|
220 |
-
"Light Diseases and Disorders of Pigmentation",
|
221 |
-
"Lupus and other Connective Tissue diseases",
|
222 |
-
"Melanoma Skin Cancer Nevi and Moles",
|
223 |
-
"Nail Fungus and other Nail Disease",
|
224 |
-
"Poison Ivy Photos and other Contact Dermatitis",
|
225 |
-
"Psoriasis pictures Lichen Planus and related diseases",
|
226 |
-
"Scabies Lyme Disease and other Infestations and Bites",
|
227 |
-
"Seborrheic Keratoses and other Benign Tumors",
|
228 |
-
"Systemic Disease",
|
229 |
-
"Tinea Ringworm Candidiasis and other Fungal Infections",
|
230 |
-
"Urticaria Hives",
|
231 |
-
"Vascular Tumors",
|
232 |
-
"Vasculitis Photos",
|
233 |
-
"Warts Molluscum and other Viral Infections"
|
234 |
-
]
|
235 |
-
|
236 |
-
# Initialize models (they'll be loaded when needed)
|
237 |
-
self.base_models = None
|
238 |
-
self.meta_model = None
|
239 |
-
self.resnet_feature_extractor = None
|
240 |
-
|
241 |
-
# Image transformations
|
242 |
-
self.transform = transforms.Compose([
|
243 |
-
transforms.Resize((224, 224)),
|
244 |
-
transforms.ToTensor(),
|
245 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
246 |
-
])
|
247 |
-
|
248 |
-
def load_models(self):
|
249 |
-
"""Load all required models"""
|
250 |
-
# Load binary models
|
251 |
-
self.base_models = load_binary_models()
|
252 |
-
for model in self.base_models:
|
253 |
-
model.to(self.device)
|
254 |
-
model.eval()
|
255 |
-
|
256 |
-
# Load ResNet feature extractor
|
257 |
-
model = resnet50(pretrained=True)
|
258 |
-
layers = [model.layer1, model.layer2, model.layer3, model.layer4]
|
259 |
-
self.resnet_feature_extractor = nn.Sequential(
|
260 |
-
model.conv1, model.bn1, model.relu, model.maxpool, *layers
|
261 |
-
)
|
262 |
-
self.resnet_feature_extractor.to(self.device)
|
263 |
-
self.resnet_feature_extractor.eval()
|
264 |
-
|
265 |
-
# Load meta model
|
266 |
-
# meta_model_path = 'best_meta_model_two_layer_version4.pth'
|
267 |
-
# checkpoint = torch.load(meta_model_path, map_location=self.device)
|
268 |
-
print("=== Loading model with weights_only=False ===")
|
269 |
-
meta_model_path = hf_hub_download(
|
270 |
-
repo_id="KeerthiVM/SkinCancerDiagnosis",
|
271 |
-
filename="best_meta_model_two_layer_version4.pth"
|
272 |
-
)
|
273 |
-
checkpoint = torch.load(meta_model_path, map_location=self.device, weights_only=False)
|
274 |
-
|
275 |
-
correct_input_size = checkpoint['state_dict']['fc.0.weight'].shape[1]
|
276 |
-
input_size = 23 + 4 + 7168 # Adjust based on your actual feature size
|
277 |
-
fc_layers = [1024, 512, 256] # Use whatever was in your best model
|
278 |
-
|
279 |
-
self.meta_model = DynamicCNN(
|
280 |
-
input_channels=correct_input_size,
|
281 |
-
fc_layers=fc_layers,
|
282 |
-
num_classes=23,
|
283 |
-
dropout_rate=0.5
|
284 |
-
)
|
285 |
-
|
286 |
-
self.meta_model.load_state_dict(checkpoint['state_dict'])
|
287 |
-
self.meta_model.to(self.device)
|
288 |
-
self.meta_model.eval()
|
289 |
-
|
290 |
-
def extract_image_features(self, image_tensor):
|
291 |
-
"""Extract features using ResNet"""
|
292 |
-
with torch.no_grad():
|
293 |
-
features = []
|
294 |
-
x = image_tensor
|
295 |
-
for layer in self.resnet_feature_extractor.children():
|
296 |
-
x = layer(x)
|
297 |
-
if isinstance(layer, nn.Sequential): # For residual blocks
|
298 |
-
# features.append(F.adaptive_avg_pool2d(x, (1, 1)).flatten(1))
|
299 |
-
# features = torch.cat(features, dim=1)
|
300 |
-
pooled = F.adaptive_avg_pool2d(x, (1, 1)).flatten(1)
|
301 |
-
features.append(pooled)
|
302 |
-
features = torch.cat(features, dim=1)
|
303 |
-
return features.cpu().numpy()
|
304 |
-
|
305 |
-
def predict(self, image, top_k=3):
|
306 |
-
"""Make prediction for a single image"""
|
307 |
-
if self.base_models is None or self.meta_model is None:
|
308 |
-
# self.load_models()
|
309 |
-
raise RuntimeError("Models not loaded - call load_models() first")
|
310 |
-
|
311 |
-
# Load and preprocess image
|
312 |
-
try:
|
313 |
-
# image = Image.open(image_path).convert('RGB')
|
314 |
-
image = image.convert('RGB')
|
315 |
-
except:
|
316 |
-
raise ValueError("Could not load image from path")
|
317 |
-
|
318 |
-
image_tensor = self.transform(image).unsqueeze(0).to(self.device)
|
319 |
-
|
320 |
-
# Extract features
|
321 |
-
with torch.no_grad():
|
322 |
-
# Get probabilities from each binary model
|
323 |
-
binary_probs = []
|
324 |
-
for model in self.base_models:
|
325 |
-
outputs = model(image_tensor)
|
326 |
-
probs = torch.sigmoid(outputs).squeeze(1)
|
327 |
-
binary_probs.append(probs)
|
328 |
-
|
329 |
-
binary_features = torch.stack(binary_probs, dim=1)
|
330 |
-
|
331 |
-
# Get image features
|
332 |
-
image_features = self.extract_image_features(image_tensor)
|
333 |
-
image_features = torch.from_numpy(image_features).float().to(self.device)
|
334 |
-
|
335 |
-
# Calculate probability statistics
|
336 |
-
top3_probs = torch.topk(binary_features, 3, dim=1).values
|
337 |
-
prob_stats = torch.stack([
|
338 |
-
binary_features.mean(dim=1, keepdim=True),
|
339 |
-
binary_features.std(dim=1, keepdim=True),
|
340 |
-
top3_probs.mean(dim=1, keepdim=True),
|
341 |
-
(top3_probs[:, 0] - top3_probs[:, 2]).unsqueeze(1) # Confidence gap
|
342 |
-
], dim=1).squeeze(-1)
|
343 |
-
|
344 |
-
# Combine all features
|
345 |
-
combined_features = torch.cat([
|
346 |
-
binary_features,
|
347 |
-
image_features,
|
348 |
-
prob_stats
|
349 |
-
], dim=1)
|
350 |
-
|
351 |
-
# Make prediction with meta-model
|
352 |
-
with torch.no_grad():
|
353 |
-
outputs = self.meta_model(combined_features)
|
354 |
-
probabilities = torch.softmax(outputs, dim=1).squeeze().cpu().numpy()
|
355 |
-
|
356 |
-
# Get top predictions
|
357 |
-
top_indices = np.argsort(probabilities)[-top_k:][::-1]
|
358 |
-
top_predictions = [
|
359 |
-
(self.class_names[i], float(probabilities[i]))
|
360 |
-
for i in top_indices
|
361 |
-
]
|
362 |
-
|
363 |
-
return {
|
364 |
-
"top_predictions": top_predictions,
|
365 |
-
"all_probabilities": {name: float(prob) for name, prob in zip(self.class_names, probabilities)}
|
366 |
-
}
|
367 |
-
|
368 |
-
|
369 |
-
@st.cache_resource
|
370 |
-
def initialize_classifier():
|
371 |
-
print("⚙️ Initializing skin disease classifier...")
|
372 |
-
classifier = SkinDiseaseClassifier()
|
373 |
-
classifier.load_models()
|
374 |
-
dummy_img = Image.new('RGB', (224, 224))
|
375 |
-
classifier.predict(dummy_img)
|
376 |
-
|
377 |
-
print("⚙️ Initialization successful")
|
378 |
-
return classifier
|
379 |
|
380 |
-
# print("⚙️ Initializing skin disease classifier...")
|
381 |
-
# classifier = initialize_classifier()
|
382 |
-
# print("⚙️ Initializing classifier done...")
|
383 |
with st.spinner("Loading AI models (one-time operation)..."):
|
384 |
classifier = initialize_classifier()
|
385 |
st.success("Models loaded successfully!")
|
@@ -441,7 +101,7 @@ if uploaded_file:
|
|
441 |
st.session_state.messages.append({"role": "user", "content": query})
|
442 |
|
443 |
with st.spinner("Analyzing the image and retrieving response..."):
|
444 |
-
response =
|
445 |
st.session_state.messages.append({"role": "assistant", "content": response['result']})
|
446 |
|
447 |
with st.chat_message("assistant"):
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
import torchvision.transforms as transforms
|
3 |
import torch
|
|
|
9 |
import io
|
10 |
import os
|
11 |
from fpdf import FPDF
|
|
|
|
|
|
|
12 |
from torchvision.models import resnet50
|
13 |
import nest_asyncio
|
|
|
14 |
from huggingface_hub import hf_hub_download
|
15 |
+
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
16 |
+
from SkinCancerDiagnosis import initialize_classifier
|
17 |
+
from rag_pipeline import invoke_rag_chain
|
18 |
|
19 |
nest_asyncio.apply()
|
20 |
device='cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
21 |
st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
|
22 |
|
|
|
|
|
23 |
# === Model Selection ===
|
24 |
available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro"]
|
25 |
st.session_state["selected_model"] = st.sidebar.selectbox("Select LLM Model", available_models)
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
# Dynamically initialize LLM based on selection
|
28 |
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
|
29 |
selected_model = st.session_state["selected_model"]
|
|
|
39 |
st.error("Unsupported model selected.")
|
40 |
st.stop()
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
|
|
|
|
|
|
43 |
with st.spinner("Loading AI models (one-time operation)..."):
|
44 |
classifier = initialize_classifier()
|
45 |
st.success("Models loaded successfully!")
|
|
|
101 |
st.session_state.messages.append({"role": "user", "content": query})
|
102 |
|
103 |
with st.spinner("Analyzing the image and retrieving response..."):
|
104 |
+
response = invoke_rag_chain(llm).invoke(query)
|
105 |
st.session_state.messages.append({"role": "assistant", "content": response['result']})
|
106 |
|
107 |
with st.chat_message("assistant"):
|
rag_pipeline.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.chains import RetrievalQA
|
2 |
+
from langchain.prompts import PromptTemplate
|
3 |
+
from sentence_transformers import SentenceTransformer
|
4 |
+
from qdrant_client import QdrantClient
|
5 |
+
from langchain_qdrant import Qdrant
|
6 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
7 |
+
from langchain_community.embeddings import SentenceTransformerEmbeddings
|
8 |
+
import os
|
9 |
+
import torch
|
10 |
+
|
11 |
+
def invoke_rag_chain(llm):
|
12 |
+
# === Qdrant DB Setup ===
|
13 |
+
qdrant_client = QdrantClient(
|
14 |
+
url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
|
15 |
+
api_key="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.HXzezXdWMFeeR16F7zvqgjzsqrcm8hqa-StXdToFP9Q"
|
16 |
+
)
|
17 |
+
collection_name = "ks_collection_1.5BE"
|
18 |
+
|
19 |
+
model = SentenceTransformer("Alibaba-NLP/gte-Qwen2-1.5B-instruct", trust_remote_code=True)
|
20 |
+
|
21 |
+
|
22 |
+
local_embedding = HuggingFaceEmbeddings(
|
23 |
+
model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
24 |
+
model_kwargs={"trust_remote_code": True, "device": "cuda" if torch.cuda.is_available() else "cpu"}
|
25 |
+
)
|
26 |
+
print(" Qwen2-1.5B local embedding model loaded.")
|
27 |
+
|
28 |
+
vector_store = Qdrant(
|
29 |
+
client=qdrant_client,
|
30 |
+
collection_name=collection_name,
|
31 |
+
embeddings=local_embedding
|
32 |
+
)
|
33 |
+
retriever = vector_store.as_retriever()
|
34 |
+
|
35 |
+
|
36 |
+
AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
|
37 |
+
You provide accurate, compassionate, and detailed explanations while using correct medical terminology.
|
38 |
+
|
39 |
+
Guidelines:
|
40 |
+
1. Symptoms - Explain in simple terms with proper medical definitions.
|
41 |
+
2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
|
42 |
+
3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
|
43 |
+
4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
|
44 |
+
5. Emergency Note - If symptoms worsen or include difficulty breathing, **advise calling 911 immediately.
|
45 |
+
|
46 |
+
Query: {question}
|
47 |
+
Relevant Information: {context}
|
48 |
+
Answer:
|
49 |
+
"""
|
50 |
+
prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
|
51 |
+
|
52 |
+
rag_chain = RetrievalQA.from_chain_type(
|
53 |
+
llm=llm,
|
54 |
+
retriever=retriever,
|
55 |
+
chain_type="stuff",
|
56 |
+
chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
|
57 |
+
)
|
58 |
+
|
59 |
+
return rag_chain
|