yakine commited on
Commit
9191425
·
verified ·
1 Parent(s): daac335

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +3 -2
main.py CHANGED
@@ -4,6 +4,7 @@ import torch
4
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
5
  import logging
6
  import re
 
7
 
8
  app = FastAPI()
9
 
@@ -23,8 +24,8 @@ logging.basicConfig(level=logging.INFO)
23
  ####################################
24
  # Text Generation Endpoint
25
  ####################################
26
-
27
- TEXT_MODEL_NAME = "aubmindlab/aragpt2-medium"
28
  text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
29
  text_model = AutoModelForCausalLM.from_pretrained(TEXT_MODEL_NAME)
30
 
 
4
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
5
  import logging
6
  import re
7
+ import os
8
 
9
  app = FastAPI()
10
 
 
24
  ####################################
25
  # Text Generation Endpoint
26
  ####################################
27
+ os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
28
+ TEXT_MODEL_NAME = "aubmindlab/aragpt2-base"
29
  text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
30
  text_model = AutoModelForCausalLM.from_pretrained(TEXT_MODEL_NAME)
31