from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Dict, Any, Optional
from pydantic import BaseModel
import asyncio
import httpx

from config import cookies, headers
from prompts import ChiplingPrompts

app = FastAPI()
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define request model
class ChatRequest(BaseModel):
    message: str
    messages: List[Dict[Any, Any]]
    model: Optional[str] = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
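
# Illustrative JSON body accepted by ChatRequest above (the values are
# placeholders, not taken from the original source):
# {
#     "message": "Explain attention in one paragraph",
#     "messages": [],
#     "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
# }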

async def generate(json_data: Dict[str, Any]):
    # Stream the upstream inference response, retrying on rate limits (HTTP 429).
    max_retries = 5
    for attempt in range(max_retries):
        async with httpx.AsyncClient(timeout=None) as client:
            try:
                request_ctx = client.stream(
                    "POST",
                    "https://api.together.ai/inference",
                    cookies=cookies,
                    headers=headers,
                    json=json_data
                )
                async with request_ctx as response:
                    if response.status_code == 200:
                        # Forward the upstream SSE lines to the client as-is.
                        async for line in response.aiter_lines():
                            if line:
                                yield f"{line}\n"
                        return
                    elif response.status_code == 429:
                        if attempt < max_retries - 1:
                            await asyncio.sleep(0.5)
                            continue
                        yield "data: [Rate limited, max retries]\n\n"
                        return
                    else:
                        yield f"data: [Unexpected status code: {response.status_code}]\n\n"
                        return
            except Exception as e:
                yield f"data: [Connection error: {str(e)}]\n\n"
                return

    yield "data: [Max retries reached]\n\n"

# Health-check endpoint. The route decorator is not present in the original
# source; the path "/" is an assumption.
@app.get("/")
async def index():
    return {"status": "ok"}

# Chat endpoint. The route decorator is not present in the original source;
# the path "/chat" is an assumption.
@app.post("/chat")
async def chat(request: ChatRequest):
    current_messages = request.messages.copy()

    # Handle both single-text and list-style content in the last message
    if request.messages and isinstance(request.messages[-1].get('content'), list):
        current_messages = request.messages
    else:
        current_messages.append({
            'content': [{
                'type': 'text',
                'text': request.message
            }],
            'role': 'user'
        })

    json_data = {
        'model': request.model,
        'max_tokens': None,
        'temperature': 0.7,
        'top_p': 0.7,
        'top_k': 50,
        'repetition_penalty': 1,
        'stream_tokens': True,
        'stop': ['<|eot_id|>', '<|eom_id|>'],
        'messages': current_messages,
        'stream': True,
    }

    return StreamingResponse(generate(json_data), media_type='text/event-stream')
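
# Illustrative call against this endpoint once the app is running locally
# (the "/chat" path above is assumed, and the URL/port are placeholders):
#   curl -N -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello", "messages": []}'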

# Module-generation endpoint. The route decorator is not present in the
# original source; the path "/generate-modules" is an assumption.
@app.post("/generate-modules")
async def generate_modules(request: Request):
    data = await request.json()
    search_query = data.get("searchQuery")
    if not search_query:
        return {"error": "searchQuery is required"}

    system_prompt = ChiplingPrompts.generateModules(search_query)

    current_messages = [
        {
            'role': 'system',
            'content': [{
                'type': 'text',
                'text': system_prompt
            }]
        },
        {
            'role': 'user',
            'content': [{
                'type': 'text',
                'text': search_query
            }]
        }
    ]

    json_data = {
        'model': "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        'max_tokens': None,
        'temperature': 0.7,
        'top_p': 0.7,
        'top_k': 50,
        'repetition_penalty': 1,
        'stream_tokens': True,
        'stop': ['<|eot_id|>', '<|eom_id|>'],
        'messages': current_messages,
        'stream': True,
    }

    return StreamingResponse(generate(json_data), media_type='text/event-stream')
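
# Illustrative request body for this endpoint; "searchQuery" is the only
# field read from the payload (the value below is a placeholder):
# {"searchQuery": "graph neural networks"}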

# Topic-generation endpoint. The route decorator is not present in the
# original source; the path "/generate-topics" is an assumption.
@app.post("/generate-topics")
async def generate_topics(request: Request):
    data = await request.json()
    search_query = data.get("searchQuery")
    if not search_query:
        return {"error": "searchQuery is required"}

    system_prompt = ChiplingPrompts.generateTopics(search_query)

    current_messages = [
        {
            'role': 'system',
            'content': [{
                'type': 'text',
                'text': system_prompt
            }]
        },
        {
            'role': 'user',
            'content': [{
                'type': 'text',
                'text': search_query
            }]
        }
    ]

    json_data = {
        'model': "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        'max_tokens': None,
        'temperature': 0.7,
        'top_p': 0.7,
        'top_k': 50,
        'repetition_penalty': 1,
        'stream_tokens': True,
        'stop': ['<|eot_id|>', '<|eom_id|>'],
        'messages': current_messages,
        'stream': True,
    }

    return StreamingResponse(generate(json_data), media_type='text/event-stream')
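
# Minimal local entrypoint sketch; this block is an addition, not part of the
# original file, and the host/port are placeholders. Hosted platforms usually
# launch the app with their own ASGI command instead.
if __name__ == "__main__":
    import uvicorn  # assumed dependency; the standard ASGI server for FastAPI

    uvicorn.run(app, host="0.0.0.0", port=8000)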