# maiarhSLM/app.py
import os

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, Query
from transformers import AutoTokenizer, AutoModelForCausalLM

load_dotenv()
# Prefer the math SDPA kernel over flash/memory-efficient attention for broader GPU compatibility
if torch.cuda.is_available():
    if hasattr(torch.backends.cuda, "enable_mem_efficient_sdp"):
        torch.backends.cuda.enable_mem_efficient_sdp(False)
    if hasattr(torch.backends.cuda, "enable_flash_sdp"):
        torch.backends.cuda.enable_flash_sdp(False)
    if hasattr(torch.backends.cuda, "enable_math_sdp"):
        torch.backends.cuda.enable_math_sdp(True)
# Hugging Face token and model name
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
MODEL_NAME = "Qwen/Qwen1.5-1.8B-Chat"

# Initialize the FastAPI app
app = FastAPI()

# Load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    token=HUGGINGFACE_TOKEN
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    token=HUGGINGFACE_TOKEN,
    device_map="auto",
    torch_dtype=torch.float16
).eval()
@app.get("/")
def home():
return {"message": "API Qwen avec prompt + system"}
@app.get("/generate")
def generate(
prompt: str = Query(..., description="Message de l'utilisateur"),
system: str = Query(..., description="Instruction système obligatoire")
):
try:
full_prompt = f"<|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=700,
do_sample=True
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
final_answer = response.split("<|im_start|>assistant\n")[-1].strip()
return {"response": final_answer}
except Exception as e:
return {"error": str(e)}
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
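
# Example request (assumes the server is running locally on port 7860 as above;
# the prompt and system values are placeholders):
#   curl "http://localhost:7860/generate?prompt=Hello&system=You%20are%20a%20helpful%20assistant"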