File size: 3,400 Bytes
e18a7f6
 
 
 
02ac23d
e18a7f6
 
02ac23d
11b89be
 
78da8bb
 
e18a7f6
78da8bb
11b89be
3e90d44
 
 
 
e18a7f6
3e90d44
 
fd54971
3e90d44
 
11b89be
12e43e4
 
 
 
 
e18a7f6
 
491f951
11b89be
e18a7f6
fd54971
3e90d44
02ac23d
3e90d44
e18a7f6
3e90d44
 
 
 
 
12e43e4
3e90d44
12e43e4
3e90d44
12e43e4
3e90d44
 
12e43e4
 
 
 
3e90d44
 
 
2fa652d
 
 
 
3e90d44
 
 
12e43e4
 
 
 
 
 
 
 
 
 
 
3e90d44
12e43e4
 
 
3e90d44
 
 
 
12e43e4
11b89be
 
e18a7f6
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import secrets

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import uvicorn

app = FastAPI()
# API key read once at startup; every /generate request must echo it back.
API_KEY = os.environ.get("API_KEY")

try:
    # Load the model weights and tokenizer from the Hugging Face Hub at startup;
    # any failure is caught by the enclosing try and logged instead of crashing.
    model = AutoModelForCausalLM.from_pretrained("petertill/cordia-a6")
    tokenizer = AutoTokenizer.from_pretrained("petertill/cordia-a6")
    # Text-generation pipeline used by the /generate endpoint.
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    print("Model and tokenizer loaded successfully!")

    class Message(BaseModel):
        """A single turn of the conversation sent to /generate."""

        role: str  # "system", "user", or "assistant"
        content: str  # text content of the turn

    class GenerateRequest(BaseModel):
        """Request body for POST /generate."""

        # PEP 8: no space before the annotation colon (was `system_prompt : str`).
        system_prompt: str  # rendered as the leading <|system|> turn
        messages: list[Message]  # conversation history, oldest first
        key: str  # API key, compared against the API_KEY env var
        max_length: int = 1024  # caller-requested generation length cap
        temperature: float = 0.7  # caller-requested sampling temperature

    class TokenUsage(BaseModel):
        """Token accounting for one generation, OpenAI-style field names."""

        prompt_tokens: int  # tokens in the formatted prompt
        completion_tokens: int  # tokens newly generated by the model
        total_tokens: int  # prompt_tokens + completion_tokens

    class GenerateResponse(BaseModel):
        """Response body for POST /generate."""

        generated_text: str  # the extracted assistant reply only
        usage: TokenUsage  # token counts for this call

    @app.post("/generate", response_model=GenerateResponse)
    async def generate(request: GenerateRequest):
        """Generate an assistant reply for a chat-style conversation.

        Formats the system prompt and message history into the tagged
        prompt format the model expects, runs the text-generation
        pipeline, and returns the new assistant turn plus token usage.

        Raises:
            HTTPException 401: if the supplied key does not match API_KEY
                (or API_KEY is unset in the environment).
            HTTPException 500: if prompt formatting or generation fails.

        NOTE(review): the pipeline call is blocking; inside an `async def`
        it stalls the event loop for the duration of generation. Consider
        a plain `def` so FastAPI runs it in its threadpool.
        """
        # Constant-time comparison avoids leaking key bytes via timing.
        # An unset API_KEY means there is no valid key: reject everything.
        if not API_KEY or not secrets.compare_digest(request.key, API_KEY):
            raise HTTPException(status_code=401, detail="Unauthorized")

        try:
            # Build the tagged prompt the model was trained on.
            formatted_prompt = f"<|system|>\n{request.system_prompt}</s>\n"
            for message in request.messages:
                if message.role == "system":
                    formatted_prompt += f"<system>\n{message.content}\n</system>\n"
                elif message.role == "user":
                    formatted_prompt += f"<user>\n{message.content}\n</user>\n"
                elif message.role == "assistant":
                    formatted_prompt += f"<assistant>\n{message.content}\n</assistant>\n"

            # Open a fresh assistant turn so the model continues from here.
            formatted_prompt += "<assistant>\n"

            # Count prompt tokens before generation for the usage report.
            prompt_tokens = len(tokenizer.encode(formatted_prompt))

            # Honor the caller-supplied sampling parameters. Previously
            # max_length/temperature were accepted on the request model but
            # silently ignored (the kwargs were commented out).
            output = pipe(
                formatted_prompt,
                max_length=request.max_length,
                temperature=request.temperature,
                do_sample=True,
                return_full_text=True,  # full text keeps token accounting consistent
            )[0]['generated_text']

            # Keep only the newly generated assistant turn.
            response_text = output.split("<assistant>\n")[-1].split("</assistant>")[0]

            # Completion tokens = tokens in the full output minus the prompt.
            full_output_tokens = len(tokenizer.encode(output))
            completion_tokens = full_output_tokens - prompt_tokens

            usage = TokenUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            )

            return GenerateResponse(generated_text=response_text, usage=usage)
        except Exception as e:
            # Surface formatting/generation failures as a 500 with the message.
            raise HTTPException(status_code=500, detail=str(e))

except Exception as e:
    print(f"Error: {e}")

# Serve the API with uvicorn on all interfaces, port 7860, when run directly.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)