"""FastAPI inference service for the petertill/cordia-a6 text-generation model.

Exposes POST /generate: formats a chat-style message list into the model's
prompt template, runs generation, and returns the newly generated assistant
text plus token-usage accounting.
"""

import os
import secrets

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

app = FastAPI()

# Shared secret; clients must echo this value in the request "key" field.
API_KEY = os.environ.get("API_KEY")

try:
    model = AutoModelForCausalLM.from_pretrained("petertill/cordia-a6")
    tokenizer = AutoTokenizer.from_pretrained("petertill/cordia-a6")
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    print("Model and tokenizer loaded successfully!")

    class Message(BaseModel):
        role: str      # "system", "user", or "assistant"
        content: str

    class GenerateRequest(BaseModel):
        system_prompt: str
        messages: list[Message]
        key: str                  # API key, checked against API_KEY
        max_length: int = 1024    # upper bound on total sequence length
        temperature: float = 0.7  # sampling temperature

    class TokenUsage(BaseModel):
        prompt_tokens: int
        completion_tokens: int
        total_tokens: int

    class GenerateResponse(BaseModel):
        generated_text: str
        usage: TokenUsage

    @app.post("/generate", response_model=GenerateResponse)
    async def generate(request: GenerateRequest):
        """Generate an assistant reply for the given conversation.

        Raises:
            HTTPException 401: bad or missing API key (or server has no key).
            HTTPException 500: any failure during prompt building/generation.
        """
        # Constant-time comparison avoids leaking the key via timing; an
        # unset API_KEY rejects every request rather than matching None.
        if API_KEY is None or not secrets.compare_digest(request.key, API_KEY):
            raise HTTPException(status_code=401, detail="Unauthorized")
        try:
            # Build the prompt in the layout the model expects. Unknown
            # roles are silently skipped (same as the original chain).
            parts = [f"<|system|>\n{request.system_prompt}\n"]
            for message in request.messages:
                if message.role in ("system", "user", "assistant"):
                    parts.append(f"\n{message.content}\n\n")
            # Trailing prefix cues the model to produce the assistant turn.
            parts.append("\n")
            formatted_prompt = "".join(parts)

            prompt_tokens = len(tokenizer.encode(formatted_prompt))

            output = pipe(
                formatted_prompt,
                max_length=request.max_length,
                temperature=request.temperature,
                do_sample=True,
                return_full_text=True,  # full text back; prompt sliced off below
            )[0]["generated_text"]

            # Only the text generated beyond the prompt is the reply.
            # (The previous split("") raised ValueError: empty separator.)
            response_text = output[len(formatted_prompt):]

            completion_tokens = len(tokenizer.encode(output)) - prompt_tokens
            usage = TokenUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            )
            return GenerateResponse(generated_text=response_text, usage=usage)
        except HTTPException:
            # Don't wrap deliberate HTTP errors in a 500.
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

except Exception as e:
    # Startup failure (model download/load): log it and leave the app running
    # without the /generate route, matching the original module-level try.
    print(f"Error: {e}")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)