Hjgugugjhuhjggg committed on
Commit
053347d
·
verified ·
1 Parent(s): 00ee742

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -8
app.py CHANGED
@@ -24,6 +24,7 @@ import logging
24
 
25
  GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
26
  GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
 
27
 
28
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
29
  logger = logging.getLogger(__name__)
@@ -88,15 +89,36 @@ class GCSModelLoader:
88
  return json.loads(blob.download_as_bytes())
89
  except Exception as e:
90
  raise Exception(f"Error downloading config: {e}")
 
 
 
 
 
 
 
 
 
91
 
92
  async def load_config(self, model_name):
93
- gcs_uri = self._get_gcs_uri(model_name)
94
- try:
95
- config_data = await self._download_config_from_gcs(gcs_uri)
96
- config = AutoConfig.from_dict(config_data)
97
- return config
98
- except Exception as e:
99
- raise HTTPException(status_code=500, detail=f"Error loading config: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  model_loader = GCSModelLoader(bucket, client)
102
 
@@ -130,7 +152,7 @@ async def stream_from_gcs(model_name, input_text, generation_config, stop_sequen
130
  yield str(token)
131
  del model
132
 
133
- class CustomTextIteratorStreamer:
134
  def __init__(self, gcs_uri, chunk_delay):
135
  self.chunk_delay = chunk_delay
136
  self.queue = asyncio.Queue()
 
24
 
25
  GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
26
  GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
27
+ HUGGINGFACE_HUB_TOKEN = os.getenv("HF_API_TOKEN")
28
 
29
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
30
  logger = logging.getLogger(__name__)
 
89
  return json.loads(blob.download_as_bytes())
90
  except Exception as e:
91
  raise Exception(f"Error downloading config: {e}")
92
+
93
+ async def _upload_to_gcs(self, gcs_uri, model_files):
94
+ try:
95
+ for file_name, file_content in model_files.items():
96
+ blob_path = gcs_uri.replace(f'gs://{self.bucket.name}/', '') + '/' + file_name
97
+ blob = self.bucket.blob(blob_path)
98
+ blob.upload_from_string(file_content)
99
+ except Exception as e:
100
+ raise Exception(f"Error uploading to GCS: {e}")
101
 
102
  async def load_config(self, model_name):
103
+ gcs_uri = self._get_gcs_uri(model_name)
104
+ try:
105
+ config_data = await self._download_config_from_gcs(gcs_uri)
106
+ config = AutoConfig.from_pretrained(gcs_uri, config = config_data)
107
+ return config
108
+ except Exception as e:
109
+ try:
110
+ config = AutoConfig.from_pretrained(model_name, use_auth_token=HUGGINGFACE_HUB_TOKEN)
111
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=HUGGINGFACE_HUB_TOKEN)
112
+ model_files = {}
113
+ config_file = f"{model_name}/config.json"
114
+ tokenizer_file = f"{model_name}/tokenizer.json"
115
+ model_files[config_file] = json.dumps(config.to_dict())
116
+ model_files[tokenizer_file] = json.dumps(tokenizer.to_dict())
117
+ await self._upload_to_gcs(gcs_uri, model_files)
118
+ return config
119
+ except Exception as e:
120
+ raise HTTPException(status_code=500, detail=f"Error loading config: {e}")
121
+
122
 
123
  model_loader = GCSModelLoader(bucket, client)
124
 
 
152
  yield str(token)
153
  del model
154
 
155
+ class CustomTextIteratorStreamer(TextIteratorStreamer):
156
  def __init__(self, gcs_uri, chunk_delay):
157
  self.chunk_delay = chunk_delay
158
  self.queue = asyncio.Queue()