import os
import zipfile

import torch
import gradio as gr
import chromadb
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM

# πŸ”Ή Load API token from Hugging Face Secrets (stored under the name "api_key")
HF_TOKEN = os.getenv("api_key")  # βœ… Securely load API key

# πŸ”Ή Ensure the token is present before loading gated models
if HF_TOKEN is None:
    raise ValueError("❌ Hugging Face API token not found. Add `api_key` in Hugging Face Secrets.")

# πŸ”Ή Load Mistral-7B-Instruct for LLM responses, with authentication
# (`token` replaces the deprecated `use_auth_token` argument in recent transformers)
llm_name = "mistralai/Mistral-7B-Instruct-v0.1"
llm_tokenizer = AutoTokenizer.from_pretrained(llm_name, token=HF_TOKEN)
llm_model = AutoModelForCausalLM.from_pretrained(
    llm_name,
    torch_dtype=torch.float16,
    device_map="auto",
    token=HF_TOKEN,
)
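
# πŸ”Ή Optional: in fp16, Mistral-7B needs roughly 14 GB of GPU memory. A minimal
# sketch of loading it 4-bit quantized instead, assuming `bitsandbytes` is
# installed (the config below is an illustrative choice, not part of this app):
#
#     from transformers import BitsAndBytesConfig
#     bnb_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_compute_dtype=torch.float16,
#     )
#     llm_model = AutoModelForCausalLM.from_pretrained(
#         llm_name, quantization_config=bnb_config, device_map="auto", token=HF_TOKEN
#     )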


# πŸ”Ή Optimize Mistral for faster inference
torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 matmuls on Ampere+ GPUs
torch.backends.cudnn.benchmark = True  # let cuDNN pick the fastest kernels
llm_model = torch.compile(llm_model)  # JIT-compile the forward pass (PyTorch 2.x)

# πŸ”Ή Unzip the pre-built ChromaDB database if it has not been extracted yet
if not os.path.exists("./chroma_db"):
    with zipfile.ZipFile("chroma_db.zip", "r") as zip_ref:
        zip_ref.extractall("./")

print("βœ… ChromaDB database extracted!")

# πŸ”Ή Load ChromaDB from local storage
chroma_client = chromadb.PersistentClient(path="./chroma_db")
collection = chroma_client.get_or_create_collection(name="hepB_knowledge")

print("βœ… ChromaDB initialized!")


# πŸ”Ή Detect device (shared by the LLM and the embedding model)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"βœ… Using device: {device}")


# πŸ”Ή Function to Generate LLM Responses

def generate_humanized_response(query, retrieved_text):
    """Passes retrieved chunks through Mistral-7B to improve readability."""

    # πŸ”Ή Truncate retrieved text to avoid overly long prompts
    retrieved_text = retrieved_text[:500]

    prompt = f"""You are a medical assistant. Answer the following question based on retrieved text:

    Retrieved Text:
    {retrieved_text}

    User Query: {query}

    Provide a well-structured, human-like response:
    """

    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)  # βœ… GPU if available, otherwise CPU
    output = llm_model.generate(**inputs, max_new_tokens=150, do_sample=True)
    # βœ… Decode only the newly generated tokens so the prompt is not echoed back
    prompt_len = inputs["input_ids"].shape[1]
    response = llm_tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    return response
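
# πŸ”Ή Quick smoke test (illustrative query and retrieved text, not real data):
#     print(generate_humanized_response(
#         "How is hepatitis B transmitted?",
#         "Hepatitis B virus spreads through contact with infected blood...",
#     ))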


# πŸ”Ή Load BioMedBERT for Embeddings
embed_model_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
embed_model = AutoModel.from_pretrained(embed_model_name)
embed_model.to(device)  # reuse the device detected above

# πŸ”Ή Function to Generate Text Embeddings
def get_embedding(text):
    """Generates BioMedBERT embeddings using the CLS token (max 512 tokens)."""
    inputs = embed_tokenizer(
        text, 
        return_tensors="pt", 
        truncation=True,  
        padding="max_length",  
        max_length=512  
    ).to(device)
    
    with torch.no_grad():
        outputs = embed_model(**inputs)
    
    cls_embedding = outputs.last_hidden_state[:, 0, :].cpu()  # Move back to CPU
    return cls_embedding.squeeze().numpy().tolist()
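
# πŸ”Ή Example: BiomedBERT-base has a hidden size of 768, so each text maps to a
# 768-dimensional vector:
#     vec = get_embedding("How is hepatitis B transmitted?")
#     print(len(vec))  # 768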

# πŸ”Ή Function to Retrieve Similar Chunks
def retrieve_similar_chunks(query, top_k=5, distance_threshold=0.5):
    """Finds top-k similar chunks in ChromaDB and filters them by distance."""
    print("πŸ”Ή Generating embedding for query...")
    query_embedding = get_embedding(query)

    print("πŸ”Ή Querying ChromaDB...")
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k
    )

    # βœ… ChromaDB returns one list per query, so index [0] for this query's hits
    documents = results["documents"][0] if results["documents"] else []
    distances = results["distances"][0] if results["distances"] else []

    if not documents:
        print("❌ No relevant chunks found in ChromaDB.")
        return ["Sorry, I couldn't find relevant information."]

    print(f"πŸ”Ή Retrieved {len(documents)} chunks from ChromaDB.")

    # πŸ” Print distances (lower = more similar)
    for i, dist in enumerate(distances):
        print(f"Chunk {i+1} distance: {dist}")

    # πŸ” Keep only chunks whose distance falls under the threshold
    filtered_results = [
        doc for doc, dist in zip(documents, distances) if dist <= distance_threshold
    ]

    print("βœ… Retrieval completed.")
    return filtered_results if filtered_results else ["Sorry, I couldn't find relevant information."]
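
# πŸ”Ή Example (illustrative): returns up to `top_k` chunk strings, or the fallback
# message when nothing is within the distance threshold:
#     retrieve_similar_chunks("hepatitis B vaccination schedule", top_k=3)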


# πŸ”Ή Chatbot Function
def chatbot(query):
    """Returns a structured and human-like answer using Mistral-7B."""
    retrieved_chunks = retrieve_similar_chunks(query)

    # βœ… Match the actual fallback message returned by retrieve_similar_chunks
    if not retrieved_chunks or retrieved_chunks == ["Sorry, I couldn't find relevant information."]:
        return "Sorry, I couldn't find relevant information."

    # βœ… Chroma documents are strings here, but join defensively in case of nested lists
    retrieved_texts = [chunk if isinstance(chunk, str) else " ".join(chunk) for chunk in retrieved_chunks]
    retrieved_text = "\n\n".join(retrieved_texts)[:500]

    response = generate_humanized_response(query, retrieved_text)
    return response

# πŸ”Ή Gradio Chat Interface
ui = gr.Interface(
    fn=chatbot,
    inputs=gr.Textbox(lines=2, placeholder="Ask about Hepatitis B..."),
    outputs=gr.Textbox(),
    title="πŸ’‘ Hepatitis B Chatbot",
    description="βš•οΈ Ask questions based on WHO Hepatitis B guidelines (2024). Uses ChromaDB & Mistral-7B for responses.",
)

# πŸ”₯ Run the Chatbot
if __name__ == "__main__":
    ui.launch()