omri374 committed on
Commit
8f78c8f
·
verified ·
1 Parent(s): c39b19c

Upload 5 files

Browse files
src/data_handler.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ import io
3
+ from typing import List
4
+
5
+ import pandas as pd
6
+ import base64
7
+
8
+
9
def generate_excel_base64(dataframe: pd.DataFrame) -> str:
    """Render ``dataframe`` as an .xlsx workbook and return it base64-encoded.

    The frame is written, without its index, to a single sheet named "Data"
    using the ``xlsxwriter`` engine. Nothing touches the disk: the workbook
    is built in an in-memory buffer.

    Args:
        dataframe: The data to export.

    Returns:
        The workbook bytes encoded as a base64 UTF-8 string.
    """
    buffer = io.BytesIO()

    # Leaving the writer context finalizes the workbook into the buffer.
    with pd.ExcelWriter(buffer, engine="xlsxwriter") as excel_writer:
        dataframe.to_excel(excel_writer, index=False, sheet_name="Data")

    # getvalue() returns the full buffer content regardless of stream position.
    workbook_bytes = buffer.getvalue()
    return base64.b64encode(workbook_bytes).decode("utf-8")
23
+
24
+
25
def generate_excel(dataframe: pd.DataFrame) -> str:
    """Write ``dataframe`` to a temporary .xlsx file and return the file's path.

    The workbook is first built in memory (sheet "Data", no index column,
    ``xlsxwriter`` engine) and then copied into a named temporary file. The
    file is created with ``delete=False`` so it outlives this call; removing
    it is left to the caller.

    Args:
        dataframe: The data to export.

    Returns:
        Filesystem path of the generated ``.xlsx`` file.
    """
    buffer = io.BytesIO()

    with pd.ExcelWriter(buffer, engine="xlsxwriter") as excel_writer:
        dataframe.to_excel(excel_writer, index=False, sheet_name="Data")

    workbook_bytes = buffer.getvalue()

    with tempfile.NamedTemporaryFile(delete=False, suffix=".xlsx") as handle:
        handle.write(workbook_bytes)
        return handle.name
src/gradio_utils.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+
3
+ from src.llm_calls import llm_extract_table
4
+ from src.parse_response import extract_and_return_data_table
5
+ import gradio as gr
6
+
7
+
8
def __update_df_state(df_before, df_state, updated_df):
    """Record the current frame in the undo history and adopt a new one.

    Args:
        df_before: Undo history (list of earlier frames) or ``None``.
        df_state: The frame that is current before the update.
        updated_df: Data for the new current frame; passed to ``pd.DataFrame``.

    Returns:
        Tuple ``(undo_history, new_current_frame, redo_history)`` where the
        redo history is always an empty list.
    """
    # A missing history starts a fresh one; otherwise build a new list so the
    # caller's list is not mutated.
    history = [df_state] if df_before is None else df_before + [df_state]
    current = pd.DataFrame(updated_df).copy()
    redo_stack = []  # any new change invalidates the redo history
    return history, current, redo_stack
