from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Dict, Any, Optional
from pydantic import BaseModel
import asyncio
import httpx
import random
from config import cookies, headers, groqapi
from prompts import ChiplingPrompts
from groq import Groq
import json
import time

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Define request model
class ChatRequest(BaseModel):
    message: str
    messages: List[Dict[Any, Any]]
    model: Optional[str] = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"


client = Groq(api_key=groqapi)


async def generate(json_data: Dict[str, Any]):
    """Stream the Together AI inference endpoint as SSE lines, retrying on 429."""
    max_retries = 5
    for attempt in range(max_retries):
        async with httpx.AsyncClient(timeout=None) as client:
            try:
                request_ctx = client.stream(
                    "POST",
                    "https://api.together.ai/inference",
                    cookies=cookies,
                    headers=headers,
                    json=json_data
                )
                async with request_ctx as response:
                    if response.status_code == 200:
                        async for line in response.aiter_lines():
                            if line:
                                yield f"{line}\n"
                        return
                    elif response.status_code == 429:
                        if attempt < max_retries - 1:
                            await asyncio.sleep(0.5)
                            continue
                        yield "data: [Rate limited, max retries]\n\n"
                        return
                    else:
                        yield f"data: [Unexpected status code: {response.status_code}]\n\n"
                        return
            except Exception as e:
                yield f"data: [Connection error: {str(e)}]\n\n"
                return
    yield "data: [Max retries reached]\n\n"


def convert_to_groq_schema(messages: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    """Flatten list-style message content into plain strings for the Groq API."""
    converted = []
    for message in messages:
        role = message.get("role", "user")
        content = message.get("content")
        if isinstance(content, list):
            flattened = []
            for item in content:
                if isinstance(item, dict) and item.get("type") == "text":
                    flattened.append(item.get("text", ""))
            content = "\n".join(flattened)
        elif not isinstance(content, str):
            content = str(content)
        converted.append({"role": role, "content": content})
    return converted
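
# Illustrative example (comment only, not executed): list-style content such as
#   [{"role": "user", "content": [{"type": "text", "text": "hi"}]}]
# is flattened by convert_to_groq_schema to
#   [{"role": "user", "content": "hi"}]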


async def groqgenerate(json_data: Dict[str, Any]):
    """Stream a Groq chat completion as OpenAI-style SSE chunks."""
    try:
        messages = convert_to_groq_schema(json_data["messages"])
        chunk_id = "groq-" + "".join(random.choices("0123456789abcdef", k=32))
        created = int(time.time())  # Unix timestamp for the chunk "created" field

        # Create streaming response
        stream = client.chat.completions.create(
            messages=messages,
            model="meta-llama/llama-4-scout-17b-16e-instruct",
            temperature=json_data.get("temperature", 0.7),
            max_completion_tokens=json_data.get("max_tokens", 1024),
            top_p=json_data.get("top_p", 1),
            stop=json_data.get("stop", None),
            stream=True,
        )

        total_tokens = 0
        # Use a normal for-loop since the Groq stream is not async
        for chunk in stream:
            content = chunk.choices[0].delta.content
            if content:
                response = {
                    "id": chunk_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": json_data.get("model", "llama-3.3-70b-versatile"),
                    "choices": [{
                        "index": 0,
                        "text": content,
                        "logprobs": None,
                        "finish_reason": None
                    }],
                    "usage": None
                }
                yield f"data: {json.dumps(response)}\n\n"
                total_tokens += 1

        final = {
            "id": chunk_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": json_data.get("model", "llama-3.3-70b-versatile"),
            "choices": [],
            "usage": {
                "prompt_tokens": len(messages),
                "completion_tokens": total_tokens,
                "total_tokens": len(messages) + total_tokens,
            }
        }
        yield f"data: {json.dumps(final)}\n\n"
        yield "data: [DONE]\n\n"

    except Exception:
        # Fall back to the Together generator; iterate it so its chunks are re-yielded
        async for chunk in generate(json_data):
            yield chunk


@app.get("/")
async def index():
    return {"status": "ok"}
@app.post("/chat")
async def chat(request: ChatRequest):
current_messages = request.messages.copy()
# Handle both single text or list content
if request.messages and isinstance(request.messages[-1].get('content'), list):
current_messages = request.messages
else:
current_messages.append({
'content': [{
'type': 'text',
'text': request.message
}],
'role': 'user'
})
json_data = {
'model': request.model,
'max_tokens': None,
'temperature': 0.7,
'top_p': 0.7,
'top_k': 50,
'repetition_penalty': 1,
'stream_tokens': True,
'stop': ['<|eot_id|>', '<|eom_id|>'],
'messages': current_messages,
'stream': True,
}
selected_generator = random.choice([groqgenerate, generate])
return StreamingResponse(selected_generator(json_data), media_type='text/event-stream')
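
# Minimal client sketch for the /chat endpoint (comment only; the base URL and
# port are assumptions, not defined in this file):
#
#   import httpx
#
#   async def demo():
#       payload = {"message": "Hello", "messages": []}
#       async with httpx.AsyncClient(timeout=None) as c:
#           async with c.stream("POST", "http://localhost:8000/chat", json=payload) as r:
#               async for line in r.aiter_lines():
#                   if line:
#                       print(line)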
@app.post("/generate-modules")
async def generate_modules(request: Request):
data = await request.json()
search_query = data.get("searchQuery")
if not search_query:
return {"error": "searchQuery is required"}
system_prompt = ChiplingPrompts.generateModules(search_query)
current_messages = [
{
'role': 'system',
'content': [{
'type': 'text',
'text': system_prompt
}]
},
{
'role': 'user',
'content': [{
'type': 'text',
'text': search_query
}]
}
]
json_data = {
'model': "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
'max_tokens': None,
'temperature': 0.7,
'top_p': 0.7,
'top_k': 50,
'repetition_penalty': 1,
'stream_tokens': True,
'stop': ['<|eot_id|>', '<|eom_id|>'],
'messages': current_messages,
'stream': True,
}
selected_generator = random.choice([groqgenerate])
return StreamingResponse(selected_generator(json_data), media_type='text/event-stream')
@app.post("/generate-topics")
async def generate_topics(request: Request):
data = await request.json()
search_query = data.get("searchQuery")
if not search_query:
return {"error": "searchQuery is required"}
system_prompt = ChiplingPrompts.generateTopics(search_query)
current_messages = [
{
'role': 'system',
'content': [{
'type': 'text',
'text': system_prompt
}]
},
{
'role': 'user',
'content': [{
'type': 'text',
'text': search_query
}]
}
]
json_data = {
'model': "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
'max_tokens': None,
'temperature': 0.7,
'top_p': 0.7,
'top_k': 50,
'repetition_penalty': 1,
'stream_tokens': True,
'stop': ['<|eot_id|>', '<|eom_id|>'],
'messages': current_messages,
'stream': True,
}
selected_generator = random.choice([groqgenerate, generate])
return StreamingResponse(selected_generator(json_data), media_type='text/event-stream')
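

# Optional local entry point (a sketch, not part of the original file): serve the
# app with uvicorn when the module is executed directly. Host and port are
# assumptions; adjust them for your deployment.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)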