File size: 5,098 Bytes
e2b8671
58c74fe
49add26
 
 
a073b25
e2b8671
49add26
a073b25
 
e2b8671
 
49add26
 
 
 
 
a073b25
 
 
 
 
 
 
 
e2b8671
a073b25
 
 
 
 
 
e2b8671
a073b25
 
 
 
 
 
49add26
a073b25
 
e2b8671
58c74fe
a073b25
 
 
 
 
 
e2b8671
49add26
a073b25
49add26
 
a073b25
49add26
a073b25
 
49add26
 
a073b25
49add26
a073b25
 
 
e2b8671
49add26
a073b25
 
49add26
a073b25
49add26
a073b25
 
 
 
 
 
 
 
 
58c74fe
49add26
a073b25
49add26
58c74fe
 
 
e2b8671
49add26
a073b25
49add26
a073b25
58c74fe
49add26
 
 
f8d190b
49add26
a073b25
49add26
 
a073b25
 
49add26
a073b25
 
49add26
a073b25
 
 
 
 
 
49add26
 
a073b25
49add26
a073b25
49add26
a073b25
e2b8671
58c74fe
 
49add26
a073b25
58c74fe
a073b25
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import pickle
import re
import threading
import time

import faiss
import numpy as np
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModelForCausalLM, AutoTokenizer