20
+
21
+
22
def extract_table_from_chat(
    chat_output, df_before, df_state, df_after, llm_type, api_key, key="Medications"
):
    """Extract a data table from the chat and make it the current frame.

    Rule-based extraction is attempted first. If that fails, the chat output
    is sent to the LLM to be converted to JSON and extraction is retried.

    Args:
        chat_output: Chat output handed to ``extract_and_return_data_table``.
        df_before: Undo history (list of frames) or ``None``.
        df_state: The frame that is current before the update.
        df_after: Redo history (unused; the redo history is reset on success).
        llm_type: LLM provider name forwarded to ``llm_extract_table``.
        api_key: API key forwarded to ``llm_extract_table``.
        key: Top-level JSON key that holds the table rows.

    Returns:
        Tuple ``(displayed_frame, undo_history, current_frame, redo_history,
        prev_button_update, next_button_update)``.

    Raises:
        gr.Error: If no table can be extracted even with the LLM fallback.
            The undo/redo state is left untouched in that case.
    """
    extraction_errors = (ValueError, KeyError)

    try:
        updated_df = extract_and_return_data_table(chat_output=chat_output, key=key)
    except extraction_errors:
        # A missing `key` (KeyError) is just as recoverable as malformed JSON
        # (ValueError), so both trigger the LLM-based fallback.
        try:
            json_str = llm_extract_table(chat_output, llm_type, api_key)
            updated_df = extract_and_return_data_table(chat_output=json_str, key=key)
        except extraction_errors as err:
            # gr.Error is an exception: it only reaches the user when raised.
            # Raising also aborts the update, so the history is not polluted
            # with a bogus frame.
            raise gr.Error(
                "Cannot extract table information from chat. "
                "Please ask the LLM to provide the dataset in JSON format.",
                duration=None,
            ) from err

    new_df_before, new_df_state, new_df_after = __update_df_state(
        df_before, df_state, updated_df
    )
    return (
        new_df_state,
        new_df_before,
        new_df_state,
        new_df_after,
        gr.update(interactive=True),  # prev button: there is now history
        gr.update(interactive=False),  # next button: redo history was reset
    )
57
+
58
+
59
def update_llm_selection(selected_llm):
    """Return a Gradio update relabelling the API-key field for the chosen LLM.

    Args:
        selected_llm: Either ``"OpenAI"`` or ``"Perplexity"``.

    Returns:
        A ``gr.update`` carrying the matching label and placeholder.

    Raises:
        ValueError: If ``selected_llm`` is any other value.
    """
    if selected_llm == "OpenAI":
        label, placeholder = "OpenAI API Key", "Enter OpenAI API Key"
    elif selected_llm == "Perplexity":
        label, placeholder = "Perplexity API Key", "Enter Perplexity API Key"
    else:
        raise ValueError("Invalid LLM type selected.")

    return gr.update(label=label, placeholder=placeholder)
68
+
69
+
70
def edit_or_save_changes(updated_df, df_before, df_state, df_after, current_edit_mode):
    """Toggle the Edit/Save button and record the table in the undo history.

    In both modes a copy of ``df_state`` is appended to the undo history and
    the redo history is reset.

    Args:
        updated_df: Table content as currently shown in the UI.
        df_before: Undo history (list of frames).
        df_state: The frame that is current before this call.
        df_after: Redo history (unused; a fresh empty list is returned).
        current_edit_mode: ``"Save"`` or ``"Edit"`` - the button's current mode.

    Returns:
        Tuple ``(displayed_frame, undo_history, current_frame, redo_history,
        prev_button_update, next_button_update, table_update,
        edit_button_update, new_edit_mode)``.

    Raises:
        ValueError: If ``current_edit_mode`` is neither ``"Save"`` nor ``"Edit"``.
    """
    edited = pd.DataFrame(updated_df)
    history = df_before + [df_state.copy()]
    redo_stack = []  # a new change invalidates the redo history
    snapshot = edited.copy()

    if current_edit_mode not in ("Save", "Edit"):
        raise ValueError(f"Wrong edit mode selected: {current_edit_mode}. ")

    if current_edit_mode == "Save":
        # Leaving "Save" mode: the button flips back to "Edit".
        prev_button = gr.update(interactive=True)  # a change was recorded
        next_button = gr.update(interactive=False)
        table = gr.update(interactive=True)
        edit_button = gr.update(value="Edit")
        return (
            edited,
            history,
            snapshot,
            redo_stack,
            prev_button,
            next_button,
            table,
            edit_button,
            "Edit",
        )

    # Leaving "Edit" mode: the button flips to "Save" and navigation is locked.
    prev_button = gr.update(interactive=False)
    next_button = gr.update(interactive=False)
    table = gr.update(interactive=True)
    edit_button = gr.update(value="Save")
    return (
        snapshot,
        history,
        snapshot,
        redo_stack,
        prev_button,
        next_button,
        table,
        edit_button,
        "Save",
    )
