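"""AI Model Inference Dashboard.

Streamlit app that loads a Hugging Face model selected from the sidebar
(a seq2seq model that generates shell commands, or a sequence classification
model) and runs inference on user-supplied text.

Dependencies: streamlit, transformers, torch.
Launch with `streamlit run <this file>`.
"""
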
import streamlit as st
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
)
import torch

# Define the model names and mappings
MODEL_MAPPING = {
    "text2shellcommands": "t5-small",  # Example seq2seq model
    "pentest_ai": "bert-base-uncased",  # Example sequence classification model
}
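# Note: the two checkpoints above are generic placeholders so the demo runs
# out of the box; assuming you have fine-tuned checkpoints for these tasks,
# point the values at their Hub repo IDs and the rest of the app stays the same.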


# Sidebar for model selection
def select_model():
    st.sidebar.header("Model Configuration")
    return st.sidebar.selectbox("Select a model", list(MODEL_MAPPING.keys()))


# Load model and tokenizer with caching
@st.cache_resource
def load_model_and_tokenizer(model_name):
    try:
        # Load the tokenizer and model
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if "t5" in model_name or "seq2seq" in model_name:
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        else:
            model = AutoModelForSequenceClassification.from_pretrained(model_name)
        return tokenizer, model
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None, None


# Handle predictions
def predict_with_model(user_input, model, tokenizer, model_choice):
    if model_choice == "text2shellcommands":
        # Generate shell commands (seq2seq task)
        inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = model.generate(**inputs)
        generated_command = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"Generated Shell Command": generated_command}
    else:
        # Perform classification
        inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        predicted_class = torch.argmax(logits, dim=-1).item()
        return {
            "Predicted Class": predicted_class,
            "Logits": logits.tolist(),
        }


# Main Streamlit app
def main():
    st.title("AI Model Inference Dashboard")

    # Model selection
    model_choice = select_model()
    model_name = MODEL_MAPPING.get(model_choice)
    tokenizer, model = load_model_and_tokenizer(model_name)

    # Input text box
    user_input = st.text_area("Enter text:")

    # Perform prediction if input and models are available
    if user_input and model and tokenizer:
        result = predict_with_model(user_input, model, tokenizer, model_choice)
        for key, value in result.items():
            st.write(f"{key}: {value}")
    else:
        st.info("Please enter some text for prediction.")


if __name__ == "__main__":
    main()