IS361Group4 committed on
Commit b68a3c9 · verified · 1 Parent(s): d03ac60

Update app.py

Files changed (1)
  1. app.py +118 -137
app.py CHANGED
@@ -1,189 +1,170 @@
 
 
  import os
  import gradio as gr
  import pandas as pd
  import numpy as np
  import joblib
  import spacy
- from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
  from langchain_core.pydantic_v1 import BaseModel, Field
  from langchain.prompts import HumanMessagePromptTemplate, ChatPromptTemplate
  from langchain.output_parsers import PydanticOutputParser
  from langchain_openai import ChatOpenAI

- # --- Translator App ---
  chat = ChatOpenAI()
  class TextTranslator(BaseModel):
-     output: str = Field(description="Python string containing the output text translated in the desired language")

  output_parser = PydanticOutputParser(pydantic_object=TextTranslator)
  format_instructions = output_parser.get_format_instructions()

- def text_translator(input_text : str, language : str) -> str:
-     human_template = """Enter the text that you want to translate:
-     {input_text}, and enter the language that you want it to translate to {language}. {format_instructions}"""
-     human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
-     chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt])
-     prompt = chat_prompt.format_prompt(input_text = input_text, language = language, format_instructions = format_instructions)
-     messages = prompt.to_messages()
-     response = chat(messages = messages)
-     output = output_parser.parse(response.content)
-     return output.output
-
- translator_tab = gr.Interface(fn=text_translator,
-                               inputs=[gr.Textbox(label="Text to translate"), gr.Textbox(label="Target Language")],
-                               outputs=[gr.Textbox(label="Translated Text")],
-                               title="Text Translator")
-
- # --- Sentiment Analysis App ---
  sentiment_model = pipeline("sentiment-analysis", model="cardiffnlp/twitter-xlm-roberta-base-sentiment")
  def sentiment_analysis(message, history):
      result = sentiment_model(message)
-     return f"Sentiment: {result[0]['label']} (Probability: {result[0]['score']:.2f})"
-
- sentiment_tab = gr.ChatInterface(fn=sentiment_analysis, title="Sentiment Analysis")
-
- # --- Financial Analyst ---
- spacy_model = spacy.load('en_core_web_sm')
- spacy_model.add_pipe('sentencizer')
- auth_token = os.environ.get("HF_Token")
- asr = pipeline("automatic-speech-recognition", "facebook/wav2vec2-base-960h")
- summarizer = pipeline("summarization", model="knkarthick/MEETING_SUMMARY")
- fin_model = pipeline("sentiment-analysis", model='yiyanghkust/finbert-tone', tokenizer='yiyanghkust/finbert-tone')

  def split_in_sentences(text):
-     doc = spacy_model(text)
-     return [str(sent).strip() for sent in doc.sents]

  def make_spans(text, results):
-     return list(zip(split_in_sentences(text), [r["label"] for r in results]))

- def speech_to_text(speech):
-     return asr(speech)["text"]

  def summarize_text(text):
      return summarizer(text)[0]['summary_text']

  def text_to_sentiment(text):
      return fin_model(text)[0]["label"]

  def fin_ext(text):
-     results = fin_model(split_in_sentences(text))
-     return make_spans(text, results)

  def fls(text):
-     fls_model = pipeline("text-classification", model="demo-org/finbert_fls", tokenizer="demo-org/finbert_fls", use_auth_token=auth_token)
-     results = fls_model(split_in_sentences(text))
-     return make_spans(text, results)

- def fin_ner(text):
-     api = gr.Interface.load("dslim/bert-base-NER", src='models', use_auth_token=auth_token)
-     return api(text)
-
- financial_tab = gr.Blocks()
- with financial_tab:
-     gr.Markdown("## Financial Analyst AI")
-     audio_file = gr.Audio(source="microphone", type="filepath")
-     text = gr.Textbox(label="Recognized Text")
-     summary = gr.Textbox(label="Summary")
-     tone = gr.Label(label="Financial Tone")
-     spans = gr.HighlightedText()
-     fls_spans = gr.HighlightedText()
-     ner_spans = gr.HighlightedText()
-     with gr.Row():
-         gr.Button("Recognize Speech").click(speech_to_text, inputs=audio_file, outputs=text)
-         gr.Button("Summarize Text").click(summarize_text, inputs=text, outputs=summary)
-         gr.Button("Classify Tone").click(text_to_sentiment, inputs=summary, outputs=tone)
-     with gr.Row():
-         gr.Button("Financial Sentiment").click(fin_ext, inputs=text, outputs=spans)
-         gr.Button("Forward Looking").click(fls, inputs=text, outputs=fls_spans)
-         gr.Button("NER Companies").click(fin_ner, inputs=text, outputs=ner_spans)
-
- # --- Personal Information Detection ---
- pii_tab = gr.load("models/iiiorg/piiranha-v1-detect-personal-information")
-
- # --- Customer Churn ---
  script_dir = os.path.dirname(os.path.abspath(__file__))