108
+
109
+
110
def undo(df_before, df_state, df_after):
    """Step back to the most recent frame in the undo history.

    Args:
        df_before: Undo history (list of frames, oldest first).
        df_state: The frame that is current before the undo.
        df_after: Redo history (list of frames).

    Returns:
        Tuple ``(displayed_frame, undo_history, current_frame, redo_history,
        prev_button_update, next_button_update)``. When there is nothing to
        undo, the inputs are returned unchanged and the prev button is disabled.
    """
    if df_before:
        # The frame being left becomes the newest redo entry.
        redo_stack = df_after + [df_state.copy()]
        restored = df_before[-1]
        history = df_before[:-1]
        return (
            restored,
            history,
            restored,
            redo_stack,
            gr.update(interactive=(len(history) > 0)),  # prev button
            gr.update(interactive=(len(redo_stack) > 0)),  # next button
        )

    # Nothing to undo: keep the state as is.
    return (
        df_state,
        df_before,
        df_state,
        df_after,
        gr.update(interactive=False),
        gr.update(interactive=(len(df_after) > 0)),
    )
134
+
135
+
136
def redo(df_before, df_state, df_after):
    """Step forward to the most recent frame in the redo history.

    Args:
        df_before: Undo history (list of frames, oldest first).
        df_state: The frame that is current before the redo.
        df_after: Redo history (list of frames).

    Returns:
        Tuple ``(displayed_frame, undo_history, current_frame, redo_history,
        prev_button_update, next_button_update)``. When there is nothing to
        redo, the inputs are returned unchanged and the next button is disabled.
    """
    if df_after:
        # NOTE(review): with no current frame, the whole redo list is used as
        # the "current" value and a copy of that list is pushed onto the undo
        # history - confirm this is intended.
        leaving = df_after if df_state is None else df_state

        history = df_before + [leaving.copy()]
        restored = df_after[-1]
        redo_stack = df_after[:-1]
        return (
            restored,
            history,
            restored,
            redo_stack,
            gr.update(interactive=(len(history) > 0)),  # prev button
            gr.update(interactive=(len(redo_stack) > 0)),  # next button
        )

    # Nothing to redo: keep the state as is.
    return (
        df_state,
        df_before,
        df_state,
        df_after,
        gr.update(interactive=(len(df_before) > 0)),
        gr.update(interactive=False),
    )
163
+
164
+
165
+ # def toggle_save_edit(button_state, dataframe_display):
166
+ # if button_state == "Edit":
167
+ # return "Save", gr.update(interactive=True), gr.update(interactive=True) # Enable DataFrame editing
168
+ # else:
169
+ # return "Edit", gr.update(interactive=True), gr.update(interactive=False) # Disable DataFrame editing
170
+ #
171
+
172
+
173
def upload_file(file, df_before, df_state, df_after):
    """Load an uploaded Excel file and make it the current frame.

    Args:
        file: Uploaded file object exposing a ``name`` path, or ``None``.
        df_before: Undo history (list of frames) or ``None``.
        df_state: The frame that is current before the upload.
        df_after: Redo history (unused; the redo history is reset).

    Returns:
        A bare ``gr.update()`` when no file was provided; otherwise the tuple
        ``(displayed_frame, undo_history, current_frame, redo_history,
        prev_button_update, next_button_update)`` with both buttons disabled.
    """
    if file is None:
        return gr.update()

    uploaded = pd.read_excel(file.name, engine="openpyxl")

    # Only the histories are taken from the helper; the freshly loaded frame
    # itself is what gets displayed and stored.
    history, _, redo_stack = __update_df_state(df_before, df_state, uploaded)

    return (
        uploaded,
        history,
        uploaded,
        redo_stack,
        gr.update(interactive=False),
        gr.update(interactive=False),
    )
