fatmata commited on
Commit
61e08a1
·
verified ·
1 Parent(s): 42dbf5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -17
app.py CHANGED
@@ -1,22 +1,34 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from transformers import pipeline
4
 
5
- app = FastAPI()
6
 
7
- # Charger le modèle Hugging Face
8
- model_name = "fatmata/psybot"
9
- generator = pipeline("text-generation", model=model_name)
 
10
 
11
- # Définition du format des requêtes
12
- class TextRequest(BaseModel):
13
- prompt: str
 
14
 
15
- @app.get("/")
16
- def home():
17
- return {"message": "Bienvenue sur l'API PsyBot !"}
18
 
19
- @app.post("/generate/")
20
- def generate_text(request: TextRequest):
21
- response = generator(request.prompt, max_length=100, do_sample=True)
22
- return {"response": response[0]["generated_text"]}
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
 
5
+ app = Flask(__name__)
6
 
7
+ # Charger le modèle depuis Hugging Face
8
+ MODEL_NAME = "fatmata/psybot" # Remplace avec le vrai nom de ton modèle
9
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
10
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
11
 
12
+ @app.route("/chat", methods=["POST"])
13
+ def chat():
14
+ data = request.json
15
+ user_input = data.get("message", "")
16
 
17
+ if not user_input:
18
+ return jsonify({"error": "Message vide"}), 400
 
19
 
20
+ # Génération de la réponse
21
+ prompt = f"<|startoftext|><|user|> {user_input} <|bot|>"
22
+ inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
23
+
24
+ with torch.no_grad():
25
+ output = model.generate(inputs, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)
26
+
27
+ response = tokenizer.decode(output[0], skip_special_tokens=True)
28
+ if "<|bot|>" in response:
29
+ response = response.split("<|bot|>")[-1].strip()
30
+
31
+ return jsonify({"response": response})
32
+
33
+ if __name__ == "__main__":
34
+ app.run(host="0.0.0.0", port=7860)