BaRiDo committed on
Commit
66b3608
·
verified ·
1 Parent(s): a0e73a7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -0
app.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import getpass
3
+
4
+ import streamlit as st
5
+
6
def get_credentials():
    """Return watsonx.ai connection settings.

    The service endpoint is pinned to the us-south region; the API key
    is read from the ``IBM_API_KEY`` environment variable at call time.
    """
    api_key = os.getenv("IBM_API_KEY")
    return {
        "url": "https://us-south.ml.cloud.ibm.com",
        "apikey": api_key,
    }
11
+
12
# Foundation model used for answer generation.
model_id = "ibm/granite-3-8b-instruct"

# Greedy decoding: deterministic output, at most 900 generated tokens,
# no repetition penalty (1 is the neutral value).
parameters = {
    "decoding_method": "greedy",
    "max_new_tokens": 900,
    "min_new_tokens": 0,
    "repetition_penalty": 1,
}

# Deployment scoping identifiers, supplied via environment variables.
project_id = os.getenv("IBM_PROJECT_ID")
space_id = os.getenv("IBM_SPACE_ID")
23
+
24
from ibm_watsonx_ai.foundation_models import ModelInference
from ibm_watsonx_ai.client import APIClient

# Inference handle for the configured Granite model.
model = ModelInference(
    model_id=model_id,
    params=parameters,
    credentials=get_credentials(),
    project_id=project_id,
    space_id=space_id,
)

# API client used below to fetch the vector-index data asset.
wml_credentials = get_credentials()
client = APIClient(credentials=wml_credentials, project_id=project_id, space_id=space_id)

# Load the stored vector-index configuration for this asset id.
vector_index_id = "14c14504-5f45-4e6c-8f0f-25f2378a1d99"
vector_index_details = client.data_assets.get_details(vector_index_id)
vector_index_properties = vector_index_details["entity"]["vector_index"]

# When reranking is enabled, over-fetch 20 candidates so the reranker has
# something to trim; otherwise fetch exactly the configured top_k.
top_n = 20 if vector_index_properties["settings"].get("rerank") else int(vector_index_properties["settings"]["top_k"])
44
+
45
def rerank(client, documents, query, top_n):
    """Reorder *documents* by cross-encoder relevance to *query*.

    Sends the documents to the watsonx.ai Rerank service and returns the
    top_n of them, most relevant first.
    """
    from ibm_watsonx_ai.foundation_models import Rerank

    reranker = Rerank(
        model_id="cross-encoder/ms-marco-minilm-l-12-v2",
        api_client=client,
        params={
            "return_options": {
                "top_n": top_n
            },
            "truncate_input_tokens": 512
        }
    )

    ranked = reranker.generate(query=query, inputs=documents)["results"]

    # Each result carries the index of its source document; map back in
    # ranked order.
    return [documents[entry["index"]] for entry in ranked]
68
+
69
from ibm_watsonx_ai.foundation_models.embeddings.sentence_transformer_embeddings import SentenceTransformerEmbeddings

# Query-time embedder; presumably the same model that built the stored
# vector index — embeddings must match for nearest-neighbour search.
emb = SentenceTransformerEmbeddings('sentence-transformers/all-MiniLM-L6-v2')
72
+
73
+ import subprocess
74
+ import gzip
75
+ import json
76
+ import chromadb
77
+ import random
78
+ import string
79
+
80
def hydrate_chromadb():
    """Rebuild the in-memory Chroma collection from the stored vector index.

    Downloads the gzip-compressed JSON vector dump attached to the
    vector-index data asset, then loads every embedding/document/metadata
    triple into a fresh collection named "my_collection".

    Returns:
        The populated chromadb collection.
    """
    data = client.data_assets.get_content(vector_index_id)
    content = gzip.decompress(data)
    stringified_vectors = str(content, "utf-8")
    vectors = json.loads(stringified_vectors)

    chroma_client = chromadb.Client()

    # make sure collection is empty if it already existed
    collection_name = "my_collection"
    try:
        # delete_collection's return value is not meaningful; original code
        # assigned it and immediately overwrote it below.
        chroma_client.delete_collection(name=collection_name)
    except Exception:
        # Narrowed from a bare except: so Ctrl-C / SystemExit still propagate.
        print("Collection didn't exist - nothing to do.")
    collection = chroma_client.create_collection(name=collection_name)

    vector_embeddings = []
    vector_documents = []
    vector_metadatas = []
    vector_ids = []

    for vector in vectors:
        vector_embeddings.append(vector["embedding"])
        vector_documents.append(vector["content"])

        metadata = vector["metadata"]
        lines = metadata["loc"]["lines"]
        # Keep only the flat metadata fields Chroma can store.
        clean_metadata = {
            "asset_id": metadata["asset_id"],
            "asset_name": metadata["asset_name"],
            "url": metadata["url"],
            "from": lines["from"],
            "to": lines["to"],
        }
        vector_metadatas.append(clean_metadata)

        # Random suffix keeps ids unique even if the same chunk span repeats.
        random_string = ''.join(random.choices(string.ascii_uppercase + string.digits, k=10))
        vector_id = "{}:{}-{}-{}".format(metadata["asset_id"], lines["from"], lines["to"], random_string)
        vector_ids.append(vector_id)

    collection.add(
        embeddings=vector_embeddings,
        documents=vector_documents,
        metadatas=vector_metadatas,
        ids=vector_ids
    )
    return collection
125
+
126
# Build the searchable collection once at module import / app startup.
chroma_collection = hydrate_chromadb()
127
+
128
def proximity_search(question):
    """Return newline-joined context chunks relevant to *question*.

    Embeds the question, fetches the top_n nearest chunks from Chroma,
    and — when the index settings enable it — reranks them down to the
    configured top_k before joining.
    """
    query_vectors = emb.embed_query(question)
    query_result = chroma_collection.query(
        query_embeddings=query_vectors,
        n_results=top_n,
        include=["documents", "metadatas", "distances"]
    )

    # Reverse the returned chunk order before assembling the context.
    documents = query_result["documents"][0][::-1]

    settings = vector_index_properties["settings"]
    if settings.get("rerank"):
        documents = rerank(client, documents, question, settings["top_k"])

    return "\n".join(documents)
142
+
143
# Streamlit UI
st.title("🔍 IBM Watson RAG Chatbot")

# User input in Streamlit
question = st.text_input("Enter your question:")

if question:
    # Retrieve relevant grounding context
    grounding = proximity_search(question)

    # Format the question with retrieved context, using the Granite
    # chat-role markers the model expects.
    formatted_question = f"""<|start_of_role|>user<|end_of_role|>Use the following pieces of context to answer the question.

{grounding}

Question: {question}<|end_of_text|>
<|start_of_role|>assistant<|end_of_role|>"""

    # Placeholder for a prompt input (Optional)
    prompt_input = ""  # Set this dynamically if needed
    prompt = f"""{prompt_input}{formatted_question}"""

    # Generate the grounded answer with the ModelInference handle built
    # above — the original left this as a simulated f-string response and
    # never used the model.
    generated_response = model.generate_text(prompt=prompt)

    # Display results
    st.subheader("📌 Retrieved Context")
    st.write(grounding)

    st.subheader("🤖 AI Response")
    st.write(generated_response)