Update app
Browse files
app.py
CHANGED
@@ -5,7 +5,11 @@ import os
|
|
5 |
from dotenv import load_dotenv
|
6 |
from huggingface_hub import login
|
7 |
from transformers import BitsAndBytesConfig
|
8 |
-
|
|
|
|
|
|
|
|
|
9 |
|
10 |
load_dotenv()
|
11 |
|
@@ -16,38 +20,56 @@ login(hf_token)
|
|
16 |
# Configuration du modèle
|
17 |
model_path = "mistralai/Mistral-Large-Instruct-2411"
|
18 |
|
|
|
|
|
|
|
|
|
19 |
# Configuration de la quantification 4-bits
|
|
|
20 |
quantization_config = BitsAndBytesConfig(
|
21 |
load_in_4bit=True,
|
22 |
-
bnb_4bit_compute_dtype=
|
23 |
bnb_4bit_quant_type="nf4",
|
24 |
bnb_4bit_use_double_quant=True
|
25 |
)
|
26 |
|
27 |
-
# Initialisation du modèle
|
|
|
28 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
|
|
|
|
|
|
29 |
model = AutoModelForCausalLM.from_pretrained(
|
30 |
model_path,
|
31 |
device_map="auto",
|
32 |
quantization_config=quantization_config
|
33 |
)
|
|
|
|
|
|
|
34 |
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
|
|
|
35 |
|
36 |
def generate_response(message, temperature=0.7, max_new_tokens=256):
|
37 |
try:
|
38 |
-
response
|
39 |
-
|
40 |
-
temperature
|
41 |
-
max_new_tokens
|
42 |
-
do_sample
|
43 |
-
top_k
|
44 |
-
top_p
|
45 |
-
pad_token_id
|
46 |
-
eos_token_id
|
47 |
-
batch_size
|
48 |
-
|
|
|
|
|
|
|
|
|
49 |
return response[0]['generated_text']
|
50 |
except Exception as e:
|
|
|
51 |
return f"Une erreur s'est produite : {str(e)}"
|
52 |
|
53 |
# Interface Gradio
|
@@ -64,4 +86,5 @@ demo = gr.Interface(
|
|
64 |
)
|
65 |
|
66 |
if __name__ == "__main__":
|
|
|
67 |
demo.launch(share=True)
|
|
|
5 |
from dotenv import load_dotenv
|
6 |
from huggingface_hub import login
|
7 |
from transformers import BitsAndBytesConfig
|
8 |
+
import logging
|
9 |
+
|
10 |
+
# Configuration du logging
|
11 |
+
logging.basicConfig(level=logging.INFO)
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
|
14 |
load_dotenv()
|
15 |
|
|
|
20 |
# Model configuration
model_path = "mistralai/Mistral-Large-Instruct-2411"

# Choose the compute dtype for the 4-bit layers: bfloat16 on GPUs whose
# compute-capability major version is 8 or higher, float16 otherwise.
# `>=` (not `==`) so that 9.x and later architectures also get bfloat16, and
# the availability guard keeps the capability query from raising on a host
# without CUDA.
_bf16_capable = (
    torch.cuda.is_available()
    and torch.cuda.get_device_capability()[0] >= 8
)
dtype = torch.bfloat16 if _bf16_capable else torch.float16
logger.info(f"Using dtype: {dtype}")

# 4-bit quantization configuration
logger.info("Configuring 4-bit quantization")
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=dtype,  # use the dtype selected above
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

# Model initialisation: tokenizer first, then the quantized model, then the
# text-generation pipeline that generate_response() calls.
logger.info(f"Loading tokenizer from {model_path}")
tokenizer = AutoTokenizer.from_pretrained(model_path)
logger.info("Tokenizer loaded successfully")

logger.info(f"Loading model from {model_path} with 4-bit quantization")
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    quantization_config=quantization_config,
)
logger.info("Model loaded successfully")

logger.info("Creating inference pipeline")
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
logger.info("Inference pipeline created successfully")
|
52 |
|
53 |
def generate_response(message, temperature=0.7, max_new_tokens=256):
    """Generate a completion for ``message`` with the module-level pipeline.

    Args:
        message: Prompt text handed to the text-generation pipeline.
        temperature: Sampling temperature. A value above 0 enables sampling
            with that temperature; 0 or less selects greedy decoding.
        max_new_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The ``generated_text`` field of the first pipeline result, or a
        French error-message string if anything raised along the way.
        This function never propagates an exception to its caller.
    """
    try:
        logger.info(f"Generating response for message: {message[:50]}...")

        # Temperature is only applied when sampling is switched on, so tie
        # do_sample to it explicitly; a non-positive temperature means
        # "deterministic", i.e. greedy decoding without a temperature.
        do_sample = temperature > 0
        parameters = {
            # NOTE(review): UI widgets may deliver floats; generation
            # expects an integer token count -- confirm against the caller.
            "max_new_tokens": int(max_new_tokens),
            "do_sample": do_sample,
        }
        if do_sample:
            parameters["temperature"] = temperature
        logger.info(f"Parameters: {parameters}")

        response = pipe(message, **parameters)
        logger.info("Response generated successfully")
        return response[0]['generated_text']
    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the
        # interface, but keep the full traceback in the log.
        logger.exception(f"Error during generation: {str(e)}")
        return f"Une erreur s'est produite : {str(e)}"
|
74 |
|
75 |
# Interface Gradio
|
|
|
86 |
)
|
87 |
|
88 |
# Script entry point: the interface is launched only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    logger.info("Starting Gradio interface")
    # share=True is forwarded to launch(); presumably this requests a public
    # share link in addition to the local server -- confirm in Gradio docs.
    demo.launch(share=True)
|