Lord-Raven commited on
Commit
fd25b82
·
1 Parent(s): a822923

Blocking requests from other origins.

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -18,9 +18,11 @@ app.add_middleware(
18
  allow_headers=["*"],
19
  )
20
 
21
- model_name = "xenova/mobilebert-uncased-mnli"
22
- model = ORTModelForSequenceClassification.from_pretrained(model_name, file_name="onnx/model_quantized.onnx")
23
- tokenizer = AutoTokenizer.from_pretrained("typeform/mobilebert-uncased-mnli", model_max_length=512)
 
 
24
 
25
  # file = cached_download("https://huggingface.co/" + model_name + "")
26
  # sess = InferenceSession(file)
@@ -28,13 +30,11 @@ tokenizer = AutoTokenizer.from_pretrained("typeform/mobilebert-uncased-mnli", mo
28
  classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)
29
 
30
  def zero_shot_classification(data_string, request: gradio.Request):
31
- print(data_string)
32
  if request:
33
  print("Request headers dictionary:", request.headers)
34
- print("IP address:", request.client.host)
35
- print("Query parameters:", dict(request.query_params))
36
- else:
37
- print("No request")
38
  data = json.loads(data_string)
39
  print(data)
40
  results = classifier(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])
 
18
  allow_headers=["*"],
19
  )
20
 
21
+ model_name = "xenova/deberta-v3-base-tasksource-nli" # "xenova/mobilebert-uncased-mnli"
22
+ file_name = "onnx/model_quantized.onnx"
23
+ tokenizer_name = "sileod/deberta-v3-base-tasksource-nli" # "typeform/mobilebert-uncased-mnli"
24
+ model = ORTModelForSequenceClassification.from_pretrained(model_name, file_name=file_name)
25
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
26
 
27
  # file = cached_download("https://huggingface.co/" + model_name + "")
28
  # sess = InferenceSession(file)
 
30
  classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)
31
 
32
  def zero_shot_classification(data_string, request: gradio.Request):
 
33
  if request:
34
  print("Request headers dictionary:", request.headers)
35
+ if !(request.origin in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://jhuhman-statosphere-backend.hf.space"])
36
+ return ""
37
+ print(data_string)
 
38
  data = json.loads(data_string)
39
  print(data)
40
  results = classifier(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])