from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import uvicorn
import os
app = FastAPI()
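# Shared-secret API key read from the environment at startup; each request
# must echo it back in its "key" field to be served.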
API_KEY = os.environ.get("API_KEY")
# Load the model and tokenizer once at startup so requests don't pay the cost;
# fail fast if loading fails rather than serving with an undefined pipeline.
try:
    model = AutoModelForCausalLM.from_pretrained("petertill/cordia-a6")
    tokenizer = AutoTokenizer.from_pretrained("petertill/cordia-a6")
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    print("Model and tokenizer loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    raise

class Message(BaseModel):
    role: str  # "system", "user", or "assistant"
    content: str


class GenerateRequest(BaseModel):
    system_prompt: str
    messages: list[Message]
    key: str
    max_length: int = 1024
    temperature: float = 0.7


class TokenUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class GenerateResponse(BaseModel):
    generated_text: str
    usage: TokenUsage
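# Illustrative /generate response shape built from the models above
# (the values themselves are made up):
#
# {
#   "generated_text": "Hello! How can I help?",
#   "usage": {"prompt_tokens": 42, "completion_tokens": 12, "total_tokens": 54}
# }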

@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    if request.key != API_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")
    try:
        # Format messages into a prompt format the model expects
        formatted_prompt = f"<|system|>\n{request.system_prompt}</s>\n"
        for message in request.messages:
            if message.role == "system":
                formatted_prompt += f"<system>\n{message.content}\n</system>\n"
            elif message.role == "user":
                formatted_prompt += f"<user>\n{message.content}\n</user>\n"
            elif message.role == "assistant":
                formatted_prompt += f"<assistant>\n{message.content}\n</assistant>\n"
        # Add a final assistant prefix so the model continues as the assistant
        formatted_prompt += "<assistant>\n"

        # Count tokens in the prompt
        prompt_tokens = len(tokenizer.encode(formatted_prompt))

        output = pipe(
            formatted_prompt,
            max_length=request.max_length,
            temperature=request.temperature,
            do_sample=True,
            return_full_text=True,  # Make sure we get the full text back
        )[0]["generated_text"]

        # Extract only the newly generated assistant response
        response_text = output.split("<assistant>\n")[-1].split("</assistant>")[0]

        # Completion tokens = tokens in the full output minus tokens in the prompt
        full_output_tokens = len(tokenizer.encode(output))
        completion_tokens = full_output_tokens - prompt_tokens

        usage = TokenUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        return GenerateResponse(generated_text=response_text, usage=usage)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
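# Example call once the server is up (a hedged sketch; assumes the app runs
# locally on port 7860 and that API_KEY is set to "secret" in the environment):
#
#   curl -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"system_prompt": "You are helpful.", "key": "secret",
#          "messages": [{"role": "user", "content": "Hi!"}]}'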