Select_Model / app.py
MartinKosela's picture
Update app.py
817156c
raw
history blame
4.43 kB
import streamlit as st
import openai
import os
# Read the OpenAI API key from the environment instead of hard-coding it:
# a key committed to source is readable by anyone with access to the repository.
# NOTE(review): the key that was previously embedded here should be treated as
# leaked and revoked.
openai.api_key = os.environ.get("OPENAI_API_KEY")
if not openai.api_key:
    st.error("OpenAI API key is not configured. Set the OPENAI_API_KEY environment variable.")
    # Every recommendation below needs the key, so end this script run here.
    st.stop()
# Catalogue of model families and techniques, grouped by application domain.
# NOTE(review): this list is not referenced by any other code visible in this file.
KNOWN_MODELS = [
    # --- General machine learning ---
    "Neural Networks",
    "Decision Trees",
    "Support Vector Machines",
    "Random Forests",
    "Linear Regression",
    "Reinforcement Learning",
    "Logistic Regression",
    "k-Nearest Neighbors",
    "Naive Bayes",
    "Gradient Boosting Machines",
    "Regularization Techniques",
    "Ensemble Methods",
    "Time Series Analysis",
    # --- Deep learning ---
    "Deep Learning",
    "Convolutional Neural Networks",
    "Recurrent Neural Networks",
    "Transformer Models",
    "Generative Adversarial Networks",
    "Autoencoders",
    "Bidirectional LSTM",
    "Residual Networks (ResNets)",
    "Variational Autoencoders",
    # --- Computer vision ---
    "Object Detection (e.g., YOLO, SSD)",
    "Semantic Segmentation",
    "Image Classification",
    "Face Recognition",
    "Optical Character Recognition (OCR)",
    "Pose Estimation",
    "Style Transfer",
    "Image-to-Image Translation",
    "Image Generation",
    "Capsule Networks",
    # --- Natural language processing ---
    "BERT",
    "GPT",
    "ELMo",
    "T5",
    "Word2Vec",
    "Doc2Vec",
    "Topic Modeling",
    "Sentiment Analysis",
    "Text Classification",
    "Machine Translation",
    "Speech Recognition",
    "Sequence-to-Sequence Models",
    "Attention Mechanisms",
    "Named Entity Recognition",
    "Text Summarization",
]
def recommend_ai_model_via_gpt(description):
    """Ask GPT-4 to recommend a single AI model or service for an application.

    Args:
        description: Free-text description of the application (and, for
            example, the dataset it will use).

    Returns:
        The assistant's reply with surrounding whitespace stripped. The system
        prompt asks for the name only, so the reply should be short enough to
        display as a heading and to pass on to a follow-up prompt.

    Errors raised by the OpenAI client are not caught here; they propagate to
    the caller.
    """
    messages = [
        # Without an instruction, the model answers the raw description with
        # free-form prose instead of naming a model.
        {
            "role": "system",
            "content": (
                "You are an expert in choosing AI models. Based on the user's "
                "description of their application and dataset, reply with the "
                "name of the single most suitable AI model or AI API service "
                "only, with no explanation."
            ),
        },
        {"role": "user", "content": description},
    ]
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=messages,
    )
    return response['choices'][0]['message']['content'].strip()
def explain_recommendation(model_name, application_description=None):
    """Ask GPT-4 why *model_name* would be a suitable choice for the application.

    Args:
        model_name: Name of the recommended model, inserted verbatim into the
            question.
        application_description: Optional description of the application.
            The request contains only the message built here, so without this
            argument the model is asked about "the application" with no
            information about what that application is. When omitted, only the
            question itself is sent.

    Returns:
        The assistant's reply with surrounding whitespace stripped.

    Errors raised by the OpenAI client are not caught here; they propagate to
    the caller.
    """
    question = f"Why would {model_name} be a suitable choice for the application?"
    if application_description:
        question += f"\n\nApplication description: {application_description}"
    messages = [{"role": "user", "content": question}]
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=messages,
    )
    return response['choices'][0]['message']['content'].strip()
# Streamlit UI
# Streamlit re-executes this whole script on every widget interaction, so the
# GPT-4 results are cached in st.session_state rather than recomputed each run.
st.image("./A8title2.png")
st.title('Find the best AI stack for your app')
description = st.text_area("Describe your application:")
dataset_description = st.text_area("Describe the dataset you want to use for fine-tuning your model:")
recommendation_type = st.radio("What type of recommendation are you looking for?", ["Recommend Open-Source Model", "Recommend API Service"])

if "rec_model_pressed" not in st.session_state:
    st.session_state.rec_model_pressed = False
if "feedback_submitted" not in st.session_state:
    st.session_state.feedback_submitted = False

if st.button("Recommend AI Model"):
    st.session_state.rec_model_pressed = True

if st.session_state.rec_model_pressed:
    if description and dataset_description:
        # The API is queried again only when an input changes, not on every
        # rerun (e.g. when the rating slider moves or feedback is typed).
        query_key = (description, dataset_description, recommendation_type)
        if st.session_state.get("rec_query_key") != query_key:
            # Include the selected recommendation type; it was otherwise unused.
            combined_query = (
                f"{description} Dataset: {dataset_description} "
                f"Recommendation type: {recommendation_type}"
            )
            try:
                recommended = recommend_ai_model_via_gpt(combined_query)
                reason = explain_recommendation(recommended)
            except openai.OpenAIError as err:
                st.error(f"An error occurred while communicating with the OpenAI API: {err}")
                # Nothing is cached, so the request is retried on the next run.
                st.stop()
            st.session_state.rec_query_key = query_key
            st.session_state.rec_result = (recommended, reason)
            # Feedback refers to one recommendation; reset it for the new one.
            st.session_state.feedback_submitted = False
        recommended_model, explanation = st.session_state.rec_result
        st.subheader(f"Recommended: {recommended_model}")
        st.write("Reason:", explanation)
        rating = st.slider("Rate the explanation from 1 (worst) to 5 (best):", 1, 5)
        feedback = st.text_input("Any additional feedback?")
        if st.button("Submit Feedback"):
            # NOTE(review): rating and feedback are not stored or sent anywhere;
            # only the confirmation message below is shown.
            st.session_state.feedback_submitted = True
        if st.session_state.feedback_submitted:
            st.success("Thank you for your feedback!")
    else:
        st.warning("Please provide a description and dataset details.")