chipling-api / app.py
Maouu's picture
Create app.py
9cb3fae verified
raw
history blame
5.18 kB
import asyncio
from typing import Any, Callable, Dict, List, Optional

import httpx
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

from config import cookies, headers
from prompts import ChiplingPrompts
app = FastAPI()

# Cross-origin policy: every origin, method and request header is accepted,
# and credentialed requests (cookies / Authorization) are allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets
# any website issue credentialed requests to this API -- confirm intended.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
# Define request model
class ChatRequest(BaseModel):
    """JSON body accepted by the ``/chat`` endpoint."""

    # Text of the new user turn. The ``/chat`` handler appends it to
    # ``messages`` unless the last entry's ``content`` is already a list.
    message: str
    # Conversation history forwarded upstream. Entries are expected to carry
    # ``role`` and ``content`` keys (NOTE(review): not validated here).
    messages: List[Dict[Any, Any]]
    # Upstream model identifier; the annotation also permits an explicit null.
    model: Optional[str] = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
async def generate(
    json_data: Dict[str, Any],
    max_retries: int = 5,
    retry_delay: float = 0.5,
    backoff_factor: float = 2.0,
):
    """Stream a completion from Together AI, relaying upstream lines as text.

    POSTs ``json_data`` to ``https://api.together.ai/inference`` using the
    ``cookies`` and ``headers`` imported from ``config``. An HTTP 429 answer
    is retried: at most ``max_retries`` attempts are made in total, sleeping
    ``retry_delay`` seconds before the first retry and multiplying the wait
    by ``backoff_factor`` before each later retry.

    Args:
        json_data: JSON body forwarded unchanged to the upstream endpoint.
        max_retries: Total number of attempts allowed while rate limited.
        retry_delay: Seconds to wait before the first retry.
        backoff_factor: Multiplier applied to the wait after every retry.

    Yields:
        On HTTP 200, every non-empty upstream line followed by a newline.
        Otherwise a single ``data: [...]`` message followed by a blank line,
        describing why the request failed: rate limit still hit on the last
        attempt, an unexpected status code, or an exception raised while
        connecting or streaming.
    """
    delay = retry_delay
    try:
        # One client (and connection pool) is reused for every attempt rather
        # than building a new one per retry. Cookies are set on the client
        # because per-request ``cookies=`` is deprecated in httpx.
        async with httpx.AsyncClient(timeout=None, cookies=cookies) as client:
            for attempt in range(max_retries):
                async with client.stream(
                    "POST",
                    "https://api.together.ai/inference",
                    headers=headers,
                    json=json_data,
                ) as response:
                    if response.status_code == 200:
                        async for line in response.aiter_lines():
                            if line:
                                yield f"{line}\n"
                        return
                    if response.status_code != 429:
                        yield f"data: [Unexpected status code: {response.status_code}]\n\n"
                        return
                # Rate limited. Leaving the ``async with`` above has closed the
                # 429 response, so no connection is held during the sleep.
                if attempt == max_retries - 1:
                    yield "data: [Rate limited, max retries]\n\n"
                    return
                await asyncio.sleep(delay)
                delay *= backoff_factor
    except Exception as e:
        yield f"data: [Connection error: {e}]\n\n"
        return
    # Only reachable when max_retries < 1, i.e. no attempt was made at all.
    yield "data: [Max retries reached]\n\n"
@app.get("/")
async def index():
    """Return the constant status payload ``{"status": "ok"}``."""
    payload = {"status": "ok"}
    return payload
@app.post("/chat")
async def chat(request: ChatRequest):
    """Stream a chat completion for ``request`` as ``text/event-stream``.

    ``request.message`` is appended to the supplied history as a user turn
    with a single text part, unless the last history entry's ``content`` is
    already a list. In that case the history is forwarded unchanged.
    NOTE(review): presumably the client has then already included the
    current turn -- confirm with the frontend.

    A null or empty ``model`` falls back to the default model rather than
    being forwarded upstream as ``None``.
    """
    # Shallow copy so the parsed request body is never mutated.
    current_messages = list(request.messages)
    last_has_list_content = bool(current_messages) and isinstance(
        current_messages[-1].get('content'), list
    )
    if not last_has_list_content:
        current_messages.append({
            'content': [{
                'type': 'text',
                'text': request.message,
            }],
            'role': 'user',
        })
    json_data = {
        'model': request.model or "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        'max_tokens': None,
        'temperature': 0.7,
        'top_p': 0.7,
        'top_k': 50,
        'repetition_penalty': 1,
        'stream_tokens': True,
        'stop': ['<|eot_id|>', '<|eom_id|>'],
        'messages': current_messages,
        'stream': True,
    }
    return StreamingResponse(generate(json_data), media_type='text/event-stream')
_DEFAULT_MODEL = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"


def _text_message(role: str, text: str) -> Dict[str, Any]:
    """Return a chat message whose content is a single ``text`` part."""
    return {'role': role, 'content': [{'type': 'text', 'text': text}]}


async def _stream_search_query(
    request: Request,
    build_system_prompt: Callable[[str], str],
):
    """Validate ``searchQuery`` from the JSON body and stream a completion.

    ``build_system_prompt`` turns the query into the system prompt, and the
    query itself is sent as the user turn. The upstream call is made by
    ``generate`` with the default model and fixed sampling settings.

    Returns:
        A ``text/event-stream`` ``StreamingResponse`` relaying the upstream
        output. A 400 ``JSONResponse`` with body
        ``{"error": "searchQuery is required"}`` is returned instead when the
        body is not valid JSON, is not a JSON object, or lacks a non-empty
        string ``searchQuery``.
    """
    try:
        data = await request.json()
    except ValueError:
        # Malformed JSON (json.JSONDecodeError subclasses ValueError) is
        # reported like a missing query instead of surfacing as a 500.
        data = None
    # A JSON array/string/number body has no .get(); treat it as missing.
    search_query = data.get("searchQuery") if isinstance(data, dict) else None
    if not isinstance(search_query, str) or not search_query:
        return JSONResponse({"error": "searchQuery is required"}, status_code=400)

    json_data = {
        'model': _DEFAULT_MODEL,
        'max_tokens': None,
        'temperature': 0.7,
        'top_p': 0.7,
        'top_k': 50,
        'repetition_penalty': 1,
        'stream_tokens': True,
        'stop': ['<|eot_id|>', '<|eom_id|>'],
        'messages': [
            _text_message('system', build_system_prompt(search_query)),
            _text_message('user', search_query),
        ],
        'stream': True,
    }
    return StreamingResponse(generate(json_data), media_type='text/event-stream')


@app.post("/generate-modules")
async def generate_modules(request: Request):
    """Stream a completion prompted by ``ChiplingPrompts.generateModules``.

    Expects a JSON object body with a non-empty string ``searchQuery``.
    """
    return await _stream_search_query(request, ChiplingPrompts.generateModules)


@app.post("/generate-topics")
async def generate_topics(request: Request):
    """Stream a completion prompted by ``ChiplingPrompts.generateTopics``.

    Expects a JSON object body with a non-empty string ``searchQuery``.
    """
    return await _stream_search_query(request, ChiplingPrompts.generateTopics)