import os
import zipfile

import gradio as gr
import chromadb
import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM

# 🔹 Load API token from Hugging Face Secrets
HF_TOKEN = os.getenv("api_key")  # ✅ Securely load the API key stored under the secret name "api_key"
# 🔹 Ensure the API token is loaded
if HF_TOKEN is None:
    raise ValueError("❌ Hugging Face API token not found. Add a secret named `api_key` in the Space settings.")
# 🔹 Load Mistral-7B-Instruct with authentication
llm_name = "mistralai/Mistral-7B-Instruct-v0.1"
llm_tokenizer = AutoTokenizer.from_pretrained(llm_name, token=HF_TOKEN)  # `token` replaces the deprecated `use_auth_token`
llm_model = AutoModelForCausalLM.from_pretrained(
    llm_name,
    torch_dtype=torch.float16,
    device_map="auto",
    token=HF_TOKEN,
)
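# NOTE: device_map="auto" requires the `accelerate` package to be installed; it
# places model shards across the available GPU(s) and CPU memory automatically.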
# 🔹 Optimize Mistral for faster inference
torch.backends.cuda.matmul.allow_tf32 = True  # Allow TF32 matmuls on Ampere+ GPUs
torch.backends.cudnn.benchmark = True  # Auto-tune cuDNN kernels for fixed input shapes
llm_model = torch.compile(llm_model)  # JIT-compile the model (PyTorch 2.x)
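# If the runtime might ship PyTorch < 2.0, a guarded variant avoids an
# AttributeError (a sketch, assuming eager execution is an acceptable fallback):
# llm_model = torch.compile(llm_model) if hasattr(torch, "compile") else llm_model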
# 🔹 Initialize ChromaDB: unzip the prebuilt database if it has not been extracted yet
if not os.path.exists("./chroma_db"):
    with zipfile.ZipFile("chroma_db.zip", "r") as zip_ref:
        zip_ref.extractall("./")
    print("✅ ChromaDB database loaded!")
# 🔹 Load ChromaDB from local storage
chroma_client = chromadb.PersistentClient(path="./chroma_db")
collection = chroma_client.get_or_create_collection(name="hepB_knowledge")
print("✅ ChromaDB initialized!")
# 🔹 Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"✅ Using device: {device}")
# 🔹 Function to generate LLM responses
def generate_humanized_response(query, retrieved_text):
    """Passes retrieved chunks through Mistral-7B to improve readability."""
    # 🔹 Truncate retrieved text to avoid overly long inputs
    retrieved_text = retrieved_text[:500]
    prompt = f"""You are a medical assistant. Answer the following question based on the retrieved text:

Retrieved Text:
{retrieved_text}

User Query: {query}

Provide a well-structured, human-like response:
"""
    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)  # ✅ Uses the GPU if available, otherwise the CPU
    output = llm_model.generate(**inputs, max_new_tokens=150, do_sample=True)
    # ✅ Decode only the newly generated tokens so the prompt is not echoed back
    response = llm_tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return response
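# Example smoke test (hypothetical inputs; uncomment to try it manually):
# print(generate_humanized_response(
#     "How is hepatitis B transmitted?",
#     "Hepatitis B is transmitted through contact with infected blood or body fluids...",
# ))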
# 🔹 Load BiomedBERT for embeddings
embed_model_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
embed_model = AutoModel.from_pretrained(embed_model_name)
embed_model.to(device)  # Reuse the device detected above
embed_model.eval()  # Inference only; disable dropout
# 🔹 Function to generate text embeddings
def get_embedding(text):
    """Generates a BiomedBERT embedding from the CLS token (max 512 tokens)."""
    # Padding is unnecessary for a single sequence; truncation alone keeps inputs ≤ 512 tokens
    inputs = embed_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(device)
    with torch.no_grad():
        outputs = embed_model(**inputs)
    cls_embedding = outputs.last_hidden_state[:, 0, :].cpu()  # CLS token; move back to the CPU
    return cls_embedding.squeeze().numpy().tolist()
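# ✅ Sanity check at startup: BiomedBERT-base produces 768-dimensional CLS
# embeddings, which must match the dimensionality of the vectors already stored
# in the Chroma collection (the probe text is arbitrary).
assert len(get_embedding("hepatitis B")) == 768, "Embedding size mismatch: expected 768 (BERT-base)."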
# 🔹 Function to retrieve similar chunks
def retrieve_similar_chunks(query, top_k=5, similarity_threshold=0.5):
    """Finds the top-k most similar chunks in ChromaDB for a query."""
    print("🔹 Generating embedding for query...")
    query_embedding = get_embedding(query)
    print("🔹 Querying ChromaDB...")
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k,
    )
    # ✅ Chroma nests results per query; take the inner lists for our single query
    documents = results["documents"][0] if results["documents"] else []
    distances = results["distances"][0] if results["distances"] else []
    if not documents:
        print("❌ No relevant chunks found in ChromaDB.")
        return ["Sorry, I couldn't find relevant information."]
    print(f"🔹 Retrieved {len(documents)} chunks from ChromaDB.")
    # 🔍 Print distances (Chroma returns distances, where lower means more similar)
    for i, distance in enumerate(distances):
        print(f"Chunk {i+1} Distance: {distance}")
    # 🔍 Filter out low-similarity chunks. This assumes a cosine-distance collection,
    # where similarity = 1 - distance; adjust if the collection uses another metric.
    filtered_results = [
        doc for doc, distance in zip(documents, distances)
        if (1 - distance) >= similarity_threshold
    ]
    print("✅ Retrieval completed.")
    return filtered_results if filtered_results else ["Sorry, I couldn't find relevant information."]
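# Example retrieval (hypothetical query; uncomment to inspect the raw chunks):
# for chunk in retrieve_similar_chunks("Who should be screened for hepatitis B?"):
#     print(chunk[:200], "...")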
# 🔹 Chatbot function
def chatbot(query):
    """Returns a structured, human-like answer using Mistral-7B."""
    retrieved_chunks = retrieve_similar_chunks(query)
    # ✅ Compare against the fallback message that retrieve_similar_chunks actually returns
    if not retrieved_chunks or retrieved_chunks == ["Sorry, I couldn't find relevant information."]:
        return "Sorry, I couldn't find relevant information."
    retrieved_texts = [chunk if isinstance(chunk, str) else " ".join(chunk) for chunk in retrieved_chunks]
    retrieved_text = "\n\n".join(retrieved_texts)[:500]
    response = generate_humanized_response(query, retrieved_text)
    return response
# 🔹 Gradio chat interface
ui = gr.Interface(
    fn=chatbot,
    inputs=gr.Textbox(lines=2, placeholder="Ask about Hepatitis B..."),
    outputs=gr.Textbox(),
    title="💡 Hepatitis B Chatbot",
    description="⚕️ Ask questions based on the WHO Hepatitis B guidelines (2024). Uses ChromaDB & Mistral-7B for responses.",
)
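# Optional: on a shared Space, serializing GPU-bound generation avoids concurrent
# generate() calls (Gradio supports request queuing via ui.queue() before launch).
# ui.queue()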
# 🔥 Run the chatbot
if __name__ == "__main__":
    ui.launch()