# import streamlit as st
# import faiss
# import pickle
# import numpy as np
# import torch
# from transformers import T5Tokenizer, T5ForConditionalGeneration
# from sentence_transformers import SentenceTransformer
# # Load LLM model (local folder)
# @st.cache_resource
# def load_llm():
#     model_path = "./Generator_Model"
#     tokenizer = T5Tokenizer.from_pretrained(model_path)
#     model = T5ForConditionalGeneration.from_pretrained(model_path)
#     return tokenizer, model
# # Load embedding model (local folder)
# @st.cache_resource
# def load_embedding_model():
#     embed_model_path = "./Embedding_Model1"
#     return SentenceTransformer(embed_model_path)
# # Load FAISS index and embeddings
# @st.cache_resource
# def load_faiss():
#     # Load FAISS index
#     faiss_index = faiss.read_index("faiss_index_file.index")
#     # Load the texts (raw data)
#     with open("texts.pkl", "rb") as f:
#         data = pickle.load(f)
#     # Load the embeddings
#     embeddings = np.load("embeddings_file.npy", allow_pickle=True)
#     return faiss_index, data, embeddings
# # Search function to find top-k contexts based on query
# def search(query, embed_model, index, data, k=5):
#     # Generate query embedding
#     query_embedding = embed_model.encode([query]).astype('float32')
#     # Perform FAISS search
#     _, I = index.search(query_embedding, k)  # Top-k results
#     results = [data[i] for i in I[0] if i != -1]
#     return results
# # Generate response using the LLM model (T5 model)
# def generate_response(context, query, tokenizer, model):
#     input_text = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
#     inputs = tokenizer.encode(input_text, return_tensors="pt")
#     outputs = model.generate(inputs, max_length=512, do_sample=True, temperature=0.7)
#     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     return response
# # Streamlit App
# def main():
#     st.title("Local LLM + FAISS + Embedding Search App")
#     st.markdown("Ask a question, and get context-aware answers!")
#     # Load everything once
#     tokenizer, llm_model = load_llm()
#     embed_model = load_embedding_model()
#     faiss_index, data, embeddings = load_faiss()
#     query = st.text_input("Enter your query:")
#     if query:
#         with st.spinner("Processing..."):
#             # Search for relevant contexts based on the query
#             contexts = search(query, embed_model, faiss_index, data)
#             combined_context = " ".join(contexts)
#             # Generate an answer using the LLM model
#             response = generate_response(combined_context, query, tokenizer, llm_model)
#             st.subheader("Response:")
#             st.write(response)
#             # st.subheader("Top Retrieved Contexts:")
#             # for idx, ctx in enumerate(contexts, 1):
#             #     st.markdown(f"**{idx}.** {ctx}")
# if __name__ == "__main__":
#     main()
###########################################
import os
import streamlit as st
import faiss
import pickle
import numpy as np
import torch
import gdown
from transformers import T5Tokenizer, T5ForConditionalGeneration
from sentence_transformers import SentenceTransformer
# Function to download a full folder from Google Drive
def download_folder_from_google_drive(folder_url, output_path):
    # setup_files() pre-creates the output directories, so checking for mere
    # existence would always skip the download; only skip if the folder is non-empty
    if not os.path.isdir(output_path) or not os.listdir(output_path):
        gdown.download_folder(url=folder_url, output=output_path, quiet=False, use_cookies=False)
# Download individual files
def download_file_from_google_drive(file_id, destination):
    if not os.path.exists(destination):
        url = f"https://drive.google.com/uc?id={file_id}"
        gdown.download(url, destination, quiet=False)
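# Note (assumption, not part of the original code): gdown folder downloads can
# hit Google Drive quota limits and gdown's per-folder file-count cap; if a
# model folder fails to download, fetching the individual files by ID (as done
# for the index and pickle files below) is a workable fallback.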
# Setup models and files
@st.cache_resource
def setup_files():
    os.makedirs("models/embedding_model", exist_ok=True)
    os.makedirs("models/generator_model", exist_ok=True)
    os.makedirs("models/files", exist_ok=True)
    # Download embedding model (folder)
    download_folder_from_google_drive(
        "https://drive.google.com/drive/folders/1GzPk2ehr7rzOr65Am1Hg3A87FOTNHLAM?usp=sharing",
        "models/embedding_model"
    )
    # Download generator model (folder)
    download_folder_from_google_drive(
        "https://drive.google.com/drive/folders/1338KWiBE-6sWsTO2iH7Pgu8eRI7EE7Vr?usp=sharing",
        "models/generator_model"
    )
    # Download FAISS index, texts.pkl, embeddings.npy
    download_file_from_google_drive("11J_VI1buTgnvhoP3z2HM6X5aPzbBO2ed", "models/files/faiss_index_file.index")
    download_file_from_google_drive("1RTEwp8xDgxLnRUiy7ClTskFuTu0GtWBT", "models/files/texts.pkl")
    download_file_from_google_drive("1N54imsqJIJGeqM3buiRzp1ivK_BtC7rR", "models/files/embeddings.npy")
# Paths
EMBEDDING_MODEL_PATH = "models/embedding_model"
GENERATOR_MODEL_PATH = "models/generator_model"
FAISS_INDEX_PATH = "models/files/faiss_index_file.index"
TEXTS_PATH = "models/files/texts.pkl"
EMBEDDINGS_PATH = "models/files/embeddings.npy"
# Load LLM model (Generator model)
@st.cache_resource
def load_llm():
    tokenizer = T5Tokenizer.from_pretrained(GENERATOR_MODEL_PATH)
    model = T5ForConditionalGeneration.from_pretrained(GENERATOR_MODEL_PATH)
    return tokenizer, model
# Load embedding model
@st.cache_resource
def load_embedding_model():
    return SentenceTransformer(EMBEDDING_MODEL_PATH)
# Load FAISS index and embeddings
@st.cache_resource
def load_faiss():
    faiss_index = faiss.read_index(FAISS_INDEX_PATH)
    with open(TEXTS_PATH, "rb") as f:
        data = pickle.load(f)
    embeddings = np.load(EMBEDDINGS_PATH, allow_pickle=True)
    return faiss_index, data, embeddings
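# The prebuilt index is downloaded above rather than built here. A minimal
# offline build sketch, assuming the stored embeddings came from the same
# SentenceTransformer model and a flat L2 index (the actual build script and
# index type are not included in this file):
#
# def build_faiss_index(texts):
#     embed_model = SentenceTransformer(EMBEDDING_MODEL_PATH)
#     embeddings = embed_model.encode(texts).astype("float32")
#     index = faiss.IndexFlatL2(embeddings.shape[1])
#     index.add(embeddings)
#     faiss.write_index(index, FAISS_INDEX_PATH)
#     np.save(EMBEDDINGS_PATH, embeddings)
#     with open(TEXTS_PATH, "wb") as f:
#         pickle.dump(texts, f)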
# Search top-k contexts
def search(query, embed_model, index, data, k=5):
    query_embedding = embed_model.encode([query]).astype('float32')
    _, I = index.search(query_embedding, k)
    results = [data[i] for i in I[0] if i != -1]
    return results
# Generate response
def generate_response(context, query, tokenizer, model):
    input_text = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
    # Truncate to the tokenizer's maximum length so long retrieved contexts don't overflow the model input
    inputs = tokenizer.encode(input_text, return_tensors="pt", truncation=True, max_length=512)
    outputs = model.generate(inputs, max_length=512, do_sample=True, temperature=0.7)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
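# Illustrative end-to-end call outside Streamlit (the question is a made-up
# placeholder; setup_files() must have run so the local files exist):
#
# tokenizer, model = load_llm()
# embed_model = load_embedding_model()
# index, data, _ = load_faiss()
# contexts = search("What are the symptoms of anemia?", embed_model, index, data)
# answer = generate_response(" ".join(contexts), "What are the symptoms of anemia?", tokenizer, model)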
# Streamlit app
def main():
    st.set_page_config(page_title="Clinical QA with RAG", page_icon="🩺")
    st.title("Clinical QA System (RAG + FAISS + T5)")
    st.markdown(
        """
        Enter your **clinical question** below.
        The system will retrieve relevant context and generate an informed answer using a local model.
        """
    )

    # Download + load everything
    setup_files()
    tokenizer, llm_model = load_llm()
    embed_model = load_embedding_model()
    faiss_index, data, embeddings = load_faiss()

    query = st.text_input("💬 Your Question:")
    if query:
        with st.spinner("Retrieving and generating..."):
            contexts = search(query, embed_model, faiss_index, data)
            combined_context = " ".join(contexts)
            response = generate_response(combined_context, query, tokenizer, llm_model)
            st.success("✅ Answer Ready!")
            st.subheader("Response:")
            st.write(response)

if __name__ == "__main__":
    main()