Canstralian committed
Commit 6a09dd7 · verified · 1 Parent(s): b8e3be9

Update app.py

Files changed (1)
  1. app.py +114 -20
app.py CHANGED
@@ -5,40 +5,74 @@ from transformers import (
     AutoModelForSeq2SeqLM,
 )
 import torch
+import os
 
-# Define the model names and mappings
+# Define the model names and their corresponding Hugging Face models
 MODEL_MAPPING = {
-    "text2shellcommands": "t5-small",  # Example seq2seq model
-    "pentest_ai": "bert-base-uncased",  # Example sequence classification model
+    "text2shellcommands": "t5-small",  # Example seq2seq model for generating shell commands
+    "pentest_ai": "bert-base-uncased",  # Example classification model for pentesting tasks
 }
 
-# Sidebar for model selection
+# Function to create a sidebar for model selection
 def select_model():
+    """
+    Adds a dropdown to the Streamlit sidebar for selecting a model.
+    Returns:
+        str: The selected model key from MODEL_MAPPING.
+    """
     st.sidebar.header("Model Configuration")
-    return st.sidebar.selectbox("Select a model", list(MODEL_MAPPING.keys()))
+    selected_model = st.sidebar.selectbox("Select a model", list(MODEL_MAPPING.keys()))
+    return selected_model
 
 
-# Load model and tokenizer with caching
+# Function to load the model and tokenizer with caching
 @st.cache_resource
 def load_model_and_tokenizer(model_name):
+    """
+    Loads the tokenizer and model for the specified Hugging Face model name.
+    Uses caching to optimize performance.
+
+    Args:
+        model_name (str): The name of the Hugging Face model to load.
+
+    Returns:
+        tuple: A tokenizer and model instance.
+    """
     try:
-        # Load the tokenizer and model
+        # Load the tokenizer
         tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        # Determine the correct model class to use
         if "t5" in model_name or "seq2seq" in model_name:
+            # Load a sequence-to-sequence model
             model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
         else:
+            # Load a sequence classification model
             model = AutoModelForSequenceClassification.from_pretrained(model_name)
 
         return tokenizer, model
     except Exception as e:
-        st.error(f"Error loading model: {e}")
+        # Display an error message in the Streamlit app
+        st.error(f"An error occurred while loading the model or tokenizer: {str(e)}")
         return None, None
 
 
-# Handle predictions
+# Function to handle predictions based on the selected model
 def predict_with_model(user_input, model, tokenizer, model_choice):
+    """
+    Handles predictions using the loaded model and tokenizer.
+
+    Args:
+        user_input (str): Text input from the user.
+        model: Loaded Hugging Face model.
+        tokenizer: Loaded Hugging Face tokenizer.
+        model_choice (str): Selected model key from MODEL_MAPPING.
+
+    Returns:
+        dict: A dictionary containing the prediction results.
+    """
     if model_choice == "text2shellcommands":
-        # Generate shell commands (seq2seq task)
+        # Generate shell commands (Seq2Seq task)
         inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
         with torch.no_grad():
             outputs = model.generate(**inputs)
@@ -57,25 +91,85 @@ def predict_with_model(user_input, model, tokenizer, model_choice):
     }
 
 
-# Main Streamlit app
+# Function to process uploaded files
+def process_uploaded_file(uploaded_file):
+    """
+    Reads and processes the uploaded file. Supports text and CSV files.
+
+    Args:
+        uploaded_file: The uploaded file.
+
+    Returns:
+        str: The content of the file as a string.
+    """
+    try:
+        if uploaded_file is not None:
+            file_type = uploaded_file.type
+
+            # Text file processing
+            if "text" in file_type:
+                content = uploaded_file.read().decode("utf-8")
+                return content
+            # CSV file processing
+            elif "csv" in file_type:
+                import pandas as pd
+                df = pd.read_csv(uploaded_file)
+                return df.to_string()  # Convert the dataframe to string
+            else:
+                st.error("Unsupported file type. Please upload a text or CSV file.")
+                return None
+    except Exception as e:
+        st.error(f"Error processing file: {e}")
+        return None
+
+
+# Main function to define the Streamlit app
 def main():
     st.title("AI Model Inference Dashboard")
+    st.markdown(
+        """
+        This dashboard allows you to interact with different AI models for inference tasks,
+        such as generating shell commands or performing text classification.
+        """
+    )
 
     # Model selection
     model_choice = select_model()
     model_name = MODEL_MAPPING.get(model_choice)
     tokenizer, model = load_model_and_tokenizer(model_name)
 
-    # Input text box
-    user_input = st.text_area("Enter text:")
-
-    # Perform prediction if input and models are available
-    if user_input and model and tokenizer:
-        result = predict_with_model(user_input, model, tokenizer, model_choice)
-        for key, value in result.items():
-            st.write(f"{key}: {value}")
-    else:
-        st.info("Please enter some text for prediction.")
+    # Input text area or file upload
+    input_choice = st.radio("Choose Input Method", ("Text Input", "Upload File"))
+
+    if input_choice == "Text Input":
+        user_input = st.text_area("Enter your text input:", placeholder="Type your text here...")
+
+        # Handle prediction after submit
+        submit_button = st.button("Submit")
+
+        if submit_button and user_input:
+            st.write("### Prediction Results:")
+            result = predict_with_model(user_input, model, tokenizer, model_choice)
+            for key, value in result.items():
+                st.write(f"**{key}:** {value}")
+
+    elif input_choice == "Upload File":
+        uploaded_file = st.file_uploader("Choose a text or CSV file", type=["txt", "csv"])
+
+        # Handle prediction after submit
+        submit_button = st.button("Submit")
+
+        if submit_button and uploaded_file:
+            file_content = process_uploaded_file(uploaded_file)
+            if file_content:
+                st.write("### File Content:")
+                st.write(file_content)
+                result = predict_with_model(file_content, model, tokenizer, model_choice)
+                st.write("### Prediction Results:")
+                for key, value in result.items():
+                    st.write(f"**{key}:** {value}")
+            else:
+                st.info("No valid content found in the file.")
 
 
 if __name__ == "__main__":
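
A quick way to sanity-check the seq2seq path outside Streamlit is a minimal round trip with the same t5-small checkpoint. This is an illustrative sketch, not code from the commit: the tokenizer.decode call stands in for whatever decoding logic sits in the lines elided between the two hunks, and max_new_tokens=32 is an arbitrary cap chosen here.

    # Sanity check for the tokenize -> generate -> decode round trip.
    # Note: t5-small is not fine-tuned for shell commands, so this only
    # verifies the mechanics, not the quality of the output.
    import torch
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

    # Same tokenizer arguments as in predict_with_model above
    inputs = tokenizer(
        "list all files in the current directory",
        return_tensors="pt", padding=True, truncation=True,
    )
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=32)

    # Decode the first generated sequence back to text
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

To exercise the full dashboard, including the new file-upload path, the app launches the usual Streamlit way with "streamlit run app.py" (assuming streamlit, transformers, and torch are installed; pandas is only needed for CSV uploads).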