Spaces:
Sleeping
Sleeping
File size: 13,891 Bytes
719d0db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 |
# templates
from typing import Dict, List, Tuple

import numpy as np
import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AIMessage

import utils.util_app as util_app
from models.prompts.generate_explanation import Template4GenerateExplanation
from models.prompts.identify_question import Template4IdentifyQuestion
class StreamingChatCallbackHandler(BaseCallbackHandler):
    """LangChain callback that streams LLM output into a Streamlit placeholder.

    A new ``st.empty()`` placeholder is opened when the LLM starts, every new
    token re-renders the accumulated text as Markdown (HTML disabled), and the
    final generation text replaces it when the LLM finishes.
    """

    def __init__(self):
        super().__init__()
        # Placeholder and accumulated text. Initialised here so that token/end
        # events arriving without a preceding start event do not raise
        # AttributeError; on_llm_start replaces both.
        self.container = None
        self.text = ""

    def _get_container(self):
        """Return the current placeholder, creating one if none exists yet."""
        if self.container is None:
            self.container = st.empty()
        return self.container

    def on_llm_start(self, *args, **kwargs):
        """Open a fresh Streamlit placeholder and reset the accumulated text."""
        self.container = st.empty()
        self.text = ""

    def on_llm_new_token(self, token: str, *args, **kwargs):
        """Append ``token`` to the partial answer and re-render it."""
        self.text += token
        self._get_container().markdown(
            body=self.text,
            unsafe_allow_html=False,
        )

    def on_llm_end(self, response, *args, **kwargs):
        """Render the final answer.

        Args:
            response: the LangChain ``LLMResult`` for the call; only the text
                of its first generation (``response.generations[0][0].text``)
                is displayed.
        """
        self._get_container().markdown(
            body=response.generations[0][0].text,
            unsafe_allow_html=False,
        )
