gleisonnanet commited on
Commit
fb2eca4
·
1 Parent(s): 12883ff

start teste

Browse files
Files changed (3) hide show
  1. Dockerfile +11 -0
  2. main.py +50 -0
  3. requirements.txt +7 -0
Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM tiangolo/uvicorn-gunicorn-fastapi:python3.9
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+
7
+ RUN pip install -r requirements.txt
8
+
9
+ COPY . .
10
+
11
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]
main.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import time
3
+ from typing import List, Literal
4
+ from fastapi import FastAPI
5
+ from pydantic import BaseModel
6
+ from enum import Enum
7
+ from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
8
+ import torch
9
+
10
+ app = FastAPI()
11
+ device = torch.device("cpu")
12
+
13
+
14
+ class TranslationRequest(BaseModel):
15
+ user_input: str
16
+ source_lang: str
17
+ target_lang: str
18
+
19
+
20
+ def load_model(pretrained_model: str = "facebook/m2m100_1.2B", cache_dir: str = "models/"):
21
+ tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
22
+ model = M2M100ForConditionalGeneration.from_pretrained(pretrained_model, cache_dir=cache_dir).to(device)
23
+ model.eval()
24
+ return tokenizer, model
25
+
26
+
27
+ @app.post("/translate")
28
+ async def translate(request: TranslationRequest):
29
+ time_start = time.time()
30
+ tokenizer, model = load_model()
31
+ src_lang = request.source_lang
32
+ trg_lang = request.target_lang
33
+ tokenizer.src_lang = src_lang
34
+ with torch.no_grad():
35
+ encoded_input = tokenizer(request.user_input, return_tensors="pt").to(device)
36
+ generated_tokens = model.generate(
37
+ **encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
38
+ )
39
+ translated_text = tokenizer.batch_decode(
40
+ generated_tokens, skip_special_tokens=True
41
+ )[0]
42
+ time_end = time.time()
43
+ response = {"translation": translated_text, "computation_time": round((time_end - time_start), 3)}
44
+ return response
45
+
46
+
47
+ if __name__ == "__main__":
48
+ import uvicorn
49
+
50
+ uvicorn.run(app, host="0.0.0.0", port=8000)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ wheel
2
+ fastapi[all]
3
+ gunicorn
4
+ streamlit
5
+ torch
6
+ transformers
7
+ transformers[sentencepiece]