- pipeline = joblib.load(os.path.join(script_dir, 'toolkit', 'pipeline.joblib'))
- model = joblib.load(os.path.join(script_dir, 'toolkit', 'Random Forest Classifier.joblib'))

  def calculate_total_charges(tenure, monthly_charges):
      return tenure * monthly_charges

- def predict_churn(SeniorCitizen, Partner, Dependents, tenure,
                    InternetService, OnlineSecurity, OnlineBackup, DeviceProtection, TechSupport,
                    StreamingTV, StreamingMovies, Contract, PaperlessBilling, PaymentMethod,
                    MonthlyCharges):
      TotalCharges = calculate_total_charges(tenure, MonthlyCharges)
      input_df = pd.DataFrame({
-         'SeniorCitizen': [SeniorCitizen],
-         'Partner': [Partner],
-         'Dependents': [Dependents],
-         'tenure': [tenure],
-         'InternetService': [InternetService],
-         'OnlineSecurity': [OnlineSecurity],
-         'OnlineBackup': [OnlineBackup],
-         'DeviceProtection': [DeviceProtection],
-         'TechSupport': [TechSupport],
-         'StreamingTV': [StreamingTV],
-         'StreamingMovies': [StreamingMovies],
-         'Contract': [Contract],
-         'PaperlessBilling': [PaperlessBilling],
-         'PaymentMethod': [PaymentMethod],
-         'MonthlyCharges': [MonthlyCharges],
-         'TotalCharges': [TotalCharges]
      })
-     X_processed = pipeline.transform(input_df)
-     cat_encoder = pipeline.named_steps['preprocessor'].named_transformers_['cat'].named_steps['onehot']
-     cat_cols = [col for col in input_df.columns if input_df[col].dtype == 'object']
-     feature_names = [col for col in input_df.columns if input_df[col].dtype != 'object'] + list(cat_encoder.get_feature_names_out(cat_cols))
      final_df = pd.DataFrame(X_processed, columns=feature_names)
-     final_df = pd.concat([final_df.iloc[:, 3:], final_df.iloc[:, :3]], axis=1)
-     prediction_probs = model.predict_proba(final_df)[0]
      return {
-         "Prediction: CHURN 🔴": prediction_probs[1],
-         "Prediction: STAY ✅": prediction_probs[0]
      }

- churn_tab = gr.Interface(
-     fn=predict_churn,
-     inputs=[
-         gr.Radio(['Yes', 'No'], label="Senior Citizen"),
-         gr.Radio(['Yes', 'No'], label="Partner"),
-         gr.Radio(['No', 'Yes'], label="Dependents"),
-         gr.Slider(1, 73, step=1, label="Tenure (months)"),
-         gr.Radio(['DSL', 'Fiber optic', 'No Internet'], label="Internet Service"),
-         gr.Radio(['No', 'Yes'], label="Online Security"),
-         gr.Radio(['No', 'Yes'], label="Online Backup"),
-         gr.Radio(['No', 'Yes'], label="Device Protection"),
-         gr.Radio(['No', 'Yes'], label="Tech Support"),
-         gr.Radio(['No', 'Yes'], label="Streaming TV"),
-         gr.Radio(['No', 'Yes'], label="Streaming Movies"),
-         gr.Radio(['Month-to-month', 'One year', 'Two year'], label="Contract"),
-         gr.Radio(['Yes', 'No'], label="Paperless Billing"),
-         gr.Radio(['Electronic check', 'Mailed check', 'Bank transfer (automatic)', 'Credit card (automatic)'], label="Payment Method"),
-         gr.Slider(18.4, 118.65, label="Monthly Charges")
-     ],
-     outputs=gr.Label(label="Prediction"),
-     title="Customer Churn Prediction"
- )
-
- # --- Launching All Tabs ---
- demo = gr.TabbedInterface(
-     interface_list=[
-         translator_tab,
-         sentiment_tab,
-         financial_tab,
-         pii_tab,
-         churn_tab
-     ],
-     tab_names=[
-         "Translator",
-         "Sentiment Analysis",
-         "Financial Analyst",
-         "Personal Info Detection",
-         "Customer Churn"
-     ]
- )
-
- if __name__ == '__main__':
      demo.launch()
 
