NanLi2021 commited on
Commit
87e2d22
·
1 Parent(s): 592cf19
Files changed (1) hide show
  1. app.py +431 -0
app.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ from pathlib import Path
4
+ import string
5
+ import random
6
+ import torch
7
+ import numpy as np
8
+ import pickle
9
+
10
+ import gradio as gr
11
+ import pandas as pd
12
+ from scipy.special import softmax
13
+ import numpy as np
14
+ import seaborn as sns
15
+ import matplotlib.pyplot as plt
16
+ import hydra
17
+ from omegaconf import open_dict, DictConfig
18
+ import matplotlib.pyplot as plt
19
+ import matplotlib
20
+ from matplotlib.patches import Patch
21
+ sns.set()
22
+ sns.set_style("darkgrid")
23
+
24
+ from utils.data import *
25
+ from utils.metrics import *
26
+
27
+
28
+
29
+ def user_interface(Ufile, Pfile, Sfile=None, job_meta_file=None, user_meta_file=None, user_groups=None):
30
+ recdata = Data(Ufile, Pfile, Sfile, job_meta_file, user_meta_file, user_groups)
31
+
32
+
33
+ def calculate_user_item_metrics(res, S, U, k=10):
34
+ # get rec
35
+ m, n = res.shape
36
+ if not torch.is_tensor(res):
37
+ res = torch.from_numpy(res)
38
+ if not torch.is_tensor(U):
39
+ U = torch.from_numpy(U)
40
+ _, rec = torch.topk(res, k, dim=1)
41
+ rec_onehot = slow_onehot(rec, res)
42
+ # rec_onehot = F.one_hot(rec, num_classes=n).sum(1).float()
43
+ try:
44
+ rec_per_job = rec_onehot.sum(axis=0).numpy()
45
+ except:
46
+ rec_per_job = rec_onehot.sum(axis=0).cpu().numpy()
47
+ rec = rec.cpu()
48
+ S = S.cpu()
49
+ # envy
50
+ envy = expected_envy_torch_vec(U, rec_onehot, k=1).numpy()
51
+
52
+ # competitors for each rec job
53
+ competitors = get_competitors(rec_per_job, rec)
54
+
55
+ # rank
56
+ better_competitors = get_num_better_competitors(rec, S)
57
+
58
+ # scores per job for later zoom in scores
59
+ scores = get_scores_per_job(rec, S)
60
+
61
+ return {'rec': rec, 'envy': envy, 'competitors': competitors, 'ranks': better_competitors, 'scores_job': scores}
62
+
63
+
64
+ def plot_user_envy(user=0, k=2):
65
+ plt.close('all')
66
+ user = int(user)
67
+ if k in recdata.lookup_dict:
68
+ ret_dict = recdata.lookup_dict[k]
69
+ else:
70
+ ret_dict = calculate_user_item_metrics(recdata.P_sub, recdata.S_sub, recdata.U_sub, k=k)
71
+ recdata.lookup_dict[k] = ret_dict
72
+ # user's recommended jobs
73
+ users_rec = ret_dict['rec'][user].numpy()
74
+ # Plot
75
+ fig, ax1 = plt.subplots(figsize=(10, 5))
76
+ # fig.tight_layout()
77
+ fig.subplots_adjust(bottom=0.2)
78
+
79
+ envy = ret_dict['envy'].sum(-1)
80
+ envy_user = envy[user]
81
+ # plot envy histogram
82
+ n, bins, patches = ax1.hist(envy, bins=50, color='grey', alpha=0.5)
83
+ ax1.set_yscale('symlog')
84
+ sns.kdeplot(envy, color='grey', bw_adjust=0.3, cut=0, ax=ax1)
85
+ # mark this user's envy
86
+ # index of the bin that contains this user's envy
87
+ idx = np.digitize(envy_user, bins)
88
+ # print(envy_user, idx)
89
+ patches[idx-1].set_fc('r')
90
+ ax1.legend(handles=[Patch(facecolor='r', edgecolor='r', alpha=0.5,
91
+ label='Your envy group')])
92
+ ax1.set_xlabel('Envy')
93
+ ax1.set_ylabel('Number of users (log scale)')
94
+
95
+ return fig
96
+
97
+ def plot_user_scores(user=0, k=2):
98
+ user = int(user)
99
+ if k in recdata.lookup_dict:
100
+ ret_dict = recdata.lookup_dict[k]
101
+ else:
102
+ ret_dict = calculate_user_item_metrics(recdata.P_sub, recdata.S_sub, recdata.U_sub, k=k)
103
+ recdata.lookup_dict[k] = ret_dict
104
+ users_rec = ret_dict['rec'][user].numpy()
105
+ scores = ret_dict['scores_job']
106
+
107
+ # scores = [softmax(np.array(scores[jb])*0.5) for jb in users_rec]
108
+ scores = [scores[jb] for jb in users_rec]
109
+
110
+ rank_xs = [list(range(1, len(s)+1)) for s in scores]
111
+ my_ranks = [1+int(i) for i in ret_dict['ranks'][user]]
112
+ # my scores are the scores of the recommended jobs with rank
113
+ # my_scores = [scores[i][j] for i, j in enumerate(my_ranks)]
114
+ my_scores = [recdata.S_sub[user, job_id].item() for job_id in users_rec]
115
+ # my_scores_log = np.log(np.array(my_scores).astype(float))
116
+ ys = np.arange(len(users_rec))
117
+ # user's recommended jobs
118
+ if (user, k) in recdata.user_temp_data:
119
+ df = recdata.user_temp_data[(user, k)]
120
+ else:
121
+ df = pd.DataFrame({'x': rank_xs, 's': scores, 'y': ys})
122
+ df = df.explode(list('xs'))
123
+ recdata.user_temp_data[(user, k)] = df
124
+
125
+ # df['log_scores'] = np.log(df['s'].values.astype(float))
126
+ fig, ax = plt.subplots(figsize=(10, 5))
127
+ # fig.tight_layout()
128
+ fig.subplots_adjust(bottom=0.3)
129
+
130
+ def sub_cmap(cmap, vmin, vmax):
131
+ return lambda v: cmap(vmin + (vmax - vmin) * v)
132
+
133
+ # palette=matplotlib.cm.get_cmap('Greens').reversed()
134
+ # palette = sub_cmap(palette,0.2, 0.8)
135
+
136
+ sns.scatterplot(data=df, x="y", y="s", ax=ax, alpha=0.6,
137
+ legend=False, s=100, hue='y', palette="summer") #monotone color palette
138
+ sns.scatterplot(y=my_scores, x=range(k), ax=ax,
139
+ alpha=0.8, s=200, ec='r', fc='none', label='Your rank')
140
+ # add ranking of this user's score for each job
141
+ # find score gaps
142
+ gaps = np.diff(np.sort(scores[0])).mean()
143
+ for i, (y, x) in enumerate(zip(my_scores, range(k))):
144
+ ax.text(x-0.3, y+gaps, my_ranks[i], color='r', fontsize=15)
145
+ # add notation for 'rank'
146
+ # ax.text(-0.8, 1.12, 'Your rank', color='r', fontsize=12)
147
+ ax.set_xticks(range(k))
148
+ # shorten the job title
149
+ titles = [recdata.job_metadata[jb] for jb in users_rec]
150
+ titles = [t[:20] + '...' if len(t) > 20 else t for t in titles]
151
+ ax.set_xticklabels(titles, rotation=30, ha='right')
152
+ ax.set_xlabel('')
153
+ ax.set_xlim(-1, k)
154
+ # ax.grid(False)
155
+ ax.set_ylabel('Score')
156
+ # ax.set_ylim(-0.09, 1.2)
157
+ ax.legend()
158
+ return fig
159
+
160
+
161
+ # demo = gr.Blocks(gr.themes.Base.from_hub('finlaymacklon/smooth_slate'))
162
+ demo = gr.Blocks(gr.themes.Soft())
163
+ with demo:
164
+ def submit0(user, k):
165
+ fig = plot_user_envy(user, k)
166
+ return {
167
+ hist_plot: gr.update(value=fig, visible=True),
168
+ }
169
+
170
+
171
+ def submit2(user, k):
172
+ bar = plot_user_scores(user, k)
173
+ return {
174
+ bar_plot2: gr.update(value=bar, visible=True)
175
+ }
176
+
177
+ def submit(user):
178
+ new_job_num = random.randint(1,6)
179
+ # if new_job_num == 0, do nothing but clear the plots
180
+ if new_job_num > 0:
181
+ print(f'adding {new_job_num} new jobs')
182
+ recdata.update(new_user_num=0, new_job_num=new_job_num)
183
+ recdata.tweak_P(user)
184
+
185
+ return {
186
+ hist_plot: gr.update(visible=False),
187
+ bar_plot2: gr.update(visible=False)
188
+ }
189
+
190
+ # def submit_login(user):
191
+ # return {
192
+ # k: gr.update(visible=True),
193
+ # btn: gr.update(visible=True),
194
+ # btn0: gr.update(visible=True),
195
+ # btn2: gr.update(visible=True),
196
+ # pswd: gr.update(visible=False),
197
+ # lgbtn: gr.update(visible=False),
198
+ # }
199
+
200
+
201
+ # layout
202
+ gr.Markdown("## Job Recommendation Inferiority and Envy Monitor Demo")
203
+
204
+ with gr.Row():
205
+ with gr.Column(scale=1):
206
+ user = gr.Textbox(label='User ID',default='0', placeholder='Enter a random integer user ID')
207
+ # with gr.Column(scale=1):
208
+ # pswd = gr.Textbox(label='Password',default='********')
209
+ # with gr.Column(scale=1):
210
+ # lgbtn = gr.Button("Login")
211
+ # with gr.Row():
212
+ with gr.Column(scale=1):
213
+ k = gr.Slider(minimum=1, maximum=20,
214
+ default=4, step=1, label='Number of Jobs', visible=True)
215
+ with gr.Column(scale=1):
216
+ btn = gr.Button("Refresh to see new jobs", visible=True)
217
+
218
+ with gr.Tab('Envy'):
219
+ btn0 = gr.Button("User envy distribution", visible=True)
220
+ hist_plot = gr.Plot(visible=False)
221
+
222
+ with gr.Tab('Inferiority'):
223
+ with gr.Row():
224
+ # btn1 = gr.Button("User ranks for the recommended jobs")
225
+ btn2 = gr.Button("User scores/ranks for the recommended jobs", visible=True)
226
+
227
+ # bar_plot = gr.Plot()
228
+ bar_plot2 = gr.Plot(visible=False)
229
+
230
+ # lgbtn.click(submit_login, inputs=[user], outputs=[k, btn, btn0, btn2, pswd, lgbtn])
231
+ btn.click(submit, inputs=[user], outputs=[hist_plot, bar_plot2])
232
+ btn0.click(submit0, inputs=[user, k], outputs=[hist_plot])
233
+ # btn1.click(submit1, inputs=[user, k], outputs=[bar_plot])
234
+ btn2.click(submit2, inputs=[user, k], outputs=[bar_plot2])
235
+
236
+ return demo
237
+
238
+
239
+ def developer_interface(Ufile, Pfile, Sfile=None, job_meta_file=None, user_meta_file=None, user_groups=None):
240
+
241
+ recdata = Data(Ufile, Pfile, Sfile, job_meta_file, user_meta_file, user_groups, sub_sample_size=500)
242
+
243
+ def calculate_all_metrics(k, S_sub, U_sub, P_sub):
244
+ print('calculating all metrics')
245
+ if k in recdata.lookup_dict:
246
+ print('Found in lookup dict')
247
+ return recdata.lookup_dict[k]
248
+ else:
249
+ if not torch.is_tensor(P_sub):
250
+ P_sub = torch.from_numpy(P_sub)
251
+ envy, inferiority, utility = eiu_cut_off2(
252
+ (S_sub, U_sub), P_sub, k=k, agg=False)
253
+ envy = envy.sum(-1)
254
+ inferiority = inferiority.sum(-1)
255
+
256
+ _, rec = torch.topk(P_sub, k=k, dim=1)
257
+ rec_onehot = slow_onehot(rec, P_sub)
258
+ try:
259
+ rec_per_job = rec_onehot.sum(axis=0).numpy()
260
+ except:
261
+ rec_per_job = rec_onehot.sum(axis=0).cpu().numpy()
262
+ rec = rec.cpu()
263
+ metrics_at_k = {'rec': rec, 'envy': envy, 'inferiority': inferiority, 'utility': utility,
264
+ 'rec_per_job': rec_per_job}
265
+ print('Finished calculating all metrics')
266
+ return metrics_at_k
267
+
268
+ def plot_user_box(metrics_dict):
269
+ print('plotting user box')
270
+ plt.close('all')
271
+ envy = metrics_dict['envy'].numpy()
272
+ inferiority = metrics_dict['inferiority'].numpy()
273
+ fig, (ax1, ax2) = plt.subplots(ncols=2)
274
+ fig.tight_layout()
275
+ ax1.boxplot(envy)
276
+ ax1.set_ylabel('envy')
277
+ ax1.set_title('Envy')
278
+ ax1.set_xticks([])
279
+ ax2.boxplot(inferiority)
280
+ ax2.set_ylabel('inferiority')
281
+ ax2.set_title('Inferiority')
282
+ ax2.set_xticks([])
283
+ return fig
284
+
285
+ def plot_scatter(k, group=None):
286
+ print('plotting scatter')
287
+ plt.close('all')
288
+ if group == 'None':
289
+ group = None
290
+ if k in recdata.lookup_dict:
291
+ metrics_dict = recdata.lookup_dict[k]
292
+ else:
293
+ metrics_dict = calculate_all_metrics(k, recdata.S_sub, recdata.U_sub, recdata.P_sub)
294
+ recdata.lookup_dict[k] = metrics_dict
295
+
296
+ data = {'log(envy+1)': np.log(metrics_dict['envy']+1),
297
+ 'inferiority': metrics_dict['inferiority']}
298
+ data = pd.DataFrame(data)
299
+ data = data.join(recdata.user_metadata)
300
+ fig, ax = plt.subplots()
301
+ sns.scatterplot(data=data, x='log(envy+1)', y='inferiority', hue=group, ax=ax)
302
+ return fig
303
+
304
+ def lorenz_curve(X, ax, label):
305
+ # ref: https://zhiyzuo.github.io/Plot-Lorenz/
306
+ X.sort()
307
+ X_lorenz = X.cumsum() / X.sum()
308
+ X_lorenz = np.insert(X_lorenz, 0, 0)
309
+ X_lorenz[0], X_lorenz[-1]
310
+
311
+ ax.plot(np.arange(X_lorenz.size) / (X_lorenz.size - 1), X_lorenz, label=label)
312
+ ## line plot of equality
313
+ ax.plot([0, 1], [0, 1], linestyle='dashed', color='k')
314
+ return ax
315
+
316
+ def plot_item(rec_per_job):
317
+ print('plotting item')
318
+ plt.close('all')
319
+ fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(10, 10))
320
+ fig.tight_layout(pad=5.0)
321
+ labels, counts = np.unique(rec_per_job, return_counts=True)
322
+ ax1.bar(labels, counts, align='center')
323
+
324
+ ax1.set_xlabel('Number of times a job is recommended')
325
+ ax1.set_ylabel('Number of jobs')
326
+ ax1.set_title('Distribution of job exposure')
327
+ ax2 = lorenz_curve(rec_per_job, ax2,'')
328
+ ax2.set_title('Lorenz Curve')
329
+ return fig
330
+
331
+
332
+ # build the interface
333
+ demo = gr.Blocks(gr.themes.Soft())
334
+ with demo:
335
+ # callbacks
336
+ def submit_u():
337
+ # generate two random integers including 0 representing user num and job num
338
+ user_num = np.random.randint(0, 5)
339
+ job_num = np.random.randint(0, 5)
340
+
341
+ if user_num > 0 or job_num > 0:
342
+ recdata.update(user_num, job_num)
343
+
344
+ return{
345
+ info: gr.update(value='New {} users and {} jobs'.format(user_num, job_num),visible=True),
346
+ }
347
+
348
+
349
+ def submit1(k):
350
+ metrics_dict = calculate_all_metrics(k, recdata.S_sub, recdata.U_sub, recdata.P_sub)
351
+ return {
352
+ user_box_plot: plot_user_box(metrics_dict),
353
+ scatter_plot: plot_scatter(k),
354
+ btn2: gr.update(visible=True)
355
+ }
356
+
357
+ def submit2():
358
+ return {
359
+ radio: gr.update(visible=True)
360
+ }
361
+
362
+ def submit3(k):
363
+ metrics_dict = calculate_all_metrics(k, recdata.S_sub, recdata.U_sub, recdata.P_sub)
364
+ return {
365
+ item_plots: plot_item(metrics_dict['rec_per_job'])
366
+ }
367
+
368
+ # layout
369
+ gr.Markdown("## Envy & Inferiority Monitor for Developers Demo")
370
+ # 1. accept k
371
+ with gr.Row():
372
+ with gr.Column(scale=1):
373
+ k = gr.inputs.Slider(minimum=1, maximum=min(30,len(
374
+ recdata.P[0])), default=1, step=1, label='Number of Jobs')
375
+ with gr.Column(scale=1):
376
+ btn = gr.Button('Refresh')
377
+ with gr.Column(scale=1):
378
+ info = gr.Textbox('', label='Updated info', visible=False)
379
+ btn.click(submit_u, inputs=[], outputs=[info])
380
+
381
+
382
+ with gr.Tab('User'):
383
+ plt.close('all')
384
+ btn1 = gr.Button('Visualize user-side fairness')
385
+ user_box_plot = gr.Plot()
386
+ scatter_plot = gr.Plot()
387
+
388
+ btn2 = gr.Button('Visualize intra-group fairness', visible=False)
389
+
390
+ radio = gr.Radio(choices=user_groups, value=user_groups[0] if len(user_groups) > 0 else "",
391
+ interactive=True, label="User group", visible=False)
392
+
393
+ btn1.click(submit1, inputs=[k], outputs=[
394
+ user_box_plot, scatter_plot, btn2])
395
+ btn2.click(submit2, inputs=[], outputs=[radio])
396
+ radio.change(fn=plot_scatter, inputs=[
397
+ k, radio], outputs=[scatter_plot])
398
+
399
+ with gr.Tab('Item'):
400
+ plt.close('all')
401
+ btn3 = gr.Button('Visualize item-side fairness')
402
+ item_plots = gr.Plot()
403
+ btn3.click(submit3, inputs=[k], outputs=[item_plots])
404
+
405
+ return demo
406
+
407
+
408
+ @hydra.main(version_base=None, config_path='./utils', config_name='monitor')
409
+ def main(config: DictConfig):
410
+ print(config)
411
+ Ufile = config.Ufile
412
+ Sfile = config.Sfile
413
+ Pfile = config.Pfile
414
+ user_meta_file = config.user_meta_file
415
+ job_meta_file = config.job_meta_file
416
+ user_groups = ['None'] + \
417
+ list(config.user_groups) if config.user_groups else ['None']
418
+ server_name = config.server_name
419
+ role = config.role
420
+ if role == 'user':
421
+ demo = user_interface(Ufile, Pfile, Sfile,
422
+ job_meta_file, user_meta_file, user_groups)
423
+ elif role == 'developer':
424
+ demo = developer_interface(
425
+ Ufile, Pfile, Sfile, job_meta_file, user_meta_file, user_groups)
426
+ demo.launch(server_name=server_name, server_port=config.server_port)
427
+ # demo.launch()
428
+
429
+
430
+ if __name__ == "__main__":
431
+ main()