# --- Earlier prototype (local model folders + SentenceTransformer), kept commented out for reference ---
# import streamlit as st
# import faiss
# import pickle
# import numpy as np
# import torch
# from transformers import T5Tokenizer, T5ForConditionalGeneration
# from sentence_transformers import SentenceTransformer

# # Load LLM model (local folder)
# @st.cache_resource
# def load_llm():
#     model_path = "./Generator_Model"
#     tokenizer = T5Tokenizer.from_pretrained(model_path)
#     model = T5ForConditionalGeneration.from_pretrained(model_path)
#     return tokenizer, model

# # Load embedding model (local folder)
# @st.cache_resource
# def load_embedding_model():
#     embed_model_path = "./Embedding_Model1"
#     return SentenceTransformer(embed_model_path)

# # Load FAISS index and embeddings
# @st.cache_resource
# def load_faiss():
#     faiss_index = faiss.read_index("faiss_index_file.index")
#     data = np.load("embeddings_file.npy", allow_pickle=True)
#     return faiss_index, data


# # Search function
# def search(query, embed_model, index, data):
#     query_embedding = embed_model.encode([query]).astype('float32')
#     _, I = index.search(query_embedding, k=5)  # Top 5 results
#     results = [data['texts'][i] for i in I[0] if i != -1]
#     return results

# # Generate response using LLM
# def generate_response(context, query, tokenizer, model):
#     input_text = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
#     inputs = tokenizer.encode(input_text, return_tensors="pt")
#     outputs = model.generate(inputs, max_length=512, do_sample=True, temperature=0.7)
#     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     return response

# # Streamlit App
# def main():
#     st.title("Local LLM + FAISS + Embedding Search App")
#     st.markdown("🔍 Ask a question, and get context-aware answers!")

#     # Load everything once
#     tokenizer, llm_model = load_llm()
#     embed_model = load_embedding_model()
#     faiss_index, data = load_faiss()

#     query = st.text_input("Enter your query:")

#     if query:
#         with st.spinner("Processing..."):
#             # Search relevant contexts
#             contexts = search(query, embed_model, faiss_index, data)
#             combined_context = " ".join(contexts)

#             # Generate answer
#             response = generate_response(combined_context, query, tokenizer, llm_model)

#             st.subheader("Response:")
#             st.write(response)

#             st.subheader("Top Retrieved Contexts:")
#             for idx, ctx in enumerate(contexts, 1):
#                 st.markdown(f"**{idx}.** {ctx}")

# if __name__ == "__main__":
#     main()



###########################
import os
import streamlit as st
import faiss
import pickle
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM

# Local data files (FAISS index, context texts, precomputed embeddings); the models are pulled from the Hugging Face Hub
FAISS_INDEX_PATH = "faiss_index_file.index"
TEXTS_PATH = "texts.pkl"
EMBEDDINGS_PATH = "embeddings_file.npy"
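
# ---------------------------------------------------------------------------
# Hedged sketch (not part of the app): one way these three files could have
# been produced offline, assuming `corpus` is a list of context strings and
# the same "Ah1111/Embedding_Model" checkpoint is used for encoding. The
# names `corpus` and `build_index` are illustrative, not taken from this repo.
#
# def build_index(corpus):
#     tokenizer = AutoTokenizer.from_pretrained("Ah1111/Embedding_Model")
#     model = AutoModel.from_pretrained("Ah1111/Embedding_Model")
#     vectors = []
#     with torch.no_grad():
#         for text in corpus:
#             inputs = tokenizer(text, return_tensors="pt", truncation=True)
#             vectors.append(model(**inputs).last_hidden_state.mean(dim=1).squeeze(0).numpy())
#     embeddings = np.vstack(vectors).astype("float32")
#     index = faiss.IndexFlatL2(embeddings.shape[1])   # exact L2 index (assumption)
#     index.add(embeddings)
#     faiss.write_index(index, FAISS_INDEX_PATH)
#     np.save(EMBEDDINGS_PATH, embeddings)
#     with open(TEXTS_PATH, "wb") as f:
#         pickle.dump(corpus, f)
# ---------------------------------------------------------------------------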

EMBEDDING_MODEL_NAME = "Ah1111/Embedding_Model"
GENERATOR_MODEL_NAME = "Ah1111/Generator_Model"

# Load generator model (T5)
@st.cache_resource
def load_llm():
    tokenizer = T5Tokenizer.from_pretrained(GENERATOR_MODEL_NAME)
    model = T5ForConditionalGeneration.from_pretrained(GENERATOR_MODEL_NAME)
    return tokenizer, model

    # Alternative generator, kept for reference:
    # model_name = "google/flan-t5-base"
    # tokenizer = AutoTokenizer.from_pretrained(model_name)
    # model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # return tokenizer, model

# Load embedding model (custom Hugging Face model)
@st.cache_resource
def load_embedding_model():
    tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_NAME)
    model = AutoModel.from_pretrained(EMBEDDING_MODEL_NAME)
    return tokenizer, model

# Load FAISS index and texts
@st.cache_resource
def load_faiss():
    faiss_index = faiss.read_index(FAISS_INDEX_PATH)
    with open(TEXTS_PATH, "rb") as f:
        data = pickle.load(f)
    embeddings = np.load(EMBEDDINGS_PATH, allow_pickle=True)
    return faiss_index, data, embeddings

# Function to encode query using the embedding model
def encode_query(query, tokenizer, model):
    inputs = tokenizer(query, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        embeddings = model(**inputs).last_hidden_state.mean(dim=1)
    return embeddings.cpu().numpy()
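
# Hedged alternative (assumption, not from this repo): a mask-aware mean pool
# that ignores padding tokens. For a single, unpadded query it gives the same
# result as the plain mean above; it only matters if queries are ever batched.
#
# def encode_query_masked(query, tokenizer, model):
#     inputs = tokenizer(query, return_tensors="pt", truncation=True, padding=True)
#     with torch.no_grad():
#         hidden = model(**inputs).last_hidden_state              # (batch, seq, dim)
#     mask = inputs["attention_mask"].unsqueeze(-1).float()       # (batch, seq, 1)
#     pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
#     return pooled.cpu().numpy()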

# Search top-k contexts
def search(query, tokenizer, model, index, data, k=5):
    query_embedding = encode_query(query, tokenizer, model).astype('float32')
    _, I = index.search(query_embedding, k)           # indices of the k nearest stored vectors
    results = [data[i] for i in I[0] if i != -1]      # FAISS pads with -1 when fewer than k hits exist
    return results
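
# Illustrative call (assumes the loaders above have already run):
#   contexts = search("What are the symptoms of sepsis?",
#                     embed_tokenizer, embed_model, faiss_index, data, k=3)
#   -> up to 3 stored context strings, nearest first under the index's metric.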

# Generate response using generator model
def generate_response(context, query, tokenizer, model):
    input_text = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
    inputs = tokenizer.encode(input_text, return_tensors="pt", truncation=True)
    outputs = model.generate(inputs, max_length=512, do_sample=True, temperature=0.7)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
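
# Hedged note: do_sample=True makes answers non-deterministic. If reproducible
# output is preferred, beam search is one drop-in alternative, e.g.
#   outputs = model.generate(inputs, max_length=512, num_beams=4, early_stopping=True)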

# Streamlit app
def main():
    st.set_page_config(page_title="Clinical QA with RAG", page_icon="🩺")
    st.title("🔎 Clinical QA System (RAG + FAISS + T5)")

    st.markdown(
        """
        Enter your **clinical question** below.

        The system will retrieve relevant context from the FAISS index and generate an informed answer with the fine-tuned T5 model. 🚀
        """
    )

    # Load models and files
    embed_tokenizer, embed_model = load_embedding_model()
    gen_tokenizer, gen_model = load_llm()
    faiss_index, data, embeddings = load_faiss()

    query = st.text_input("💬 Your Question:")

    if query:
        with st.spinner("🔍 Retrieving and Generating..."):
            contexts = search(query, embed_tokenizer, embed_model, faiss_index, data)
            combined_context = " ".join(contexts)
            response = generate_response(combined_context, query, gen_tokenizer, gen_model)

            st.success("✅ Answer Ready!")
            st.subheader("📄 Response:")
            st.write(response)

if __name__ == "__main__":
    main()