Lord-Raven
Messing with fastAPI.
573c10d
raw
history blame
4.32 kB
import spaces
import torch
import gradio
import json
import onnxruntime
import time
from datetime import datetime
from transformers import pipeline
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# CORS configuration for the FastAPI app: cross-origin requests are accepted
# only from the front-end origins listed below, with credentials, and with
# any HTTP method or header.
app = FastAPI()
_cors_settings = {
    "allow_origins": [
        "https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win",
        "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win",
        "https://crunchatize-2-2b4f5b1479a6.c5v4v4jx6pq5.win",
        "https://tamabotchi-2dba63df3bf1.c5v4v4jx6pq5.win",
    ],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Startup diagnostics: report whether CUDA is usable and, if so, which device
# is current.
_cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {_cuda_available}")
if _cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    # current_device()/get_device_name() raise when no CUDA device is usable,
    # which would abort the import of this module just to print a log line.
    print("CUDA device: none (CUDA is not available)")
# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
# "xenova/deberta-v3-base-tasksource-nli" Not impressed
# "Xenova/bart-large-mnli" A bit slow
# "Xenova/distilbert-base-uncased-mnli" "typeform/distilbert-base-uncased-mnli" Bad answers
# "Xenova/deBERTa-v3-base-mnli" "MoritzLaurer/DeBERTa-v3-base-mnli" Still a bit slow and not great answers
# "xenova/nli-deberta-v3-small" "cross-encoder/nli-deberta-v3-small" Was using this for a good while and it was...okay
# Checkpoint used for both the model weights and the tokenizer.
model_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
tokenizer_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"

# Keyword arguments common to both pipelines; they differ only in the
# explicit device passed to the GPU variant.
_pipeline_options = {
    "task": "zero-shot-classification",
    "model": model_name,
    "tokenizer": tokenizer_name,
}
# No device is given here, so the pipeline uses its default placement.
classifier_cpu = pipeline(**_pipeline_options)
# Pinned to the first CUDA device.
classifier_gpu = pipeline(**_pipeline_options, device="cuda:0")
# classifier = pipeline(task="zero-shot-classification", model=model_name, tokenizer=tokenizer_name)
# Origins allowed to call classify(). Built once at import time; a frozenset
# gives constant-time membership tests instead of rebuilding a list per call.
_ALLOWED_ORIGINS = frozenset({
    "https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win",
    "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win",
    "https://crunchatize-2-2b4f5b1479a6.c5v4v4jx6pq5.win",
    "https://tamabotchi-2dba63df3bf1.c5v4v4jx6pq5.win",
    "https://ravenok-statosphere-backend.hf.space",
    "https://lord-raven.github.io",
})


def classify(data_string, request: gradio.Request):
    """Run zero-shot classification on a JSON-encoded request.

    Args:
        data_string: JSON text. The decoded object is handed to the
            classification helpers, which read the keys 'sequence',
            'candidate_labels', 'hypothesis_template' and 'multi_label'.
            If the object contains a 'cpu' key the CPU pipeline is used,
            otherwise the GPU pipeline.
        request: Request object supplied by Gradio. When it is falsy the
            origin check is skipped.

    Returns:
        The classifier result serialized as a JSON string, or the string
        "{}" when the request's Origin header is missing or not allowed.

    Raises:
        json.JSONDecodeError: If ``data_string`` is not valid JSON.
    """
    # Use .get() rather than ["origin"]: a request without an Origin header
    # is rejected like any other disallowed origin instead of raising KeyError.
    if request and request.headers.get("origin") not in _ALLOWED_ORIGINS:
        return "{}"

    data = json.loads(data_string)

    # Prevent batch suggestion warning in log.
    classifier_cpu.call_count = 0
    classifier_gpu.call_count = 0

    # if 'task' in data and data['task'] == 'few_shot_classification':
    #     return few_shot_classification(data)
    # else:
    start_time = time.time()
    if 'cpu' in data:
        result = zero_shot_classification_cpu(data)
    else:
        result = zero_shot_classification_gpu(data)
    print(f"Classification @ [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] took {time.time() - start_time}.")
    return json.dumps(result)
def zero_shot_classification_cpu(data):
    """Classify ``data['sequence']`` with the module-level ``classifier_cpu``.

    Args:
        data: Mapping providing 'sequence', 'candidate_labels',
            'hypothesis_template' and 'multi_label'; each is forwarded
            unchanged to the pipeline call.

    Returns:
        Whatever ``classifier_cpu`` returns for these arguments.

    Raises:
        KeyError: If any of the four keys is missing from ``data``.
    """
    text = data['sequence']
    options = {
        'candidate_labels': data['candidate_labels'],
        'hypothesis_template': data['hypothesis_template'],
        'multi_label': data['multi_label'],
    }
    return classifier_cpu(text, **options)
@spaces.GPU(duration=3)
def zero_shot_classification_gpu(data):
    """Classify ``data['sequence']`` with the module-level ``classifier_gpu``.

    The function is wrapped by ``spaces.GPU(duration=3)``.

    Args:
        data: Mapping providing 'sequence', 'candidate_labels',
            'hypothesis_template' and 'multi_label'; each is forwarded
            unchanged to the pipeline call.

    Returns:
        Whatever ``classifier_gpu`` returns for these arguments.

    Raises:
        KeyError: If any of the four keys is missing from ``data``.
    """
    sequence = data['sequence']
    labels = data['candidate_labels']
    template = data['hypothesis_template']
    multi_label = data['multi_label']
    return classifier_gpu(
        sequence,
        candidate_labels=labels,
        hypothesis_template=template,
        multi_label=multi_label,
    )
def create_sequences(data):
    """Build one "sequence + hypothesis" string per candidate label.

    Args:
        data: Mapping with 'candidate_labels' (iterable of labels),
            'sequence' (str) and 'hypothesis_template' (a ``str.format``
            template taking the label as its single positional argument).

    Returns:
        A list with, for each label in order, the sequence followed by a
        newline and the template formatted with that label.
    """
    sequences = []
    for label in data['candidate_labels']:
        # Look the fields up per label so that an empty label list never
        # touches 'sequence' or 'hypothesis_template'.
        prefix = data['sequence'] + '\n'
        sequences.append(prefix + data['hypothesis_template'].format(label))
    return sequences
# def few_shot_classification(data):
# sequences = create_sequences(data)
# print(sequences)
# # results = onnx_few_shot_model(sequences)
# probs = onnx_few_shot_model.predict_proba(sequences)
# scores = [true[0] for true in probs]
# composite = list(zip(scores, data['candidate_labels']))
# composite = sorted(composite, key=lambda x: x[0], reverse=True)
# labels, scores = zip(*composite)
# response_dict = {'scores': scores, 'labels': labels}
# print(response_dict)
# response_string = json.dumps(response_dict)
# return response_string
# Expose classify() through a Gradio interface: a single textbox labelled
# "JSON Input" feeds the function, and its returned string is shown in an
# output textbox.
gradio_interface = gradio.Interface(
fn = classify,
inputs = gradio.Textbox(label="JSON Input"),
outputs = gradio.Textbox()
)
# Mount the interface object on the FastAPI app under the /gradio path.
# NOTE(review): Gradio documents gradio.mount_gradio_app() (the commented-out
# line below) for attaching an interface to FastAPI; confirm that passing the
# Interface object directly to app.mount() yields a working ASGI route.
app.mount("/gradio", gradio_interface)
# app = gradio.mount_gradio_app(app, gradio_interface, path="/gradio")
# Start the Gradio interface.
# NOTE(review): launch() is called on the interface, not on `app`; it looks
# like the FastAPI app (and therefore its CORS middleware) is never served by
# this call - confirm against the hosting runtime.
gradio_interface.launch()
# if __name__ == "__main__":
# import uvicorn
# uvicorn.run(app, host="0.0.0.0", port=8000)