import os
import logging
from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from openai import AsyncOpenAI
from typing import Optional

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Orion AI API",
    description="API for streaming AI responses with model and publisher selection via URL",
    version="1.0.0"
)

# Valid models (replace with the actual models supported by https://models.github.ai/inference)
VALID_MODELS = [
    "deepseek/DeepSeek-V3-0324",  # primary model for this deployment
    "gpt-3.5-turbo",              # placeholder
    "llama-3",                    # placeholder
    "mistral-7b"                  # placeholder
]

class GenerateRequest(BaseModel):
    prompt: str
    publisher: Optional[str] = None  # allow the publisher to be supplied in the body

async def generate_ai_response(prompt: str, model: str, publisher: Optional[str]):
    logger.debug(f"Received prompt: {prompt}, model: {model}, publisher: {publisher}")

    # Configuration for the AI endpoint
    token = os.getenv("GITHUB_TOKEN")
    endpoint = os.getenv("AI_SERVER_URL", "https://models.github.ai/inference")
    default_publisher = os.getenv("DEFAULT_PUBLISHER", "abdullahalioo")  # fallback publisher

    if not token:
        logger.error("GitHub token not configured")
        raise HTTPException(status_code=500, detail="GitHub token not configured")

    # Use the provided publisher, falling back to the environment variable
    final_publisher = publisher or default_publisher
    if not final_publisher:
        logger.error("Publisher is required")
        raise HTTPException(status_code=400, detail="Publisher is required")

    # Validate the model
    if model not in VALID_MODELS:
        logger.error(f"Invalid model: {model}. Valid models: {VALID_MODELS}")
        raise HTTPException(status_code=400, detail=f"Invalid model. Valid models: {VALID_MODELS}")

    logger.debug(f"Using endpoint: {endpoint}, publisher: {final_publisher}")
    client = AsyncOpenAI(base_url=endpoint, api_key=token)

    try:
        # Include the publisher in the request payload
        stream = await client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful assistant named Orion, created by Abdullah Ali"},
                {"role": "user", "content": prompt}
            ],
            model=model,
            temperature=1.0,
            top_p=1.0,
            stream=True,
            extra_body={"publisher": final_publisher}  # pass the publisher via extra_body
        )
        async for chunk in stream:
            if chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content
    except Exception as err:
        # Once streaming has started the response headers are already sent, so
        # raising an HTTPException here would never reach the client as a 500;
        # report the error inside the stream instead and stop.
        logger.error(f"AI generation failed: {err}")
        yield f"Error: {err}"

@app.api_route("/generate", methods=["GET", "POST"])  # route decorator was missing; path and methods are assumed
async def generate_response(
    model: str = Query("deepseek/DeepSeek-V3-0324", description="The AI model to use"),
    prompt: Optional[str] = Query(None, description="The input text prompt for the AI"),
    publisher: Optional[str] = Query(None, description="Publisher identifier (optional, defaults to the DEFAULT_PUBLISHER env var)"),
    request: Optional[GenerateRequest] = None
):
    """
    Generate a streaming AI response based on the provided prompt, model, and publisher.

    - **model**: The AI model to use (e.g., deepseek/DeepSeek-V3-0324)
    - **prompt**: The input text prompt for the AI (query parameter or request body)
    - **publisher**: The publisher identifier (optional, defaults to the DEFAULT_PUBLISHER env var)
    """
logger.debug(f"Request received - model: {model}, prompt: {prompt}, publisher: {publisher}, body: {request}") | |
# Determine prompt source: query parameter or request body | |
final_prompt = prompt if prompt is not None else (request.prompt if request is not None else None) | |
# Determine publisher source: query parameter or request body | |
final_publisher = publisher if publisher is not None else (request.publisher if request is not None else None) | |
if not final_prompt or not final_prompt.strip(): | |
logger.error("Prompt cannot be empty") | |
raise HTTPException(status_code=400, detail="Prompt cannot be empty") | |
if not model or not model.strip(): | |
logger.error("Model cannot be empty") | |
raise HTTPException(status_code=400, detail="Model cannot be empty") | |
return StreamingResponse( | |
generate_ai_response(final_prompt, model, final_publisher), | |
media_type="text/event-stream" | |
) | |

@app.get("/models")  # route decorator was missing; path is assumed
async def list_models():
    """
    List all models supported by the AI server.
    """
    return {"models": VALID_MODELS}

def get_app():
    return app
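
# Minimal local-run sketch (an assumption, not part of the original file): port 7860
# is the Hugging Face Spaces default, and the curl example assumes the /generate
# route registered above.
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#   curl -N "http://localhost:7860/generate?prompt=Hello&model=deepseek/DeepSeek-V3-0324"
if __name__ == "__main__":
    import uvicorn  # assumes uvicorn is installed alongside fastapi
    uvicorn.run(app, host="0.0.0.0", port=7860)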