Canstralian committed
Commit c0298f8 · verified · 1 parent: 9bc591d

Update app.py

Files changed (1)
  1. app.py +50 -42
app.py CHANGED
@@ -2,73 +2,81 @@ import streamlit as st
 from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
 import torch
 
-# Define the model names
-model_mapping = {
-    "CyberAttackDetection": "Canstralian/CyberAttackDetection",
+# Define the model names and mappings
+MODEL_MAPPING = {
     "text2shellcommands": "Canstralian/text2shellcommands",
-    "pentest_ai": "Canstralian/pentest_ai"
+    "pentest_ai": "Canstralian/pentest_ai",
 }
 
-def load_model(model_name):
+# Sidebar for model selection
+def select_model():
+    st.sidebar.header("Model Configuration")
+    return st.sidebar.selectbox("Select a model", list(MODEL_MAPPING.keys()))
+
+
+# Load model and tokenizer with caching
+@st.cache_resource
+def load_model_and_tokenizer(model_name):
     try:
-        # Fallback to a known model for debugging
+        # Use a fallback model for testing
         if model_name == "Canstralian/text2shellcommands":
-            model_name = "t5-small"  # Use a known model like T5 for testing
+            model_name = "t5-small"
 
-        # Load the model and tokenizer
+        # Load the tokenizer and model
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         if "seq2seq" in model_name.lower():
             model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
         else:
             model = AutoModelForSequenceClassification.from_pretrained(model_name)
-
+
         return tokenizer, model
     except Exception as e:
        st.error(f"Error loading model: {e}")
        return None, None
 
-def validate_input(user_input):
-    if not user_input:
-        st.error("Please enter some text for prediction.")
-        return False
-    return True
 
-def make_prediction(model, tokenizer, user_input):
-    try:
+# Handle predictions
+def predict_with_model(user_input, model, tokenizer, model_choice):
+    if model_choice == "text2shellcommands":
+        # Generate shell commands
+        inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
+        with torch.no_grad():
+            outputs = model.generate(**inputs)
+        generated_command = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return {"Generated Shell Command": generated_command}
+    else:
+        # Perform classification
         inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
         with torch.no_grad():
             outputs = model(**inputs)
-        return outputs
-    except Exception as e:
-        st.error(f"Error making prediction: {e}")
-        return None
+        logits = outputs.logits
+        predicted_class = torch.argmax(logits, dim=-1).item()
+        return {
+            "Predicted Class": predicted_class,
+            "Logits": logits.tolist(),
+        }
 
-def main():
-    st.sidebar.header("Model Configuration")
-    model_choice = st.sidebar.selectbox("Select a model", [
-        "CyberAttackDetection",
-        "text2shellcommands",
-        "pentest_ai"
-    ])
 
-    model_name = model_mapping.get(model_choice, "Canstralian/CyberAttackDetection")
+# Main Streamlit app
+def main():
+    st.title("AI Model Inference Dashboard")
 
-    tokenizer, model = load_model(model_name)
+    # Model selection
+    model_choice = select_model()
+    model_name = MODEL_MAPPING.get(model_choice)
+    tokenizer, model = load_model_and_tokenizer(model_name)
 
-    st.title(f"{model_choice} Model")
+    # Input text box
     user_input = st.text_area("Enter text:")
 
-    if validate_input(user_input) and model is not None and tokenizer is not None:
-        outputs = make_prediction(model, tokenizer, user_input)
-        if outputs is not None:
-            if model_choice == "text2shellcommands":
-                generated_command = tokenizer.decode(outputs[0], skip_special_tokens=True)
-                st.write(f"Generated Shell Command: {generated_command}")
-            else:
-                logits = outputs.logits
-                predicted_class = torch.argmax(logits, dim=-1).item()
-                st.write(f"Predicted Class: {predicted_class}")
-                st.write(f"Logits: {logits}")
+    # Perform prediction if input and models are available
+    if user_input and model and tokenizer:
+        result = predict_with_model(user_input, model, tokenizer, model_choice)
+        for key, value in result.items():
+            st.write(f"{key}: {value}")
+    else:
+        st.info("Please enter some text for prediction.")
+
 
 if __name__ == "__main__":
-    main()
+    main()
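
A note on the updated load path: once the fallback rewrites model_name to "t5-small", the check '"seq2seq" in model_name.lower()' can no longer match, so the T5 checkpoint is loaded with AutoModelForSequenceClassification, and the model.generate(**inputs) call in predict_with_model would fail for the text2shellcommands choice. A minimal sketch of one alternative, dispatching on the selected task rather than the checkpoint name (the is_seq2seq parameter is illustrative, not part of this commit):

# Sketch only: pick the model class from the selected task, not the model id.
# The is_seq2seq flag is an assumption for illustration.
@st.cache_resource
def load_model_and_tokenizer(model_name, is_seq2seq=False):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if is_seq2seq:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    else:
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return tokenizer, model

# Caller side, in main():
# tokenizer, model = load_model_and_tokenizer(model_name, is_seq2seq=(model_choice == "text2shellcommands"))

Since st.cache_resource keys its cache on the function arguments, passing the flag explicitly also keeps the seq2seq and classification models cached separately. To try the app locally: streamlit run app.py.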