pgurazada1 commited on
Commit
0c83a68
·
verified ·
1 Parent(s): bdae1d4

Create chat_interface.py

Browse files
Files changed (1) hide show
  1. chat_interface.py +154 -0
chat_interface.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import json
4
+ import chromadb
5
+
6
+ import gradio as gr
7
+
8
+ from dotenv import load_dotenv
9
+ from openai import OpenAI
10
+
11
+ from langchain_community.embeddings import AnyscaleEmbeddings
12
+ from langchain_community.vectorstores import Chroma
13
+
14
+ from huggingface_hub import CommitScheduler
15
+ from pathlib import Path
16
+
17
+
18
# Load environment variables (e.g. ANYSCALE_API_KEY, PASSWD) from a .env file.
load_dotenv()

# Name of the Chroma collection that holds the chunked Tesla 10-K filings.
tesla_10k_collection = 'tesla-10k-2019-to-2023'

# Direct indexing (not .get) fails fast with KeyError if the key is missing.
anyscale_api_key = os.environ['ANYSCALE_API_KEY']

# OpenAI-compatible client pointed at Anyscale's hosted endpoints; used both
# for chat completions and (via AnyscaleEmbeddings) for query embeddings.
client = OpenAI(
    base_url="https://api.endpoints.anyscale.com/v1",
    api_key=anyscale_api_key
)

# Chat model used to answer user questions.
qna_model = 'meta-llama/Meta-Llama-3-8B-Instruct'

# Embedding model for query-time embedding; presumably the same model was
# used to build the index — TODO confirm against the indexing job.
embedding_model = AnyscaleEmbeddings(
    client=client,
    model='thenlper/gte-large'
)

# On-disk Chroma store; assumes ./tesla_db was populated ahead of time by an
# offline indexing step — NOTE(review): verify the collection exists on deploy.
chromadb_client = chromadb.PersistentClient(path='./tesla_db')

vectorstore_persisted = Chroma(
    client=chromadb_client,
    collection_name=tesla_10k_collection,
    embedding_function=embedding_model
)

# Retriever returning the 5 most similar chunks for each query.
retriever = vectorstore_persisted.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 5}
)

# Prepare the logging functionality

# Each app instance writes to its own uniquely named JSON-lines log file so
# concurrent instances never clobber each other's logs.
log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
log_folder = log_file.parent

# Background job that periodically commits the log folder to a Hugging Face
# dataset repo (every=2 is in minutes per huggingface_hub's CommitScheduler).
scheduler = CommitScheduler(
    repo_id="document-qna-chroma-anyscale-logs",
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=2
)

# System prompt: constrains the model to answer only from the provided
# context. Sent to the model verbatim — do not edit casually.
qna_system_message = """
You are an assistant to a financial services firm who answers user queries on annual reports.
Users will ask questions delimited by triple backticks, that is, ```.
User input will have the context required by you to answer user questions.
This context will begin with the token: ###Context.
The context contains references to specific portions of a document relevant to the user query.
Please answer only using the context provided in the input. However, do not mention anything about the context in your answer.
If the answer is not found in the context, respond "I don't know".
"""

# Per-turn user message: retrieved context after the ###Context token, then
# the question wrapped in triple backticks — matching the system prompt.
qna_user_message_template = """
###Context
Here are some documents that are relevant to the question.
{context}
```
{question}
```
"""
80
+
81
def predict(input: str, history):
    """
    Answer a user question grounded in retrieved 10-K chunks and log the turn.

    Parameters
    ----------
    input : str
        The user's current question (supplied by Gradio's ChatInterface).
    history : list
        Prior (user, assistant) message pairs maintained by Gradio.

    Returns
    -------
    str
        The model's answer, or an apologetic fallback message if the
        completion request fails.
    """
    # Retrieve the top-k similar chunks and join them into one context string.
    relevant_document_chunks = retriever.invoke(input)
    context_for_query = "\n".join(
        d.page_content for d in relevant_document_chunks
    )

    user_message = [{
        'role': 'user',
        'content': qna_user_message_template.format(
            context=context_for_query,
            question=input
        )
    }]

    # Rebuild the conversation: system prompt, then alternating user/assistant
    # turns from history, then the new user message.
    prompt = [{'role': 'system', 'content': qna_system_message}]
    for entry in history:
        prompt += (
            [{'role': 'user', 'content': entry[0]}] +
            [{'role': 'assistant', 'content': entry[1]}]
        )

    final_prompt = prompt + user_message

    try:
        response = client.chat.completions.create(
            model=qna_model,
            messages=final_prompt,
            temperature=0
        )
        prediction = response.choices[0].message.content.strip()
    except Exception as e:
        # Surface a friendly message in the chat UI instead of crashing.
        prediction = f"Sorry, I cannot answer your question at this point. {e}"

    # While the prediction is made, log both the inputs and outputs to a local
    # log file. Hold the scheduler's lock while writing so the background
    # commit job never uploads a half-written line.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    # BUGFIX: was `user_input`, an undefined name that raised
                    # NameError on every call; the parameter is `input`.
                    'user_input': input,
                    'retrieved_context': context_for_query,
                    'model_response': prediction
                }
            ))
            f.write("\n")

    return prediction
137
+
138
# Gradio chat UI wired to `predict`; ChatInterface passes (message, history)
# on every turn. cache_examples=False because answers depend on live API calls.
demo = gr.ChatInterface(
    fn=predict,
    title="AMA on Tesla 10-K statements",
    description="This web API presents an interface to ask questions on contents of the Tesla 10-K reports for the period 2019 - 2023.",
    article="Note that questions that are not relevant to the Tesla 10-K report will not be answered.",
    examples=[["What was the total revenue of the company in 2022?"],
              ["Summarize the Management Discussion and Analysis section of the 2021 report in 50 words."],
              ["What was the company's debt level in 2020?"],
              ["Identify 5 key risks identified in the 2019 10k report? Respond with bullet point summaries."],
              ["What is the view of the management on the future of electric vehicle batteries?"]
              ],
    cache_examples=False,
    concurrency_limit=8,  # cap on simultaneous predict() calls
    show_progress="full"
)

# Basic auth with a fixed username; the password comes from the PASSWD env
# var — NOTE(review): os.getenv returns None if unset, so confirm PASSWD is
# always configured in deployment.
demo.launch(auth=("demouser", os.getenv('PASSWD')))