Canstralian commited on
Commit
b8e3be9
·
verified ·
1 Parent(s): c0298f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -1,11 +1,15 @@
1
  import streamlit as st
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
 
 
 
 
3
  import torch
4
 
5
  # Define the model names and mappings
6
  MODEL_MAPPING = {
7
- "text2shellcommands": "Canstralian/text2shellcommands",
8
- "pentest_ai": "Canstralian/pentest_ai",
9
  }
10
 
11
  # Sidebar for model selection
@@ -18,13 +22,9 @@ def select_model():
18
  @st.cache_resource
19
  def load_model_and_tokenizer(model_name):
20
  try:
21
- # Use a fallback model for testing
22
- if model_name == "Canstralian/text2shellcommands":
23
- model_name = "t5-small"
24
-
25
  # Load the tokenizer and model
26
  tokenizer = AutoTokenizer.from_pretrained(model_name)
27
- if "seq2seq" in model_name.lower():
28
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
29
  else:
30
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
@@ -38,7 +38,7 @@ def load_model_and_tokenizer(model_name):
38
  # Handle predictions
39
  def predict_with_model(user_input, model, tokenizer, model_choice):
40
  if model_choice == "text2shellcommands":
41
- # Generate shell commands
42
  inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
43
  with torch.no_grad():
44
  outputs = model.generate(**inputs)
@@ -79,4 +79,4 @@ def main():
79
 
80
 
81
  if __name__ == "__main__":
82
- main()
 
1
  import streamlit as st
2
+ from transformers import (
3
+ AutoTokenizer,
4
+ AutoModelForSequenceClassification,
5
+ AutoModelForSeq2SeqLM,
6
+ )
7
  import torch
8
 
9
  # Define the model names and mappings
10
  MODEL_MAPPING = {
11
+ "text2shellcommands": "t5-small", # Example seq2seq model
12
+ "pentest_ai": "bert-base-uncased", # Example sequence classification model
13
  }
14
 
15
  # Sidebar for model selection
 
22
  @st.cache_resource
23
  def load_model_and_tokenizer(model_name):
24
  try:
 
 
 
 
25
  # Load the tokenizer and model
26
  tokenizer = AutoTokenizer.from_pretrained(model_name)
27
+ if "t5" in model_name or "seq2seq" in model_name:
28
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
29
  else:
30
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
 
38
  # Handle predictions
39
  def predict_with_model(user_input, model, tokenizer, model_choice):
40
  if model_choice == "text2shellcommands":
41
+ # Generate shell commands (seq2seq task)
42
  inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
43
  with torch.no_grad():
44
  outputs = model.generate(**inputs)
 
79
 
80
 
81
  if __name__ == "__main__":
82
+ main()