Canstralian commited on
Commit
b38a095
·
verified ·
1 Parent(s): 5cda0af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -19,24 +19,28 @@ model_mapping = {
19
 
20
  model_name = model_mapping.get(model_choice, "Canstralian/CyberAttackDetection")
21
 
22
- # Cache model and tokenizer to optimize load time
23
  @st.cache_resource
24
  def load_model(model_name):
25
- """Load the model and tokenizer."""
26
  try:
27
- tokenizer = AutoTokenizer.from_pretrained(model_name)
28
  if model_name == "Canstralian/text2shellcommands":
 
 
 
 
 
29
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
30
  else:
31
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
 
32
  return tokenizer, model
33
  except Exception as e:
34
  st.error(f"Error loading model: {e}")
35
  return None, None
36
 
37
- # Display progress spinner while loading model
38
- with st.spinner("Loading model..."):
39
- tokenizer, model = load_model(model_name)
40
 
41
  # Input text box in the main panel
42
  st.title(f"{model_choice} Model")
@@ -59,9 +63,7 @@ if user_input and model and tokenizer:
59
  outputs = model(**inputs)
60
  logits = outputs.logits
61
  predicted_class = torch.argmax(logits, dim=-1).item()
62
- confidence = torch.softmax(logits, dim=-1).max().item() # Calculate confidence score
63
  st.write(f"Predicted Class: {predicted_class}")
64
- st.write(f"Confidence: {confidence:.2f}")
65
  st.write(f"Logits: {logits}")
66
 
67
  else:
 
19
 
20
  model_name = model_mapping.get(model_choice, "Canstralian/CyberAttackDetection")
21
 
22
+ # Load model and tokenizer on demand
23
  @st.cache_resource
24
  def load_model(model_name):
 
25
  try:
26
+ # Fallback to a known model for debugging
27
  if model_name == "Canstralian/text2shellcommands":
28
+ model_name = "t5-small" # Use a known model like T5 for testing
29
+
30
+ # Load the model and tokenizer
31
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
32
+ if "seq2seq" in model_name.lower():
33
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
34
  else:
35
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
36
+
37
  return tokenizer, model
38
  except Exception as e:
39
  st.error(f"Error loading model: {e}")
40
  return None, None
41
 
42
+ # Load the model and tokenizer
43
+ tokenizer, model = load_model(model_name)
 
44
 
45
  # Input text box in the main panel
46
  st.title(f"{model_choice} Model")
 
63
  outputs = model(**inputs)
64
  logits = outputs.logits
65
  predicted_class = torch.argmax(logits, dim=-1).item()
 
66
  st.write(f"Predicted Class: {predicted_class}")
 
67
  st.write(f"Logits: {logits}")
68
 
69
  else: