fatmata commited on
Commit
7090deb
·
verified ·
1 Parent(s): 877b12a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -0
app.py CHANGED
@@ -1,7 +1,25 @@
1
  from fastapi import FastAPI
 
 
2
 
3
  app = FastAPI()
4
 
 
 
 
 
 
 
 
5
  @app.get("/")
6
  def read_root():
7
  return {"message": "Hello from PsyBot API!"}
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
 
5
  app = FastAPI()
6
 
7
+ # Définir le chemin correct vers ton modèle
8
+ MODEL_PATH = "fatmata/psyboy/psybot_model" # Remplace par le bon chemin
9
+
10
+ # Charger le modèle et le tokenizer
11
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
12
+ model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
13
+
14
  @app.get("/")
15
  def read_root():
16
  return {"message": "Hello from PsyBot API!"}
17
+
18
+ @app.post("/generate/")
19
+ def generate_response(prompt: str):
20
+ inputs = tokenizer(prompt, return_tensors="pt")
21
+ with torch.no_grad():
22
+ output = model.generate(**inputs, max_length=150)
23
+
24
+ response = tokenizer.decode(output[0], skip_special_tokens=True)
25
+ return {"response": response}