src/llm_calls.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import Generator, List, Optional
4
+
5
+ import pandas as pd
6
+ import requests
7
+ from dotenv import load_dotenv
8
+ from openai import OpenAI
9
+
10
+ load_dotenv()
11
+
12
+
13
def query_llm(
    messages,
    history: List,
    df: Optional[pd.DataFrame],
    llm_type: str,
    api_key: str,
    system_prompt: str,
) -> Generator[str, None, None]:
    """Chat function that streams responses using an LLM API.

    Args:
        messages (str or list): User input message(s). A plain string is
            wrapped into a single user message.
        history (list): Conversation history; only the last two entries are
            forwarded to the model.
        df (pd.DataFrame): A representation of the data already obtained.
        llm_type (str): Either "OpenAI" or "Perplexity".
        api_key (str): API key for the selected provider. When empty, the
            ``OPENAI_API_KEY`` / ``PERPLEXITY_API_KEY`` environment variable
            is used instead.
        system_prompt (str): The system prompt.

    Yields:
        str: Chunks of the assistant's response as produced by the provider
        helper, or a single human-readable error message when no API key is
        available or the LLM type is unsupported.
    """

    if not api_key:
        if llm_type == "OpenAI":
            api_key = os.environ.get("OPENAI_API_KEY")
        elif llm_type == "Perplexity":
            api_key = os.environ.get("PERPLEXITY_API_KEY")

    # Covers an unknown LLM type without a key and a missing/empty environment
    # variable. Stop here: continuing would call len(None) below and send an
    # unauthenticated request.
    if not api_key:
        yield "No API key provided for the selected LLM type."
        return

    print(f"LLM Type: {llm_type}, API Key len: {len(api_key)}")  # Debugging

    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]

    # Extract last 2 messages from history (if available)
    history = history[-2:] if history else []

    # Build message history (prepend system prompt)
    full_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Past interactions: {history}"},
        {
            "role": "assistant",
            "content": f"Dataset: {df.to_json() if df is not None else {}}",
        },
    ] + messages

    if llm_type == "Perplexity":
        yield from query_perplexity(full_messages, api_key=api_key)
    elif llm_type == "OpenAI":
        yield from query_openai(full_messages, api_key=api_key)
    else:
        yield "Unsupported LLM type. Please choose either 'OpenAI' or 'Perplexity'."
66
+
67
def query_perplexity(
    full_messages,
    api_key: str,
    url="https://api.perplexity.ai/chat/completions",
    model="sonar-pro",
):
    """Query the Perplexity AI API and stream the response.

    Args:
        full_messages (list): List of messages in the conversation.
        api_key (str): Perplexity API key.
        url (str): API endpoint URL.
        model (str): Model to use for the query.

    Yields:
        str: The ``choices[0].message.content`` value of every streamed chunk
        that carries one. If the request fails, a single message with the
        status code and response text is yielded instead. A payload line that
        is not valid JSON yields an "Error decoding JSON" message.
    """

    payload = {
        "model": model,
        "messages": full_messages,
        "stream": True,
    }

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    # Server-sent-event lines that carry no JSON payload and must not be
    # surfaced to the user as decoding errors.
    sse_control_prefixes = (":", "event:", "id:", "retry:")

    with requests.post(url, json=payload, headers=headers, stream=True) as response:
        if response.status_code != 200:
            yield f"API request failed with status code {response.status_code}, details: {response.text}"
            return

        for raw_line in response.iter_lines():
            if not raw_line:
                continue

            line = raw_line.decode("utf-8").strip()
            if line.startswith(sse_control_prefixes):
                continue
            if line.startswith("data:"):
                line = line[len("data:") :].strip()  # Remove "data:" prefix

            # "[DONE]" is the end-of-stream sentinel used by OpenAI-compatible
            # streaming APIs; it is not JSON.
            if not line or line == "[DONE]":
                continue

            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                yield f"Error decoding JSON: {line}"
                continue

            choices = data.get("choices") if isinstance(data, dict) else None
            if not choices:
                continue

            # Tolerate chunks without a "message" instead of raising KeyError
            # in the middle of a stream.
            message = choices[0].get("message") or {}
            content = message.get("content")
            if content is not None:
                yield content