+ # app.py
+
  import os
  import gradio as gr
  import pandas as pd
  import numpy as np
  import joblib
  import spacy
+
  from langchain_core.pydantic_v1 import BaseModel, Field
  from langchain.prompts import HumanMessagePromptTemplate, ChatPromptTemplate
  from langchain.output_parsers import PydanticOutputParser
  from langchain_openai import ChatOpenAI
+ from transformers import pipeline

+ ### 1. Translator ###
  chat = ChatOpenAI()
+
  class TextTranslator(BaseModel):
+     output: str = Field(description="Translated output")

  output_parser = PydanticOutputParser(pydantic_object=TextTranslator)
  format_instructions = output_parser.get_format_instructions()

+ def text_translator(input_text: str, language: str) -> str:
+     template = """Enter the text that you want to translate:
+     {input_text}, and enter the language that you want it to translate to {language}. {format_instructions}"""
+     human_prompt = HumanMessagePromptTemplate.from_template(template)
+     prompt = ChatPromptTemplate.from_messages([human_prompt]).format_prompt(
+         input_text=input_text, language=language, format_instructions=format_instructions)
+     response = chat(messages=prompt.to_messages())
+     return output_parser.parse(response.content).output
+
+ ### 2. Sentiment ###
  sentiment_model = pipeline("sentiment-analysis", model="cardiffnlp/twitter-xlm-roberta-base-sentiment")
  def sentiment_analysis(message, history):
      result = sentiment_model(message)
+     return f"Sentiment: {result[0]['label']} (Probability: {result[0]['score']:.2f})"

+ ### 3. Financial Analyst ###
+ nlp = spacy.load('en_core_web_sm')
+ nlp.add_pipe('sentencizer')
  def split_in_sentences(text):
+     return [str(sent).strip() for sent in nlp(text).sents]

  def make_spans(text, results):
+     labels = [r['label'] for r in results]
+     return list(zip(split_in_sentences(text), labels))

+ auth_token = os.environ.get("HF_Token")

+ asr = pipeline("automatic-speech-recognition", "facebook/wav2vec2-base-960h")
+ def speech_to_text(audio):
+     return asr(audio)["text"]
+
+ summarizer = pipeline("summarization", model="knkarthick/MEETING_SUMMARY")
  def summarize_text(text):
      return summarizer(text)[0]['summary_text']

+ fin_model = pipeline("sentiment-analysis", model='yiyanghkust/finbert-tone')
  def text_to_sentiment(text):
      return fin_model(text)[0]["label"]

+ def fin_ner(text):
+     return gr.Interface.load("dslim/bert-base-NER", src='models', use_auth_token=auth_token)(text)
+
  def fin_ext(text):
+     return make_spans(text, fin_model(split_in_sentences(text)))

  def fls(text):
+     model = pipeline("text-classification", model="demo-org/finbert_fls", tokenizer="demo-org/finbert_fls", use_auth_token=auth_token)
+     return make_spans(text, model(split_in_sentences(text)))

+ ### 4. Personal Info Detection ###
+ def detect_personal_info(text):
+     model = gr.Interface.load("iiiorg/piiranha-v1-detect-personal-information")
+     return model(text)
+
+ ### 5. Customer Churn ###
  script_dir = os.path.dirname(os.path.abspath(__file__))
+ pipeline_path = os.path.join(script_dir, 'toolkit', 'pipeline.joblib')
+ model_path = os.path.join(script_dir, 'toolkit', 'Random Forest Classifier.joblib')
+ pipeline_model = joblib.load(pipeline_path)
+ model = joblib.load(model_path)

  def calculate_total_charges(tenure, monthly_charges):
      return tenure * monthly_charges

+ def predict(SeniorCitizen, Partner, Dependents, tenure,
              InternetService, OnlineSecurity, OnlineBackup, DeviceProtection, TechSupport,
              StreamingTV, StreamingMovies, Contract, PaperlessBilling, PaymentMethod,
              MonthlyCharges):
