Nac31 commited on
Commit
d3dcfb8
·
1 Parent(s): 498dadd

Update app

Browse files
Files changed (1) hide show
  1. app.py +37 -14
app.py CHANGED
@@ -5,7 +5,11 @@ import os
5
  from dotenv import load_dotenv
6
  from huggingface_hub import login
7
  from transformers import BitsAndBytesConfig
8
- from functools import lru_cache
 
 
 
 
9
 
10
  load_dotenv()
11
 
@@ -16,38 +20,56 @@ login(hf_token)
16
  # Configuration du modèle
17
  model_path = "mistralai/Mistral-Large-Instruct-2411"
18
 
 
 
 
 
19
  # Configuration de la quantification 4-bits
 
20
  quantization_config = BitsAndBytesConfig(
21
  load_in_4bit=True,
22
- bnb_4bit_compute_dtype=torch.float16,
23
  bnb_4bit_quant_type="nf4",
24
  bnb_4bit_use_double_quant=True
25
  )
26
 
27
- # Initialisation du modèle avec quantification
 
28
  tokenizer = AutoTokenizer.from_pretrained(model_path)
 
 
 
29
  model = AutoModelForCausalLM.from_pretrained(
30
  model_path,
31
  device_map="auto",
32
  quantization_config=quantization_config
33
  )
 
 
 
34
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
 
35
 
36
  def generate_response(message, temperature=0.7, max_new_tokens=256):
37
  try:
38
- response = pipe(
39
- message,
40
- temperature=temperature,
41
- max_new_tokens=max_new_tokens,
42
- do_sample=True,
43
- top_k=50,
44
- top_p=0.9,
45
- pad_token_id=tokenizer.pad_token_id,
46
- eos_token_id=tokenizer.eos_token_id,
47
- batch_size=1
48
- )
 
 
 
 
49
  return response[0]['generated_text']
50
  except Exception as e:
 
51
  return f"Une erreur s'est produite : {str(e)}"
52
 
53
  # Interface Gradio
@@ -64,4 +86,5 @@ demo = gr.Interface(
64
  )
65
 
66
  if __name__ == "__main__":
 
67
  demo.launch(share=True)
 
5
  from dotenv import load_dotenv
6
  from huggingface_hub import login
7
  from transformers import BitsAndBytesConfig
8
+ import logging
9
+
10
+ # Configuration du logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
 
14
  load_dotenv()
15
 
 
20
  # Configuration du modèle
21
  model_path = "mistralai/Mistral-Large-Instruct-2411"
22
 
23
+ # Détermination automatique du dtype optimal
24
+ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
25
+ logger.info(f"Using dtype: {dtype}")
26
+
27
  # Configuration de la quantification 4-bits
28
+ logger.info("Configuring 4-bit quantization")
29
  quantization_config = BitsAndBytesConfig(
30
  load_in_4bit=True,
31
+ bnb_4bit_compute_dtype=dtype, # Utilisation du dtype optimal
32
  bnb_4bit_quant_type="nf4",
33
  bnb_4bit_use_double_quant=True
34
  )
35
 
36
+ # Initialisation du modèle
37
+ logger.info(f"Loading tokenizer from {model_path}")
38
  tokenizer = AutoTokenizer.from_pretrained(model_path)
39
+ logger.info("Tokenizer loaded successfully")
40
+
41
+ logger.info(f"Loading model from {model_path} with 4-bit quantization")
42
  model = AutoModelForCausalLM.from_pretrained(
43
  model_path,
44
  device_map="auto",
45
  quantization_config=quantization_config
46
  )
47
+ logger.info("Model loaded successfully")
48
+
49
+ logger.info("Creating inference pipeline")
50
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
51
+ logger.info("Inference pipeline created successfully")
52
 
53
  def generate_response(message, temperature=0.7, max_new_tokens=256):
54
  try:
55
+ logger.info(f"Generating response for message: {message[:50]}...")
56
+ parameters = {
57
+ "temperature": temperature,
58
+ "max_new_tokens": max_new_tokens,
59
+ # "do_sample": True,
60
+ # "top_k": 50,
61
+ # "top_p": 0.9,
62
+ # "pad_token_id": tokenizer.pad_token_id,
63
+ # "eos_token_id": tokenizer.eos_token_id,
64
+ # "batch_size": 1
65
+ }
66
+ logger.info(f"Parameters: {parameters}")
67
+
68
+ response = pipe(message, **parameters)
69
+ logger.info("Response generated successfully")
70
  return response[0]['generated_text']
71
  except Exception as e:
72
+ logger.error(f"Error during generation: {str(e)}")
73
  return f"Une erreur s'est produite : {str(e)}"
74
 
75
  # Interface Gradio
 
86
  )
87
 
88
  if __name__ == "__main__":
89
+ logger.info("Starting Gradio interface")
90
  demo.launch(share=True)