File size: 4,229 Bytes
bda5a7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import tempfile
from transformers import pipeline, RobertaForSequenceClassification, RobertaTokenizer
import gradio as gr
from fastapi import FastAPI, UploadFile, File, Request, HTTPException
import os
import json
from typing import Optional, Dict, List
import torch

# Load the emotion model once at import time and wrap it, together with its
# tokenizer, in a text-classification pipeline. return_all_scores=True asks the
# pipeline for a score for every label instead of only the best one.
# NOTE(review): newer transformers releases deprecate return_all_scores in
# favour of top_k=None, which changes the nesting of the returned list; the
# callers below index result[0], so confirm output shape before migrating.
model_name = "cardiffnlp/twitter-roberta-base-emotion"
model = RobertaForSequenceClassification.from_pretrained(model_name)
tokenizer = RobertaTokenizer.from_pretrained(model_name)
emotion_analysis = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
)

# FastAPI application; the Gradio UI is mounted onto it further below.
app = FastAPI()

def save_upload_file(upload_file: "UploadFile") -> str:
    """Copy an uploaded file's raw bytes into a new temporary file on disk.

    The temporary file keeps the upload's extension, lower-cased so that an
    upload named ``data.JSON`` is still recognised by callers that check
    ``path.endswith('.json')``. It is created with ``delete=False``: the
    caller owns the returned path and must ``os.unlink`` it when done.

    Args:
        upload_file: An object exposing ``filename`` (may be ``None``) and a
            readable binary ``file`` attribute, e.g. FastAPI's ``UploadFile``.

    Returns:
        Path of the temporary file holding a verbatim copy of the upload.

    The upload's underlying ``file`` is always closed, even on failure; if
    copying fails, the partially written temporary file is removed.
    """
    try:
        # filename can be None (e.g. a multipart part sent without a name).
        suffix = os.path.splitext(upload_file.filename or "")[1].lower()
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            try:
                # Copy bytes verbatim; decoding is the reader's responsibility,
                # so non-UTF-8 content is no longer rejected here.
                tmp.write(upload_file.file.read())
            except BaseException:
                # delete=False means a failed copy would otherwise leak the
                # file; close first so the unlink also works on Windows.
                tmp.close()
                os.unlink(tmp.name)
                raise
            return tmp.name
    finally:
        upload_file.file.close()

@app.post("/api/predict")
async def predict_from_upload(file: UploadFile = File(...)):
    """API endpoint: analyse the emotion of an uploaded text or JSON file.

    A ``.json`` upload is expected to hold an object whose ``description``
    field contains the text; any other upload is read as UTF-8 text.

    Returns:
        ``{"success": True, "results": [...]}`` where ``results`` is a list of
        ``{"label": str, "score": float}`` dicts sorted by descending score.

    Raises:
        HTTPException: 400 if the file cannot be decoded/parsed or contains
            no text; 500 for any other failure (e.g. during inference).
    """
    temp_path = None
    try:
        # Save the uploaded file temporarily
        temp_path = save_upload_file(file)

        try:
            if temp_path.endswith('.json'):
                with open(temp_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                # Non-object JSON (list, number, ...) has no 'description'.
                text = data.get('description', '') if isinstance(data, dict) else ''
            else:  # Assume text file
                with open(temp_path, 'r', encoding='utf-8') as f:
                    text = f.read()
        except (UnicodeDecodeError, json.JSONDecodeError) as e:
            raise HTTPException(
                status_code=400, detail=f"Could not read uploaded file: {e}"
            ) from e

        if not isinstance(text, str) or not text.strip():
            raise HTTPException(status_code=400, detail="No text content found")

        # NOTE(review): inference is synchronous and runs on the event-loop
        # thread inside this async endpoint; consider offloading it.
        result = emotion_analysis(text)
        emotions = [
            {'label': e['label'], 'score': float(e['score'])}
            for e in sorted(result[0], key=lambda x: x['score'], reverse=True)
        ]
        return {"success": True, "results": emotions}

    except HTTPException:
        # Propagate deliberate client errors (e.g. the 400s above) unchanged
        # instead of rewrapping them as 500s below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always remove the temporary copy, on success and on every error path.
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)

# Gradio interface
def _read_gradio_file_text(file_data):
    """Return the text content of a Gradio file value.

    Accepts either a filesystem path (``str`` / ``os.PathLike``) or an object
    with a ``.name`` path attribute (e.g. a temporary-file wrapper). Which of
    the two Gradio passes depends on its version (assumption; confirm against
    the installed Gradio). ``.json`` files yield their top-level
    ``description`` string ('' if absent or not a string); anything else is
    read whole as UTF-8 text.
    """
    path = file_data if isinstance(file_data, (str, os.PathLike)) else file_data.name
    path = os.fspath(path)
    if path.lower().endswith('.json'):
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        text = data.get('description', '') if isinstance(data, dict) else ''
        return text if isinstance(text, str) else ''
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()

def gradio_predict(input_data, file_data=None):
    """Run emotion analysis for the Gradio UI.

    The interface has two input components, so Gradio calls this with two
    positional values: the textbox text and the uploaded file (or ``None``).
    Non-blank text takes priority; otherwise the uploaded file is read. For
    backward compatibility, a lone non-string first argument is treated as
    the file.

    Returns:
        ``{"emotions": [{label: score}, ...]}`` sorted by descending score, or
        ``{"error": message}`` when there is no text or anything fails.
    """
    try:
        text = ''
        if isinstance(input_data, str):  # Direct text input
            text = input_data
        elif input_data is not None and file_data is None:
            file_data = input_data

        # An empty textbox typically arrives as '' rather than None, so fall
        # back to the file whenever the text is blank.
        if not text.strip() and file_data is not None:
            text = _read_gradio_file_text(file_data)

        if not text.strip():
            return {"error": "No text content found"}

        result = emotion_analysis(text)
        return {
            "emotions": [
                {e['label']: float(e['score'])}
                for e in sorted(result[0], key=lambda x: x['score'], reverse=True)
            ]
        }

    except Exception as e:
        # The UI displays the return value, so surface failures as a payload.
        return {"error": str(e)}

# Create Gradio interface
# Gradio UI: a free-text box plus an optional .txt/.json upload, rendered as
# JSON output. Gradio calls fn with one positional value per input component,
# so gradio_predict must accept the textbox value and the file value.
demo = gr.Interface(
    fn=gradio_predict,
    title="Text Emotion Analysis",
    description="Analyze emotion in text using RoBERTa model",
    inputs=[
        gr.Textbox(lines=5, label="Enter text directly"),
        gr.File(file_types=[".txt", ".json"], label="Or upload text/JSON file"),
    ],
    outputs=gr.JSON(label="Emotion Analysis"),
    examples=[
        ["I'm feeling absolutely thrilled about this new project!"],
        ["This situation is making me extremely anxious and worried."],
    ],
)

# Serve the Gradio UI at the root of the FastAPI app. The /api/predict route
# was registered on `app` before this mount.
# NOTE(review): presumably it therefore shadows any same-path route inside
# the mounted Gradio app — confirm this does not break the UI.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # Local development entry point.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")