# Spaces status: Running
import logging

import streamlit as st
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
)
# Map each user-facing model key (shown in the sidebar) to the checkpoint name
# passed to ``from_pretrained``. The loader chooses the model head from the
# checkpoint name ("t5"/"seq2seq" -> seq2seq LM, otherwise a sequence
# classifier); the predictor chooses the task from the key.
MODEL_MAPPING = {
    "text2shellcommands": "t5-small",  # Example seq2seq model
    "pentest_ai": "bert-base-uncased",  # Example sequence classification model
}
def select_model():
    """Render the sidebar model picker and return the chosen key.

    Returns:
        One of the keys of ``MODEL_MAPPING``, offered in definition order.
    """
    st.sidebar.header("Model Configuration")
    return st.sidebar.selectbox("Select a model", list(MODEL_MAPPING))
@st.cache_resource
def _load_model_and_tokenizer_cached(model_name):
    """Load and memoize ``(tokenizer, model)`` for ``model_name``.

    Streamlit re-executes the script on every widget interaction, so without
    ``st.cache_resource`` the checkpoint would be reloaded on each rerun.
    Exceptions propagate and are not cached, so a failed load is retried on
    the next run instead of being remembered as ``(None, None)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Checkpoints whose name mentions "t5" or "seq2seq" get the seq2seq LM
    # head; every other checkpoint gets a sequence-classification head.
    if "t5" in model_name or "seq2seq" in model_name:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    else:
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return tokenizer, model


def load_model_and_tokenizer(model_name):
    """Return ``(tokenizer, model)`` for ``model_name``, cached across reruns.

    Args:
        model_name: Checkpoint name passed to ``from_pretrained``.

    Returns:
        ``(tokenizer, model)`` on success, or ``(None, None)`` after showing
        the error in the UI (and logging the traceback) when loading fails.
    """
    if not model_name:
        st.error("Error loading model: no model name configured for this selection.")
        return None, None
    try:
        return _load_model_and_tokenizer_cached(model_name)
    except Exception as e:  # UI boundary: report the failure instead of crashing the page
        logging.getLogger(__name__).exception("Failed to load model %r", model_name)
        st.error(f"Error loading model: {e}")
        return None, None
def predict_with_model(
    user_input, model, tokenizer, model_choice, *, max_new_tokens=64
):
    """Run ``model`` on ``user_input`` and return display-ready results.

    Args:
        user_input: Text entered by the user.
        model: Loaded model. For ``"text2shellcommands"`` it must provide
            ``generate``; otherwise its forward pass must return an object
            with a ``logits`` attribute.
        tokenizer: Tokenizer matching ``model``.
        model_choice: Key from ``MODEL_MAPPING``. ``"text2shellcommands"``
            selects text generation; any other key selects classification.
        max_new_tokens: Upper bound on generated tokens for the generation
            path. Without an explicit bound ``generate`` falls back to the
            default ``max_length`` (20 unless the generation config says
            otherwise), which silently truncates longer commands.

    Returns:
        ``{"Generated Shell Command": str}`` for generation, or
        ``{"Predicted Class": int, "Logits": list}`` for classification.
    """
    # Both tasks tokenize the input the same way; truncation keeps over-long
    # input within the tokenizer's maximum length.
    inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        if model_choice == "text2shellcommands":
            output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
            command = tokenizer.decode(output_ids[0], skip_special_tokens=True)
            return {"Generated Shell Command": command}
        logits = model(**inputs).logits
    predicted_class = torch.argmax(logits, dim=-1).item()
    return {
        "Predicted Class": predicted_class,
        "Logits": logits.tolist(),
    }
def main():
    """Render the dashboard: choose a model, enter text, and show the result."""
    st.title("AI Model Inference Dashboard")

    model_choice = select_model()
    model_name = MODEL_MAPPING.get(model_choice)
    tokenizer, model = load_model_and_tokenizer(model_name)

    user_input = st.text_area("Enter text:")

    # A failed load has already been reported by the loader; asking for text
    # would be misleading when there is no model to run it through.
    if model is None or tokenizer is None:
        return
    # Treat empty or whitespace-only input as "nothing entered yet".
    if not (user_input and user_input.strip()):
        st.info("Please enter some text for prediction.")
        return

    result = predict_with_model(user_input, model, tokenizer, model_choice)
    for key, value in result.items():
        st.write(f"{key}: {value}")
# Run the dashboard when this file is executed as the main script.
if __name__ == "__main__":
    main()