import streamlit as st
import torch
import cohere
import google.generativeai as genai
from groq import Groq
from transformers import pipeline
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain_qdrant import Qdrant
from qdrant_client import QdrantClient
available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro", "Ensemble"]
AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
You provide accurate, compassionate, and detailed explanations while using correct medical terminology.
Guidelines:
1. Symptoms - Explain in simple terms with proper medical definitions.
2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
5. Emergency Note - If symptoms worsen or include difficulty breathing, **advise calling 911 immediately.**
Query: {question}
Relevant Information: {context}
Answer:
"""
@st.cache_resource(show_spinner=False)
def initialize_rag_components():
    """Build and cache the shared RAG components: Cohere client, PairRM ranker, fuser, retriever."""
    components = {
        'cohere_client': cohere.Client(st.secrets["COHERE_API_KEY"]),
        # PairRM scores (prompt, response) pairs so candidate answers can be ranked.
        'pair_ranker': pipeline(
            "text-classification",
            model="llm-blender/PairRM",
            tokenizer="llm-blender/PairRM",
            return_all_scores=True,
        ),
        # gen_fuser_3b merges the top-ranked answers into a single response.
        'gen_fuser': pipeline(
            "text-generation",
            model="llm-blender/gen_fuser_3b",
            tokenizer="llm-blender/gen_fuser_3b",
            max_length=2048,
            do_sample=False,
        ),
        'retriever': get_retriever(),
    }
    return components
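# Usage sketch (illustrative, not executed at import): the Streamlit app is
# expected to build the cached components once and stash them in session state,
# which is where AllModelsWrapper.invoke later reads them:
#
#     if 'app_models' not in st.session_state:
#         st.session_state.app_models = {'rag_components': initialize_rag_components()}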
class AllModelsWrapper:
    """Duck-typed LLM wrapper: fans the prompt out to every model, then ranks and fuses the answers."""

    def invoke(self, messages):
        prompt = messages[0]["content"]
        rag_components = st.session_state.app_models['rag_components']  # shared cached components
        responses = get_all_responses(prompt)
        fused = rank_and_fuse(prompt, responses, rag_components)
        # Return an object exposing .content, matching the LangChain message interface.
        return type('obj', (object,), {'content': fused})()
def get_all_responses(prompt):
    """Query OpenAI, Gemini, and LLaMA (via Groq) with the same prompt and collect the answers."""
    openai_resp = ChatOpenAI(
        model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"]
    ).invoke([{"role": "user", "content": prompt}]).content
    # Configure Gemini here as well, so the ensemble path works even if
    # initialize_llm("Gemini ...") was never called first.
    genai.configure(api_key=st.secrets["GEMINI_API_KEY"])
    gemini = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")
    gemini_resp = gemini.generate_content(prompt).text
    llama = Groq(api_key=st.secrets["GROQ_API_KEY"])
    llama_resp = llama.chat.completions.create(
        model="meta-llama/llama-4-maverick-17b-128e-instruct",
        messages=[{"role": "user", "content": prompt}],
        temperature=1, max_completion_tokens=1024, top_p=1, stream=False,
    ).choices[0].message.content
    return [openai_resp, gemini_resp, llama_resp]
def rank_and_fuse(prompt, responses, rag_components):
    """Score each candidate answer with PairRM, then fuse the top two with gen_fuser."""
    ranked = [
        (resp, rag_components['pair_ranker'](f"{prompt}\n\n{resp}")[0][1]['score'])
        for resp in responses
    ]
    ranked.sort(key=lambda x: x[1], reverse=True)
    # Fuse the two highest-ranked responses into a single answer.
    fusion_input = "\n\n".join(f"[Answer {i + 1}]: {ans}" for i, (ans, _) in enumerate(ranked[:2]))
    return rag_components['gen_fuser'](
        f"Fuse these responses:\n{fusion_input}", return_full_text=False
    )[0]['generated_text']
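# Illustrative ensemble flow, kept as comments so the module stays import-safe;
# assumes OpenAI, Gemini, and Groq keys are all present in st.secrets:
#
#     components = initialize_rag_components()
#     question = "What are common triggers for eczema flare-ups?"
#     answers = get_all_responses(question)                 # one answer per model
#     fused = rank_and_fuse(question, answers, components)  # PairRM ranks, gen_fuser merges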
def get_retriever():
    """Connect to the Qdrant collection and wrap it as a LangChain retriever."""
    # === Qdrant DB Setup ===
    qdrant_client = QdrantClient(
        url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
        # Read the key from Streamlit secrets rather than hardcoding it in source;
        # "QDRANT_API_KEY" is an assumed secret name.
        api_key=st.secrets["QDRANT_API_KEY"],
    )
    collection_name = "ks_collection_1.5BE"
    # HuggingFaceEmbeddings loads the gte-Qwen2 model itself, so no separate
    # SentenceTransformer instantiation is needed.
    local_embedding = HuggingFaceEmbeddings(
        model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
        model_kwargs={
            "trust_remote_code": True,
            "device": "cuda" if torch.cuda.is_available() else "cpu",
        },
    )
    print("Qwen2-1.5B local embedding model loaded.")
    vector_store = Qdrant(
        client=qdrant_client,
        collection_name=collection_name,
        embeddings=local_embedding,
    )
    return vector_store.as_retriever()
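# Quick retriever check (illustrative; needs network access to the Qdrant
# cluster and a local download of the embedding model):
#
#     retriever = get_retriever()
#     docs = retriever.get_relevant_documents("psoriasis symptoms")
#     print(len(docs), docs[0].page_content[:200])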
def initialize_llm(_model_name):
    """Initialize the LLM backend matching the model selected in the UI."""
    print(f"Model name: {_model_name}")
    if "OpenAI" in _model_name:
        return ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"])
    elif "LLaMA" in _model_name:
        client = Groq(api_key=st.secrets["GROQ_API_KEY"])

        def get_llama_response(prompt):
            completion = client.chat.completions.create(
                model="meta-llama/llama-4-maverick-17b-128e-instruct",
                messages=[{"role": "user", "content": prompt}],
                temperature=1,
                max_completion_tokens=1024,
                top_p=1,
                stream=False,
            )
            return completion.choices[0].message.content

        # Wrap the raw string in an object exposing .content, since
        # get_reranked_response reads llm.invoke(...).content.
        return type('obj', (object,), {
            'invoke': lambda self, x: type('obj', (object,), {'content': get_llama_response(x[0]["content"])})()
        })()
    elif "Gemini" in _model_name:
        genai.configure(api_key=st.secrets["GEMINI_API_KEY"])
        gemini_model = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")

        def get_gemini_response(prompt):
            return gemini_model.generate_content(prompt).text

        return type('obj', (object,), {
            'invoke': lambda self, x: type('obj', (object,), {'content': get_gemini_response(x[0]["content"])})()
        })()
    elif "Ensemble" in _model_name:
        return AllModelsWrapper()
    else:
        raise ValueError("Unsupported model selected")
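# Backend selection matches on substrings of the UI labels in available_models
# (illustrative):
#
#     llm = initialize_llm("OpenAI GPT-4o")  # LangChain ChatOpenAI
#     llm = initialize_llm("Ensemble")       # AllModelsWrapper fan-out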
def load_rag_chain(llm):
    """Build a stuff-type RetrievalQA chain around the shared dermatology prompt."""
    prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
    rag_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=get_retriever(),
        chain_type="stuff",
        chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"},
    )
    return rag_chain
def rerank_with_cohere(query, documents, co, top_n=5):
    """Rerank retrieved documents with Cohere rerank-v3.5 and return the top_n best matches."""
    if not documents:
        return []
    raw_texts = [doc.page_content for doc in documents]
    results = co.rerank(query=query, documents=raw_texts, top_n=min(top_n, len(raw_texts)), model="rerank-v3.5")
    return [documents[result.index] for result in results.results]
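# Rerank step in isolation (illustrative; assumes `docs` holds LangChain
# Document objects returned by the retriever):
#
#     co = cohere.Client(st.secrets["COHERE_API_KEY"])
#     top_docs = rerank_with_cohere("melanoma warning signs", docs, co, top_n=3)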
def get_reranked_response(query, llm, rag_components):
    """Retrieve, rerank, and answer: fetch candidates, rerank them with Cohere, then prompt the LLM."""
    docs = rag_components['retriever'].get_relevant_documents(query)
    reranked_docs = rerank_with_cohere(query, docs, rag_components['cohere_client'])
    context = "\n\n".join(doc.page_content for doc in reranked_docs)
    if isinstance(llm, (ChatOpenAI, AllModelsWrapper)):
        # Note: RetrievalQA performs its own retrieval internally, so the reranked
        # context is only applied directly on the else path below.
        return load_rag_chain(llm).invoke({"query": query, "context": context})['result']
    else:
        prompt = AI_PROMPT_TEMPLATE.format(question=query, context=context)
        return llm.invoke([{"role": "user", "content": prompt}]).content
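# Full pipeline sketch, roughly how a Streamlit page would call this module
# (illustrative; all provider keys assumed present in st.secrets):
#
#     rag_components = initialize_rag_components()
#     llm = initialize_llm("Gemini Pro")
#     answer = get_reranked_response("How is rosacea treated?", llm, rag_components)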
if __name__ == "__main__":
    print("This is a module - import it instead of running directly")