priya2k commited on
Commit
f41a38c
·
verified ·
1 Parent(s): f6216bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -26
app.py CHANGED
@@ -1,10 +1,7 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModel
4
  import torch
5
  import os
6
-
7
- app = FastAPI()
8
 
9
  # Load Hugging Face Token
10
  HF_TOKEN = os.getenv("HF_TOKEN")
@@ -13,32 +10,35 @@ if not HF_TOKEN:
13
 
14
  # Load tokenizer and model
15
  tokenizer = AutoTokenizer.from_pretrained("mental/mental-bert-base-uncased", use_auth_token=HF_TOKEN)
16
- model = AutoModel.from_pretrained("mental/mental-bert-base-uncased", use_auth_token=HF_TOKEN)
17
 
18
  model.eval() # Set model to evaluation mode
19
 
20
- # Request body schema
21
- class TextRequest(BaseModel):
22
- text: str
23
 
24
- # Helper function to compute embedding
25
- def compute_embedding(text: str) -> list[float]:
26
- """Generate a sentence embedding using mean pooling on MentalBERT output."""
27
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
28
  with torch.no_grad():
29
  outputs = model(**inputs)
30
- embedding = outputs.last_hidden_state.mean(dim=1).squeeze()
31
- return embedding.tolist()
32
-
33
- # POST endpoint to return embedding
34
- @app.post("/embed")
35
- def get_embedding(request: TextRequest):
36
- text = request.text.strip()
37
- if not text:
38
- raise HTTPException(status_code=400, detail="Input text cannot be empty.")
39
 
40
- try:
41
- embedding = compute_embedding(text)
42
- return {"embedding": embedding}
43
- except Exception as e:
44
- raise HTTPException(status_code=500, detail=f"Error computing embedding: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from transformers import AutoTokenizer, AutoModel
2
  import torch
3
  import os
4
+ import gradio as gr
 
5
 
6
  # Load Hugging Face Token
7
  HF_TOKEN = os.getenv("HF_TOKEN")
 
10
 
11
  # Load tokenizer and model
12
  tokenizer = AutoTokenizer.from_pretrained("mental/mental-bert-base-uncased", use_auth_token=HF_TOKEN)
13
+ model = AutoModel.from_pretrained("mental/mental-bert-base-uncased", use_auth_token=HF_TOKEN,output_hidden_states=True)
14
 
15
  model.eval() # Set model to evaluation mode
16
 
 
 
 
17
 
18
+
19
+ def infer(text):
20
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
 
21
  with torch.no_grad():
22
  outputs = model(**inputs)
 
 
 
 
 
 
 
 
 
23
 
24
+ last_hidden_state = outputs.last_hidden_state # (1, seq_len, hidden_size)
25
+ mask = inputs['attention_mask'].unsqueeze(-1).expand(last_hidden_state.size()).float()
26
+
27
+ masked_embeddings = last_hidden_state * mask
28
+ summed = torch.sum(masked_embeddings, dim=1)
29
+ counts = torch.clamp(mask.sum(dim=1), min=1e-9)
30
+ mean_pooled = summed / counts
31
+
32
+ return mean_pooled.squeeze().tolist()
33
+
34
+
35
+ # Gradio interface
36
+ iface = gr.Interface(
37
+ fn=infer,
38
+ inputs=[
39
+ gr.Textbox(label="text"),
40
+ ],
41
+ outputs="text"
42
+ )
43
+ iface.launch()
44
+