Hjgugugjhuhjggg committed
Commit 2ca418a · verified · 1 parent: 49a991b

Update app.py

Files changed (1): app.py (+13 -47)
app.py CHANGED
@@ -1,7 +1,7 @@
 import os
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel, field_validator
 from transformers import pipeline, AutoConfig, AutoTokenizer
 from transformers.utils import logging
 from google.cloud import storage
@@ -55,7 +55,7 @@ class GenerateRequest(BaseModel):
     num_return_sequences: int = 1
     do_sample: bool = False
     chunk_delay: float = 0.0
-    max_new_tokens: int = 512 # Initial max tokens, can be large
+    max_new_tokens: int = 512

     @field_validator("model_name")
     def model_name_cannot_be_empty(cls, v):
@@ -114,25 +114,6 @@ class GCSModelLoader:

 model_loader = GCSModelLoader(bucket)

-class TokenIteratorStreamer:
-    def __init__(self):
-        self.queue = asyncio.Queue()
-
-    def put(self, value):
-        self.queue.put_nowait(value)
-
-    def end(self):
-        self.queue.put_nowait(None)
-
-    def __aiter__(self):
-        return self
-
-    async def __anext__(self):
-        value = await self.queue.get()
-        if value is None:
-            raise StopAsyncIteration
-        return value
-
 @app.post("/generate")
 async def generate(request: GenerateRequest):
     model_name = request.model_name
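
A note on the TokenIteratorStreamer class deleted above: transformers already ships TextIteratorStreamer, which provides the same thread-to-iterator hand-off on top of a thread-safe queue.Queue. A minimal sketch of the built-in class; the "gpt2" model and the prompt are placeholders, not the app's GCS-loaded models:

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # The streamer buffers decoded text in a thread-safe queue.Queue, so
    # generate() can fill it from a worker thread while the caller iterates.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer("Hello, my name is", return_tensors="pt")

    thread = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=32))
    thread.start()
    for new_text in streamer:  # yields decoded text chunks as they are generated
        print(new_text, end="", flush=True)
    thread.join()

Unlike the deleted class, which pushes into an asyncio.Queue with put_nowait(), this one is safe to fill from a plain thread.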
@@ -151,34 +132,19 @@ async def generate(request: GenerateRequest):

     tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)

-    async def generate_responses() -> AsyncIterator[Dict[str, str]]:
-        current_input_text = input_text
-
-        text_pipeline = pipeline(task_type, model=model_name, tokenizer=tokenizer, token=HUGGINGFACE_HUB_TOKEN)
-        token_streamer = TokenIteratorStreamer()
-        generated_text = ""
-
-        def generate_on_thread(pipeline, input_text, streamer, generation_params, max_new_tokens):
-            try:
-                for output in pipeline(input_text,
-                                       max_new_tokens=max_new_tokens,
-                                       return_full_text=False,
-                                       streamer=streamer,
-                                       **generation_params):
-                    streamer.put(output)  # Put the output dictionary into the queue
-            finally:
-                streamer.end()
-
-        thread = Thread(target=generate_on_thread, args=(text_pipeline, current_input_text, token_streamer, generation_params, initial_max_new_tokens))
-        thread.start()
+    async def generate_responses() -> AsyncIterator[Dict[str, List[Dict[str, str]]]]:
+        text_pipeline = pipeline(task_type, model=model_name, tokenizer=tokenizer, token=HUGGINGFACE_HUB_TOKEN, **generation_params, max_new_tokens=initial_max_new_tokens)

-        async for output_dict in token_streamer:
-            if isinstance(output_dict, dict) and "generated_text" in output_dict:
-                token = output_dict["generated_text"]
-                generated_text += token
-                yield {"token": token, "generated_text": generated_text}
+        def generate_on_thread(pipeline, input_text, output_queue):
+            result = pipeline(input_text)
+            output_queue.put_nowait(result)

-        thread.join()  # Ensure the thread finishes before exiting the generator
+        output_queue = asyncio.Queue()
+        thread = Thread(target=generate_on_thread, args=(text_pipeline, input_text, output_queue))
+        thread.start()
+        result = await output_queue.get()
+        thread.join()
+        yield {"response": result}

     async def text_stream():
         async for data in generate_responses():
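
One caveat on the replacement generate_responses above: asyncio.Queue is not thread-safe, so output_queue.put_nowait(result) called from the worker thread is not guaranteed to wake the await output_queue.get() running on the event loop. Since the new code waits for the complete result anyway (per-token streaming is gone), the manual Thread plus queue handshake can be dropped. A minimal sketch, assuming the same text_pipeline and input_text objects that generate() builds:

    import asyncio
    from typing import Any, AsyncIterator, Dict

    async def generate_responses(text_pipeline, input_text: str) -> AsyncIterator[Dict[str, Any]]:
        # asyncio.to_thread (Python 3.9+) runs the blocking pipeline call on a
        # worker thread and resumes this coroutine through the event loop,
        # which is the thread-safe equivalent of the Thread/queue pair above.
        result = await asyncio.to_thread(text_pipeline, input_text)
        yield {"response": result}

On Python 3.8, loop.run_in_executor(None, text_pipeline, input_text) does the same job.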
 
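For exercising the endpoint end to end, a streaming client sketch; the JSON field names besides model_name and max_new_tokens are inferred from the input_text variable used in generate() and are assumptions, as is the localhost URL:

    import asyncio
    import httpx

    async def main() -> None:
        payload = {
            "model_name": "gpt2",         # placeholder model name
            "input_text": "Hello there",  # assumed field name
            "max_new_tokens": 64,
        }
        async with httpx.AsyncClient(timeout=None) as client:
            # Consume the body incrementally as the server's text_stream() yields it.
            async with client.stream("POST", "http://localhost:8000/generate", json=payload) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_text():
                    print(chunk, end="", flush=True)

    asyncio.run(main())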