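"""Medical PDF chatbot built with Streamlit, LangChain, FAISS, and LaMini.

Overview (summarized from the code below): the app loads Medical_book.pdf,
splits it into overlapping chunks, embeds them with
sentence-transformers/all-MiniLM-L6-v2, and indexes them in FAISS. Each user
question retrieves the two most similar chunks, which are combined with the
question into a prompt for MBZUAI/LaMini-Flan-T5-248M. A Streamlit chat
interface displays the conversation. Launch with `streamlit run <this file>`.
"""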
import os
import torch
import torch.backends.cudnn as cudnn
import streamlit as st
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
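
# Assumed dependencies, inferred from the imports above (package names and
# versions may differ in your environment):
#   pip install streamlit torch transformers langchain langchain-community \
#       faiss-cpu pypdf sentence-transformers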

# Enable CUDA optimizations if available
if torch.cuda.is_available():
    cudnn.benchmark = True

# Step 1: Load the PDF and create a vector store
@st.cache_resource
def load_pdf_to_vectorstore(pdf_path):
    # Load and split PDF
    loader = PyPDFLoader(pdf_path)
    documents = loader.load()
    
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=20,
        separators=["\n\n", "\n", ".", " ", ""]
    )
    
    chunks = text_splitter.split_documents(documents)
    
    # Create embeddings and vector store
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    vectorstore = FAISS.from_documents(chunks, embeddings)
    
    return vectorstore
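
# Design aside (not part of the original flow): FAISS.from_documents builds the
# index in memory on each run, and @st.cache_resource keeps it only while the
# app process is running; LangChain's FAISS wrapper also provides
# save_local()/load_local() if the index should persist across runs.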

# Step 2: Initialize the LaMini model
@st.cache_resource
def setup_model():
    model_id = "MBZUAI/LaMini-Flan-T5-248M"  # Using smaller model for faster inference
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_id,
        # Half precision on GPU reduces memory use; fall back to float32 on CPU
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
    )
    
    if torch.cuda.is_available():
        model = model.cuda()
    
    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=256,
        do_sample=False,  # greedy decoding; temperature/top_p omitted because they are ignored when sampling is disabled
        device=0 if torch.cuda.is_available() else -1,
        batch_size=1
    )
    return pipe

# Step 3: Generate a response using the model and vector store
def generate_response(pipe, vectorstore, user_input):
    # Get relevant context
    docs = vectorstore.similarity_search(user_input, k=2)
    context = "\n".join([
        f"Page {doc.metadata.get('page', 'unknown')}: {doc.page_content}"
        for doc in docs
    ])
    
    # Create prompt
    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""
        Using the following medical text excerpts, answer the question.
        If the information isn't clearly provided in the context, or if you're unsure, please say so and recommend consulting a healthcare professional.
        
        Context: {context}
        
        Question: {question}
        
        Answer (citing relevant page numbers when possible):"""
    )
    
    # Format the prompt and run it through the text2text-generation pipeline
    prompt_text = prompt.format(context=context, question=user_input)
    response = pipe(prompt_text)[0]['generated_text']
    
    return response

# Cache responses for repeated questions (the leading underscores tell st.cache_data
# not to hash the unhashable pipeline and vector store arguments)
@st.cache_data
def cached_generate_response(user_input, _pipe, _vectorstore):
    return generate_response(_pipe, _vectorstore, user_input)

# Answer multiple questions, processed sequentially in groups of batch_size
def batch_generate_responses(pipe, vectorstore, questions, batch_size=4):
    responses = []
    for i in range(0, len(questions), batch_size):
        batch = questions[i:i + batch_size]
        batch_responses = [generate_response(pipe, vectorstore, q) for q in batch]
        responses.extend(batch_responses)
    return responses
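
# Hypothetical usage of the batch helper (question strings are illustrative):
#   answers = batch_generate_responses(pipe, vectorstore,
#                                      ["What is hypertension?", "What causes anemia?"])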

# Streamlit UI
def main():
    st.title("Medical Chatbot Assistant 🏥")
    
    # Use the PDF file from the root directory
    pdf_path = "Medical_book.pdf"
    
    if os.path.exists(pdf_path):
        # Load vector store and model with progress indication
        with st.spinner("Loading PDF and initializing model..."):
            vectorstore = load_pdf_to_vectorstore(pdf_path)
            pipe = setup_model()
            st.success("Ready to answer questions!")
        
        # Create a chat-like interface
        if "messages" not in st.session_state:
            st.session_state.messages = []

        # Display chat history
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # User input
        if prompt := st.chat_input("Ask your medical question:"):
            # Add user message to chat history
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)

            # Generate and display response
            with st.chat_message("assistant"):
                with st.spinner("Generating response..."):
                    response = cached_generate_response(prompt, pipe, vectorstore)
                    st.markdown(response)
                    # Add assistant message to chat history
                    st.session_state.messages.append({"role": "assistant", "content": response})
                    
    else:
        st.error("The file 'Medical_book.pdf' was not found in the root directory.")


if __name__ == "__main__":
    main()