import os

import cohere
import google.generativeai as genai
import streamlit as st
import torch
from groq import Groq
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.runnables import Runnable
from langchain_openai import ChatOpenAI
from langchain_qdrant import Qdrant
from llm_blender.blender.blender import Blender
from qdrant_client import QdrantClient
from transformers import AutoTokenizer

# Clear cached resources on reload and silence the tokenizers fork warning
st.cache_resource.clear()
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Display names shown in the UI; initialize_llm() matches on substrings
# ("OpenAI", "LLaMA", "Gemini", "Ensemble"). The Groq backend actually serves
# meta-llama/llama-4-maverick-17b-128e-instruct.
available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro", "Ensemble"]

AI_PROMPT_TEMPLATE = """You are an experienced dermatologist. Your job is to help patients understand possible skin conditions based on features detected from a clinical image. Your answers should be medically accurate explanations that are friendly and easy to understand — like how a doctor would talk to a patient.

Start your response with a brief overview (no more than 300 words) that summarizes what the identified features most likely indicate. Mention the likely condition(s) clearly in plain language.

After the summary, continue with a detailed explanation covering:
- What the condition is and how it typically presents
- Common symptoms and possible causes
- Treatment options, explained simply
- Recommended next steps, if appropriate

If you use medical terms (e.g., biologics, PUVA), include a short explanation in parentheses. For example: “PUVA (a light-based treatment using UVA and a medication called psoralen)”

Use friendly transitions like:
- “Here’s what this could mean…”
- “Let’s go over what you can do next…”
- “Treatment usually begins with…”

Do not repeat the same treatments if already mentioned earlier—summarize or refer back instead.

End with a gentle reminder that this explanation is for educational purposes and not a substitute for seeing a licensed dermatologist.

---

**Identified Image Features:**
{question}

**Relevant Medical Information:**
{context}

Response:"""

ENSEMBLE_PROMPT_TEMPLATE = """You are DermBOT, a highly experienced dermatologist. You are reviewing responses from multiple medical AI systems. Based on the image analysis and clinical features, provide a detailed, comprehensive explanation.

Start with a **summary (around 250–300 words)** of what the identified image features suggest. Then continue with:
- A detailed explanation of the condition
- Typical symptoms and known causes
- Recommended treatment options (explain briefly)
- Suggested next steps for patients

Use friendly, plain language. Include brief definitions for medical terms (e.g., “PUVA (a type of light therapy using UVA and a sensitizing drug)”).

**Important:**
- Do NOT stop early. Continue until all sections are covered.
- Assume this answer will be fused with other responses — make it as rich and complete as possible.

---

**Detected Skin Features:**
{question}

**Medical Reference Context:**
{context}

Answer:"""


@st.cache_resource(show_spinner=False)
def initialize_rag_components():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # LLM-Blender: PairRM ranks candidate answers, gen_fuser_3b fuses the top-ranked ones
    blender = Blender()
    blender.loadranker("llm-blender/PairRM", device=device, truncation=True, max_length=512)
    blender.loadfuser("llm-blender/gen_fuser_3b", device=device, max_length=4096, length_penalty=1.2)

    components = {
        'cohere_client': cohere.Client(st.secrets["COHERE_API_KEY"]),
        'blender': blender,
        'retriever': get_retriever()
    }
    return components


class AllModelsWrapper(Runnable):
    """Runnable that queries all backend LLMs and fuses their answers with LLM-Blender."""

    def invoke(self, input, config=None, **kwargs):
        current_query = st.session_state.get("current_query", "")
        print(f"Inside invoke, current query: {current_query}")

        if isinstance(input, dict):
            prompt = input.get("query", "")
            context = input.get("context", "")
            prompt = (
                "Please explain this comprehensively in 2–3 paragraphs.\n\n"
                + ENSEMBLE_PROMPT_TEMPLATE.format(question=prompt, context=context)
            )
        elif isinstance(input, list):
            prompt = input[0]["content"]
        else:
            prompt = str(input)

        rag_components = st.session_state.app_models['rag_components']
        responses = get_all_responses(prompt, max_tokens=2048)  # longer responses for the ensemble
        fused = rank_and_fuse_caller(prompt, responses, rag_components, current_query)
        return fused

    def __call__(self, input, config=None, **kwargs):
        return self.invoke(input, config, **kwargs)


def get_all_responses(prompt, max_tokens=2048):
    """Query GPT-4o, Gemini, and LLaMA independently; skip any model that errors."""
    responses = []

    # OpenAI GPT-4o
    try:
        openai_resp = ChatOpenAI(
            model="gpt-4o",
            temperature=0.5,
            max_tokens=max_tokens,
            api_key=st.secrets["OPENAI_API_KEY"]
        ).invoke([{"role": "user", "content": prompt}]).content
        responses.append(openai_resp)
    except Exception as e:
        print(f"[OpenAI GPT-4o] Error: {e}")

    # Gemini 2.5 Pro
    try:
        genai.configure(api_key=st.secrets["GEMINI_API_KEY"])
        gemini = genai.GenerativeModel("gemini-2.5-pro-preview-05-06")
        generation_config = genai.types.GenerationConfig(
            temperature=0.5,
            top_p=0.9,
            max_output_tokens=2048
        )
        gemini_resp = gemini.generate_content(prompt, generation_config=generation_config).text
        responses.append(gemini_resp)
    except Exception as e:
        print(f"[Gemini Pro] Error: {e}")

    # LLaMA 4 Maverick (served via Groq)
    try:
        llama = Groq(api_key=st.secrets["GROQ_API_KEY"])
        llama_resp = llama.chat.completions.create(
            model="meta-llama/llama-4-maverick-17b-128e-instruct",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.5,
            top_p=0.9,
            max_completion_tokens=max_tokens,
            stream=False
        ).choices[0].message.content
        responses.append(llama_resp)
    except Exception as e:
        print(f"[LLaMA 4 Maverick] Error: {e}")

    return responses
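
# The model calls above, and the Cohere/Qdrant clients elsewhere in this module, read
# their credentials from Streamlit secrets. A minimal .streamlit/secrets.toml could look
# like the sketch below (values are placeholders; QDRANT_API_KEY is an assumed secret
# name used in get_retriever()):
#
#   OPENAI_API_KEY = "sk-..."
#   GEMINI_API_KEY = "..."
#   GROQ_API_KEY = "gsk_..."
#   COHERE_API_KEY = "..."
#   QDRANT_API_KEY = "..."
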
def rank_and_fuse_caller(prompt, responses, rag_components, current_query, top_k=3):
    """Rank candidate answers with PairRM and fuse the top_k of them with gen_fuser_3b."""
    print(f"Current query: {current_query}")

    if not isinstance(responses, list):
        responses = [str(responses)]
    responses = [str(r) for r in responses if r is not None]
    print(f"Responses from models: {responses}")
    if not responses:
        return "No valid responses to fuse"

    tokenizer = AutoTokenizer.from_pretrained("llm-blender/PairRM")

    def truncate_text(text, max_tokens=1024):
        # Keep ranker inputs within PairRM's context window
        tokens = tokenizer.encode(text, truncation=True, max_length=max_tokens)
        return tokenizer.decode(tokens, skip_special_tokens=True)

    truncated_prompt = truncate_text(current_query)
    truncated_responses = [truncate_text(r) for r in responses]

    inputs = [str(truncated_prompt)]
    candidates = [truncated_responses]
    print(f"Inputs: {inputs}")
    print(f"Candidates: {candidates}")

    fused_text, _ = rag_components['blender'].rank_and_fuse(
        inputs=inputs,
        candidates=candidates,
        return_scores=True,
        batch_size=1,
        top_k=top_k,
        max_length=4096,        # prompt + output limit for the fuser
        max_new_tokens=3000,
        min_new_tokens=800,     # encourages a reasonably detailed answer
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        length_penalty=1.3,
        repetition_penalty=1.1,
        no_repeat_ngram_size=4
    )
    print(f"Fused text: {fused_text}")
    return fused_text[0]


def get_retriever():
    qdrant_client = QdrantClient(
        url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
        api_key=st.secrets["QDRANT_API_KEY"]  # kept in secrets rather than hardcoded; secret name is assumed
    )
    collection_name = "ks_collection_1.5BE"

    local_embedding = HuggingFaceEmbeddings(
        model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
        model_kwargs={
            "trust_remote_code": True,
            "device": "cuda" if torch.cuda.is_available() else "cpu"
        }
    )
    print("Qwen2-1.5B local embedding model loaded.")

    vector_store = Qdrant(
        client=qdrant_client,
        collection_name=collection_name,
        embeddings=local_embedding
    )
    return vector_store.as_retriever()


def initialize_llm(_model_name):
    """Initialize the LLM based on the selected model name."""
    print(f"Model name: {_model_name}")

    if "OpenAI" in _model_name:
        return ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"])

    elif "LLaMA" in _model_name:
        client = Groq(api_key=st.secrets["GROQ_API_KEY"])

        class LLaMAWrapper(Runnable):
            def __init__(self, client):
                self.client = client

            def invoke(self, input, config=None, **kwargs):
                prompt = str(input)
                completion = self.client.chat.completions.create(
                    model="meta-llama/llama-4-maverick-17b-128e-instruct",
                    messages=[{"role": "user", "content": prompt}],
                    temperature=1,
                    max_completion_tokens=1024,
                    top_p=1,
                    stream=False
                )
                return completion.choices[0].message.content

            def __call__(self, input, config=None, **kwargs):
                return self.invoke(input, config, **kwargs)

        return LLaMAWrapper(client)

    elif "Gemini" in _model_name:
        genai.configure(api_key=st.secrets["GEMINI_API_KEY"])
        gemini_model = genai.GenerativeModel("gemini-2.5-pro-preview-05-06")

        class GeminiWrapper(Runnable):
            def __init__(self, model):
                self.model = model

            def invoke(self, input, config=None, **kwargs):
                prompt = str(input)
                response = self.model.generate_content(prompt)
                return response.text

            def __call__(self, input, config=None, **kwargs):
                return self.invoke(input, config, **kwargs)

        return GeminiWrapper(gemini_model)

    elif "Ensemble" in _model_name:
        return AllModelsWrapper()

    else:
        raise ValueError("Unsupported model selected")


def load_rag_chain(llm):
    prompt_template = PromptTemplate(
        template=AI_PROMPT_TEMPLATE,
        input_variables=["question", "context"]
    )
    rag_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=get_retriever(),
        chain_type="stuff",
        chain_type_kwargs={
            "prompt": prompt_template,
            "document_variable_name": "context"
        }
    )
    return rag_chain


def rerank_with_cohere(query, documents, co, top_n=5):
    """Rerank retrieved documents with Cohere rerank-v3.5 and return the top_n."""
    if not documents:
        return []
    raw_texts = [doc.page_content for doc in documents]
    results = co.rerank(
        query=query,
        documents=raw_texts,
        top_n=min(top_n, len(raw_texts)),
        model="rerank-v3.5"
    )
    return [documents[result.index] for result in results.results]
def get_reranked_response(query, llm, rag_components, selected_model, current_query):
    try:
        st.session_state.current_query = current_query

        # For ensemble follow-ups, prepend the previous answer so the fused response keeps context
        if st.session_state.get("last_model_output") and selected_model == "Ensemble":
            query = (
                f"The earlier diagnosis was: {st.session_state['last_model_output']}\n"
                f"Now the user asks: {query}"
            )

        docs = rag_components['retriever'].get_relevant_documents(query)
        reranked_docs = rerank_with_cohere(query, docs, rag_components['cohere_client'])
        context = "\n\n".join([doc.page_content for doc in reranked_docs])

        rag_chain = load_rag_chain(llm)
        if selected_model == "Ensemble":
            response = rag_chain.invoke({"query": query, "context": context, "current_query": current_query})
            result = response.get('result', response.get('answer', str(response)))
        else:
            response = rag_chain.invoke({"query": query, "context": context})
            result = response.get('content', response.get('result', response.get('answer', str(response))))

        st.session_state["last_model_output"] = result
        return result
    except Exception as e:
        print(f"Error in get_reranked_response: {str(e)}")
        return "I apologize, but I'm having trouble processing your request right now."


if __name__ == "__main__":
    print("This is a module - import it instead of running directly")
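
# Minimal usage sketch (an assumption about how the surrounding Streamlit app wires this
# module together; it expects st.secrets to be configured and the cached components to be
# stored in st.session_state.app_models, as AllModelsWrapper.invoke() requires):
#
#   rag_components = initialize_rag_components()
#   st.session_state.app_models = {"rag_components": rag_components}
#   llm = initialize_llm("Ensemble")
#   answer = get_reranked_response(
#       query="Plaque and erythema detected on the elbows",
#       llm=llm,
#       rag_components=rag_components,
#       selected_model="Ensemble",
#       current_query="Plaque and erythema detected on the elbows",
#   )
#   st.write(answer)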