fatmata committed on
Commit 6f871b6 · verified · 1 Parent(s): 61e08a1

Update app.py

Files changed (1)
app.py +12 -30
app.py CHANGED
@@ -1,34 +1,16 @@
- from flask import Flask, request, jsonify
- from transformers import AutoModelForCausalLM, AutoTokenizer
- import torch
-
- app = Flask(__name__)
-
- # Load the model from Hugging Face
- MODEL_NAME = "fatmata/psybot"  # Replace with the real name of your model
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
-
- @app.route("/chat", methods=["POST"])
- def chat():
-     data = request.json
-     user_input = data.get("message", "")
-
-     if not user_input:
-         return jsonify({"error": "Empty message"}), 400
-
-     # Generate the response
-     prompt = f"<|startoftext|><|user|> {user_input} <|bot|>"
-     inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
-
-     with torch.no_grad():
-         output = model.generate(inputs, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)
-
-     response = tokenizer.decode(output[0], skip_special_tokens=True)
-     if "<|bot|>" in response:
-         response = response.split("<|bot|>")[-1].strip()
-
-     return jsonify({"response": response})
-
- if __name__ == "__main__":
-     app.run(host="0.0.0.0", port=7860)
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ from transformers import pipeline
+
+ app = FastAPI()
+
+ # Load the model hosted on Hugging Face
+ chatbot = pipeline("text-generation", model="fatmata/psybot")  # Replace with your model
+
+ class UserInput(BaseModel):
+     text: str
+
+ @app.post("/chatbot/")
+ async def generate_response(user_input: UserInput):
+     response = chatbot(user_input.text, max_length=100, do_sample=True)[0]["generated_text"]
+     return {"response": response}
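
The rewritten app no longer starts its own server (the Flask app.run(...) call is gone), so it has to be launched with an ASGI server such as uvicorn. Below is a minimal sketch of how the new /chatbot/ endpoint could be exercised, assuming the file is named app.py and is served on port 7860 as before; the example message and the use of the requests library are illustrative assumptions, not part of this commit:

# Shell: start the server first, e.g.
#   uvicorn app:app --host 0.0.0.0 --port 7860
import requests

# Hypothetical example message; the endpoint expects JSON matching the UserInput model.
payload = {"text": "I have been feeling anxious lately."}
r = requests.post("http://localhost:7860/chatbot/", json=payload)
r.raise_for_status()
print(r.json()["response"])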