rongguangw commited on
Commit
5ad4396
·
verified ·
1 Parent(s): df36903

add training script

Browse files
Files changed (1) hide show
  1. synthetic_data_generation.ipynb +193 -0
synthetic_data_generation.ipynb ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "eaa22304",
6
+ "metadata": {
7
+ "id": "eaa22304"
8
+ },
9
+ "source": [
10
+ "### Kernel Density Estimation \n",
11
+ "Given $n$ data points, $X \\in R^{n\\times m}$, estimate the probability density function of the data, i.e. $P(x)$.\n",
12
+ "\n",
13
+ "In KDE, the pdf is given by $P(x) = \\frac{1}{nh}\\sum_{i=1}^{n}K(\\frac{X_i-x}{h})$,\n",
14
+ "where K is the kernel function, h is smoothing bandwidth (small h undersmoothing, large h oversmoothing)."
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": null,
20
+ "id": "e139aff4",
21
+ "metadata": {
22
+ "id": "e139aff4"
23
+ },
24
+ "outputs": [],
25
+ "source": [
26
+ "import sklearn\n",
27
+ "import fnmatch\n",
28
+ "import numpy as np\n",
29
+ "import pandas as pd\n",
30
+ "import seaborn as sns\n",
31
+ "import statsmodels.api as sm\n",
32
+ "import matplotlib.pyplot as plt\n",
33
+ "from sklearn.decomposition import PCA\n",
34
+ "from sklearn.neighbors import KernelDensity\n",
35
+ "from sklearn.model_selection import GridSearchCV"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "markdown",
40
+ "id": "14a48fb0",
41
+ "metadata": {
42
+ "id": "14a48fb0"
43
+ },
44
+ "source": [
45
+ "#### Load the real data and select samples for a specific race and sex"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": null,
51
+ "id": "c784a28b",
52
+ "metadata": {
53
+ "id": "c784a28b"
54
+ },
55
+ "outputs": [],
56
+ "source": [
57
+ "df = pd.read_csv('istaging_all.csv') # load istaging data"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "id": "0159223c",
64
+ "metadata": {
65
+ "id": "0159223c"
66
+ },
67
+ "outputs": [],
68
+ "source": [
69
+ "# select black females\n",
70
+ "df = df[((df.Race == 'Black') & (df.Sex == 'F'))].reset_index(drop=True)"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "id": "0758d447",
77
+ "metadata": {
78
+ "id": "0758d447"
79
+ },
80
+ "outputs": [],
81
+ "source": [
82
+ "# select baseline data for each subject\n",
83
+ "df.Date = pd.to_datetime(df.Date)\n",
84
+ "df_tp1 = df.loc[df.groupby('PTID')['Date'].idxmin()].reset_index(drop=True)"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": null,
90
+ "id": "9611e735",
91
+ "metadata": {
92
+ "scrolled": true,
93
+ "id": "9611e735"
94
+ },
95
+ "outputs": [],
96
+ "source": [
97
+ "# split the data into train and test sets; the train set will be used to learn the probability distribution of the real data\n",
98
+ "df_train, df_test = sklearn.model_selection.train_test_split(df_tp1, test_size=0.3, random_state=40)"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "markdown",
103
+ "id": "0800ceda",
104
+ "metadata": {
105
+ "id": "0800ceda"
106
+ },
107
+ "source": [
108
+ "#### Fit a KDE model to estimate the joint probability density of Age and ROI volumes."
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": null,
114
+ "id": "b37250f8",
115
+ "metadata": {
116
+ "id": "b37250f8"
117
+ },
118
+ "outputs": [],
119
+ "source": [
120
+ "## standardized ROI grid search\n",
121
+ "# use grid search to select the bandwidth of the kernel density estimate\n",
122
+ "cols = ['Age']\n",
123
+ "roi_cols = []  # fill in with fnmatch patterns for the roi column names\n",
124
+ "# fnmatch.filter expects a single pattern string, not a list, so match each pattern in turn\n",
125
+ "cols.extend(c for pat in roi_cols for c in fnmatch.filter(df_train.columns, pat))  # select the ROI volumes\n",
126
+ "\n",
127
+ "# standardize the data (zero mean, unit variance) so one bandwidth suits every dimension\n",
128
+ "data_standard = pd.DataFrame()\n",
129
+ "data_standard[cols] = (df_train.loc[:, cols] - df_train.loc[:, cols].mean()) / df_train.loc[:, cols].std()\n",
130
+ "data_standard = data_standard.to_numpy()\n",
131
+ "\n",
132
+ "# Use a Gaussian kernel; KernelDensity requires bandwidth > 0, so the grid must start above zero\n",
133
+ "kde = GridSearchCV(KernelDensity(kernel='gaussian'), {'bandwidth': np.linspace(0.05, 3, 100)}, cv=5)\n",
134
+ "kde.fit(data_standard)\n",
135
+ "kde = kde.best_estimator_\n",
136
+ "print(f'optimal bandwidth of kernel estimated via grid search is {kde.bandwidth}')"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "markdown",
141
+ "id": "32c78445",
142
+ "metadata": {
143
+ "id": "32c78445"
144
+ },
145
+ "source": [
146
+ "#### Generate synthetic data using a KDE model for the specified category of race and sex"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "id": "e06523c2",
153
+ "metadata": {
154
+ "id": "e06523c2"
155
+ },
156
+ "outputs": [],
157
+ "source": [
158
+ "# sample 3000 synthetic data points from the fitted KDE (fixed seed for reproducibility)\n",
159
+ "n_synth = 3000\n",
160
+ "sample = kde.sample(n_synth, random_state=0)\n",
161
+ "# undo the standardization: rescale each column by the training-set std and mean\n",
162
+ "sample = sample * df_train.loc[:, cols].std().to_numpy() + df_train.loc[:, cols].mean().to_numpy()\n",
163
+ "# build covariates and numeric features separately so the Age/ROI columns keep a numeric dtype\n",
164
+ "cov = pd.DataFrame({'PTID': [f'Synth_{i+1}' for i in range(n_synth)], 'Sex': 'F', 'Race': 'Black'})\n",
165
+ "df_kde_synth = pd.concat([cov, pd.DataFrame(sample, columns=cols)], axis=1)"
166
+ ]
167
+ }
168
+ ],
169
+ "metadata": {
170
+ "kernelspec": {
171
+ "display_name": "Python 3",
172
+ "language": "python",
173
+ "name": "python3"
174
+ },
175
+ "language_info": {
176
+ "codemirror_mode": {
177
+ "name": "ipython",
178
+ "version": 3
179
+ },
180
+ "file_extension": ".py",
181
+ "mimetype": "text/x-python",
182
+ "name": "python",
183
+ "nbconvert_exporter": "python",
184
+ "pygments_lexer": "ipython3",
185
+ "version": "3.8.8"
186
+ },
187
+ "colab": {
188
+ "provenance": []
189
+ }
190
+ },
191
+ "nbformat": 4,
192
+ "nbformat_minor": 5
193
+ }