# Spaces: Running
import functools
import io
import os
import re
import time
from enum import Enum
from typing import List, Literal

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
# Create the static directory before it is mounted below.
os.makedirs("static", exist_ok=True)

# Interactive docs at /docs; the ReDoc page is disabled.
app = FastAPI(docs_url="/docs", redoc_url=None)

# Files under ./static are served at /static/<name>.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Wide-open CORS: any origin, any method, any header may call the API.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Run inference on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
class TranslationRequest(BaseModel):
    # JSON body for a translation request. Fields are documented with comments
    # rather than a docstring so the generated JSON schema is unchanged.

    # Text to translate.
    user_input: str
    # Language code of user_input (one of the codes listed in the translate
    # endpoint's docstring, e.g. "en").
    source_lang: str
    # Language code to translate into (same code set as source_lang).
    target_lang: str
    # Checkpoint name handed to the model loader. The documented options are
    # facebook/m2m100_418M (default) and facebook/m2m100_1.2B, but any string
    # is accepted.
    # NOTE(review): `Literal` is imported but unused; constraining this field
    # to the two documented checkpoints would reject other values with a 422.
    model: str = "facebook/m2m100_418M"
# Checkpoints this service is documented to use:
#   facebook/m2m100_418M  (default)
#   facebook/m2m100_1.2B


@functools.lru_cache(maxsize=2)
def _load_model_cached(model: str, model_dir: str):
    """Load and memoize ``(tokenizer, model)`` for checkpoint *model*.

    Each checkpoint is read from disk and moved to ``device`` once per process
    instead of once per call. The cache is bounded because the checkpoint name
    can come from a request body. A load that raises is not cached, so it is
    retried on the next call.
    """
    tokenizer = M2M100Tokenizer.from_pretrained(model, cache_dir=model_dir)
    nmt_model = M2M100ForConditionalGeneration.from_pretrained(
        model, cache_dir=model_dir
    ).to(device)
    nmt_model.eval()  # evaluation mode for inference
    return tokenizer, nmt_model


def load_model(model: str = 'facebook/m2m100_418M', cache_dir: str = "models/"):
    """Return a cached ``(tokenizer, model)`` pair for an M2M100 checkpoint.

    Args:
        model: Checkpoint name passed to ``from_pretrained``.
        cache_dir: Cache directory, resolved against the current working
            directory (an absolute path is used as-is by ``os.path.join``).

    Returns:
        ``(M2M100Tokenizer, M2M100ForConditionalGeneration)``, with the model
        on ``device`` in eval mode. Repeated calls with the same arguments
        return the same objects, so state set on them (e.g.
        ``tokenizer.src_lang``) is visible to later callers.

    Raises:
        Any exception raised by ``from_pretrained``.

    NOTE(review): ``model`` is client-controlled via the request body, so an
    arbitrary checkpoint name reaches ``from_pretrained``; consider an
    allowlist.
    """
    model_dir = os.path.join(os.getcwd(), cache_dir)
    # Call the cached loader positionally so that load_model() and
    # load_model(model="facebook/m2m100_418M") hit the same cache entry.
    return _load_model_cached(model, model_dir)


# Load the default checkpoint at application start so the first request does
# not time out while the model loads. Because of the cache above, later
# requests for the same checkpoint reuse this instance.
load_model()
# NOTE(review): no route decorator (e.g. `@app.get(...)`) is visible on this
# coroutine, so as shown it is never registered on `app`. Confirm the intended
# path and decorator.
async def read_index():
    """Serve the index.html file from the static folder."""
    index_page = "static/index.html"
    return FileResponse(index_page)
# NOTE(review): no route decorator (e.g. `@app.post(...)`) is visible on this
# coroutine, so as shown it is never registered on `app`. Confirm the intended
# path and decorator.
async def translate(request: TranslationRequest):
    """
    models: facebook/m2m100_418M | facebook/m2m100_1.2B

    language support
    Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az), Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs), Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greek (el), English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr), Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn), Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro), Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk), Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn), Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh), Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)

    Returns {"translation": <text>} on success, or {"error": <message>} when
    the model cannot be loaded or the translation step fails (for example an
    unsupported language code).
    """
    try:
        tokenizer, model = load_model(model=request.model)
    except Exception as exc:
        return {"error": str(exc)}

    # This coroutine never awaits, so it runs to completion without yielding.
    # No other request can change tokenizer.src_lang between setting it here
    # and tokenizing.
    try:
        # The tokenizer lookups below are where unsupported language codes are
        # expected to fail. Report those (and any generation failure) with the
        # same {"error": ...} shape used for load failures.
        tokenizer.src_lang = request.source_lang
        # Target-language token forced as the first generated token.
        forced_bos_token_id = tokenizer.get_lang_id(request.target_lang)
        with torch.no_grad():
            encoded_input = tokenizer(request.user_input, return_tensors="pt").to(device)
            generated_tokens = model.generate(
                **encoded_input, forced_bos_token_id=forced_bos_token_id
            )
        translated_text = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]
    except Exception as exc:
        return {"error": str(exc)}

    return {"translation": translated_text}
if __name__ == "__main__":
    # Bind to all network interfaces on port 7860 when run as a script.
    uvicorn.run(app=app, port=7860, host="0.0.0.0")