import spaces
import torch
import gradio
import json
import onnxruntime
import time
from datetime import datetime
from transformers import pipeline
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

# Origins registered with the CORS middleware below.
CORS_ORIGINS = [
    "https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win",
    "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win",
    "https://crunchatize-2-2b4f5b1479a6.c5v4v4jx6pq5.win",
    "https://tamabotchi-2dba63df3bf1.c5v4v4jx6pq5.win",
]

# Origins accepted by the manual whitelist check in `classify`: the CORS origins
# plus this Space's own hf.space host and the GitHub Pages site.
ALLOWED_ORIGINS = frozenset(CORS_ORIGINS + [
    "https://ravenok-statosphere-backend.hf.space",
    "https://lord-raven.github.io",
])

# CORS Config - This isn't actually working; instead, I am taking a gross approach to origin whitelisting within the service.
# NOTE(review): `app` is never served in this file. `gradio_interface.launch()` at the
# bottom starts Gradio's own server, which is likely why this middleware has no effect.
# If the FastAPI app is meant to be used, consider `gradio.mount_gradio_app` + uvicorn.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
model_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
tokenizer_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"

# The same model is loaded twice: once on the default device (CPU fallback) and once on
# the first CUDA device.
classifier_cpu = pipeline(task="zero-shot-classification", model=model_name, tokenizer=tokenizer_name)
classifier_gpu = pipeline(task="zero-shot-classification", model=model_name, tokenizer=tokenizer_name, device="cuda:0")


def _timestamp():
    """Return the current local time formatted for log lines."""
    return datetime.now().strftime('%Y-%m-%d %H:%M:%S')


def classify(data_string, request: gradio.Request):
    """Classify the sequence described by a JSON request and return the result as JSON.

    Args:
        data_string: JSON object with keys 'sequence', 'candidate_labels',
            'hypothesis_template' and 'multi_label'. If a 'cpu' key is present
            (any value), the GPU path is skipped.
        request: The Gradio request. When present, its Origin header must be in
            ALLOWED_ORIGINS.

    Returns:
        The zero-shot pipeline output serialized with json.dumps, or "{}" when the
        request's origin is not whitelisted (including a missing Origin header).

    Raises:
        json.JSONDecodeError: if data_string is not valid JSON.
        KeyError: if a required key is missing (raised from the CPU path).
    """
    # Manual origin whitelist, since the CORS middleware is not effective.
    # `.get` makes a missing Origin header a rejection instead of a KeyError.
    if request and request.headers.get("origin") not in ALLOWED_ORIGINS:
        return "{}"

    data = json.loads(data_string)

    # Try to prevent batch suggestion warning in log.
    classifier_cpu.call_count = 0
    classifier_gpu.call_count = 0

    result = {}
    if 'cpu' not in data:
        start_time = time.time()
        try:
            result = zero_shot_classification_gpu(data)
        except Exception as e:
            # Deliberate best-effort: any GPU failure (quota, allocation, CUDA error)
            # falls through to the CPU pipeline below.
            print(f"GPU classification failed: {e}\nFall back to CPU.")
        else:
            print(f"{_timestamp()} - GPU Classification took {time.time() - start_time}.")

    if not result:
        # Restart the timer so the CPU log excludes any failed GPU attempt.
        start_time = time.time()
        result = zero_shot_classification_cpu(data)
        print(f"{_timestamp()} - CPU Classification took {time.time() - start_time}.")

    return json.dumps(result)


def _run_classifier(classifier, data):
    """Invoke a zero-shot pipeline with the fields taken from the request dict."""
    return classifier(
        data['sequence'],
        candidate_labels=data['candidate_labels'],
        hypothesis_template=data['hypothesis_template'],
        multi_label=data['multi_label'],
    )


def zero_shot_classification_cpu(data):
    """Run zero-shot classification with the CPU-hosted pipeline."""
    return _run_classifier(classifier_cpu, data)


@spaces.GPU(duration=3)
def zero_shot_classification_gpu(data):
    """Run zero-shot classification with the CUDA pipeline inside a `spaces.GPU` allocation."""
    return _run_classifier(classifier_gpu, data)


def create_sequences(data):
    """Build one '<sequence>\\n<hypothesis>' string per candidate label.

    Not called anywhere in this file.
    """
    return [data['sequence'] + '\n' + data['hypothesis_template'].format(label)
            for label in data['candidate_labels']]


gradio_interface = gradio.Interface(
    fn=classify,
    inputs=gradio.Textbox(label="JSON Input"),
    outputs=gradio.Textbox(label="JSON Output"),
    title="Statosphere Backend",
    description="This Space is a classification service for a set of chub.ai stages and not really intended for use through this UI.",
)

# NOTE(review): `app` is not served (see the CORS note above), so this mount is
# presumably inert; `launch()` below is what actually serves the interface.
app.mount("/gradio", gradio_interface)
gradio_interface.launch()