# Hugging Face Spaces status: Sleeping
import gradio as gr | |
import joblib | |
import numpy as np | |
import pandas as pd | |
from openai import OpenAI | |
from huggingface_hub import hf_hub_download | |
# Download the serialized classifier bundle from the Hugging Face Hub.
model_path = hf_hub_download(
    repo_id="govtech/zoo-entry-001",
    filename="model.joblib",
)
# NOTE: joblib.load unpickles the downloaded file, and unpickling can execute
# arbitrary code -- only point this at a repository you trust.
model_data = joblib.load(model_path)
# The bundle is indexed by 'model' (the classifier) and 'label_names'
# (the labels shown next to each score in the results table).
model, label_names = model_data['model'], model_data['label_names']

# Client used for embedding requests. No credentials are passed here, so the
# client has to pick them up from its own defaults (presumably the
# OPENAI_API_KEY environment variable -- confirm against the deployment).
client = OpenAI()
def get_embedding(text, embedding_model="text-embedding-3-large"):
    """
    Get embedding for the input text from OpenAI.

    Newlines in the text are replaced with spaces, then the embeddings
    endpoint is called with the text as a single-item batch.

    Args:
        text: The string to embed. Must be non-empty.
        embedding_model: Name of the OpenAI embedding model to request.

    Returns:
        A NumPy array built from the embedding vector of the first (and
        only) item in the response.

    Raises:
        ValueError: If ``text`` is empty or None. Checked locally so that
            an invalid request never reaches the API.
    """
    # Fail fast: an empty input cannot be embedded, and None has no
    # .replace(); either way there is no point making a network call.
    if not text:
        raise ValueError("text must be a non-empty string")

    text = text.replace("\n", " ")
    response = client.embeddings.create(
        input=[text],
        model=embedding_model,
    )
    # Exactly one input was sent, so the embedding is the first data entry.
    return np.array(response.data[0].embedding)
def classify_text(text):
    """
    Classify the provided text and return an update for the results table.

    The text is embedded with get_embedding, the embedding is passed to the
    loaded model as a single-row batch, and the first row of the model
    output is tabulated against label_names.

    Args:
        text: The raw text entered by the user.

    Returns:
        A gr.update(...) payload that sets the DataFrame component's value
        to a table with 'Label', 'Probability' and 'Prediction' columns and
        makes the component visible.

    Raises:
        gr.Error: If ``text`` is empty or contains only whitespace, so the
            user sees a clear message instead of a backend failure.
    """
    # Blank input cannot be classified meaningfully; report it in the UI
    # rather than sending it to the embedding service.
    if not text or not text.strip():
        raise gr.Error("Please enter some text to classify.")

    embedding = get_embedding(text)
    # Add batch dimension: the model is called with a single-row 2-D input.
    X = np.asarray(embedding)[None, :]
    # NOTE(review): the output of predict() is treated as per-label
    # probabilities. If the model is a scikit-learn style classifier,
    # predict_proba() may be what is intended -- confirm against the model.
    probabilities = model.predict(X)
    # Only one row was submitted, so take the first row of the output once.
    scores = probabilities[0]

    # Labels, raw scores, and a 0/1 decision using a 0.5 cut-off.
    df = pd.DataFrame({
        'Label': label_names,
        'Probability': scores,
        'Prediction': (scores > 0.5).astype(int),
    })
    # Returning an update (rather than the bare frame) also un-hides the table.
    return gr.update(value=df, visible=True)
# Assemble the UI: a text box, a submit button, and a results table,
# each placed in its own row.
with gr.Blocks(title="Zoo Entry 001") as iface:
    with gr.Row():
        input_text = gr.Textbox(label="Input Text", lines=5)
    with gr.Row():
        submit_btn = gr.Button("Submit")
    # The results table starts hidden; classify_text returns an update
    # that fills it in and makes it visible.
    with gr.Row():
        output_table = gr.DataFrame(visible=False, label="Classification Results")

    # Clicking the button runs classify_text on the text box contents and
    # routes its return value to the results table.
    submit_btn.click(classify_text, inputs=input_text, outputs=output_table)

# Start the app only when this file is executed directly.
if __name__ == "__main__":
    iface.launch()