Spaces:
Running
Running
Commit
·
10cb7da
1
Parent(s):
953ee36
adicionado chat
Browse files- main.py +40 -1
- requirements.txt +8 -1
main.py
CHANGED
@@ -5,7 +5,7 @@ from typing import List, Literal
|
|
5 |
from fastapi import FastAPI
|
6 |
from pydantic import BaseModel
|
7 |
from enum import Enum
|
8 |
-
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
|
9 |
import torch
|
10 |
|
11 |
app = FastAPI(docs_url="/")
|
@@ -57,6 +57,45 @@ async def translate(request: TranslationRequest):
|
|
57 |
return response
|
58 |
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
if __name__ == "__main__":
|
61 |
import uvicorn
|
62 |
|
|
|
5 |
from fastapi import FastAPI
|
6 |
from pydantic import BaseModel
|
7 |
from enum import Enum
|
8 |
+
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
|
9 |
import torch
|
10 |
|
11 |
app = FastAPI(docs_url="/")
|
|
|
57 |
return response
|
58 |
|
59 |
|
60 |
+
|
61 |
+
# chat
|
62 |
+
WHITESPACE_HANDLER = lambda k: re.sub('\s+', ' ', re.sub('\n+', ' ', k.strip()))
|
63 |
+
|
64 |
+
chat_model_name = "csebuetnlp/mT5_multilingual_XLSum"
|
65 |
+
tokenizer = AutoTokenizer.from_pretrained(chat_model_name)
|
66 |
+
modelchat = AutoModelForSeq2SeqLM.from_pretrained(chat_model_name)
|
67 |
+
|
68 |
+
@app.get("/chat")
|
69 |
+
async def read_root(text: str, ):
|
70 |
+
input_ids = tokenizer(
|
71 |
+
[WHITESPACE_HANDLER(text)],
|
72 |
+
return_tensors="pt",
|
73 |
+
padding="max_length",
|
74 |
+
truncation=True,
|
75 |
+
max_length=512
|
76 |
+
)["input_ids"]
|
77 |
+
|
78 |
+
# max_length=84,
|
79 |
+
output_ids = modelchat.generate(
|
80 |
+
input_ids=input_ids,
|
81 |
+
max_length=500,
|
82 |
+
no_repeat_ngram_size=2,
|
83 |
+
num_beams=4
|
84 |
+
)[0]
|
85 |
+
|
86 |
+
summary = tokenizer.decode(
|
87 |
+
output_ids,
|
88 |
+
skip_special_tokens=True,
|
89 |
+
clean_up_tokenization_spaces=True
|
90 |
+
)
|
91 |
+
|
92 |
+
return {"summary": summary}
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
if __name__ == "__main__":
|
100 |
import uvicorn
|
101 |
|
requirements.txt
CHANGED
@@ -4,4 +4,11 @@ requests==2.27.*
|
|
4 |
sentencepiece==0.1.*
|
5 |
torch==1.11.*
|
6 |
transformers==4.*
|
7 |
-
uvicorn[standard]==0.17.*
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
sentencepiece==0.1.*
|
5 |
torch==1.11.*
|
6 |
transformers==4.*
|
7 |
+
uvicorn[standard]==0.17.*
|
8 |
+
|
9 |
+
tensorboard
|
10 |
+
scikit-learn
|
11 |
+
seqeval
|
12 |
+
psutil
|
13 |
+
sacrebleu
|
14 |
+
protobuf
|