gleisonnanet committed on
Commit 10cb7da · 1 Parent(s): 953ee36

added chat

Files changed (2)
  1. main.py +40 -1
  2. requirements.txt +8 -1
main.py CHANGED
@@ -5,7 +5,7 @@ from typing import List, Literal
 from fastapi import FastAPI
 from pydantic import BaseModel
 from enum import Enum
-from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
+from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
 import torch

 app = FastAPI(docs_url="/")
@@ -57,6 +57,45 @@ async def translate(request: TranslationRequest):
     return response


+
+# chat
+WHITESPACE_HANDLER = lambda k: re.sub('\s+', ' ', re.sub('\n+', ' ', k.strip()))
+
+chat_model_name = "csebuetnlp/mT5_multilingual_XLSum"
+tokenizer = AutoTokenizer.from_pretrained(chat_model_name)
+modelchat = AutoModelForSeq2SeqLM.from_pretrained(chat_model_name)
+
+@app.get("/chat")
+async def read_root(text: str, ):
+    input_ids = tokenizer(
+        [WHITESPACE_HANDLER(text)],
+        return_tensors="pt",
+        padding="max_length",
+        truncation=True,
+        max_length=512
+    )["input_ids"]
+
+    # max_length=84,
+    output_ids = modelchat.generate(
+        input_ids=input_ids,
+        max_length=500,
+        no_repeat_ngram_size=2,
+        num_beams=4
+    )[0]
+
+    summary = tokenizer.decode(
+        output_ids,
+        skip_special_tokens=True,
+        clean_up_tokenization_spaces=True
+    )
+
+    return {"summary": summary}
+
+
+
+
+
+
 if __name__ == "__main__":
     import uvicorn
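Note on the new route: despite the /chat path and the commit message, the added endpoint runs the csebuetnlp/mT5_multilingual_XLSum summarization model and returns a summary of the input text. The handler also calls re.sub through WHITESPACE_HANDLER, but the diff does not add import re, so the route will raise a NameError at request time unless re is imported elsewhere in main.py. Below is a minimal sketch of that fix plus a client call; the host, port, and sample text are assumptions (a local server started with uvicorn main:app on 127.0.0.1:8000), not part of the commit.

# Fix for the missing import in main.py (next to the other imports):
import re  # required by WHITESPACE_HANDLER; not added by this diff

# Separate client script exercising the new endpoint; `requests` is already
# pinned in requirements.txt. URL and text below are illustrative assumptions.
import requests

resp = requests.get(
    "http://127.0.0.1:8000/chat",
    params={"text": "Some long article text to be summarized ..."},
)
resp.raise_for_status()
print(resp.json()["summary"])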
 
requirements.txt CHANGED
@@ -4,4 +4,11 @@ requests==2.27.*
 sentencepiece==0.1.*
 torch==1.11.*
 transformers==4.*
-uvicorn[standard]==0.17.*
+uvicorn[standard]==0.17.*
+
+tensorboard
+scikit-learn
+seqeval
+psutil
+sacrebleu
+protobuf
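The newly added entries are unpinned, and two of their pip names differ from their import names (scikit-learn imports as sklearn, protobuf as google.protobuf). An optional sanity-check sketch, assuming only the package list above and an environment where the updated requirements have been installed:

# Optional sanity check: import each package added in this commit.
# The pip-name -> import-name mapping is the only assumption made here.
import importlib

for module in ["tensorboard", "sklearn", "seqeval", "psutil", "sacrebleu", "google.protobuf"]:
    importlib.import_module(module)
    print("ok:", module)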