import os
import zipfile

import gradio as gr
import chromadb
import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM

# Load API token from Hugging Face Secrets (the Space secret is named "api_key")
HF_TOKEN = os.getenv("api_key")
# Ensure the API token is loaded
if HF_TOKEN is None:
    raise ValueError("Hugging Face API token not found. Add an `api_key` secret in the Space settings.")
# Load Mistral-7B-Instruct with authentication
llm_name = "mistralai/Mistral-7B-Instruct-v0.1"
llm_tokenizer = AutoTokenizer.from_pretrained(llm_name, token=HF_TOKEN)
llm_model = AutoModelForCausalLM.from_pretrained(
    llm_name,
    torch_dtype=torch.float16,
    device_map="auto",
    token=HF_TOKEN,
)
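# Note: downloading Mistral-7B-Instruct may also require accepting the model's
# terms on the Hugging Face Hub with the account that owns this token.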
# Optimize Mistral for faster inference
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
llm_model = torch.compile(llm_model)
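# torch.compile requires PyTorch 2.x; the first generation call is slower while
# the model compiles, so it may be worth skipping on CPU-only hardware.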
# Initialize ChromaDB: unzip the packaged database if not already extracted
if not os.path.exists("./chroma_db"):
    with zipfile.ZipFile("chroma_db.zip", "r") as zip_ref:
        zip_ref.extractall("./")
    print("ChromaDB database loaded!")

# Load ChromaDB from local storage
chroma_client = chromadb.PersistentClient(path="./chroma_db")
collection = chroma_client.get_or_create_collection(name="hepB_knowledge")
print("ChromaDB initialized!")
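# Assumption: chroma_db.zip was built offline and the "hepB_knowledge" collection
# already contains text chunks from the WHO Hepatitis B guidelines (2024), embedded
# with the same BiomedBERT model used for queries below.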
# Detect device (used for the embedding model and for LLM inputs)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Function to generate LLM responses
def generate_humanized_response(query, retrieved_text):
    """Passes retrieved chunks through Mistral-7B to improve readability."""
    # Truncate retrieved text to avoid overly long prompts
    retrieved_text = retrieved_text[:500]
    prompt = f"""You are a medical assistant. Answer the following question based on retrieved text:
Retrieved Text:
{retrieved_text}
User Query: {query}
Provide a well-structured, human-like response:
"""
    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)  # GPU if available, otherwise CPU
    output = llm_model.generate(**inputs, max_new_tokens=150, do_sample=True)
    # Decode only the newly generated tokens so the prompt is not echoed back
    response = llm_tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return response
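# Example usage (commented out so it doesn't run at import time; the sample text is illustrative):
# print(generate_humanized_response(
#     "How is hepatitis B transmitted?",
#     "Hepatitis B is transmitted through contact with infected blood or body fluids...",
# ))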
# Load BiomedBERT for embeddings
embed_model_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
embed_model = AutoModel.from_pretrained(embed_model_name)
embed_model.to(device)
# Function to generate text embeddings
def get_embedding(text):
    """Generates BiomedBERT embeddings using the CLS token (max 512 tokens)."""
    inputs = embed_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    ).to(device)
    with torch.no_grad():
        outputs = embed_model(**inputs)
    cls_embedding = outputs.last_hidden_state[:, 0, :].cpu()  # move back to CPU
    return cls_embedding.squeeze().numpy().tolist()
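# Quick sanity check (commented out): BiomedBERT-base CLS embeddings are 768-dimensional.
# assert len(get_embedding("hepatitis B surface antigen")) == 768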
# Function to retrieve similar chunks
def retrieve_similar_chunks(query, top_k=5, distance_threshold=0.5):
    """Finds the top-k most similar chunks in ChromaDB.

    ChromaDB returns distances (smaller = more similar); for a cosine-distance
    collection, a distance of 0.5 corresponds to a cosine similarity of 0.5.
    """
    print("Generating embedding for query...")
    query_embedding = get_embedding(query)

    print("Querying ChromaDB...")
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k,
    )

    # ChromaDB nests results per query; take the lists for the single query
    documents = results["documents"][0] if results["documents"] else []
    distances = results["distances"][0] if results["distances"] else []

    if not documents:
        print("No relevant chunks found in ChromaDB.")
        return ["Sorry, I couldn't find relevant information."]

    print(f"Retrieved {len(documents)} chunks from ChromaDB.")

    # Print distances and keep only chunks within the distance threshold
    filtered_results = []
    for i, (doc, distance) in enumerate(zip(documents, distances)):
        print(f"Chunk {i+1} distance: {distance}")
        if distance <= distance_threshold:
            filtered_results.append(doc)

    print("Retrieval completed.")
    return filtered_results if filtered_results else ["Sorry, I couldn't find relevant information."]
# Chatbot function
def chatbot(query):
    """Returns a structured, human-like answer using Mistral-7B."""
    retrieved_chunks = retrieve_similar_chunks(query)
    if not retrieved_chunks or retrieved_chunks == ["Sorry, I couldn't find relevant information."]:
        return "Sorry, I couldn't find relevant information."
    retrieved_texts = [chunk if isinstance(chunk, str) else " ".join(chunk) for chunk in retrieved_chunks]
    retrieved_text = "\n\n".join(retrieved_texts)[:500]
    response = generate_humanized_response(query, retrieved_text)
    return response
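# End-to-end smoke test (commented out; the example query is illustrative):
# print(chatbot("Who should be tested for hepatitis B?"))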
# Gradio chat interface
ui = gr.Interface(
    fn=chatbot,
    inputs=gr.Textbox(lines=2, placeholder="Ask about Hepatitis B..."),
    outputs=gr.Textbox(),
    title="Hepatitis B Chatbot",
    description="Ask questions based on WHO Hepatitis B guidelines (2024). Uses ChromaDB & Mistral-7B for responses.",
)
# Run the chatbot
if __name__ == "__main__":
    ui.launch()