+
      TotalCharges = calculate_total_charges(tenure, MonthlyCharges)
      input_df = pd.DataFrame({
+         'SeniorCitizen': [SeniorCitizen], 'Partner': [Partner], 'Dependents': [Dependents],
+         'tenure': [tenure], 'InternetService': [InternetService], 'OnlineSecurity': [OnlineSecurity],
+         'OnlineBackup': [OnlineBackup], 'DeviceProtection': [DeviceProtection], 'TechSupport': [TechSupport],
+         'StreamingTV': [StreamingTV], 'StreamingMovies': [StreamingMovies], 'Contract': [Contract],
+         'PaperlessBilling': [PaperlessBilling], 'PaymentMethod': [PaymentMethod],
+         'MonthlyCharges': [MonthlyCharges], 'TotalCharges': [TotalCharges]
      })
+
+     X_processed = pipeline_model.transform(input_df)
+     cat_encoder = pipeline_model.named_steps['preprocessor'].named_transformers_['cat'].named_steps['onehot']
+     feature_names = [*input_df.select_dtypes(exclude='object').columns, *cat_encoder.get_feature_names_out()]
      final_df = pd.DataFrame(X_processed, columns=feature_names)
+     pred_probs = model.predict_proba(final_df)[0]
      return {
+         "Prediction: CHURN 🔴": pred_probs[1],
+         "Prediction: STAY ✅": pred_probs[0]
      }

+ ### COMBINED UI ###
+ with gr.Blocks() as demo:
+     with gr.Tab("Translator"):
+         gr.Markdown("## Translator")
+         input_text = gr.Textbox(label="Text to Translate")
+         language = gr.Textbox(label="Target Language")
+         output = gr.Textbox(label="Translated Text")
+         gr.Button("Translate").click(text_translator, inputs=[input_text, language], outputs=output)
+
+     with gr.Tab("Sentiment"):
+         gr.Markdown("## Sentiment Analysis")
+         gr.ChatInterface(sentiment_analysis, type="messages")
+
+     with gr.Tab("Financial Analyst"):
+         gr.Markdown("## Financial Analyst")
+         audio = gr.Audio(source="microphone", type="filepath")
+         text_input = gr.Textbox()
+         summary = gr.Textbox()
+         tone_label = gr.Label()
+         gr.Button("Speech to Text").click(speech_to_text, inputs=audio, outputs=text_input)
+         gr.Button("Summarize").click(summarize_text, inputs=text_input, outputs=summary)
+         gr.Button("Classify Tone").click(text_to_sentiment, inputs=summary, outputs=tone_label)
+         gr.HighlightedText(label="Tone").render()
+         gr.HighlightedText(label="Forward-Looking").render()
+         gr.Button("Analyze All").click(fn=fin_ext, inputs=text_input, outputs=None).click(fls, inputs=text_input, outputs=None)
+         gr.Button("Entities").click(fin_ner, inputs=text_input, outputs=None)
+
+     with gr.Tab("Personal Info Detector"):
+         gr.Markdown("## Detect Personal Info")
+         pi_input = gr.Textbox()
+         pi_output = gr.HighlightedText()
+         gr.Button("Detect").click(detect_personal_info, inputs=pi_input, outputs=pi_output)
+
+     with gr.Tab("Customer Churn"):
+         gr.Markdown("## Customer Churn Prediction")
+         inputs = [
+             gr.Radio(["Yes", "No"], label="SeniorCitizen"),
+             gr.Radio(["Yes", "No"], label="Partner"),
+             gr.Radio(["No", "Yes"], label="Dependents"),
+             gr.Slider(1, 73, step=1, label="Tenure"),
+             gr.Radio(["DSL", "Fiber optic", "No Internet"], label="InternetService"),
+             gr.Radio(["No", "Yes"], label="OnlineSecurity"),
+             gr.Radio(["No", "Yes"], label="OnlineBackup"),
+             gr.Radio(["No", "Yes"], label="DeviceProtection"),
+             gr.Radio(["No", "Yes"], label="TechSupport"),
+             gr.Radio(["No", "Yes"], label="StreamingTV"),
+             gr.Radio(["No", "Yes"], label="StreamingMovies"),
+             gr.Radio(["Month-to-month", "One year", "Two year"], label="Contract"),
+             gr.Radio(["Yes", "No"], label="PaperlessBilling"),
+             gr.Radio(["Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"], label="PaymentMethod"),
+             gr.Slider(18.40, 118.65, label="MonthlyCharges")
+         ]
+         churn_output = gr.Label()
+         gr.Button("Predict").click(predict, inputs=inputs, outputs=churn_output)
+
+ if __name__ == "__main__":
      demo.launch()