Spaces:
Running
Running
import spacy | |
from transformers import pipeline | |
from dateutil.parser import parse | |
import re | |
import pandas as pd | |
from difflib import SequenceMatcher | |
class TextClassifier: | |
def __init__(self): | |
# Use a larger model for better NER (optional) | |
self.nlp = spacy.load("en_core_web_lg") # "en_core_web_lg" | |
try: | |
# Use a smaller, PyTorch-compatible model for zero-shot classification | |
self.classifier = pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli") | |
self.model_available = True | |
print("Successfully loaded zero-shot classification model.") | |
except Exception as e: | |
print(f"Failed to load zero-shot classification model: {e}. Falling back to keyword-based classification.") | |
self.classifier = None | |
self.model_available = False | |
self.intents = [ | |
"get_net_income", | |
"get_revenue", | |
"get_stock_price", | |
"get_profit_margin", | |
"get_company_profile", | |
"get_market_cap", | |
"get_historical_stock_price", | |
"get_dividend_info", | |
"get_balance_sheet", | |
"get_cash_flow", | |
"get_financial_ratios", | |
"get_earnings_per_share", | |
"get_interest", | |
"get_research_info", | |
"get_cost_info", | |
"get_income_tax" | |
] | |
# Mapping of company names to ticker symbols (case-insensitive) | |
self.company_to_ticker = { | |
"apple": "AAPL", | |
"microsoft corporation": "MSFT", | |
"microsoft": "MSFT", | |
"nvidia corporation": "NVDA", | |
"nvidia": "NVDA", | |
"amazon": "AMZN", | |
"alphabet inc": "GOOGL", | |
"google": "GOOGL", | |
"meta platforms": "META", | |
"meta": "META", | |
"facebook": "META", | |
"tesla": "TSLA", | |
"walmart inc": "WMT", | |
"walmart": "WMT", | |
"visa inc": "V", | |
"visa": "V", | |
"coca cola": "KO" | |
} | |
# Mapping of keywords to intents (case-insensitive) | |
self.intent_to_keywords = { | |
"get_net_income": ["net income", "income", "earnings"], | |
"get_revenue": ["revenue", "sales", "turnover", "gross income"], | |
"get_stock_price": ["stock price", "stock", "price", "share price", "current price", "price now", "stock value"], | |
"get_profit_margin": ["profit margin", "margin", "profit percentage", "net margin", "profit"], | |
"get_company_profile": ["who is", "company profile", "about company", "company info"], | |
"get_market_cap": ["market cap", "market capitalization", "company value", "valuation"], | |
"get_historical_stock_price": ["historical stock price", "stock price on", "past stock price", "stock price in", "price on"], | |
"get_dividend_info": ["dividend info", "dividend payout", "payout ratio", "dividend yield", "dividend"], | |
"get_balance_sheet": ["balance sheet", "sheet", "financial position", "assets and liabilities", "balance"], | |
"get_cash_flow": ["cash", "flow", "cash flow", "cashflow", "cash from operations", "operating cash"], | |
"get_financial_ratios": ["financial ratios", "ratios", "current ratio", "liquidity ratio", "debt ratio"], | |
"get_earnings_per_share": ["earnings per share", "eps", "per share earnings"], | |
} | |
def classify_by_keywords(self, text): | |
""" | |
Classify the intent based on keyword mapping. | |
Args: | |
text (str): The input text to classify. | |
Returns: | |
str: The predicted intent, or None if no match is found. | |
""" | |
text_lower = text.lower() | |
for intent, keywords in self.intent_to_keywords.items(): | |
if any(keyword in text_lower for keyword in keywords): | |
print(f"Classified intent: {intent} based on keywords: {keywords}") | |
return intent | |
print("No intent matched based on keywords.") | |
return None # Fallback if no keywords match | |
def classify_with_llm(self, text): | |
if not self.model_available: | |
print("Zero-shot classifier not available. Using keyword-based classification.") | |
return self.classify_by_keywords(text) | |
try: | |
hypothesis_template = "This text is requesting {} information." | |
result = self.classifier(text, candidate_labels=self.intents, hypothesis_template=hypothesis_template, multi_label=False) | |
predicted_intent = result["labels"][0] | |
print(f"Predicted intent: {predicted_intent} with scores: {dict(zip(result['labels'], result['scores']))}") | |
return predicted_intent | |
except Exception as e: | |
print(f"Error classifying intent with model: {e}. Falling back to keyword-based classification.") | |
return self.classify_by_keywords(text) | |
def extract_entities(self, text): | |
doc = self.nlp(text) | |
entities = {"ticker": None, "metric": None, "year": None, "date": None} | |
# Step 1: Extract entities using spaCy NER | |
for ent in doc.ents: | |
if ent.label_ == "ORG": | |
org_name = ent.text.lower() | |
ticker = self.company_to_ticker.get(org_name) | |
if ticker: | |
entities["ticker"] = ticker | |
else: | |
# If not found in the mapping, search in the CSV file | |
try: | |
# Load the CSV file (adjust the path as needed) | |
csv_path = "financial data sp500 companies.csv" # Same path as used in Retriever | |
df = pd.read_csv(csv_path) | |
# Ensure the required columns exist | |
if "firm" not in df.columns or "Ticker" not in df.columns: | |
print("Required columns 'firm' or 'Ticker' not found in CSV. Using fallback ticker.") | |
entities["ticker"] = ent.text.upper() | |
else: | |
# Calculate similarity scores between org_name and each firm name | |
df["similarity"] = df["firm"].apply( | |
lambda x: SequenceMatcher(None, org_name, str(x).lower()).ratio() | |
) | |
# Find rows with similarity >= 80% | |
matches = df[df["similarity"] >= 0.5] | |
if not matches.empty: | |
# Take the first match (highest similarity) | |
best_match = matches.sort_values(by="similarity", ascending=False).iloc[0] | |
ticker = best_match["Ticker"] | |
print(f"Found ticker {ticker} for {org_name} with similarity {best_match['similarity']:.2f}") | |
entities["ticker"] = ticker | |
else: | |
print(f"No match found for {org_name} with >= 50% similarity. Using fallback ticker.") | |
entities["ticker"] = ent.text.upper() | |
except Exception as e: | |
print(f"Error searching CSV for ticker: {e}. Using fallback ticker.") | |
entities["ticker"] = ent.text.upper() | |
elif ent.label_ == "DATE": | |
date_text = ent.text.lower() | |
try: | |
parsed_date = parse(date_text, fuzzy=True, default=parse("2025-01-01")) | |
# If the date is a year (e.g., "2023", "this year") or parsed as January 1 | |
if "year" in date_text or date_text.isdigit() or (parsed_date.day == 1 and parsed_date.month == 1): | |
entities["year"] = parsed_date.strftime("%Y") | |
else: | |
# Otherwise, treat it as a specific date (e.g., "Jan 5") | |
entities["date"] = parsed_date.strftime("%Y-%m-%d") | |
except ValueError: | |
# Fallback if parsing fails | |
if "year" in date_text or date_text.isdigit(): | |
entities["year"] = date_text | |
else: | |
entities["date"] = date_text | |
# Step 2: Fallback ticker extraction if spaCy fails to identify ORG | |
if not entities["ticker"]: | |
text_lower = text.lower() | |
for company_name, ticker in self.company_to_ticker.items(): | |
if company_name in text_lower: | |
entities["ticker"] = ticker | |
break | |
# Step 3: Extract metric using keyword matching with synonyms | |
text_lower = text.lower() | |
if any(keyword in text_lower for keyword in ["net income", "net", "income"]): | |
entities["metric"] = "netIncome" | |
elif "revenue" in text_lower: | |
entities["metric"] = "revenue" | |
elif any(keyword in text_lower for keyword in ["profit margin", "profit", "margin"]): | |
entities["metric"] = "netProfitMargin" | |
elif any(keyword in text_lower for keyword in ["market cap", "market capitalization", "market"]): | |
entities["metric"] = "mktCap" | |
elif any(keyword in text_lower for keyword in ["payout ratio", "dividend payout"]): | |
entities["metric"] = "payoutRatio" | |
elif any(keyword in text_lower for keyword in ["current ratio", "liquidity ratio"]): | |
entities["metric"] = "currentRatio" | |
elif any(keyword in text_lower for keyword in ["eps", "earnings per share", "earnings"]): | |
entities["metric"] = "eps" | |
elif any(keyword in text_lower for keyword in ["stock", "stock price", "current price", "valuation", "price"]): | |
entities["metric"] = "price" | |
elif any(keyword in text_lower for keyword in ["company info", "about company", "who is"]): | |
entities["metric"] = "ceo" | |
elif any(keyword in text_lower for keyword in ["balance sheet", "sheet", "assets"]): | |
entities["metric"] = "Assets&Liabilities" | |
elif any(keyword in text_lower for keyword in ["historical", "earnings per share", "earnings"]): | |
entities["metric"] = "historical" | |
elif any(keyword in text_lower for keyword in ["cash", "flow", "cash flow"]): | |
entities["metric"] = "cashFlowFromOperatingActivities" | |
elif any(keyword in text_lower for keyword in ["tax"]): | |
entities["metric"] = "IncomeTax" | |
elif any(keyword in text_lower for keyword in ["interest", "interest expense", "expense"]): | |
entities["metric"] = "InterestExpense" | |
elif any(keyword in text_lower for keyword in ["research", "research development", "development"]): | |
entities["metric"] = "Research" | |
elif any(keyword in text_lower for keyword in ["cost", "total cost"]): | |
entities["metric"] = "TotalCost" | |
# Step 4: Normalize year (handle "this year", "last year", etc.) | |
if entities["year"]: | |
year_text = entities["year"].lower() | |
current_year = 2025 # Based on the current date (April 16, 2025) | |
if "this year" in year_text: | |
entities["year"] = str(current_year) | |
elif "last year" in year_text: | |
entities["year"] = str(current_year - 1) | |
elif re.match(r"^\d{4}$", year_text): | |
entities["year"] = year_text | |
else: | |
# If year is not a valid format, unset it | |
entities["year"] = None | |
return entities |