Rivalcoder committed on
Commit
bda5a7d
·
1 Parent(s): 45a75d8
Files changed (2) hide show
  1. app.py +125 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import tempfile
from transformers import pipeline, RobertaForSequenceClassification, RobertaTokenizer
import gradio as gr
from fastapi import FastAPI, UploadFile, File, Request, HTTPException
import os
import json
from typing import Optional, Dict, List
import torch

# Load the emotion classifier once at import time so every request reuses it.
# NOTE(review): `return_all_scores=True` is deprecated in newer transformers
# releases in favour of `top_k=None`, but the two options nest their results
# differently and downstream code indexes `result[0]` — keep the flag as-is.
model_name = "cardiffnlp/twitter-roberta-base-emotion"
tokenizer = RobertaTokenizer.from_pretrained(model_name)
model = RobertaForSequenceClassification.from_pretrained(model_name)
emotion_analysis = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
)

# FastAPI application; the Gradio UI is mounted onto it further below.
app = FastAPI()
+
22
def save_upload_file(upload_file: "UploadFile") -> str:
    """Persist an uploaded file to a temporary path and return that path.

    The upload's original extension is preserved so callers can dispatch on
    it; the caller is responsible for deleting the returned file.

    The bytes are written verbatim: the previous decode-to-str / re-encode
    round trip for `.json` uploads was a no-op for UTF-8 content and would
    corrupt or crash on any other encoding.
    """
    try:
        suffix = os.path.splitext(upload_file.filename)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(upload_file.file.read())
            return tmp.name
    finally:
        # Always release the underlying file object, even if writing failed.
        upload_file.file.close()
+
35
@app.post("/api/predict")
async def predict_from_upload(file: UploadFile = File(...)):
    """API endpoint: run emotion analysis on an uploaded .txt or .json file.

    For `.json` uploads the text is taken from the `description` key; any
    other upload is treated as plain text. Returns all emotion labels with
    scores, sorted descending.

    Raises:
        HTTPException 400 — upload contained no usable text.
        HTTPException 500 — any other processing failure.
    """
    temp_path = save_upload_file(file)
    try:
        if temp_path.endswith('.json'):
            with open(temp_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            text = data.get('description', '')
        else:  # Assume plain text file
            with open(temp_path, 'r', encoding='utf-8') as f:
                text = f.read()

        if not text.strip():
            raise HTTPException(status_code=400, detail="No text content found")

        result = emotion_analysis(text)
        emotions = [{'label': e['label'], 'score': float(e['score'])}
                    for e in sorted(result[0], key=lambda x: x['score'], reverse=True)]

        return {
            "success": True,
            "results": emotions
        }
    except HTTPException:
        # Re-raise as-is: the original code caught its own 400 above and
        # masked it as a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Temp-file cleanup on every path, replacing the `locals()` check.
        if os.path.exists(temp_path):
            os.unlink(temp_path)
71
+
72
+ # Gradio interface
73
+ def gradio_predict(input_data):
74
+ """Handle both direct text and file uploads"""
75
+ try:
76
+ if isinstance(input_data, str): # Direct text input
77
+ text = input_data
78
+ else: # File upload
79
+ temp_path = save_upload_file(input_data)
80
+ if temp_path.endswith('.json'):
81
+ with open(temp_path, 'r') as f:
82
+ data = json.load(f)
83
+ text = data.get('description', '')
84
+ else:
85
+ with open(temp_path, 'r') as f:
86
+ text = f.read()
87
+ os.unlink(temp_path)
88
+
89
+ if not text.strip():
90
+ return {"error": "No text content found"}
91
+
92
+ result = emotion_analysis(text)
93
+ return {
94
+ "emotions": [
95
+ {e['label']: float(e['score'])}
96
+ for e in sorted(result[0], key=lambda x: x['score'], reverse=True)
97
+ ]
98
+ }
99
+
100
+ except Exception as e:
101
+ return {"error": str(e)}
102
+
103
+ # Create Gradio interface
104
# Create Gradio interface
demo = gr.Interface(
    fn=gradio_predict,
    inputs=[
        gr.Textbox(label="Enter text directly", lines=5),
        gr.File(label="Or upload text/JSON file", file_types=[".txt", ".json"]),
    ],
    outputs=gr.JSON(label="Emotion Analysis"),
    title="Text Emotion Analysis",
    description="Analyze emotion in text using RoBERTa model",
    # Each example row must provide one value per input component
    # (text, file); the original single-element rows fail Gradio's
    # examples validation against the two declared inputs.
    examples=[
        ["I'm feeling absolutely thrilled about this new project!", None],
        ["This situation is making me extremely anxious and worried.", None],
    ],
)
118
+
119
# Expose the Gradio UI at the FastAPI app's root path.
app = gr.mount_gradio_app(app, demo, path="/")

# Local development entry point; Spaces runs the ASGI app directly.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ gradio
4
+ fastapi
5
+ uvicorn
6
+ python-multipart