from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import uvicorn
import os

app = FastAPI()
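# If API_KEY is unset, os.environ.get returns None and the key check in
# /generate rejects every request with a 401, so the server fails closed.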
API_KEY = os.environ.get("API_KEY")

try:
    model = AutoModelForCausalLM.from_pretrained("petertill/cordia-a6")
    tokenizer = AutoTokenizer.from_pretrained("petertill/cordia-a6")
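    # Wrap the model and tokenizer in a text-generation pipeline so we can
    # pass raw prompt strings and get decoded completions back.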
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    print("Model and tokenizer loaded successfully!")

    class Message(BaseModel):
        role: str  # "system", "user", or "assistant"
        content: str

    class GenerateRequest(BaseModel):
        system_prompt: str
        messages: list[Message]
        key: str
        max_length: int = 1024
        temperature: float = 0.7

    class GenerateResponse(BaseModel):
        generated_text: str

    @app.post("/generate", response_model=GenerateResponse)
    async def generate(request: GenerateRequest):
        if request.key != API_KEY:
            raise HTTPException(status_code=401, detail="Unauthorized")

        try:
            # Format messages into a prompt format the model expects
            formatted_prompt = ""
            formatted_prompt += f"<|system|>\n{request.system_prompt}</s>\n"
            for message in request.messages:
                if message.role == "system":
                    formatted_prompt += f"<|system|>\n{message.content}</s>\n"
                elif message.role == "user":
                    formatted_prompt += f"<|user|>\n{message.content}</s>\n"
                elif message.role == "assistant":
                    formatted_prompt += f"<|assistant|>\n{message.content}</s>\n"
            
            # Add final assistant prefix for generation
            formatted_prompt += "<|assistant|>\n"
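            # The assembled prompt has the shape:
            #   <|system|>\n...</s>\n<|user|>\n...</s>\n<|assistant|>\n
            # and the model continues after the final assistant tag.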

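            # Note: max_length bounds prompt tokens plus generated tokens
            # (standard transformers semantics); max_new_tokens would bound
            # only the completion.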
            output = pipe(
                formatted_prompt,
                max_length=request.max_length,
                temperature=request.temperature,
                do_sample=True
            )[0]['generated_text']

            # Extract only the newly generated assistant response
            response_text = output.split("<|assistant|>\n")[-1].split("</s>")[0]
        
            return GenerateResponse(generated_text=response_text)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

except Exception as e:
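    # If loading fails, the app still starts but without the /generate route.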
    print(f"Error: {e}")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
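
# --- Example client call ---
# A minimal sketch, not part of the app itself: assumes the server is running
# locally on port 7860 and that the "requests" package is installed. The "key"
# field must match the server's API_KEY environment variable.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={
#           "system_prompt": "You are a helpful assistant.",
#           "messages": [{"role": "user", "content": "Hello!"}],
#           "key": "<API_KEY>",
#           "max_length": 256,
#           "temperature": 0.7,
#       },
#   )
#   print(resp.json()["generated_text"])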