class FinancialChatbot:
    """Question-answering chatbot combining FAISS, BM25 and a Qwen language model.

    A query is first screened by a keyword guardrail, then matched against a
    FAISS vector index and a BM25 keyword index. The retrieved documents are
    concatenated into a prompt from which the Qwen model generates the answer.
    """

    # Reply used on every path where no useful answer can be produced.
    _NO_INFO_MESSAGE = "No relevant information found"

    def __init__(self):
        """Load the retrieval indexes and models from their fixed locations.

        Reads ``financial_faiss.index``, ``index_map.pkl`` and
        ``bm25_corpus.pkl`` from the current working directory and loads the
        SentenceTransformer and Qwen models by name.
        """
        # Load FAISS index and the mapping from FAISS labels to documents.
        self.faiss_index = faiss.read_index("financial_faiss.index")
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # pickle files that come from a trusted source.
        with open("index_map.pkl", "rb") as f:
            self.index_map = pickle.load(f)

        # Load BM25 keyword-based search.
        with open("bm25_corpus.pkl", "rb") as f:
            self.bm25_corpus = pickle.load(f)
        self.bm25 = BM25Okapi(self.bm25_corpus)

        # Load SentenceTransformer for embedding-based retrieval.
        self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")

        # Load Qwen model and tokenizer.
        model_name = "Qwen/Qwen2.5-1.5b"
        self.qwen_model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True
        )
        self.qwen_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        # Guardrail: terms that cause a query to be rejected.
        self.BLOCKED_WORDS = [
            "hack", "bypass", "illegal", "exploit", "scam", "kill", "laundering",
            "murder", "suicide", "self-harm", "assault", "bomb", "terrorism",
            "attack", "genocide", "mass shooting", "credit card number"
        ]

        # Minimum best retrieval score required before an answer is generated.
        self.min_similarity_threshold = 0.2

    def moderate_query(self, query):
        """Return ``True`` if *query* is allowed, ``False`` if it must be blocked.

        The comparison is case-insensitive. A blocked term matches only when
        it begins a word, so inflected forms ("hacking", "killing") are still
        rejected while unrelated words that merely contain a term ("skills",
        "shack") are not.
        """
        query_lower = query.lower()
        return not any(
            # (?<!\w): the term must not be preceded by a word character.
            re.search(r"(?<!\w)" + re.escape(term.lower()), query_lower)
            for term in self.BLOCKED_WORDS
        )

    def query_faiss(self, query, top_k=5):
        """Retrieve documents for *query* from the FAISS index.

        Args:
            query: Text to embed and search for.
            top_k: Number of neighbours requested from the index.

        Returns:
            ``(results, confidence_scores)``: the mapped documents and, for
            each, a similarity computed as ``1 / (1 + distance)``.
        """
        query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
        distances, indices = self.faiss_index.search(query_embedding, top_k)

        results = []
        confidence_scores = []
        for idx, dist in zip(indices[0], distances[0]):
            key = int(idx)
            # Labels without a mapping are skipped; this also drops the -1
            # placeholder FAISS returns when fewer than top_k hits exist.
            if key not in self.index_map:
                continue
            results.append(self.index_map[key])
            # assumes a non-negative (L2-style) distance -- TODO confirm index type
            confidence_scores.append(float(1 / (1 + dist)))

        return results, confidence_scores

    def query_bm25(self, query, top_k=5):
        """Retrieve documents for *query* using BM25 keyword search.

        Args:
            query: Text to search for; lower-cased and split on whitespace.
            top_k: Maximum number of documents to return.

        Returns:
            ``(results, confidence_scores)``: corpus entries with a positive
            BM25 score, best first, and their raw (unbounded) scores.
        """
        tokenized_query = query.lower().split()
        scores = self.bm25.get_scores(tokenized_query)
        top_indices = np.argsort(scores)[::-1][:top_k]

        # Zero-score entries share no term with the query and are ignored.
        hits = [idx for idx in top_indices if scores[idx] > 0]
        results = [self.bm25_corpus[idx] for idx in hits]
        confidence_scores = [float(scores[idx]) for idx in hits]
        return results, confidence_scores

    def generate_answer(self, context, question, max_new_tokens=100):
        """Generate an answer with the Qwen model.

        Args:
            context: Retrieved text placed in the prompt.
            question: The user's question.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded model output, i.e. the prompt followed by the
            generated continuation.
        """
        input_text = f"Context: {context}\nQuestion: {question}\nAnswer:"
        # Calling the tokenizer (rather than .encode) also yields the attention
        # mask; inputs must live on the same device as the model.
        inputs = self.qwen_tokenizer(input_text, return_tensors="pt").to(self.qwen_model.device)
        # max_new_tokens bounds only the completion. max_length would also
        # count the prompt, which alone can exceed 100 tokens.
        outputs = self.qwen_model.generate(**inputs, max_new_tokens=max_new_tokens)
        return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)

    @staticmethod
    def _as_text(doc):
        """Return a retrieved document as a single string."""
        if isinstance(doc, str):
            return doc
        # BM25 corpora are typically stored tokenised (one token list per
        # document); join the tokens so the document can be used in a prompt.
        if isinstance(doc, (list, tuple)):
            return " ".join(str(token) for token in doc)
        return str(doc)

    @staticmethod
    def _extract_answer(generated):
        """Return the final ``Answer...`` part of *generated*, or ``None``.

        ``None`` is returned when the text after the last "Answer" marker is
        empty or starts with "--", which is treated as "no answer".
        """
        marker = "Answer"
        start = generated.rfind(marker)
        if start == -1:
            # No marker at all: fall back to the whole text.
            candidate = generated.strip()
            body = candidate
        else:
            candidate = generated[start:]
            body = candidate[len(marker):].lstrip(": \t\r\n")
        if not body or body.startswith("--"):
            return None
        return candidate

    def get_answer(self, query, timeout=200):
        """Answer *query*, giving up after *timeout* seconds.

        The work runs in a worker thread. If it does not finish in time the
        "no information" reply is returned; the worker is not cancelled and
        keeps running in the background. If the worker raises, the default
        reply with confidence 0.0 is returned.

        Args:
            query: The user's question.
            timeout: Seconds to wait for the worker thread.

        Returns:
            ``(answer, confidence)``. The confidence of a generated answer is
            the best retrieval score, capped at 1.0.
        """
        result = [self._NO_INFO_MESSAGE, 0.0]  # Default response

        def task():
            cleaned = query.strip()
            if not cleaned:
                result[:] = [self._NO_INFO_MESSAGE, 1.0]
                return

            if cleaned.lower() in ("hi", "hello", "hey"):
                result[:] = ["Hi, how can I help you?", 1.0]
                return

            if not self.moderate_query(cleaned):
                result[:] = ["I'm unable to process your request due to inappropriate language.", 1.0]
                return

            faiss_results, faiss_conf = self.query_faiss(cleaned)
            bm25_results, bm25_conf = self.query_bm25(cleaned)

            all_results = faiss_results + bm25_results
            all_conf = faiss_conf + bm25_conf

            # Check relevance.
            best_score = max(all_conf, default=0)
            if not all_results or best_score < self.min_similarity_threshold:
                result[:] = [self._NO_INFO_MESSAGE, 1.0]
                return

            context = " ".join(self._as_text(doc) for doc in all_results)
            extracted = self._extract_answer(self.generate_answer(context, cleaned))
            if extracted is None:
                result[:] = [self._NO_INFO_MESSAGE, 1.0]
            else:
                # Raw BM25 scores are unbounded; keep the confidence in [0, 1].
                result[:] = [extracted, min(1.0, float(best_score))]

        thread = threading.Thread(target=task)
        thread.start()
        thread.join(timeout)

        if thread.is_alive():
            return self._NO_INFO_MESSAGE, 1.0  # Timeout case

        return tuple(result)