112
+
113
+
114
def query_openai(full_messages, api_key: str) -> Generator[str, None, None]:
    """Stream a chat completion from the OpenAI API.

    Args:
        full_messages (list): List of messages in the conversation.
        api_key (str): OpenAI API key.

    Yields:
        str: The response text accumulated so far. Every yield is a running
        total (a superset of the previous one), not just the newest fragment.
    """
    client = OpenAI(api_key=api_key)

    stream = client.chat.completions.create(
        model="gpt-4-turbo",
        messages=full_messages,
        stream=True,  # Enable streaming
    )

    accumulated = ""
    for chunk in stream:
        fragment = chunk.choices[0].delta.content
        if not fragment:
            # Chunks without text produce no output.
            continue
        accumulated += fragment
        yield accumulated
134
+
135
+
136
def llm_extract_table(chat_output, llm_type, api_key) -> str:
    """Ask the LLM to convert chat output into a JSON dataset.

    Args:
        chat_output: Output of a previous LLM conversation, expected to
            contain a dataset formatted as JSON or markdown.
        llm_type: LLM provider name forwarded to ``query_llm``.
        api_key: API key forwarded to ``query_llm``.

    Returns:
        The final text produced by the stream, stripped of surrounding
        whitespace (an empty string if nothing was yielded).
    """
    system_prompt = """
    You are a pharmacology assistant specialized in analyzing and structuring medical data.
    Your role is to extract information in either markdown, JSON or text, and turn it structured information.
    You will be given output from a conversation with an LLM. This conversation should have a dataset formatted
    as either json or markdown. Extract the dataset and return a JSON object.
    The dataset should be a JSON object with a dict per medication, with the following format:
    ```json
    {
        "Medications": [
            {"Name": "Medication Name", "key1": "value1", "key2": "value2",..},
            {"Name": "Medication Name", "key1": "value1", "key2": "value2",..}
        ]
    }

    Guidelines:
    - Make sure the response contains only a valid JSON
    - Avoid adding text before or after
    """

    response = query_llm(
        messages=chat_output,
        history=None,
        df=None,
        llm_type=llm_type,
        api_key=api_key,
        system_prompt=system_prompt,
    )

    # The stream yields cumulative snapshots of the answer (each chunk already
    # contains everything before it), so concatenating the chunks would repeat
    # the text many times over. Only the last snapshot is the complete answer.
    final_text = ""
    for final_text in response:
        pass

    return final_text.strip()
src/mocks.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import random
4
+ from typing import List, Optional
5
+
6
+ import pandas as pd
7
+
8
+
9
def get_current_df(dfs: List[pd.DataFrame], current: int) -> pd.DataFrame:
    """Return ``dfs[current]``, or an empty DataFrame when ``dfs`` is empty."""
    return dfs[current] if len(dfs) != 0 else pd.DataFrame()
