Hjgugugjhuhjggg committed
Commit fcc4055 · verified · Parent(s): 399f6a8

Update app.py

Files changed (1):
  1. app.py +77 -178
app.py CHANGED
@@ -2,23 +2,18 @@ import os
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, field_validator
-from transformers import (
-    AutoConfig,
-    GenerationConfig,
-    AutoTokenizer,
-    AutoModelForCausalLM,
-)
+from transformers import pipeline, AutoConfig, AutoTokenizer
+from transformers.utils import logging
 from google.cloud import storage
 from google.auth.exceptions import DefaultCredentialsError
 import uvicorn
 import asyncio
 import json
-import logging
 from huggingface_hub import login
 from dotenv import load_dotenv
 import huggingface_hub
-import torch
-from safetensors.torch import load_file as safe_load
+from threading import Thread
+from typing import AsyncIterator
 
 load_dotenv()
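This hunk swaps the manual AutoModelForCausalLM/state-dict plumbing for the higher-level pipeline API. For orientation, a minimal text-generation pipeline call looks like the sketch below; the model id "gpt2" is only an illustrative small checkpoint, not one the app mandates:

```python
from transformers import pipeline

# "gpt2" is an arbitrary example model id; any causal-LM checkpoint works.
pipe = pipeline("text-generation", model="gpt2")
result = pipe("Hello, world", max_new_tokens=8, return_full_text=False)
print(result[0]["generated_text"])  # the pipeline returns a list of dicts
```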
 
@@ -33,8 +28,8 @@ os.system("git config --global credential.helper store")
 if HUGGINGFACE_HUB_TOKEN:
     huggingface_hub.login(token=HUGGINGFACE_HUB_TOKEN, add_to_git_credential=True)
 
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
 
 try:
     credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
@@ -53,8 +48,7 @@ class GenerateRequest(BaseModel):
     input_text: str
     task_type: str
     temperature: float = 1.0
-    max_new_tokens: int = 20
-    stream: bool = False
+    stream: bool = True  # Enforce stream for this functionality
     top_p: float = 1.0
     top_k: int = 50
     repetition_penalty: float = 1.0
@@ -71,7 +65,7 @@ class GenerateRequest(BaseModel):
 
     @field_validator("task_type")
     def task_type_must_be_valid(cls, v):
-        valid_types = ["text-to-text"]
+        valid_types = ["text-generation"]
         if v not in valid_types:
             raise ValueError(f"task_type must be one of: {valid_types}")
         return v
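With these two hunks, streaming is always on and the validator only accepts the pipeline task name. A quick sketch of what the updated model now accepts and rejects, assuming the remaining fields keep their defaults (the values here are illustrative):

```python
# Illustrative payloads against the GenerateRequest model above.
ok = GenerateRequest(
    model_name="gpt2",            # example model id
    input_text="Hello",
    task_type="text-generation",  # the only value the validator accepts now
)
assert ok.stream is True          # stream defaults to True in this revision

# The pre-commit task name is now rejected (pydantic's ValidationError
# subclasses ValueError):
try:
    GenerateRequest(model_name="gpt2", input_text="Hello", task_type="text-to-text")
except ValueError as e:
    print(e)  # task_type must be one of: ['text-generation']
```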
@@ -87,18 +81,6 @@ class GCSModelLoader:
         blob = self.bucket.blob(blob_path)
         return blob.exists()
 
-    def _download_content(self, blob_path):
-        blob = self.bucket.blob(blob_path)
-        try:
-            return blob.download_as_bytes()
-        except Exception as e:
-            logger.error(f"Error downloading {blob_path}: {e}")
-            return None
-
-    def _upload_content(self, content, blob_path):
-        blob = self.bucket.blob(blob_path)
-        blob.upload_from_string(content)
-
     def _create_model_folder(self, model_name):
         gcs_model_folder = self._get_gcs_uri(model_name)
         if not self._blob_exists(f"{gcs_model_folder}/.touch"):
@@ -106,133 +88,56 @@
             blob.upload_from_string("")
             logger.info(f"Created folder '{gcs_model_folder}' in GCS.")
 
-    def load_config(self, model_name):
-        gcs_config_path = f"{self._get_gcs_uri(model_name)}/config.json"
-        if self._blob_exists(gcs_config_path):
-            try:
-                config_content = self._download_content(gcs_config_path)
-                return AutoConfig.from_pretrained(pretrained_model_name_or_path="", _commit_hash=None, config_dict=json.loads(config_content), trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
-            except Exception as e:
-                logger.error(f"Error loading config from GCS: {e}")
+    def check_model_exists_locally(self, model_name):
+        gcs_model_path = self._get_gcs_uri(model_name)
+        blobs = self.bucket.list_blobs(prefix=gcs_model_path)
+        return any(blobs)
+
+    def download_model_from_huggingface(self, model_name):
+        logger.info(f"Downloading model '{model_name}' from Hugging Face.")
         try:
-            logger.info(f"Downloading config from Hugging Face for {model_name}")
-            config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
+            tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
+            config = AutoConfig.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
             gcs_model_folder = self._get_gcs_uri(model_name)
             self._create_model_folder(model_name)
-            self._upload_content(json.dumps(config.to_dict()).encode('utf-8'), f"{gcs_model_folder}/config.json")
-            return config
+            tokenizer.save_pretrained(gcs_model_folder)
+            config.save_pretrained(gcs_model_folder)
+            for filename in os.listdir(config.name_or_path):
+                if filename.endswith((".bin", ".safetensors")):
+                    blob = self.bucket.blob(f"{gcs_model_folder}/{filename}")
+                    blob.upload_from_filename(os.path.join(config.name_or_path, filename))
+            logger.info(f"Model '{model_name}' downloaded and saved to GCS.")
+            return True
         except Exception as e:
-            logger.error(f"Error loading config from Hugging Face: {e}")
-            return None
+            logger.error(f"Error downloading model from Hugging Face: {e}")
+            return False
 
-    def load_tokenizer(self, model_name):
-        gcs_tokenizer_path = self._get_gcs_uri(model_name)
-        tokenizer_files = ["tokenizer_config.json", "vocab.json", "merges.txt", "tokenizer.json", "special_tokens_map.json"]
-        gcs_files_exist = all(self._blob_exists(f"{gcs_tokenizer_path}/{f}") for f in tokenizer_files)
-
-        if gcs_files_exist:
-            try:
-                return AutoTokenizer.from_pretrained(gcs_tokenizer_path, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
-            except Exception as e:
-                logger.error(f"Error loading tokenizer from GCS: {e}")
-                return None
-        else:
-            try:
-                logger.info(f"Downloading tokenizer from Hugging Face for {model_name}")
-                tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
-                gcs_model_folder = self._get_gcs_uri(model_name)
-                self._create_model_folder(model_name)
-                tokenizer.save_pretrained(gcs_model_folder)
-                return tokenizer
-            except Exception as e:
-                logger.error(f"Error loading tokenizer from Hugging Face: {e}")
-                return None
-
-    def load_model(self, model_name, config):
-        gcs_model_path = self._get_gcs_uri(model_name)
-        logger.info(f"Attempting to load model '{model_name}' from GCS.")
-        blobs = self.bucket.list_blobs(prefix=gcs_model_path)
-        weight_files = [blob.name for blob in blobs if blob.name.endswith((".bin", ".safetensors"))]
-
-        if not weight_files:
-            logger.info(f"No weight files found in GCS for '{model_name}'. Downloading from Hugging Face.")
-            try:
-                model = AutoModelForCausalLM.from_pretrained(model_name, config=config, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
-                gcs_model_folder = self._get_gcs_uri(model_name)
-                self._create_model_folder(model_name)
-                for filename in os.listdir(model.config.name_or_path):
-                    if filename.endswith((".bin", ".safetensors")):
-                        blob = self.bucket.blob(f"{gcs_model_folder}/{filename}")
-                        blob.upload_from_filename(os.path.join(model.config.name_or_path, filename))
-                logger.info(f"Model '{model_name}' downloaded from Hugging Face and saved to GCS.")
-                return model
-            except Exception as e:
-                logger.error(f"Error downloading model from Hugging Face: {e}")
-                raise HTTPException(status_code=500, detail=f"Error downloading model from Hugging Face: {e}")
-
-        logger.info(f"Found weight files in GCS for '{model_name}': {weight_files}")
-
-        loaded_state_dict = {}
-        error_occurred = False
-        for weight_file in weight_files:
-            logger.info(f"Streaming weight file from GCS: {weight_file}")
-            blob = self.bucket.blob(weight_file)
-            try:
-                blob_content = blob.download_as_bytes()
-                if weight_file.endswith(".safetensors"):
-                    loaded_state_dict.update(safe_load(blob_content))
-                else:
-                    loaded_state_dict.update(torch.load(io.BytesIO(blob_content), map_location="cpu"))
-            except Exception as e:
-                logger.error(f"Error streaming and loading weights from GCS {weight_file}: {e}")
-                error_occurred = True
-                break
-
-        if error_occurred:
-            logger.info(f"Attempting to reload model '{model_name}' from Hugging Face due to loading error.")
-            try:
-                model = AutoModelForCausalLM.from_pretrained(model_name, config=config, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
-                gcs_model_folder = self._get_gcs_uri(model_name)
-                self._create_model_folder(model_name)
-                for filename in os.listdir(model.config.name_or_path):
-                    if filename.endswith((".bin", ".safetensors")):
-                        upload_blob = self.bucket.blob(f"{gcs_model_folder}/{filename}")
-                        upload_blob.upload_from_filename(os.path.join(model.config.name_or_path, filename))
-                logger.info(f"Model '{model_name}' reloaded from Hugging Face and saved to GCS.")
-                return model
-            except Exception as redownload_error:
-                logger.error(f"Error redownloading model from Hugging Face: {redownload_error}")
-                raise HTTPException(status_code=500, detail=f"Failed to load or redownload model: {redownload_error}")
-
-        try:
-            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
-            model.load_state_dict(loaded_state_dict, strict=False)
-            logger.info(f"Model '{model_name}' successfully loaded from GCS.")
-            return model
-        except Exception as e:
-            logger.error(f"Error loading state dict: {e}")
-            raise HTTPException(status_code=500, detail=f"Error loading state dict: {e}")
-
-model_loader = GCSModelLoader(bucket)
+model_loader = GCSModelLoader(bucket)
 
-async def generate_stream(model, tokenizer, input_text, generation_config):
-    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
-    async for output in model.generate(**inputs, generation_config=generation_config, stream=True, return_dict_in_generate=True):
-        token_id = output.sequences[0][-1]
-        token = tokenizer.decode(token_id, skip_special_tokens=True)
-        yield {"token": token}
-
-def generate_non_stream(model, tokenizer, input_text, generation_config):
-    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
-    outputs = model.generate(**inputs, generation_config=generation_config)
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+class TokenIteratorStreamer:
+    def __init__(self):
+        self.queue = asyncio.Queue()
+
+    def put(self, value):
+        self.queue.put_nowait(value)
+
+    def end(self):
+        self.queue.put_nowait(None)
+
+    async def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        value = await self.queue.get()
+        if value is None:
+            raise StopAsyncIteration
+        return value
 
 @app.post("/generate")
 async def generate(request: GenerateRequest):
     model_name = request.model_name
     input_text = request.input_text
     task_type = request.task_type
-    stream = request.stream
 
     generation_params = request.model_dump(
         exclude_none=True,
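One caveat in the new download_model_from_huggingface: config.name_or_path holds the Hub model id, not a local directory, so os.listdir(config.name_or_path) will raise unless a folder of that exact name happens to exist locally, and save_pretrained(gcs_model_folder) writes to a local path rather than into the bucket (assuming _get_gcs_uri returns a blob prefix, which the surrounding blob calls suggest). A sketch of the upload step via huggingface_hub.snapshot_download, which yields a real local directory to mirror; this is an assumption about intent, not part of the commit:

```python
# Sketch only; assumes _get_gcs_uri returns a blob prefix within self.bucket.
import os
from huggingface_hub import snapshot_download

def download_model_from_huggingface(self, model_name):
    # Materialize the checkpoint locally first, then mirror it to GCS.
    local_dir = snapshot_download(repo_id=model_name, token=HUGGINGFACE_HUB_TOKEN)
    gcs_model_folder = self._get_gcs_uri(model_name)
    self._create_model_folder(model_name)
    for filename in os.listdir(local_dir):
        if filename.endswith((".bin", ".safetensors", ".json", ".txt")):
            blob = self.bucket.blob(f"{gcs_model_folder}/{filename}")
            blob.upload_from_filename(os.path.join(local_dir, filename))
    return True
```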
@@ -240,47 +145,41 @@ async def generate(request: GenerateRequest):
     )
 
     try:
-        config = model_loader.load_config(model_name)
-        if not config:
-            raise HTTPException(status_code=400, detail="Model configuration could not be loaded.")
-
-        tokenizer = model_loader.load_tokenizer(model_name)
-        if not tokenizer:
-            raise HTTPException(status_code=400, detail="Tokenizer could not be loaded.")
-
-        model = model_loader.load_model(model_name, config)
-        if not model:
-            raise HTTPException(status_code=400, detail="Model could not be loaded.")
-
-        generation_config_kwargs = {k: v for k, v in generation_params.items() if k in GenerationConfig.__init__.__code__.co_varnames}
-        generation_config_kwargs.setdefault('pad_token_id', tokenizer.pad_token_id)
-        generation_config_kwargs.setdefault('eos_token_id', tokenizer.eos_token_id)
-        if hasattr(tokenizer, 'sep_token_id') and tokenizer.sep_token_id is not None:
-            generation_config_kwargs.setdefault('sep_token_id', tokenizer.sep_token_id)
-        if hasattr(tokenizer, 'unk_token_id') and tokenizer.unk_token_id is not None:
-            generation_config_kwargs.setdefault('unk_token_id', tokenizer.unk_token_id)
-
-        generation_config = GenerationConfig.from_pretrained(
-            model_name,
-            trust_remote_code=True,
-            token=HUGGINGFACE_HUB_TOKEN,
-            **generation_config_kwargs
-        )
-
-        model.eval()
-
-        if task_type == "text-to-text":
-            if stream:
-                async def token_streamer():
-                    async for item in generate_stream(model, tokenizer, input_text, generation_config):
-                        yield f"data: {json.dumps(item)}\n\n"
-                        await asyncio.sleep(request.chunk_delay)
-                return StreamingResponse(token_streamer(), media_type="text/event-stream")
-            else:
-                text_result = generate_non_stream(model, tokenizer, input_text, generation_config)
-                return {"text": text_result}
-        else:
-            raise HTTPException(status_code=400, detail=f"Task type not supported: {task_type}")
+        if not model_loader.check_model_exists_locally(model_name):
+            if not model_loader.download_model_from_huggingface(model_name):
+                raise HTTPException(status_code=500, detail=f"Failed to load model: {model_name}")
+
+        pipe = pipeline(task_type, model=model_name, token=HUGGINGFACE_HUB_TOKEN, device_map="auto")
+        token_streamer = TokenIteratorStreamer()
+
+        def generate_on_thread(pipe, input_text, token_streamer, generation_params):
+            try:
+                for output in pipe(input_text,
+                                   max_new_tokens=int(1e9),  # Effectively infinite
+                                   return_full_text=False,
+                                   streamer=token_streamer,
+                                   **generation_params):
+                    pass
+            finally:
+                token_streamer.end()
+
+        thread = Thread(target=generate_on_thread, args=(pipe, input_text, token_streamer, generation_params))
+        thread.start()
+
+        async def event_stream() -> AsyncIterator[str]:
+            chunk_size = 20
+            tokens_buffer = []
+            async for token in token_streamer:
+                tokens_buffer.append(token)
+                if len(tokens_buffer) >= chunk_size:
+                    yield f"data: {json.dumps({'tokens': tokens_buffer})}\n\n"
+                    tokens_buffer = []
+                await asyncio.sleep(request.chunk_delay)
+            if tokens_buffer:
+                yield f"data: {json.dumps({'tokens': tokens_buffer})}\n\n"
+            yield "\n\n"  # Ensure final newline
+
+        return StreamingResponse(event_stream(), media_type="text/event-stream")
 
     except HTTPException as e:
         raise e
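Two details of the new streaming path deserve a flag. First, asyncio.Queue is not thread-safe, and generate_on_thread calls put_nowait from the worker thread, where it will not reliably wake a consumer awaiting on the event loop. Second, `async def __aiter__` returns a coroutine rather than the iterator itself, which `async for` rejects on current Python; `__aiter__` must be a plain method. Note also that `generate` feeds a streamer's `put` with tensors of token ids rather than strings (transformers' own TextIteratorStreamer handles the decoding), so the queued values would need decoding before being JSON-serialized. A minimal thread-safe variant, assuming the object is constructed inside the running event loop:

```python
import asyncio

class ThreadSafeTokenStreamer:
    """Sketch of a thread-safe TokenIteratorStreamer; not part of the commit."""

    def __init__(self):
        self.queue = asyncio.Queue()
        self.loop = asyncio.get_running_loop()  # assumes construction on the loop

    def put(self, value):
        # Called from the generation thread: marshal onto the event loop.
        self.loop.call_soon_threadsafe(self.queue.put_nowait, value)

    def end(self):
        self.loop.call_soon_threadsafe(self.queue.put_nowait, None)

    def __aiter__(self):
        # Must be synchronous and return the async iterator itself.
        return self

    async def __anext__(self):
        value = await self.queue.get()
        if value is None:
            raise StopAsyncIteration
        return value
```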
 
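For completeness, the SSE framing above (`data: {"tokens": [...]}` events, flushed every 20 tokens) can be consumed as sketched below; the host and port are placeholders, and httpx is just one possible client:

```python
# Hypothetical client for the /generate SSE endpoint; the URL is a placeholder.
import asyncio
import json

import httpx

async def main():
    payload = {
        "model_name": "gpt2",           # example model id
        "input_text": "Hello, world",
        "task_type": "text-generation",
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", "http://localhost:8000/generate", json=payload) as resp:
            async for line in resp.aiter_lines():
                if line.startswith("data: "):
                    chunk = json.loads(line[len("data: "):])
                    print(chunk["tokens"])

asyncio.run(main())
```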