class RouteExplainer():
    """Answer a user's why-not question about a route with a contrastive explanation.

    ``generate_explanation`` runs the whole pipeline: an LLM extracts the
    question and the user's counterfactual (CF) edge, the CF edge is checked
    against the destination's close time, ``cf_generator`` builds a CF route,
    ``classifier`` labels the intention of each CF edge, and a second LLM chain
    explains the differences between the actual and the CF route.

    Several helpers read ``st.session_state.dist_matrix``,
    ``st.session_state.tour_list`` and ``st.session_state.df_tour`` directly
    instead of taking them as arguments.
    """

    # Prompt templates, instantiated once when the class is defined.
    template_identify_question = Template4IdentifyQuestion()
    template_generate_explanation = Template4GenerateExplanation()

    def __init__(self,
                 llm,
                 cf_generator,
                 classifier) -> None:
        """
        Args:
            llm: LLM handed to both prompt templates' ``sandwiches`` to build
                the question-extraction and explanation chains.
            cf_generator: callable that generates CF routes; must expose
                ``problem``.
            classifier: callable edge-intention classifier; must expose
                ``problem`` and ``get_inputs``.
        """
        assert cf_generator.problem == classifier.problem, "Problem type of cf_generator and predictor should coincide!"
        self.coord_dim = 2
        self.problem = cf_generator.problem
        self.cf_generator = cf_generator
        self.classifier = classifier
        self.actual_route = None
        self.cf_route = None
        # chains built from the templates; both are used via ``.invoke``
        self.question_extractor = self.template_identify_question.sandwiches(llm)
        self.explanation_generator = self.template_generate_explanation.sandwiches(llm)

    #----------------
    # whole pipeline
    #----------------
    def generate_explanation(self,
                             tour_list,
                             whynot_question: str,
                             actual_routes: list,
                             actual_labels: list,
                             node_feats: dict,
                             dist_matrix: np.ndarray) -> str:
        """Run the full why-not explanation pipeline for one user question.

        Side effects: streams text to the Streamlit page and writes
        ``chat_history``, ``generated_cf_route``, ``close_chat``, ``cf_step``,
        ``cf_routes`` and ``cf_labels`` into ``st.session_state``.

        Args:
            tour_list: destinations; each entry has a ``'name'`` key.
            whynot_question: the user's free-text question.
            actual_routes: actual routes; only ``actual_routes[0]`` (vehicle 0)
                is explained.
            actual_labels: per-edge intention labels of the actual routes.
            node_feats: node features; ``'time_window'`` and ``'service_time'``
                are used for CF-edge validation.
            dist_matrix: pairwise travel-cost matrix (treated as travel time).

        Returns:
            ``""`` when the question could not be identified, the
            infeasibility message when the CF edge is rejected, otherwise the
            result of the explanation chain.
        """
        #--------------------------------
        # define why & why-not questions
        #--------------------------------
        route_info_text = self.get_route_info_text(tour_list, actual_routes)
        inputs = self.question_extractor.invoke({
            "whynot_question": whynot_question,
            "route_info": route_info_text
        })
        # Show and store the same text, so the chat history keeps the space
        # between summary and intent.
        question_text = inputs["summary"] + " " + inputs["intent"]
        util_app.stream_words(question_text)
        st.session_state.chat_history.append(AIMessage(content=question_text))
        if not inputs["success"]:
            return ""
        cf_step = inputs["cf_step"]
        # The extractor sees 1-based node ids (see get_route_info_text); use 0-based ids internally.
        cf_visit = inputs["cf_visit"] - 1
        #----------------------
        # validate the CF edge
        #----------------------
        is_cf_edge_feasible, reason = self.validate_cf_edge(node_feats,
                                                            dist_matrix,
                                                            actual_routes[0],
                                                            cf_step,
                                                            cf_visit)
        if not is_cf_edge_feasible:
            util_app.stream_words(reason)
            return reason
        #---------------------
        # generate a cf route
        #---------------------
        cf_routes = self.cf_generator(actual_routes,
                                      vehicle_id=0,
                                      cf_step=cf_step,
                                      cf_next_node_id=cf_visit,
                                      node_feats=node_feats,
                                      dist_matrix=dist_matrix)
        st.session_state.generated_cf_route = True
        st.session_state.close_chat = True
        st.session_state.cf_step = cf_step
        #--------------------------------------
        # classify the intentions of each edge
        #--------------------------------------
        cf_labels = self.classifier(self.classifier.get_inputs(cf_routes,
                                                               0,
                                                               node_feats,
                                                               dist_matrix))
        st.session_state.cf_routes = cf_routes
        st.session_state.cf_labels = cf_labels
        #------------------------------------
        # generate a contrastive explanation
        #------------------------------------
        comparison_results = self.get_comparison_results(question_summary=inputs["summary"],
                                                         tour_list=tour_list,
                                                         actual_routes=actual_routes,
                                                         actual_labels=actual_labels,
                                                         cf_routes=cf_routes,
                                                         cf_labels=cf_labels,
                                                         cf_step=cf_step)
        explanation = self.explanation_generator.invoke({
            "comparison_results": comparison_results,
            "intent": inputs["intent"]
        }, config={"callbacks": [StreamingChatCallbackHandler()]})
        return explanation

    #------------------------
    # for extracting inputs
    #------------------------
    def get_route_info_text(self, tour_list, routes) -> str:
        """Describe the nodes and the first route for the question extractor.

        Example for names ``[A, B, C]`` and ``routes[0] == [0, 2, 1]``::

            Nodes(node id, name): (1, A), (2, B), (3, C)
            Route: A > (step 1) > C > (step 2) > B

        Node ids are 1-based; ``step i`` is the edge arriving at ``routes[0][i]``.
        """
        nodes_text = ", ".join(f"({i + 1}, {destination['name']})"
                               for i, destination in enumerate(tour_list))
        route = routes[0]
        steps_text = "".join(f" > (step {i}) > {tour_list[node_id]['name']}"
                             for i, node_id in enumerate(route[1:], start=1))
        route_text = tour_list[route[0]]['name'] + steps_text if len(route) > 0 else ""
        return f"Nodes(node id, name): {nodes_text}\nRoute: {route_text}\n"

    #--------------------------
    # for validating a CF edge
    #--------------------------
    def validate_cf_edge(self,
                         node_feats: Dict[str, np.ndarray],
                         dist_matrix: np.ndarray,
                         route: List[int],
                         cf_step: int,
                         cf_visit: int) -> Tuple[bool, str]:
        """Check that the CF edge ``route[cf_step - 1] -> cf_visit`` meets the close time.

        The route is replayed from the open time of ``route[0]``: at each node
        its service time and the travel time to the next node are added, and an
        arrival before a node's open time waits until it opens.

        Args:
            node_feats: ``'time_window'`` gives ``[open, close]`` per node and
                ``'service_time'`` the service time per node.
            dist_matrix: travel time between nodes.
            route: actual route as 0-based node ids.
            cf_step: 1-based step; the CF edge leaves ``route[cf_step - 1]``.
            cf_visit: 0-based id of the node the CF edge goes to.

        Returns:
            ``(is_feasible, message)``.
        """
        # Reject out-of-range values from the extractor: they would otherwise
        # wrap around through negative indexing, or fail later in the pipeline
        # (get_immediate_state reads route[cf_step]).
        if not 1 <= cf_step <= len(route) - 1:
            return False, f"Oops, your CF edge is invalid because step {cf_step} does not exist in the route (valid steps: 1-{len(route) - 1})."
        if not 0 <= cf_visit < len(dist_matrix):
            return False, f"Oops, your CF edge is invalid because node {cf_visit + 1} does not exist (valid node ids: 1-{len(dist_matrix)})."
        time_windows = node_feats["time_window"]
        service_times = node_feats["service_time"]
        # calc the time at route[cf_step - 1], starting at the start point's open time
        curr_time = time_windows[route[0]][0]
        for step in range(1, cf_step):
            curr_node_id = route[step - 1]
            next_node_id = route[step]
            curr_time += service_times[curr_node_id] + dist_matrix[curr_node_id][next_node_id]
            curr_time = max(curr_time, time_windows[next_node_id][0])  # wait until the node opens
        # validate the cf edge
        curr_node_id = route[cf_step - 1]
        next_node_close_time = time_windows[cf_visit][1]
        arrival_time = curr_time + service_times[curr_node_id] + dist_matrix[curr_node_id][cf_visit]
        if next_node_close_time < arrival_time:
            exceed_time = arrival_time - next_node_close_time
            return False, f"Oops, your CF edge is infeasible because it does not meet the destination's close time by {util_app.add_time_unit(exceed_time)}."
        return True, "The CF edge is feasible!"

    #-------------------------------
    # for generating an explanation
    #-------------------------------
    def get_comparison_results(self,
                               tour_list,
                               question_summary,
                               actual_routes: List[List[int]],
                               actual_labels: List[List[int]],
                               cf_routes: List[List[int]],
                               cf_labels: List[List[int]],
                               cf_step: int) -> str:
        """Build the text comparing the actual and CF routes for the explanation LLM.

        Sections: question, actual route, CF route, their differences, and the
        planned destinations' remarks. Only vehicle 0 of each route set is used.
        ``cf_step`` is 1-based; helpers receive the 0-based edge index.
        """
        ex_step = cf_step - 1
        comparison_results = "Question:\n" + question_summary + "\n"
        # get_representative_values ends without a newline; add one so the
        # next section header starts on its own line.
        comparison_results += "Actual route:\n" + \
            self.get_route_info(tour_list, actual_routes[0], actual_labels[0], ex_step, "actual") + \
            self.get_representative_values(actual_routes[0], actual_labels[0], ex_step, "actual") + "\n"
        comparison_results += "CF route:\n" + \
            self.get_route_info(tour_list, cf_routes[0], cf_labels[0], ex_step, "CF") + \
            self.get_representative_values(cf_routes[0], cf_labels[0], ex_step, "CF") + "\n"
        comparison_results += "Difference between two routes:\n" + self.get_diff(ex_step, actual_routes[0], cf_routes[0])
        comparison_results += "Planed desination information:\n" + self.get_node_info()
        return comparison_results

    def get_route_info(self,
                       tour_list,
                       route: List[int],
                       label: List[int],
                       ex_step: int,
                       type: str) -> str:  # NOTE: ``type`` shadows the builtin; kept for caller compatibility
        """Render a route as ``- route: A > (label) > B > ... > Z``.

        Each edge shows its intention label (0 -> ``route_len``, otherwise
        ``time_window``). The edge at index ``ex_step`` is marked as the
        ``{type} edge``; for the CF route its label is ``user_preference``.
        """
        def get_labelname(label_number):
            return "route_len" if label_number == 0 else "time_window"

        last_index = len(route) - 1
        route_info = "- route: "
        for i, node_id in enumerate(route):
            name = tour_list[node_id]['name']
            if i == last_index:
                route_info += f"{name}\n"
            elif i == ex_step:
                # the edge the question is about
                edge_label = get_labelname(label[i]) if type == "actual" else "user_preference"
                route_info += f"{name} > ({type} edge: {edge_label}) > "
            else:
                route_info += f"{name} > ({get_labelname(label[i])}) > "
        return route_info

    def get_representative_values(self, route, labels, ex_step, type) -> str:
        """Summarise a route: immediate/total travel time, missed nodes, intention ratios.

        Travel values are floor-divided by 60 and reported as minutes
        (presumably ``dist_matrix`` is in seconds — TODO confirm). The result
        has no trailing newline.
        """
        time_window_ratio = self.get_intention_ratio(1, labels, ex_step) * 100
        route_len_ratio = self.get_intention_ratio(0, labels, ex_step) * 100
        return (
            f"- short-term effect (immediate travel time): {self.get_immediate_state(route, ex_step)//60} minutes\n"
            f"- long-term effect (total travel time): {self.get_route_length(route)//60} minutes\n"
            f"- missed nodes: {self.get_infeasible_node_name(route)}\n"
            f"- edge-intention ratio after the {type} edge: time_window {time_window_ratio: .1f}%, route_len {route_len_ratio: .1f}%"
        )

    def get_immediate_state(self, route, ex_step) -> float:
        """Return the travel cost of edge ``route[ex_step] -> route[ex_step + 1]``."""
        return st.session_state.dist_matrix[route[ex_step]][route[ex_step + 1]]

    def get_route_length(self, route) -> float:
        """Return the summed travel cost over all consecutive edges of ``route``."""
        dist_matrix = st.session_state.dist_matrix
        return sum((dist_matrix[src][dst] for src, dst in zip(route, route[1:])), 0.0)

    def get_infeasible_nodes(self, route) -> int:
        """Return ``len(route) - (len(dist_matrix) - 1)``.

        NOTE: this is the NEGATED number of missed nodes (<= 0), assuming a
        complete route has ``len(dist_matrix) - 1`` entries, as
        ``get_infeasible_node_name`` also assumes.
        """
        return len(route) - (len(st.session_state.dist_matrix) - 1)

    def get_infeasible_node_name(self, route) -> str:
        """Return comma-joined names of nodes absent from ``route``, or ``"none"``.

        ``"none"`` is returned when ``route`` has ``len(dist_matrix) - 1``
        entries; otherwise every node id in ``range(len(dist_matrix))`` not in
        ``route`` is listed in ascending order.
        """
        num_nodes = len(st.session_state.dist_matrix)
        if len(route) == num_nodes - 1:
            return "none"
        missed_ids = np.setdiff1d(np.arange(num_nodes), route)
        return ",".join(st.session_state.tour_list[node_id]["name"] for node_id in missed_ids)

    def get_intention_ratio(self,
                            intention: int,
                            labels: List[int],
                            ex_step: int) -> float:
        """Return the fraction of ``labels[ex_step:]`` equal to ``intention``."""
        remaining = np.array(labels)[ex_step:]
        return np.sum(remaining == intention) / len(labels[ex_step:])

    def get_diff(self, ex_step, actual_route, cf_route) -> str:
        """Describe how the actual route differs from the CF route.

        Covers the immediate edge at ``ex_step``, the total route length, and
        the number of missed nodes.
        """
        def get_str(effect: float):
            # effect = actual - CF; positive means the actual route costs more
            effect_str = "The actual route increases it by " if effect > 0 else "The actual route reduces it by "
            effect_str += util_app.add_time_unit(abs(effect))
            return effect_str

        def get_str2(num_nodes: int, num_missed_nodes):
            # num_nodes = (#nodes missed by actual) - (#nodes missed by CF)
            if num_nodes < 0:
                return f"The actual route visits {abs(num_nodes)} more nodes"
            if num_nodes == 0:
                if num_missed_nodes == 0:
                    return "Both routes missed no node,"
                return f"Both routes missed the same number of nodes ({abs(num_missed_nodes)} node(s))"
            return f"The actual route visits {abs(num_nodes)} less nodes"

        # short/long-term effects
        short_effect = self.get_immediate_state(actual_route, ex_step) - self.get_immediate_state(cf_route, ex_step)
        long_effect = self.get_route_length(actual_route) - self.get_route_length(cf_route)
        short_effect_str = get_str(short_effect)
        long_effect_str = get_str(long_effect)
        # missed nodes: get_infeasible_nodes returns the NEGATED missed count,
        # so flip the sign to get real counts before comparing.
        num_missed_actual = -self.get_infeasible_nodes(actual_route)
        num_missed_cf = -self.get_infeasible_nodes(cf_route)
        missed_nodes_str = get_str2(num_missed_actual - num_missed_cf, num_missed_actual)
        return f"- short-term effect: {short_effect_str}\n- long-term effect: {long_effect_str}\n- missed nodes: {missed_nodes_str}\n"

    def get_node_info(self) -> str:
        """Return one ``- destination: remarks`` line per row of ``st.session_state.df_tour``."""
        df_tour = st.session_state.df_tour
        return "".join(f"- {destination}: {remarks}\n"
                       for destination, remarks in zip(df_tour["destination"], df_tour["remarks"]))