Update app.py
app.py
CHANGED
@@ -1,7 +1,7 @@
 import os
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
-from pydantic
+from pydantic import BaseModel, field_validator
 from transformers import pipeline, AutoConfig, AutoTokenizer
 from transformers.utils import logging
 from google.cloud import storage
@@ -55,7 +55,7 @@ class GenerateRequest(BaseModel):
     num_return_sequences: int = 1
     do_sample: bool = False
     chunk_delay: float = 0.0
-    max_new_tokens: int = 512
+    max_new_tokens: int = 512

     @field_validator("model_name")
     def model_name_cannot_be_empty(cls, v):
@@ -114,25 +114,6 @@ class GCSModelLoader:

 model_loader = GCSModelLoader(bucket)

-class TokenIteratorStreamer:
-    def __init__(self):
-        self.queue = asyncio.Queue()
-
-    def put(self, value):
-        self.queue.put_nowait(value)
-
-    def end(self):
-        self.queue.put_nowait(None)
-
-    def __aiter__(self):
-        return self
-
-    async def __anext__(self):
-        value = await self.queue.get()
-        if value is None:
-            raise StopAsyncIteration
-        return value
-
 @app.post("/generate")
 async def generate(request: GenerateRequest):
     model_name = request.model_name
@@ -151,34 +132,19 @@ async def generate(request: GenerateRequest):

     tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)

-    async def generate_responses() -> AsyncIterator[Dict[str, str]]:
-
-
-        text_pipeline = pipeline(task_type, model=model_name, tokenizer=tokenizer, token=HUGGINGFACE_HUB_TOKEN)
-        token_streamer = TokenIteratorStreamer()
-        generated_text = ""
-
-        def generate_on_thread(pipeline, input_text, streamer, generation_params, max_new_tokens):
-            try:
-                for output in pipeline(input_text,
-                                       max_new_tokens=max_new_tokens,
-                                       return_full_text=False,
-                                       streamer=streamer,
-                                       **generation_params):
-                    streamer.put(output)  # Put the output dictionary into the queue
-            finally:
-                streamer.end()
-
-        thread = Thread(target=generate_on_thread, args=(text_pipeline, current_input_text, token_streamer, generation_params, initial_max_new_tokens))
-        thread.start()
+    async def generate_responses() -> AsyncIterator[Dict[str, List[Dict[str, str]]]]:
+        text_pipeline = pipeline(task_type, model=model_name, tokenizer=tokenizer, token=HUGGINGFACE_HUB_TOKEN, **generation_params, max_new_tokens=initial_max_new_tokens)

-
-
-        async for token in token_streamer:
-            generated_text += token
-            yield {"token": token, "generated_text": generated_text}
+        def generate_on_thread(pipeline, input_text, output_queue):
+            result = pipeline(input_text)
+            output_queue.put_nowait(result)

-
+        output_queue = asyncio.Queue()
+        thread = Thread(target=generate_on_thread, args=(text_pipeline, input_text, output_queue))
+        thread.start()
+        result = await output_queue.get()
+        thread.join()
+        yield {"response": result}

     async def text_stream():
         async for data in generate_responses():
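The new generate_responses above runs the blocking pipeline call on a worker thread and hands the result back to the event loop through an asyncio.Queue. Below is a minimal, self-contained sketch of that hand-off pattern, not app.py's code: slow_job is a hypothetical stand-in for the text_pipeline(input_text) call, and, as one deviation from the diff (which calls put_nowait directly from the thread), the sketch routes the cross-thread put through loop.call_soon_threadsafe, since asyncio.Queue is not thread-safe on its own.

import asyncio
from threading import Thread

def slow_job(x):
    # Hypothetical stand-in for the blocking text_pipeline(input_text) call.
    return x * 2

async def run_in_thread(x):
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def worker():
        result = slow_job(x)
        # Schedule the put on the event loop's thread; asyncio.Queue should
        # not be written to directly from a foreign thread.
        loop.call_soon_threadsafe(queue.put_nowait, result)

    thread = Thread(target=worker)
    thread.start()
    result = await queue.get()  # the event loop stays free while the thread works
    thread.join()               # the thread has already finished once the result arrives
    return result

if __name__ == "__main__":
    print(asyncio.run(run_in_thread(21)))  # prints 42

The shape mirrors the diff: start the thread, await the queue, join, then yield the result to the streaming response.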