import os
import zipfile

import gradio as gr
import chromadb
import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM

# Load the API token from Hugging Face Secrets
HF_TOKEN = os.getenv("api_key")  # Securely load the API key
# Ensure the API token is loaded
if HF_TOKEN is None:
    raise ValueError("Hugging Face API token not found. Add `api_key` to this Space's secrets.")
# Load Mistral-7B-Instruct with authentication
llm_name = "mistralai/Mistral-7B-Instruct-v0.1"
llm_tokenizer = AutoTokenizer.from_pretrained(llm_name, use_auth_token=HF_TOKEN)
llm_model = AutoModelForCausalLM.from_pretrained(
    llm_name,
    torch_dtype=torch.float16,
    device_map="auto",
    use_auth_token=HF_TOKEN,
)
# Optimize Mistral for faster inference
torch.backends.cuda.matmul.allow_tf32 = True  # Allow TF32 matmuls on Ampere and newer GPUs
torch.backends.cudnn.benchmark = True         # Let cuDNN autotune kernels for fixed input shapes
llm_model = torch.compile(llm_model)          # Requires PyTorch 2.x
# Initialize ChromaDB: unzip the prebuilt database if it has not been extracted yet
if not os.path.exists("./chroma_db"):
    with zipfile.ZipFile("chroma_db.zip", "r") as zip_ref:
        zip_ref.extractall("./")
    print("ChromaDB database loaded!")
# Load ChromaDB from local storage
chroma_client = chromadb.PersistentClient(path="./chroma_db")
collection = chroma_client.get_or_create_collection(name="hepB_knowledge")
print("ChromaDB initialized!")
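
# Note: the "hepB_knowledge" collection is assumed to have been populated offline
# (when chroma_db.zip was built) with pre-chunked WHO guideline text, e.g. via
#   collection.add(ids=[...], documents=[...], embeddings=[...])
# so this app only queries it and never writes to it.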
# Detect the device so that model inputs are placed on the GPU when available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Function to generate LLM responses
def generate_humanized_response(query, retrieved_text):
    """Passes retrieved chunks through Mistral-7B to improve readability."""
    # Truncate retrieved text to avoid overly long prompts
    retrieved_text = retrieved_text[:500]
    prompt = f"""You are a medical assistant. Answer the following question based on retrieved text:
Retrieved Text:
{retrieved_text}
User Query: {query}
Provide a well-structured, human-like response:
"""
    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)  # Uses GPU if available, otherwise CPU
    output = llm_model.generate(**inputs, max_new_tokens=150, do_sample=True)
    response = llm_tokenizer.decode(output[0], skip_special_tokens=True)
    return response
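
# Usage sketch (assumes the model above loaded successfully); both strings below are
# illustrative examples, not taken from the knowledge base:
#   generate_humanized_response(
#       "How is hepatitis B transmitted?",
#       "Hepatitis B is transmitted through contact with infected blood or body fluids.",
#   )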
# Load BiomedBERT for embeddings (placed on the same device detected above)
embed_model_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
embed_model = AutoModel.from_pretrained(embed_model_name)
embed_model.to(device)
# Function to generate text embeddings
def get_embedding(text):
    """Generates BiomedBERT embeddings using the CLS token (max 512 tokens)."""
    inputs = embed_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    ).to(device)
    with torch.no_grad():
        outputs = embed_model(**inputs)
    cls_embedding = outputs.last_hidden_state[:, 0, :].cpu()  # Move back to CPU
    return cls_embedding.squeeze().numpy().tolist()
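
# Usage sketch: get_embedding("hepatitis B vaccination schedule") returns a plain
# Python list of 768 floats (the hidden size of this BERT-base model), which matches
# what ChromaDB expects for query_embeddings.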
# Function to retrieve similar chunks
def retrieve_similar_chunks(query, top_k=5, similarity_threshold=0.5):
    """Finds the top-k most similar chunks for a query in ChromaDB."""
    print("Generating embedding for query...")
    query_embedding = get_embedding(query)

    print("Querying ChromaDB...")
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k
    )

    # ChromaDB nests results per query; this app always sends exactly one query,
    # so check that the first (and only) result list is non-empty before using it
    if not results["documents"] or not results["documents"][0]:
        print("No relevant chunks found in ChromaDB.")
        return ["Sorry, I couldn't find relevant information."]

    documents = results["documents"][0]
    distances = results["distances"][0]
    print(f"Retrieved {len(documents)} chunks from ChromaDB.")

    # Print the score ChromaDB returned for each chunk
    for i, score in enumerate(distances):
        print(f"Chunk {i+1} Score: {score}")

    # Filter out chunks whose score falls below the threshold
    filtered_results = []
    for doc, score in zip(documents, distances):
        if score >= similarity_threshold:
            filtered_results.append(doc)

    print("Retrieval completed.")
    return filtered_results if filtered_results else ["Sorry, I couldn't find relevant information."]
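
# Usage sketch: retrieve_similar_chunks("How is hepatitis B diagnosed?") returns a list
# of matching text chunks, or a single fallback message when nothing qualifies; top_k and
# similarity_threshold can be tuned to trade recall against prompt length.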
# Chatbot function
def chatbot(query):
    """Returns a structured, human-like answer using Mistral-7B."""
    retrieved_chunks = retrieve_similar_chunks(query)
    if not retrieved_chunks or retrieved_chunks == ["Sorry, I couldn't find relevant information."]:
        return "Sorry, I couldn't find relevant information."
    retrieved_texts = [chunk if isinstance(chunk, str) else " ".join(chunk) for chunk in retrieved_chunks]
    retrieved_text = "\n\n".join(retrieved_texts)[:500]
    response = generate_humanized_response(query, retrieved_text)
    return response
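
# Usage sketch (outside Gradio); the question below is illustrative, not drawn from the
# indexed guidelines:
#   print(chatbot("Who should be vaccinated against hepatitis B?"))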
# Gradio chat interface
ui = gr.Interface(
    fn=chatbot,
    inputs=gr.Textbox(lines=2, placeholder="Ask about Hepatitis B..."),
    outputs=gr.Textbox(),
    title="Hepatitis B Chatbot",
    description="Ask questions based on WHO Hepatitis B guidelines (2024). Uses ChromaDB & Mistral-7B for responses.",
)
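
# Note: launch() below starts the Gradio server; when running locally outside of a
# Space, ui.launch(share=True) would additionally create a temporary public link
# (a standard Gradio option, not specific to this app).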
# Run the chatbot
if __name__ == "__main__":
    ui.launch()