NanLi2021 committed on
Commit
616f9a5
·
1 Parent(s): 87e2d22
Files changed (1) hide show
  1. app.py +0 -431
app.py DELETED
@@ -1,431 +0,0 @@
1
- import sys
2
-
3
- from pathlib import Path
4
- import string
5
- import random
6
- import torch
7
- import numpy as np
8
- import pickle
9
-
10
- import gradio as gr
11
- import pandas as pd
12
- from scipy.special import softmax
13
- import numpy as np
14
- import seaborn as sns
15
- import matplotlib.pyplot as plt
16
- import hydra
17
- from omegaconf import open_dict, DictConfig
18
- import matplotlib.pyplot as plt
19
- import matplotlib
20
- from matplotlib.patches import Patch
21
- sns.set()
22
- sns.set_style("darkgrid")
23
-
24
- from utils.data import *
25
- from utils.metrics import *
26
-
27
-
28
-
29
def user_interface(Ufile, Pfile, Sfile=None, job_meta_file=None, user_meta_file=None, user_groups=None):
    """Build the job-seeker-facing Gradio demo (envy and inferiority views).

    Every argument is forwarded unchanged to ``Data``. The app is only
    assembled here; launching it is left to the caller.

    Returns:
        The assembled ``gr.Blocks`` application.
    """
    recdata = Data(Ufile, Pfile, Sfile, job_meta_file, user_meta_file, user_groups)

    def calculate_user_item_metrics(res, S, U, k=10):
        """Derive the top-``k`` recommendations and the metrics built on them.

        Args:
            res: 2-D array or tensor; the top-``k`` columns of each row are
                taken as that row's recommendations.
            S: tensor handed to ``get_num_better_competitors`` and
                ``get_scores_per_job`` after being moved to the CPU.
            U: array or tensor handed to ``expected_envy_torch_vec``.
            k: number of recommendations kept per row.

        Returns:
            Dict with keys ``rec``, ``envy``, ``competitors``, ``ranks`` and
            ``scores_job``.
        """
        if not torch.is_tensor(res):
            res = torch.from_numpy(res)
        if not torch.is_tensor(U):
            U = torch.from_numpy(U)
        _, rec = torch.topk(res, k, dim=1)
        rec_onehot = slow_onehot(rec, res)
        # ``.cpu()`` is a no-op for tensors already on the CPU, so no
        # try/except is needed to handle both device cases.
        rec_per_job = rec_onehot.sum(axis=0).cpu().numpy()
        rec = rec.cpu()
        S = S.cpu()

        envy = expected_envy_torch_vec(U, rec_onehot, k=1).numpy()
        # competitors for each recommended job
        competitors = get_competitors(rec_per_job, rec)
        # number of better-scored competitors, later shown as a rank
        better_competitors = get_num_better_competitors(rec, S)
        # per-job scores, used by the score plot
        scores = get_scores_per_job(rec, S)

        return {'rec': rec, 'envy': envy, 'competitors': competitors,
                'ranks': better_competitors, 'scores_job': scores}

    def metrics_at(k):
        """Return the metrics for ``k``, computing and caching them on first use."""
        if k not in recdata.lookup_dict:
            recdata.lookup_dict[k] = calculate_user_item_metrics(
                recdata.P_sub, recdata.S_sub, recdata.U_sub, k=k)
        return recdata.lookup_dict[k]

    def plot_user_envy(user=0, k=2):
        """Plot the envy histogram and highlight the bin containing ``user``."""
        plt.close('all')
        user = int(user)
        ret_dict = metrics_at(k)

        fig, ax1 = plt.subplots(figsize=(10, 5))
        fig.subplots_adjust(bottom=0.2)

        envy = ret_dict['envy'].sum(-1)
        envy_user = envy[user]
        _, bins, patches = ax1.hist(envy, bins=50, color='grey', alpha=0.5)
        ax1.set_yscale('symlog')
        sns.kdeplot(envy, color='grey', bw_adjust=0.3, cut=0, ax=ax1)

        # ``np.digitize`` returns len(bins) for a value equal to the last
        # edge (i.e. the user with the maximum envy), which is one past the
        # final patch. Clamp so that user lands in the last bar instead of
        # raising IndexError.
        bin_idx = int(np.digitize(envy_user, bins)) - 1
        bin_idx = min(max(bin_idx, 0), len(patches) - 1)
        patches[bin_idx].set_fc('r')
        ax1.legend(handles=[Patch(facecolor='r', edgecolor='r', alpha=0.5,
                                  label='Your envy group')])
        ax1.set_xlabel('Envy')
        ax1.set_ylabel('Number of users (log scale)')
        return fig

    def plot_user_scores(user=0, k=2):
        """Plot, per recommended job, the competitors' scores and this user's score/rank."""
        # Release figures from earlier callbacks, as the other plot
        # functions do; otherwise pyplot keeps every figure alive.
        plt.close('all')
        user = int(user)
        ret_dict = metrics_at(k)
        users_rec = ret_dict['rec'][user].numpy()
        scores = [ret_dict['scores_job'][jb] for jb in users_rec]

        my_ranks = [1 + int(i) for i in ret_dict['ranks'][user]]
        my_scores = [recdata.S_sub[user, job_id].item() for job_id in users_rec]

        if (user, k) in recdata.user_temp_data:
            df = recdata.user_temp_data[(user, k)]
        else:
            rank_xs = [list(range(1, len(s) + 1)) for s in scores]
            ys = np.arange(len(users_rec))
            df = pd.DataFrame({'x': rank_xs, 's': scores, 'y': ys})
            df = df.explode(['x', 's'])
            recdata.user_temp_data[(user, k)] = df

        fig, ax = plt.subplots(figsize=(10, 5))
        fig.subplots_adjust(bottom=0.3)

        sns.scatterplot(data=df, x="y", y="s", ax=ax, alpha=0.6,
                        legend=False, s=100, hue='y', palette="summer")
        sns.scatterplot(y=my_scores, x=range(k), ax=ax,
                        alpha=0.8, s=200, ec='r', fc='none', label='Your rank')

        # Vertical offset for the rank labels: the mean gap between the
        # sorted scores of the first job. A job with a single score has no
        # gaps, so fall back to 0 instead of a NaN text position.
        diffs = np.diff(np.sort(scores[0]))
        gaps = diffs.mean() if diffs.size else 0.0
        for i, (y, x) in enumerate(zip(my_scores, range(k))):
            ax.text(x - 0.3, y + gaps, my_ranks[i], color='r', fontsize=15)

        ax.set_xticks(range(k))
        # shorten long job titles so the tick labels stay readable
        titles = [recdata.job_metadata[jb] for jb in users_rec]
        titles = [t[:20] + '...' if len(t) > 20 else t for t in titles]
        ax.set_xticklabels(titles, rotation=30, ha='right')
        ax.set_xlabel('')
        ax.set_xlim(-1, k)
        ax.set_ylabel('Score')
        ax.legend()
        return fig

    demo = gr.Blocks(gr.themes.Soft())
    with demo:
        # callbacks (the components they reference are created further down)
        def submit0(user, k):
            fig = plot_user_envy(user, k)
            return {
                hist_plot: gr.update(value=fig, visible=True),
            }

        def submit2(user, k):
            bar = plot_user_scores(user, k)
            return {
                bar_plot2: gr.update(value=bar, visible=True)
            }

        def submit(user):
            new_job_num = random.randint(1, 6)
            # if new_job_num == 0, do nothing but clear the plots
            if new_job_num > 0:
                print(f'adding {new_job_num} new jobs')
                recdata.update(new_user_num=0, new_job_num=new_job_num)
                recdata.tweak_P(user)

            return {
                hist_plot: gr.update(visible=False),
                bar_plot2: gr.update(visible=False)
            }

        # layout
        gr.Markdown("## Job Recommendation Inferiority and Envy Monitor Demo")

        with gr.Row():
            with gr.Column(scale=1):
                # NOTE(review): newer Gradio components take the initial
                # value via ``value=``; confirm ``default=`` is honoured by
                # the pinned Gradio version.
                user = gr.Textbox(label='User ID', default='0', placeholder='Enter a random integer user ID')
            with gr.Column(scale=1):
                k = gr.Slider(minimum=1, maximum=20,
                              default=4, step=1, label='Number of Jobs', visible=True)
            with gr.Column(scale=1):
                btn = gr.Button("Refresh to see new jobs", visible=True)

        with gr.Tab('Envy'):
            btn0 = gr.Button("User envy distribution", visible=True)
            hist_plot = gr.Plot(visible=False)

        with gr.Tab('Inferiority'):
            with gr.Row():
                btn2 = gr.Button("User scores/ranks for the recommended jobs", visible=True)

            bar_plot2 = gr.Plot(visible=False)

        btn.click(submit, inputs=[user], outputs=[hist_plot, bar_plot2])
        btn0.click(submit0, inputs=[user, k], outputs=[hist_plot])
        btn2.click(submit2, inputs=[user, k], outputs=[bar_plot2])

    return demo
