import streamlit as st
import os
import requests
import faiss
import numpy as np
from pdfminer.high_level import extract_text
from sentence_transformers import SentenceTransformer

# Load the Hugging Face token
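# ("Key2" is assumed to be the environment variable / Space secret that holds
# the API token in this deployment; rename it to match your own setup.)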
HUGGINGFACE_TOKEN = os.environ.get("Key2")
HF_MODEL = "HuggingFaceH4/zephyr-7b-alpha"

# Load Sentence Transformer Model
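# all-MiniLM-L6-v2 is a small, fast model that produces 384-dim embeddings.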
EMBEDDER = SentenceTransformer("all-MiniLM-L6-v2")

# Default system prompts
SYSTEM_PROMPTS = {
    "Multi-Query": "Generate five alternative versions of the user question: {question}",
    "RAG Fusion": "Combine multiple queries into a single, refined query: {question}",
    "Decomposition": "Break down the user question into simpler sub-questions: {question}",
    "Step Back": "Refine the user question by asking a more general question: {question}",
    "HyDE": "Generate a hypothetical document relevant to the user question: {question}",
}


# Helper function to interact with Hugging Face API
def query_hf(prompt, max_new_tokens=1000, temperature=0.7, top_k=50):
    headers = {"Authorization": f"Bearer {HUGGINGFACE_TOKEN}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_k": top_k,
            # Return only the completion, not the echoed input prompt.
            "return_full_text": False,
        },
    }
    response = requests.post(f"https://api-inference.huggingface.co/models/{HF_MODEL}", headers=headers, json=payload)
    if response.status_code == 200:
        return response.json()[0]["generated_text"]
    st.error(f"Error: {response.status_code} - {response.text}")
    # Return an empty string so callers can safely call .split() on the result.
    return ""


# Extract text from PDF
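# pdfminer's extract_text accepts a binary file-like object, which is what
# Streamlit's file_uploader returns.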
def extract_pdf_text(pdf_file):
    return extract_text(pdf_file).split("\n")


# Chunk text into segments
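# Chunks are fixed windows of 500 words with no overlap; sentence and page
# boundaries are ignored.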
def chunk_text(text_lines, chunk_size=500):
    words = " ".join(text_lines).split()
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]


# Build FAISS Index
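# IndexFlatL2 does exact (brute-force) L2 search. For cosine similarity one
# would L2-normalise the embeddings and use faiss.IndexFlatIP instead; the
# "Cosine Similarity" option in the sidebar is not wired up to this yet.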
def build_index(embeddings):
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return index


# Search FAISS Index
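# Returns the row indices of the top_k nearest chunks; distances are discarded.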
def search_index(query_embedding, index, top_k=5):
    distances, indices = index.search(query_embedding, top_k)
    return indices[0]


# Embed PDF content and build FAISS index
def process_pdf(pdf_file):
    text_lines = extract_pdf_text(pdf_file)
    chunks = chunk_text(text_lines)
    embeddings = EMBEDDER.encode(chunks, convert_to_tensor=False)
    faiss_index = build_index(np.array(embeddings))
    return chunks, faiss_index


# Generate query translations
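# Assumes the model lists one rewritten query per line, so the output is
# split on newlines.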
def translate_query(prompt, method, max_new_tokens, temperature, top_k):
    formatted_prompt = SYSTEM_PROMPTS[method].format(question=prompt)
    generated = query_hf(formatted_prompt, max_new_tokens, temperature, top_k)
    # Drop blank lines so we never embed an empty query downstream.
    return [line.strip() for line in generated.split("\n") if line.strip()]


# Retrieve relevant chunks from FAISS index
def retrieve_chunks(translated_queries, faiss_index, chunks, top_k=5):
    relevant_chunks = []
    seen = set()
    for query in translated_queries:
        query_embedding = EMBEDDER.encode([query], convert_to_tensor=False)
        indices = search_index(np.array(query_embedding), faiss_index, top_k)
        for i in indices:
            # De-duplicate: query variants often retrieve the same chunks.
            if i not in seen:
                seen.add(i)
                relevant_chunks.append(chunks[i])
    return relevant_chunks


# Generate final response using RAG approach
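# The retrieved chunks are concatenated into the prompt as grounding context.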
def generate_final_response(prompt, context, max_new_tokens, temperature, top_k):
    input_text = f"Context: {context}\n\nAnswer this question: {prompt}"
    return query_hf(input_text, max_new_tokens, temperature, top_k)


# Streamlit UI
def main():
    st.title("Enhanced RAG Model with FAISS Indexing")

    # Sidebar Inputs
    pdf_file = st.sidebar.file_uploader("Upload PDF", type="pdf")
    query_translation = st.sidebar.selectbox("Query Translation Method", list(SYSTEM_PROMPTS.keys()))
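    # NOTE: retrieval below always uses the FAISS L2 index; this selector
    # currently only changes whether the K slider is shown.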
    similarity_method = st.sidebar.selectbox("Similarity Search Method", ["Cosine Similarity", "KNN"])
    k_value = st.sidebar.slider("K Value (for KNN)", 1, 10, 5) if similarity_method == "KNN" else 5
    max_new_tokens = st.sidebar.slider("Max New Tokens", 10, 1000, 500)
    temperature = st.sidebar.slider("Temperature", 0.1, 1.0, 0.7)
    top_k = st.sidebar.slider("Top K", 1, 100, 50)

    # Input Prompt
    prompt = st.text_input("Enter your query:")

    # State Management: initialise every session key used below so the
    # numbered steps can be clicked in any order without AttributeError.
    if 'chunks' not in st.session_state:
        st.session_state.chunks = []
    if 'faiss_index' not in st.session_state:
        st.session_state.faiss_index = None
    if 'translated_queries' not in st.session_state:
        st.session_state.translated_queries = []
    if 'relevant_chunks' not in st.session_state:
        st.session_state.relevant_chunks = []

    # Step 1: Process PDF
    if st.button("1. Embed PDF") and pdf_file:
        st.session_state.chunks, st.session_state.faiss_index = process_pdf(pdf_file)
        st.success("PDF Embedded Successfully")

    # Step 2: Generate Translated Queries
    if st.button("2. Query Translation") and prompt:
        st.session_state.translated_queries = translate_query(prompt, query_translation, max_new_tokens, temperature, top_k)
        st.write("**Generated Queries:**", st.session_state.translated_queries)

    # Step 3: Retrieve Relevant Chunks
    if st.button("3. Retrieve Documents") and st.session_state.translated_queries:
        st.session_state.relevant_chunks = retrieve_chunks(st.session_state.translated_queries, st.session_state.faiss_index, st.session_state.chunks, top_k=k_value)
        st.write("**Retrieved Chunks:**", st.session_state.relevant_chunks)

    # Step 4: Generate Final Response
    if st.button("4. Generate Final Response") and st.session_state.relevant_chunks:
        context = "\n".join(st.session_state.relevant_chunks)
        final_response = generate_final_response(prompt, context, max_new_tokens, temperature, top_k)
        st.subheader("Final Response:")
        st.write(final_response)


if __name__ == "__main__":
    main()