import gradio as gr
import spaces  # Hugging Face ZeroGPU helper; kept for Spaces runtime compatibility
import pandas as pd
import torch
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
import plotly.graph_objects as go
import logging
import io
from rapidfuzz import fuzz


def fuzzy_deduplicate(df, column, threshold=55):
    """Drop rows whose text is at least `threshold`% similar (by fuzz.ratio)
    to any previously kept row; NaN values are always kept."""
    seen_texts = []
    indices_to_keep = []

    for i, text in enumerate(df[column]):
        if pd.isna(text):
            indices_to_keep.append(i)
            continue
        text = str(text)
        if not seen_texts or all(fuzz.ratio(text, seen) < threshold for seen in seen_texts):
            seen_texts.append(text)
            indices_to_keep.append(i)

    # reset_index keeps row labels contiguous, so the iterrows()-based
    # progress counters downstream stay accurate after deduplication
    return df.iloc[indices_to_keep].reset_index(drop=True)
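
# A minimal illustration of the greedy dedup above (hypothetical data, not
# used by the app): the near-duplicate second row is dropped because its
# fuzz.ratio against the first row is well above the threshold of 55.
#
#   demo = pd.DataFrame({'text': [
#       "Компания опубликовала отчетность за квартал",
#       "Компания опубликовала отчетность за 1 квартал",
#       "Суд принял иск к рассмотрению",
#   ]})
#   fuzzy_deduplicate(demo, 'text')  # keeps rows 0 and 2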

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ProcessControl:
    """Cooperative stop flag shared between the UI and the processing loop."""

    def __init__(self):
        self.stop_requested = False

    def request_stop(self):
        self.stop_requested = True

    def should_stop(self):
        return self.stop_requested

    def reset(self):
        self.stop_requested = False
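
# Usage sketch (illustrative): the stop button's event handler sets the flag,
# while the analysis loop polls it between rows and exits early.
#
#   control = ProcessControl()
#   control.request_stop()       # UI side
#   if control.should_stop():    # worker side, checked every iteration
#       ...                      # abort processing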


class EventDetector:
    def __init__(self):
        self.model_name = "google/mt5-small"
        # legacy=True suppresses the sentencepiece behaviour-change warning
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            legacy=True
        )
        self.model = None
        self.finbert = None
        self.roberta = None
        self.finbert_tone = None
        self.control = ProcessControl()

    def initialize_models(self):
        """Initialize all models with GPU support"""
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Initializing models on device: {device}")

            # Seq2seq model used to generate the event summary
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name).to(device)

            # Three sentiment pipelines, later combined by majority vote
            self.finbert = pipeline(
                "sentiment-analysis",
                model="ProsusAI/finbert",
                device=device,
                truncation=True,
                max_length=512
            )
            self.roberta = pipeline(
                "sentiment-analysis",
                model="cardiffnlp/twitter-roberta-base-sentiment",
                device=device,
                truncation=True,
                max_length=512
            )
            self.finbert_tone = pipeline(
                "sentiment-analysis",
                model="yiyanghkust/finbert-tone",
                device=device,
                truncation=True,
                max_length=512
            )

            logger.info("All models initialized successfully")
            return True
        except Exception as e:
            logger.error(f"Model initialization error: {str(e)}")
            return False
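
    # Note: on a ZeroGPU Space the GPU is only attached inside functions
    # decorated with @spaces.GPU (a sketch; this app imports `spaces` but
    # does not currently use the decorator):
    #
    #   @spaces.GPU
    #   def detect_events(self, text, entity):
    #       ...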

    def detect_events(self, text, entity):
        if not text or not entity:
            return "Нет", "Invalid input"

        try:
            # Lazily initialize models on first call
            if self.model is None:
                if not self.initialize_models():
                    return "Нет", "Model initialization failed"

            # Truncate input text
            text = text[:500]

            prompt = f"""<s>Analyze the following news about {entity}:
Text: {text}
Task: Identify the main event type and provide a brief summary.</s>"""

            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(self.model.device)

            outputs = self.model.generate(
                **inputs,
                max_length=300,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Keyword rules, not the generated text, decide the event label:
            # "Отчетность" = financial reporting, "РЦБ" = securities market,
            # "Суд" = court/litigation, "Нет" = no event detected
            event_type = "Нет"
            if any(term in text.lower() for term in ['отчет', 'выручка', 'прибыль', 'ebitda']):
                event_type = "Отчетность"
            elif any(term in text.lower() for term in ['облигаци', 'купон', 'дефолт']):
                event_type = "РЦБ"
            elif any(term in text.lower() for term in ['суд', 'иск', 'арбитраж']):
                event_type = "Суд"

            return event_type, response

        except Exception as e:
            logger.error(f"Event detection error: {str(e)}")
            return "Нет", f"Error: {str(e)}"

    def get_sentiment_label(self, result):
        """Map model-specific labels onto Positive/Negative/Neutral."""
        label = result['label'].lower()
        # "label_2"/"label_0" cover cardiffnlp's LABEL_2 (positive) and
        # LABEL_0 (negative) raw outputs
        if label in ["positive", "label_2", "pos"]:
            return "Positive"
        elif label in ["negative", "label_0", "neg"]:
            return "Negative"
        return "Neutral"

    def analyze_sentiment(self, text):
        try:
            if self.finbert is None:
                if not self.initialize_models():
                    return "Neutral"

            truncated_text = text[:500]
            results = []

            try:
                inputs = [truncated_text]
                finbert_result = self.finbert(inputs)[0]
                roberta_result = self.roberta(inputs)[0]
                finbert_tone_result = self.finbert_tone(inputs)[0]

                results = [
                    self.get_sentiment_label(finbert_result),
                    self.get_sentiment_label(roberta_result),
                    self.get_sentiment_label(finbert_tone_result)
                ]
            except Exception as e:
                logger.error(f"Model inference error: {e}")
                return "Neutral"

            # Majority vote: a label must come from at least 2 of the 3
            # models, otherwise fall back to Neutral
            sentiment_counts = pd.Series(results).value_counts()
            return sentiment_counts.index[0] if sentiment_counts.iloc[0] >= 2 else "Neutral"

        except Exception as e:
            logger.error(f"Sentiment analysis error: {e}")
            return "Neutral"


def create_visualizations(df):
    if df is None or df.empty:
        return None, None

    try:
        sentiments = df['Sentiment'].value_counts()
        # NB: colors are positional (most frequent slice first), not tied to
        # particular sentiment labels
        fig_sentiment = go.Figure(data=[go.Pie(
            labels=sentiments.index,
            values=sentiments.values,
            marker_colors=['#FF6B6B', '#4ECDC4', '#95A5A6']
        )])
        fig_sentiment.update_layout(title="Распределение тональности")  # "Sentiment distribution"

        events = df['Event_Type'].value_counts()
        fig_events = go.Figure(data=[go.Bar(
            x=events.index,
            y=events.values,
            marker_color='#2196F3'
        )])
        fig_events.update_layout(title="Распределение событий")  # "Event distribution"

        return fig_sentiment, fig_events
    except Exception as e:
        logger.error(f"Visualization error: {e}")
        return None, None
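
# Illustrative preview with a toy frame (not part of the app flow):
#
#   toy = pd.DataFrame({'Sentiment': ['Positive', 'Positive', 'Negative'],
#                       'Event_Type': ['Суд', 'Нет', 'Нет']})
#   fig_s, fig_e = create_visualizations(toy)
#   fig_s.show()  # needs a local browser/renderer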


def process_file(file_obj):
    """Batch (non-streaming) counterpart of `analyze` below; not wired into
    the Gradio interface, but usable standalone."""
    try:
        logger.info("Starting to read Excel file...")
        df = pd.read_excel(file_obj, sheet_name='Публикации')
        logger.info(f"Successfully read Excel file. Shape: {df.shape}")

        # Perform deduplication
        original_count = len(df)
        df = fuzzy_deduplicate(df, 'Выдержки из текста', threshold=55)
        logger.info(f"Removed {original_count - len(df)} duplicate entries")

        detector = EventDetector()
        processed_rows = []
        total = len(df)

        # Initialize models once for all rows
        if not detector.initialize_models():
            raise Exception("Failed to initialize models")

        for idx, row in df.iterrows():
            try:
                # Skip rows with an empty excerpt ('Выдержки из текста')
                # or entity ('Объект')
                text = str(row.get('Выдержки из текста', ''))
                if not text.strip():
                    continue
                entity = str(row.get('Объект', ''))
                if not entity.strip():
                    continue

                event_type, event_summary = detector.detect_events(text, entity)
                sentiment = detector.analyze_sentiment(text)

                processed_rows.append({
                    'Объект': entity,
                    'Заголовок': str(row.get('Заголовок', '')),  # headline
                    'Sentiment': sentiment,
                    'Event_Type': event_type,
                    'Event_Summary': event_summary,
                    'Текст': text[:1000]  # truncate text for display
                })

                if idx % 5 == 0:
                    logger.info(f"Processed {idx + 1}/{total} rows")

            except Exception as e:
                logger.error(f"Error processing row {idx}: {str(e)}")
                continue

        result_df = pd.DataFrame(processed_rows)
        logger.info(f"Processing complete. Final DataFrame shape: {result_df.shape}")
        return result_df

    except Exception as e:
        logger.error(f"File processing error: {str(e)}")
        raise
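
# Standalone usage sketch ("report.xlsx" is a hypothetical local file with a
# "Публикации" sheet and the columns referenced above):
#
#   with open("report.xlsx", "rb") as f:
#       result = process_file(io.BytesIO(f.read()))
#   result.to_excel("analyzed.xlsx", index=False)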


def create_interface():
    control = ProcessControl()

    with gr.Blocks(theme=gr.themes.Soft()) as app:
        gr.Markdown("# AI-анализ мониторинга новостей v.1.14")  # "AI news-monitoring analysis"

        with gr.Row():
            file_input = gr.File(
                label="Загрузите Excel файл",  # "Upload an Excel file"
                file_types=[".xlsx"],
                type="binary"
            )

        with gr.Row():
            with gr.Column(scale=1):
                analyze_btn = gr.Button(
                    "▶️ Начать анализ",  # "Start analysis"
                    variant="primary",
                    size="lg"
                )
            with gr.Column(scale=1):
                stop_btn = gr.Button(
                    "⏹️ Остановить",  # "Stop"
                    variant="stop",
                    size="lg"
                )

        with gr.Row():
            progress = gr.Textbox(
                label="Статус обработки",  # "Processing status"
                interactive=False,
                value="Ожидание файла..."  # "Waiting for a file..."
            )

        with gr.Row():
            stats = gr.DataFrame(
                label="Результаты анализа",  # "Analysis results"
                interactive=False,
                wrap=True
            )

        with gr.Row():
            with gr.Column(scale=1):
                sentiment_plot = gr.Plot(label="Распределение тональности")
            with gr.Column(scale=1):
                events_plot = gr.Plot(label="Распределение событий")

        def stop_processing():
            control.request_stop()
            return "Остановка обработки..."  # "Stopping..."

        def analyze(file_bytes):
            # Generator: Gradio streams each `yield` to the outputs, so every
            # exit path must yield; a bare `return value` inside a generator
            # would be silently discarded
            if file_bytes is None:
                gr.Warning("Пожалуйста, загрузите файл")  # "Please upload a file"
                yield None, None, None, "Ожидание файла..."
                return

            try:
                # Reset stop flag
                control.reset()

                file_obj = io.BytesIO(file_bytes)
                logger.info("File loaded into BytesIO successfully")

                yield None, None, None, "Начинаем обработку файла..."  # "Starting..."

                # Process file
                df = pd.read_excel(file_obj, sheet_name='Публикации')
                logger.info(f"Successfully read Excel file. Shape: {df.shape}")

                # Deduplication
                original_count = len(df)
                df = fuzzy_deduplicate(df, 'Выдержки из текста', threshold=55)
                logger.info(f"Removed {original_count - len(df)} duplicate entries")

                detector = EventDetector()
                detector.control = control  # share the stop flag with the detector
                processed_rows = []
                total = len(df)

                # Initialize models
                if not detector.initialize_models():
                    raise Exception("Failed to initialize models")

                for idx, row in df.iterrows():
                    if control.should_stop():
                        yield (
                            pd.DataFrame(processed_rows) if processed_rows else None,
                            None, None,
                            f"Обработка остановлена. Обработано {idx} из {total} строк"
                        )
                        return

                    try:
                        text = str(row.get('Выдержки из текста', ''))
                        if not text.strip():
                            continue
                        entity = str(row.get('Объект', ''))
                        if not entity.strip():
                            continue

                        event_type, event_summary = detector.detect_events(text, entity)
                        sentiment = detector.analyze_sentiment(text)

                        processed_rows.append({
                            'Объект': entity,
                            'Заголовок': str(row.get('Заголовок', '')),
                            'Sentiment': sentiment,
                            'Event_Type': event_type,
                            'Event_Summary': event_summary,
                            'Текст': text[:1000]
                        })

                        if idx % 5 == 0:
                            yield None, None, None, f"Обработано {idx + 1}/{total} строк"

                    except Exception as e:
                        logger.error(f"Error processing row {idx}: {str(e)}")
                        continue

                result_df = pd.DataFrame(processed_rows)
                fig_sentiment, fig_events = create_visualizations(result_df)

                # The final result must also be yielded, not returned
                yield (
                    result_df,
                    fig_sentiment,
                    fig_events,
                    f"Обработка завершена успешно! Обработано {len(result_df)} строк"
                )

            except Exception as e:
                error_msg = f"Ошибка анализа: {str(e)}"
                logger.error(error_msg)
                # gr.Error only surfaces when raised; show a Warning toast and
                # report the error through the status textbox instead
                gr.Warning(error_msg)
                yield None, None, None, error_msg

        stop_btn.click(fn=stop_processing, outputs=[progress])
        analyze_btn.click(
            fn=analyze,
            inputs=[file_input],
            outputs=[stats, sentiment_plot, events_plot, progress]
        )

    return app


if __name__ == "__main__":
    app = create_interface()
    # Generator events need the request queue; Gradio 4.x enables it by
    # default, while on 3.x app.queue() must be called before launch()
    app.launch(share=True)