Hjgugugjhuhjggg committed
Commit 0154ba4 · verified · 1 Parent(s): 9f54418

Update app.py

Files changed (1):
  1. app.py +29 -28
app.py CHANGED
@@ -7,7 +7,6 @@ from transformers import (
     GenerationConfig,
     AutoTokenizer,
     AutoModelForCausalLM,
-    TextIteratorStreamer
 )
 from google.cloud import storage
 from google.auth.exceptions import DefaultCredentialsError
@@ -80,7 +79,9 @@ class GCSModelLoader:
 
     def _download_content(self, blob_path):
         blob = self.bucket.blob(blob_path)
-        return blob.download_as_bytes() if blob.exists() else None
+        if blob.exists():
+            return blob.download_as_bytes()
+        return None
 
     def _upload_content(self, content, blob_path):
         blob = self.bucket.blob(blob_path)
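Note: both the old and new versions of `_download_content` make two GCS round trips (`exists()` and then `download_as_bytes()`). A minimal single-request sketch of the same method, assuming `self.bucket` is a `google.cloud.storage.Bucket` (not part of this commit):

    from google.api_core.exceptions import NotFound

    def _download_content(self, blob_path):
        # One request: attempt the download and map a missing object
        # to None instead of probing with blob.exists() first.
        try:
            return self.bucket.blob(blob_path).download_as_bytes()
        except NotFound:
            return None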
@@ -99,8 +100,7 @@ class GCSModelLoader:
         try:
             config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
             gcs_model_folder = self._get_gcs_uri(model_name)
-            config_json = json.dumps(config.to_dict())
-            self._upload_content(config_json.encode('utf-8'), f"{gcs_model_folder}/config.json")
+            self._upload_content(json.dumps(config.to_dict()).encode('utf-8'), f"{gcs_model_folder}/config.json")
             return config
         except Exception as e:
             logger.error(f"Error loading config from Hugging Face and saving to GCS: {e}")
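Note: `json.dumps(config.to_dict())` duplicates what `transformers` already provides; `PretrainedConfig.to_json_string()` emits the canonical config.json payload. A one-line alternative sketch, assuming `config` is a `transformers` `PretrainedConfig` (not part of this commit):

    self._upload_content(config.to_json_string().encode('utf-8'),
                         f"{gcs_model_folder}/config.json")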
@@ -121,9 +121,10 @@ class GCSModelLoader:
         try:
             tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
             gcs_model_folder = self._get_gcs_uri(model_name)
-            for file in tokenizer.save_pretrained(None):
-                with open(file, 'rb') as f:
-                    self._upload_content(f.read(), f"{gcs_model_folder}/{os.path.basename(file)}")
+            os.makedirs(gcs_model_folder, exist_ok=True)
+            for filename in os.listdir(tokenizer.save_pretrained(None)):
+                with open(filename, 'rb') as f:
+                    self._upload_content(f.read(), f"{gcs_model_folder}/{filename}")
             return tokenizer
         except Exception as e:
             logger.error(f"Error loading tokenizer from Hugging Face and saving to GCS: {e}")
@@ -144,6 +145,7 @@ class GCSModelLoader:
         try:
             model = AutoModelForCausalLM.from_pretrained(model_name, config=config, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
             gcs_model_folder = self._get_gcs_uri(model_name)
+            os.makedirs(gcs_model_folder, exist_ok=True)
             for filename in os.listdir(model.save_pretrained(None)):
                 with open(filename, 'rb') as f:
                     self._upload_content(f.read(), f"{gcs_model_folder}/{filename}")
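Note: in both this hunk and the tokenizer hunk above, `save_pretrained(None)` raises (it needs a directory path) and returns the saved file paths rather than a directory, so `os.listdir(...)` on its result cannot work; `os.makedirs(gcs_model_folder, exist_ok=True)` also only creates a local directory named after the GCS prefix. A minimal sketch of the usual pattern, as a hypothetical `_save_and_upload` helper built on the `_upload_content` method shown earlier (not part of this commit):

    import os
    import tempfile

    def _save_and_upload(self, hf_object, gcs_model_folder):
        # Works for tokenizers and models alike: save_pretrained needs a
        # real local directory; GCS has key prefixes, not directories.
        with tempfile.TemporaryDirectory() as tmp_dir:
            hf_object.save_pretrained(tmp_dir)
            for filename in os.listdir(tmp_dir):
                with open(os.path.join(tmp_dir, filename), 'rb') as f:
                    self._upload_content(f.read(), f"{gcs_model_folder}/{filename}")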
@@ -154,17 +156,20 @@
 
 model_loader = GCSModelLoader(bucket)
 
-async def generate_stream(model, tokenizer, input_text, generation_config, chunk_delay):
+async def generate_stream(model, tokenizer, input_text, generation_config):
     inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = dict(inputs, generation_config=generation_config, streamer=streamer)
-    asyncio.create_task(model.generate(**generation_kwargs))
-
-    async def event_stream():
-        for token in streamer:
+    generation_stream = model.generate(
+        **inputs,
+        generation_config=generation_config,
+        stream=True,
+    )
+    async def token_stream():
+        for output in generation_stream:
+            token_id = output[-1]
+            token = tokenizer.decode(token_id, skip_special_tokens=True)
             yield {"token": token}
-            await asyncio.sleep(chunk_delay)
-    return event_stream()
+            await asyncio.sleep(0.001)  # Adjust delay as needed
+    return token_stream()
 
 def generate_non_stream(model, tokenizer, input_text, generation_config):
     inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
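Note: `stream=True` is not an argument of the stock `transformers` `generate()` (it would be rejected as an unused kwarg unless the model's remote code accepts it), and the removed `asyncio.create_task(model.generate(...))` would also fail, since `generate()` is not a coroutine. The documented streaming path is the `TextIteratorStreamer` this commit stops importing, with `generate()` running on a worker thread. A minimal sketch (not part of this commit):

    import threading
    from transformers import TextIteratorStreamer

    def stream_tokens(model, tokenizer, input_text, generation_config):
        # generate() blocks, so it runs on a worker thread while the
        # streamer yields decoded tokens as they are produced.
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        thread = threading.Thread(
            target=model.generate,
            kwargs=dict(inputs, generation_config=generation_config, streamer=streamer),
        )
        thread.start()
        for token in streamer:
            yield {"token": token}
        thread.join()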
@@ -178,19 +183,15 @@ async def generate(request: GenerateRequest):
     task_type = request.task_type
     stream = request.stream
 
-    generation_params = {}
-    for key, value in request.model_dump(exclude_none=True, exclude={'model_name', 'input_text', 'task_type', 'stream', 'chunk_delay'}).items():
-        generation_params[key] = value
+    generation_params = request.model_dump(
+        exclude_none=True,
+        exclude={'model_name', 'input_text', 'task_type', 'stream', 'chunk_delay'}
+    )
 
     try:
         gcs_model_folder_uri = model_loader._get_gcs_uri(model_name)
-        if not bucket.blob(f"{gcs_model_folder_uri}/config.json").exists():
-            logger.info(f"Model '{model_name}' not found in GCS, creating placeholder and downloading.")
-            bucket.blob(f"{gcs_model_folder_uri}/.placeholder").upload_from_string("")
-            for folder in gcs_model_folder_uri.split('/'):
-                prefix = "/".join(gcs_model_folder_uri.split('/')[:gcs_model_folder_uri.split('/').index(folder)+1])
-                if not any(bucket.list_blobs(prefix=f"{prefix}/config.json")):
-                    bucket.blob(f"{prefix}/.placeholder").upload_from_string("")
+        if not model_loader._blob_exists(f"{gcs_model_folder_uri}/config.json"):
+            logger.info(f"Model '{model_name}' not found in GCS, downloading from Hugging Face.")
 
         config = model_loader.load_config(model_name)
         if not config:
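Note: `_blob_exists` is called here but defined nowhere in this diff; presumably it already exists on `GCSModelLoader`. If not, a one-line sketch of what it would have to do (hypothetical helper, not part of this commit):

    def _blob_exists(self, blob_path):
        # Existence probe against the bucket wrapped by GCSModelLoader.
        return self.bucket.blob(blob_path).exists()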
@@ -222,12 +223,12 @@ async def generate(request: GenerateRequest):
 
         if task_type == "text-to-text":
             if stream:
-                return StreamingResponse(generate_stream(model, tokenizer, input_text, generation_config, request.chunk_delay), media_type="text/event-stream")
+                return StreamingResponse(generate_stream(model, tokenizer, input_text, generation_config), media_type="text/event-stream")
             else:
                 text_result = generate_non_stream(model, tokenizer, input_text, generation_config)
                 return {"text": text_result}
         else:
-            raise HTTPException(status_code=400, detail=f"Task type not supported: {task_type}")
+            raise HTTPException(status_code=400, detail=f"Task type not supported: {task_type}")
 
     except HTTPException as e:
         raise e
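Note: Starlette's `StreamingResponse` expects `str` or `bytes` chunks, and `text/event-stream` clients expect `data: ...\n\n` framing, but `generate_stream` yields plain dicts. A minimal adapter sketch (hypothetical `as_sse`, not part of this commit), assuming `token_iter` is the async iterator `generate_stream` returns:

    import json

    async def as_sse(token_iter):
        # Wrap dict-yielding chunks in Server-Sent Events framing
        # before handing the stream to StreamingResponse.
        async for item in token_iter:
            yield f"data: {json.dumps(item)}\n\n"

Usage would then be StreamingResponse(as_sse(await generate_stream(...)), media_type="text/event-stream").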
 