petertill commited on
Commit
3e90d44
·
verified ·
1 Parent(s): 14cdd4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -8
app.py CHANGED
@@ -13,24 +13,56 @@ try:
13
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
14
  print("Model and tokenizer loaded successfully!")
15
 
 
 
 
 
16
  class GenerateRequest(BaseModel):
17
- prompt: str
 
18
  key: str
 
 
19
 
20
  class GenerateResponse(BaseModel):
21
  generated_text: str
22
 
23
  @app.post("/generate", response_model=GenerateResponse)
24
  async def generate(request: GenerateRequest):
25
- authorization = request.key
26
- #token = authorization.split('Bearer ')[1]
27
- if authorization != API_KEY:
28
  raise HTTPException(status_code=401, detail="Unauthorized")
 
29
  try:
30
- output = pipe(request.prompt)[0]['generated_text']
31
- return GenerateResponse(generated_text=output)
32
- except Exception as e:
33
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  except Exception as e:
36
  print(f"Error: {e}")
 
13
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
14
  print("Model and tokenizer loaded successfully!")
15
 
16
+ class Message(BaseModel):
17
+ role: str # "system", "user", or "assistant"
18
+ content: str
19
+
20
  class GenerateRequest(BaseModel):
21
+ system_prompt : str
22
+ messages: list[Message]
23
  key: str
24
+ max_length: int = 1024
25
+ temperature: float = 0.7
26
 
27
  class GenerateResponse(BaseModel):
28
  generated_text: str
29
 
30
  @app.post("/generate", response_model=GenerateResponse)
31
  async def generate(request: GenerateRequest):
32
+ if request.key != API_KEY:
 
 
33
  raise HTTPException(status_code=401, detail="Unauthorized")
34
+
35
  try:
36
+ # Format messages into a prompt format the model expects
37
+ formatted_prompt = ""
38
+ formatted_prompt += f"<|system|>\n{request.system_prompt}</s>\n"
39
+ for message in request.messages:
40
+ if message.role == "system":
41
+ formatted_prompt += f"<|system|>\n{message.content}</s>\n"
42
+ elif message.role == "user":
43
+ formatted_prompt += f"<|user|>\n{message.content}</s>\n"
44
+ elif message.role == "assistant":
45
+ formatted_prompt += f"<|assistant|>\n{message.content}</s>\n"
46
+
47
+ # Add final assistant prefix for generation
48
+ formatted_prompt += "<|assistant|>\n"
49
+
50
+ output = pipe(
51
+ formatted_prompt,
52
+ max_length=request.max_length,
53
+ temperature=request.temperature,
54
+ do_sample=True
55
+ )[0]['generated_text']
56
+
57
+ # Extract only the newly generated assistant response
58
+ response_text = output.split("<|assistant|>\n")[-1].split("</s>")[0]
59
+
60
+ return GenerateResponse(generated_text=response_text)
61
+ #try:
62
+ #output = pipe(request.prompt)[0]['generated_text']
63
+ #return GenerateResponse(generated_text=output)
64
+ #except Exception as e:
65
+ #raise HTTPException(status_code=500, detail=str(e))
66
 
67
  except Exception as e:
68
  print(f"Error: {e}")