hiyata commited on
Commit
555d484
·
verified ·
1 Parent(s): d55c2b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -11
app.py CHANGED
@@ -93,7 +93,11 @@ def create_importance_plot(shap_values, kmers, top_k=10):
93
  """
94
  Create horizontal bar plot of feature importance.
95
  """
96
- plt.style.use('seaborn-v0_8-whitegrid')
 
 
 
 
97
  fig = plt.figure(figsize=(10, 8))
98
 
99
  # Sort by absolute importance
@@ -115,8 +119,13 @@ def create_contribution_plot(important_kmers, final_prob):
115
  """
116
  Create waterfall plot showing cumulative feature contributions.
117
  """
118
- plt.style.use('seaborn-v0_8-whitegrid')
119
- fig = plt.figure(figsize=(12, 6))
 
 
 
 
 
120
 
121
  base_prob = 0.5
122
  cumulative = [base_prob]
@@ -126,15 +135,36 @@ def create_contribution_plot(important_kmers, final_prob):
126
  cumulative.append(cumulative[-1] + kmer_info['impact'])
127
  labels.append(kmer_info['kmer'])
128
 
129
- plt.plot(range(len(cumulative)), cumulative, 'b-o', linewidth=2)
130
- plt.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
- plt.xticks(range(len(labels)), labels, rotation=45)
133
- plt.ylim(0, 1)
134
- plt.grid(True, alpha=0.3)
135
- plt.title('Cumulative Feature Contributions')
136
- plt.ylabel('Probability of Human Origin')
 
 
 
137
 
 
 
138
  return fig
139
 
140
  def predict(file_obj, top_kmers=10, fasta_text=""):
@@ -165,7 +195,8 @@ def predict(file_obj, top_kmers=10, fasta_text=""):
165
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
166
  try:
167
  model = VirusClassifier(256).to(device)
168
- model.load_state_dict(torch.load('model.pt', map_location=device))
 
169
  scaler = joblib.load('scaler.pkl')
170
  except Exception as e:
171
  return f"Error loading model: {str(e)}", None, None
 
93
  """
94
  Create horizontal bar plot of feature importance.
95
  """
96
+ # Set style directly instead of using seaborn
97
+ plt.rcParams['figure.facecolor'] = '#ffffff'
98
+ plt.rcParams['axes.facecolor'] = '#ffffff'
99
+ plt.rcParams['axes.grid'] = True
100
+ plt.rcParams['grid.alpha'] = 0.3
101
  fig = plt.figure(figsize=(10, 8))
102
 
103
  # Sort by absolute importance
 
119
  """
120
  Create waterfall plot showing cumulative feature contributions.
121
  """
122
+ # Set style parameters
123
+ plt.rcParams['figure.facecolor'] = '#ffffff'
124
+ plt.rcParams['axes.facecolor'] = '#ffffff'
125
+ plt.rcParams['axes.grid'] = True
126
+ plt.rcParams['grid.alpha'] = 0.3
127
+
128
+ fig, ax = plt.subplots(figsize=(12, 6))
129
 
130
  base_prob = 0.5
131
  cumulative = [base_prob]
 
135
  cumulative.append(cumulative[-1] + kmer_info['impact'])
136
  labels.append(kmer_info['kmer'])
137
 
138
+ # Plot cumulative line with markers
139
+ line = ax.plot(range(len(cumulative)), cumulative, '-o',
140
+ color='#3498db', linewidth=2,
141
+ marker='o', markersize=8,
142
+ markerfacecolor='white',
143
+ markeredgecolor='#3498db',
144
+ markeredgewidth=2)
145
+
146
+ # Add reference line at 0.5
147
+ ax.axhline(y=0.5, color='#95a5a6', linestyle='--', alpha=0.5)
148
+
149
+ # Customize plot
150
+ ax.set_xticks(range(len(labels)))
151
+ ax.set_xticklabels(labels, rotation=45, ha='right')
152
+ ax.set_ylim(0, 1)
153
+ ax.grid(True, axis='y', linestyle='--', alpha=0.3)
154
+ ax.set_title('Cumulative Feature Contributions')
155
+ ax.set_ylabel('Probability of Human Origin')
156
 
157
+ # Add value labels
158
+ for i, prob in enumerate(cumulative):
159
+ ax.annotate(f'{prob:.3f}',
160
+ (i, prob),
161
+ xytext=(0, 10),
162
+ textcoords='offset points',
163
+ ha='center',
164
+ va='bottom')
165
 
166
+ # Adjust layout to prevent label cutoff
167
+ plt.tight_layout()
168
  return fig
169
 
170
  def predict(file_obj, top_kmers=10, fasta_text=""):
 
195
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
196
  try:
197
  model = VirusClassifier(256).to(device)
198
+ # Load model weights safely
199
+ model.load_state_dict(torch.load('model.pt', map_location=device, weights_only=True))
200
  scaler = joblib.load('scaler.pkl')
201
  except Exception as e:
202
  return f"Error loading model: {str(e)}", None, None