14
+
15
+
16
def query_llm_mock(
    messages,
    history: List,
    df: pd.DataFrame,
    llm_type: str,
    api_key: str,
    system_prompt: str,
):
    """Mock replacement for the streaming LLM chat function.

    All arguments are accepted for signature compatibility and ignored.

    Args:
        messages (str or list): User input message(s).
        history (list): Conversation history.
        df (pd.DataFrame): A representation of the data already obtained.
        llm_type (str): LLM provider name.
        api_key (str): API key.
        system_prompt (str): The system prompt.

    Yields:
        str: One canned reply containing a fenced ``json`` block. The block
        holds a five-medication dataset (with random "Random" values) when a
        random draw exceeds 0.5 and is empty otherwise.
    """
    medication_names = (
        "Tamsulosin",
        "Metoprolol",
        "Bromocriptine",
        "Reserpine",
        "Rasagiline",
    )
    records = [
        {
            "Medication name": name,
            "Passes_RBB": "Yes",
            "Random": random.random(),
        }
        for name in medication_names
    ]
    mock_json = json.dumps({"Medications": records})

    # Randomly leave the fenced block empty to exercise the failure path.
    fenced_payload = mock_json if random.random() > 0.5 else ""

    reply_lines = [
        "Good question!",
        "Here's the data frame in JSON format:",
        "```json",
        fenced_payload,
        "```",
        "",
        "Hope this is useful.",
    ]
    yield "\n".join(reply_lines)
76
+
77
+
78
def llm_extract_table_mock(chat_output, llm_type, api_key) -> str:
    """Mock replacement for the LLM-based table extraction.

    All arguments are ignored. Returns a JSON string with a "Medications"
    list of two placeholder records, each carrying a random "Random" value.
    """
    records = []
    for _ in range(2):
        records.append(
            {
                "Name": "Medication Name",
                "key1": "value1",
                "key2": "value2",
                "Random": random.random(),
            }
        )

    return json.dumps({"Medications": records})
src/parse_response.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import json
3
+ from typing import Optional
4
+ import pandas as pd
5
+ import simplejson
6
+
7
+
8
def json_to_dict(response: str) -> dict:
    """Parse the JSON object embedded in ``response``.

    The span from the first ``{`` to the last ``}`` is parsed first. If that
    span is not valid JSON, the whole response is parsed with ``simplejson``
    as a second attempt.

    Args:
        response (str): Text containing a JSON object.
    Returns:
        dict: The parsed JSON.
    Raises:
        ValueError: If the response contains no ``{...}`` span, or if neither
            parsing attempt succeeds.
    """
    # Greedy match across newlines: first "{" through the last "}".
    candidate = re.search(r"\{.*}", response, re.DOTALL)
    if candidate is None:
        raise ValueError("No valid JSON found in the response.")

    try:
        return json.loads(candidate.group())
    except json.JSONDecodeError:
        try:
            return simplejson.loads(response)  # second attempt on the full text
        except simplejson.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON response: {e}")
32
+
33
+
34
def json_to_pandas(json_data: str, key: Optional[str] = None) -> pd.DataFrame:
    """Build a pandas DataFrame from JSON text.

    Args:
        json_data: Text containing a JSON object.
        key: Optional top-level key whose value holds the table data. When
            falsy, the whole parsed object is used.

    Returns:
        The DataFrame built from the selected data.

    Raises:
        ValueError: If parsing or DataFrame construction raises ``ValueError``.
        KeyError: If ``key`` is given but missing from the parsed object.
    """
    try:
        parsed = json_to_dict(json_data)
        table_data = parsed[key] if key else parsed
        return pd.DataFrame(table_data)
    except ValueError as e:
        raise ValueError(f"Invalid JSON data: {e}")
44
+
45
+
46
def extract_and_return_data_table(chat_output, key="Medications"):
    """Extract a pandas data frame out of the chat output.

    If the last element of ``chat_output`` contains a "content" entry, that
    entry is used as the text to parse; otherwise ``chat_output`` is parsed
    as given.

    Args:
        chat_output: Chat history or raw text.
        key: Top-level JSON key that holds the table rows.

    Returns:
        The DataFrame built from the JSON found in the text.

    Raises:
        ValueError: If the text cannot be converted to a DataFrame.
    """
    if chat_output and "content" in chat_output[-1]:
        chat_output = chat_output[-1]["content"]

    print(f"chat output: {chat_output}type: {type(chat_output)}")

    try:
        return json_to_pandas(chat_output, key=key)
    except ValueError as e:
        raise ValueError(e)