-
238
-
239
def developer_interface(Ufile, Pfile, Sfile=None, job_meta_file=None, user_meta_file=None, user_groups=None):
    """Build the developer-facing Gradio demo (user-side and item-side fairness).

    All file arguments and ``user_groups`` are forwarded to ``Data`` together
    with ``sub_sample_size=500``. ``user_groups`` additionally supplies the
    choices of the "User group" radio; ``None`` is treated as no groups.

    Returns:
        The assembled ``gr.Blocks`` application (not launched).
    """
    recdata = Data(Ufile, Pfile, Sfile, job_meta_file, user_meta_file, user_groups, sub_sample_size=500)
    # The declared default is None; normalise so the radio setup below does
    # not fail on ``len(None)``.
    group_choices = list(user_groups) if user_groups else []

    def calculate_all_metrics(k, S_sub, U_sub, P_sub):
        """Return envy, inferiority, utility and exposure metrics at cut-off ``k``.

        A result already stored in ``recdata.lookup_dict`` under ``k`` is
        returned as is; otherwise the metrics are computed (and not stored
        here).
        """
        print('calculating all metrics')
        if k in recdata.lookup_dict:
            print('Found in lookup dict')
            return recdata.lookup_dict[k]

        if not torch.is_tensor(P_sub):
            P_sub = torch.from_numpy(P_sub)
        envy, inferiority, utility = eiu_cut_off2(
            (S_sub, U_sub), P_sub, k=k, agg=False)
        envy = envy.sum(-1)
        inferiority = inferiority.sum(-1)

        _, rec = torch.topk(P_sub, k=k, dim=1)
        rec_onehot = slow_onehot(rec, P_sub)
        # ``.cpu()`` is a no-op on CPU tensors, so one code path covers
        # both devices without a bare ``except``.
        rec_per_job = rec_onehot.sum(axis=0).cpu().numpy()
        rec = rec.cpu()
        metrics_at_k = {'rec': rec, 'envy': envy, 'inferiority': inferiority, 'utility': utility,
                        'rec_per_job': rec_per_job}
        print('Finished calculating all metrics')
        return metrics_at_k

    def plot_user_box(metrics_dict):
        """Draw side-by-side box plots of the envy and inferiority values."""
        print('plotting user box')
        plt.close('all')
        envy = metrics_dict['envy'].numpy()
        inferiority = metrics_dict['inferiority'].numpy()
        fig, (ax1, ax2) = plt.subplots(ncols=2)
        fig.tight_layout()
        for ax, values, ylabel, title in (
                (ax1, envy, 'envy', 'Envy'),
                (ax2, inferiority, 'inferiority', 'Inferiority')):
            ax.boxplot(values)
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.set_xticks([])
        return fig

    def plot_scatter(k, group=None):
        """Scatter log(envy+1) against inferiority, optionally coloured by ``group``.

        ``group`` names a column of ``recdata.user_metadata``; the string
        ``'None'`` (the radio's "no grouping" choice) disables colouring.
        """
        print('plotting scatter')
        plt.close('all')
        if group == 'None':
            group = None
        if k in recdata.lookup_dict:
            metrics_dict = recdata.lookup_dict[k]
        else:
            metrics_dict = calculate_all_metrics(k, recdata.S_sub, recdata.U_sub, recdata.P_sub)
            recdata.lookup_dict[k] = metrics_dict

        data = pd.DataFrame({'log(envy+1)': np.log(metrics_dict['envy'] + 1),
                             'inferiority': metrics_dict['inferiority']})
        data = data.join(recdata.user_metadata)
        fig, ax = plt.subplots()
        sns.scatterplot(data=data, x='log(envy+1)', y='inferiority', hue=group, ax=ax)
        return fig

    def lorenz_curve(X, ax, label):
        """Plot the Lorenz curve of ``X`` on ``ax`` together with the equality line.

        ref: https://zhiyzuo.github.io/Plot-Lorenz/
        """
        # Sort a copy: ``X`` is the ``rec_per_job`` array held in the metrics
        # dict, and sorting it in place would silently reorder that data.
        X = np.sort(X)
        X_lorenz = X.cumsum() / X.sum()
        X_lorenz = np.insert(X_lorenz, 0, 0)

        ax.plot(np.arange(X_lorenz.size) / (X_lorenz.size - 1), X_lorenz, label=label)
        # line of perfect equality
        ax.plot([0, 1], [0, 1], linestyle='dashed', color='k')
        return ax

    def plot_item(rec_per_job):
        """Plot the job-exposure distribution and its Lorenz curve."""
        print('plotting item')
        plt.close('all')
        fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(10, 10))
        fig.tight_layout(pad=5.0)
        labels, counts = np.unique(rec_per_job, return_counts=True)
        ax1.bar(labels, counts, align='center')

        ax1.set_xlabel('Number of times a job is recommended')
        ax1.set_ylabel('Number of jobs')
        ax1.set_title('Distribution of job exposure')
        ax2 = lorenz_curve(rec_per_job, ax2, '')
        ax2.set_title('Lorenz Curve')
        return fig

    # build the interface
    demo = gr.Blocks(gr.themes.Soft())
    with demo:
        # callbacks (the components they reference are created further down)
        def submit_u():
            # draw how many new users / jobs to add; both may be 0
            user_num = np.random.randint(0, 5)
            job_num = np.random.randint(0, 5)

            if user_num > 0 or job_num > 0:
                # keyword form, matching the call in user_interface
                recdata.update(new_user_num=user_num, new_job_num=job_num)

            return {
                info: gr.update(value='New {} users and {} jobs'.format(user_num, job_num), visible=True),
            }

        def submit1(k):
            metrics_dict = calculate_all_metrics(k, recdata.S_sub, recdata.U_sub, recdata.P_sub)
            return {
                user_box_plot: plot_user_box(metrics_dict),
                scatter_plot: plot_scatter(k),
                btn2: gr.update(visible=True)
            }

        def submit2():
            return {
                radio: gr.update(visible=True)
            }

        def submit3(k):
            metrics_dict = calculate_all_metrics(k, recdata.S_sub, recdata.U_sub, recdata.P_sub)
            return {
                item_plots: plot_item(metrics_dict['rec_per_job'])
            }

        # layout
        gr.Markdown("## Envy & Inferiority Monitor for Developers Demo")
        # 1. accept k
        with gr.Row():
            with gr.Column(scale=1):
                k = gr.inputs.Slider(minimum=1, maximum=min(30, len(
                    recdata.P[0])), default=1, step=1, label='Number of Jobs')
            with gr.Column(scale=1):
                btn = gr.Button('Refresh')
            with gr.Column(scale=1):
                info = gr.Textbox('', label='Updated info', visible=False)
        btn.click(submit_u, inputs=[], outputs=[info])

        with gr.Tab('User'):
            plt.close('all')
            btn1 = gr.Button('Visualize user-side fairness')
            user_box_plot = gr.Plot()
            scatter_plot = gr.Plot()

            btn2 = gr.Button('Visualize intra-group fairness', visible=False)

            radio = gr.Radio(choices=group_choices, value=group_choices[0] if len(group_choices) > 0 else "",
                             interactive=True, label="User group", visible=False)

            btn1.click(submit1, inputs=[k], outputs=[
                user_box_plot, scatter_plot, btn2])
            btn2.click(submit2, inputs=[], outputs=[radio])
            radio.change(fn=plot_scatter, inputs=[
                k, radio], outputs=[scatter_plot])

        with gr.Tab('Item'):
            plt.close('all')
            btn3 = gr.Button('Visualize item-side fairness')
            item_plots = gr.Plot()
            btn3.click(submit3, inputs=[k], outputs=[item_plots])

    return demo
-
407
-
408
- @hydra.main(version_base=None, config_path='./utils', config_name='monitor')
409
- def main(config: DictConfig):
410
- print(config)
411
- Ufile = config.Ufile
412
- Sfile = config.Sfile
413
- Pfile = config.Pfile
414
- user_meta_file = config.user_meta_file
415
- job_meta_file = config.job_meta_file
416
- user_groups = ['None'] + \
417
- list(config.user_groups) if config.user_groups else ['None']
418
- server_name = config.server_name
419
- role = config.role
420
- if role == 'user':
421
- demo = user_interface(Ufile, Pfile, Sfile,
422
- job_meta_file, user_meta_file, user_groups)
423
- elif role == 'developer':
424
- demo = developer_interface(
425
- Ufile, Pfile, Sfile, job_meta_file, user_meta_file, user_groups)
426
- demo.launch(server_name=server_name, server_port=config.server_port)
427
- # demo.launch()
428
-
429
-
430
- if __name__ == "__main__":
431
- main()