|
import logging

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse

from models.image.together.main import TogetherImageAPI
from models.image.vercel.main import FalAPI
from models.text.together.main import TogetherAPI
from models.text.vercel.main import XaiAPI, GroqAPI, DeepinfraAPI
|
|
|
# ASGI application object exposing the text/image generation endpoints below.
app = FastAPI()

# Permissive CORS: any method, any header, any origin, with credentials allowed.
# NOTE(review): browsers reject a literal "*" origin on credentialed requests,
# so the middleware may reflect the caller's Origin instead. Confirm that
# accepting credentialed requests from every origin is really intended.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
|
|
|
@app.get("/") |
|
async def root(): |
|
return {"status":"ok", "routes":{"/":"GET", "/api/v1/generate":"POST", "/api/v1/models":"GET", "/api/v1/generate-images":"POST"}, "models": ["text", "image"]} |
|
|
|
@app.post("/api/v1/generate")
async def generate(request: Request):
    """Stream a chat completion from whichever text provider serves ``model``.

    The JSON request body must contain a non-empty ``messages`` value and a
    ``model`` name. The model is matched against each provider's
    ``get_model_list()`` in the order Together, xAI, Groq, DeepInfra. The
    first provider that lists it handles the request.

    Returns:
        On success, a ``text/event-stream`` StreamingResponse over the
        provider's ``generate(query)`` output. Otherwise a JSON
        ``{"error": ...}`` body: status 400 for a bad request or unsupported
        model, 500 for a provider failure.
    """
    try:
        data = await request.json()
    except ValueError:
        # Covers json.JSONDecodeError and undecodable bytes (both ValueError).
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid request. Body must be a JSON object."},
        )
    if not isinstance(data, dict):
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid request. Body must be a JSON object."},
        )

    # .get() so that missing fields reach the validation below instead of
    # raising KeyError (which FastAPI would turn into a bare 500).
    messages = data.get("messages")
    model = data.get("model")
    if not messages or not model:
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid request. 'messages' and 'model' are required."},
        )

    query = {
        "model": model,
        "max_tokens": None,
        "temperature": 0.7,
        "top_p": 0.7,
        "top_k": 50,
        "repetition_penalty": 1,
        "stream_tokens": True,
        "stop": ["<|eot_id|>", "<|eom_id|>"],
        "messages": messages,
        "stream": True,
    }

    try:
        # Query providers lazily, in priority order, and reuse the matching
        # instance for generation. This avoids fetching every model list and
        # constructing the matching provider twice.
        for provider_cls in (TogetherAPI, XaiAPI, GroqAPI, DeepinfraAPI):
            provider = provider_cls()
            if model in provider.get_model_list():
                break
        else:
            return JSONResponse(
                status_code=400,
                content={"error": f"Model '{model}' is not supported."},
            )
        response = provider.generate(query)
    except Exception as e:
        logging.getLogger(__name__).exception(
            "Text generation failed for model %r", model
        )
        return JSONResponse(
            status_code=500,
            content={"error": f"An error occurred: {str(e)}"},
        )

    return StreamingResponse(response, media_type="text/event-stream")
|
|
|
@app.get("/api/v1/models")
async def get_models():
    """List the model names offered by every text and image provider.

    Returns:
        ``{"models": {"text": {...}, "image": {...}}}``, keyed by provider
        name, each value being that provider's ``get_model_list()``. If any
        provider lookup raises, returns a JSON ``{"error": ...}`` body with
        status 500.
    """
    try:
        models = {
            "text": {
                "together": TogetherAPI().get_model_list(),
                "xai": XaiAPI().get_model_list(),
                "groq": GroqAPI().get_model_list(),
                "deepinfra": DeepinfraAPI().get_model_list(),
            },
            "image": {
                "fal": FalAPI().get_model_list(),
                "together": TogetherImageAPI().get_model_list(),
            },
        }
    except Exception as e:
        logging.getLogger(__name__).exception("Failed to list provider models")
        # Previously returned HTTP 200 on failure; signal the error properly.
        return JSONResponse(
            status_code=500,
            content={"error": f"An error occurred: {str(e)}"},
        )
    return {"models": models}
|
|
|
@app.post('/api/v1/generate-images')
async def generate_images(request: Request):
    """Generate images for ``prompt`` with whichever image provider serves ``model``.

    The JSON request body must contain non-empty ``prompt`` and ``model``
    values. The model is matched against ``FalAPI`` first, then
    ``TogetherImageAPI``. The matching provider's awaited
    ``generate({'prompt': ..., 'modelId': ...})`` result is returned unchanged.

    Returns:
        The provider's result on success. Otherwise a JSON ``{"error": ...}``
        body: status 400 for a bad request or unsupported model, 500 for a
        provider failure.
    """
    try:
        data = await request.json()
    except ValueError:
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid request. Body must be a JSON object."},
        )
    if not isinstance(data, dict):
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid request. Body must be a JSON object."},
        )

    # .get() so that missing fields hit the validation below rather than a KeyError.
    prompt = data.get("prompt")
    model = data.get("model")
    logging.getLogger(__name__).debug("Image generation requested for model %r", model)

    # Validate before contacting any provider.
    if not prompt or not model:
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid request. 'prompt' and 'model' are required."},
        )

    try:
        # Provider lookups can fail, so they live inside the try as well.
        for provider_cls in (FalAPI, TogetherImageAPI):
            provider = provider_cls()
            if model in provider.get_model_list():
                break
        else:
            return JSONResponse(
                status_code=400,
                content={"error": f"Model '{model}' is not supported."},
            )
        query = {
            'prompt': prompt,
            'modelId': model,
        }
        response = await provider.generate(query)
    except Exception as e:
        logging.getLogger(__name__).exception(
            "Image generation failed for model %r", model
        )
        return JSONResponse(
            status_code=500,
            content={"error": f"An error occurred: {str(e)}"},
        )

    return response