from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from dotenv import load_dotenv
from fastapi import FastAPI, Query
import os

load_dotenv()
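# On CUDA, fall back to the math SDP kernel; the flash and memory-efficient
# attention kernels can fail on some GPU / PyTorch combinations.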
if torch.cuda.is_available():
    if hasattr(torch.backends.cuda, "enable_mem_efficient_sdp"):
        torch.backends.cuda.enable_mem_efficient_sdp(False)
    if hasattr(torch.backends.cuda, "enable_flash_sdp"):
        torch.backends.cuda.enable_flash_sdp(False)
    if hasattr(torch.backends.cuda, "enable_math_sdp"):
        torch.backends.cuda.enable_math_sdp(True)
# Set your Hugging Face token and model name
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
MODEL_NAME = "Qwen/Qwen1.5-1.8B-Chat"

# Initialize the FastAPI app
app = FastAPI()

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    token=HUGGINGFACE_TOKEN
)
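# Half precision keeps memory usage down; device_map="auto" lets accelerate
# place the weights on the GPU when one is available.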
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    token=HUGGINGFACE_TOKEN,
    device_map="auto",
    torch_dtype=torch.float16
).eval()

@app.get("/")
def home():
    return {"message": "API Qwen avec prompt + system"}

@app.get("/generate")
def generate(
    prompt: str = Query(..., description="User message"),
    system: str = Query(..., description="Required system instruction")
):
    try:
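        # Build a ChatML-style prompt by hand; tokenizer.apply_chat_template
        # would be an equivalent alternative for Qwen chat models.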
        full_prompt = f"<|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
        inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=700,
                do_sample=True
            )
        # Decode only the newly generated tokens: skip_special_tokens strips the
        # ChatML markers, so splitting the full decode on them would return the prompt too.
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        final_answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        return {"response": final_answer}
    except Exception as e:
        return {"error": str(e)}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
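
# Example request (assuming the server is running locally on the port above):
#   curl "http://localhost:7860/generate?prompt=Hello&system=You%20are%20a%20helpful%20assistant"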