# OSINT_Tool / app.py
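"""Streamlit dashboard for running inference with Hugging Face models.

Supports two placeholder tasks: shell-command generation (seq2seq) and
text classification.
"""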
import streamlit as st
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
)

# Define the model names and task mappings
MODEL_MAPPING = {
    "text2shellcommands": "t5-small",  # Example seq2seq model
    "pentest_ai": "bert-base-uncased",  # Example sequence classification model
}
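# Note: both checkpoints are generic base models used here as placeholders;
# neither is fine-tuned for shell-command generation or pentest
# classification out of the box.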

# Sidebar for model selection
def select_model():
    st.sidebar.header("Model Configuration")
    return st.sidebar.selectbox("Select a model", list(MODEL_MAPPING.keys()))

# Load model and tokenizer with caching
@st.cache_resource
def load_model_and_tokenizer(model_name):
    try:
        # Load the tokenizer and the task-appropriate model class
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if "t5" in model_name or "seq2seq" in model_name:
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        else:
            model = AutoModelForSequenceClassification.from_pretrained(model_name)
        model.eval()  # disable dropout etc. for deterministic inference
        return tokenizer, model
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None, None

# Handle predictions
def predict_with_model(user_input, model, tokenizer, model_choice):
    # Tokenize once; both branches use the same encoding
    inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
    if model_choice == "text2shellcommands":
        # Generate shell commands (seq2seq task)
        with torch.no_grad():
            outputs = model.generate(**inputs)
        generated_command = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"Generated Shell Command": generated_command}
    else:
        # Perform classification
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        predicted_class = torch.argmax(logits, dim=-1).item()
        return {
            "Predicted Class": predicted_class,
            "Logits": logits.tolist(),
        }

# Main Streamlit app
def main():
    st.title("AI Model Inference Dashboard")

    # Model selection
    model_choice = select_model()
    model_name = MODEL_MAPPING.get(model_choice)
    tokenizer, model = load_model_and_tokenizer(model_name)

    # Input text box
    user_input = st.text_area("Enter text:")

    # Perform prediction only when both input and a loaded model are available
    if model is None or tokenizer is None:
        st.warning("Model failed to load; see the error above.")
    elif user_input:
        result = predict_with_model(user_input, model, tokenizer, model_choice)
        for key, value in result.items():
            st.write(f"{key}: {value}")
    else:
        st.info("Please enter some text for prediction.")


if __name__ == "__main__":
    main()
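
# To launch the dashboard locally (assuming streamlit, transformers, and torch
# are installed, e.g. `pip install streamlit transformers torch`):
#
#   streamlit run app.py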