mwaliahmad commited on
Commit
6337bb1
·
1 Parent(s): 2e85e33

Initial commit

Browse files
Files changed (2) hide show
  1. app.py +48 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoModelForSequenceClassification
5
+
6
+ # Load ONLY the model, NOT the tokenizer
7
+ model = AutoModelForSequenceClassification.from_pretrained(
8
+ "Kevintu/Engessay_grading_ML")
9
+
10
+
11
+ def process_embeddings(embeddings_array):
12
+ # Convert the received embeddings to the format expected by the model
13
+ embeddings_tensor = torch.tensor(embeddings_array)
14
+
15
+ # Process embeddings with the model
16
+ model.eval()
17
+ with torch.no_grad():
18
+ # Create a dict with the expected input format
19
+ model_inputs = {
20
+ 'input_ids': None, # Not needed since we're using embeddings directly
21
+ 'attention_mask': None, # Not needed for this use case
22
+ 'inputs_embeds': embeddings_tensor # Pass embeddings directly
23
+ }
24
+ outputs = model(**model_inputs)
25
+
26
+ predictions = outputs.logits.squeeze()
27
+
28
+ item_names = ["cohesion", "syntax", "vocabulary",
29
+ "phraseology", "grammar", "conventions"]
30
+ scaled_scores = 2.25 * predictions.numpy() - 1.25
31
+ rounded_scores = [round(score * 2) / 2 for score in scaled_scores]
32
+
33
+ results = {item: f"{score:.1f}" for item,
34
+ score in zip(item_names, rounded_scores)}
35
+ return results
36
+
37
+
38
+ # Create Gradio interface for embeddings input
39
+ demo = gr.Interface(
40
+ fn=process_embeddings,
41
+ inputs=gr.JSON(label="Embeddings"),
42
+ outputs=gr.JSON(label="Scores"),
43
+ title="Essay Grading API (Embeddings Only)",
44
+ description="Grade essays based on precomputed embeddings"
45
+ )
46
+
47
+ demo.queue()
48
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio>=3.50.2
2
+ torch>=2.0.0
3
+ transformers>=4.30.0
4
+ numpy>=1.24.0