Canstralian commited on
Commit
5d7c58c
·
verified ·
1 Parent(s): 6a46dba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -19,11 +19,11 @@ model_mapping = {
19
 
20
  model_name = model_mapping.get(model_choice, "Canstralian/CyberAttackDetection")
21
 
22
- # Load model and tokenizer on demand
23
  @st.cache_resource
24
  def load_model(model_name):
 
25
  try:
26
- # Load the model and tokenizer
27
  tokenizer = AutoTokenizer.from_pretrained(model_name)
28
  if model_name == "Canstralian/text2shellcommands":
29
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
@@ -34,8 +34,9 @@ def load_model(model_name):
34
  st.error(f"Error loading model: {e}")
35
  return None, None
36
 
37
- # Load the model and tokenizer
38
- tokenizer, model = load_model(model_name)
 
39
 
40
  # Input text box in the main panel
41
  st.title(f"{model_choice} Model")
@@ -58,7 +59,9 @@ if user_input and model and tokenizer:
58
  outputs = model(**inputs)
59
  logits = outputs.logits
60
  predicted_class = torch.argmax(logits, dim=-1).item()
 
61
  st.write(f"Predicted Class: {predicted_class}")
 
62
  st.write(f"Logits: {logits}")
63
 
64
  else:
 
19
 
20
  model_name = model_mapping.get(model_choice, "Canstralian/CyberAttackDetection")
21
 
22
+ # Cache model and tokenizer to optimize load time
23
  @st.cache_resource
24
  def load_model(model_name):
25
+ """Load the model and tokenizer."""
26
  try:
 
27
  tokenizer = AutoTokenizer.from_pretrained(model_name)
28
  if model_name == "Canstralian/text2shellcommands":
29
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
34
  st.error(f"Error loading model: {e}")
35
  return None, None
36
 
37
+ # Display progress spinner while loading model
38
+ with st.spinner("Loading model..."):
39
+ tokenizer, model = load_model(model_name)
40
 
41
  # Input text box in the main panel
42
  st.title(f"{model_choice} Model")
 
59
  outputs = model(**inputs)
60
  logits = outputs.logits
61
  predicted_class = torch.argmax(logits, dim=-1).item()
62
+ confidence = torch.softmax(logits, dim=-1).max().item() # Calculate confidence score
63
  st.write(f"Predicted Class: {predicted_class}")
64
+ st.write(f"Confidence: {confidence:.2f}")
65
  st.write(f"Logits: {logits}")
66
 
67
  else: