# jzou1995's picture
# Update app.py
# be75490 verified
import os
import re
import json
import requests
from typing import List, Dict, Optional, Tuple
import gradio as gr
from googlesearch import search
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from bs4 import BeautifulSoup
def initialize_gemini(api_key: str):
    """Configure the Gemini SDK and return the model used for NAICS classification.

    Args:
        api_key: Google Gemini API key handed to ``genai.configure``.

    Returns:
        A ``genai.GenerativeModel`` for ``gemini-1.5-flash`` configured with
        low-temperature sampling, a 1024-token output cap, and BLOCK_NONE for
        the harassment, hate-speech, sexually-explicit and dangerous-content
        harm categories.
    """
    genai.configure(api_key=api_key)

    # Low temperature / narrow nucleus keeps the model's answers focused.
    sampling = dict(temperature=0.2, top_p=0.8, top_k=40, max_output_tokens=1024)

    # Every listed harm category gets the same "do not block" threshold.
    unfiltered_categories = (
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    )
    safety = {category: HarmBlockThreshold.BLOCK_NONE for category in unfiltered_categories}

    return genai.GenerativeModel(
        model_name="gemini-1.5-flash",
        generation_config=sampling,
        safety_settings=safety,
    )
# Two-digit sector prefixes of the 2022 NAICS structure. A six-digit number
# starting with any other pair cannot be a NAICS code, so it is discarded.
_NAICS_SECTOR_PREFIXES = frozenset({
    "11", "21", "22", "23", "31", "32", "33", "42", "44", "45", "48", "49",
    "51", "52", "53", "54", "55", "56", "61", "62", "71", "72", "81", "92",
})
_SIX_DIGIT_RE = re.compile(r'\b\d{6}\b')


def _visible_soup(html: str):
    """Parse *html* and drop <script>/<style>/<noscript> so only rendered text remains."""
    soup = BeautifulSoup(html, 'html.parser')
    for tag in soup(['script', 'style', 'noscript']):
        tag.decompose()
    return soup


def _extract_naics_codes(text: str) -> List[str]:
    """Return distinct six-digit numbers in *text* with a valid NAICS sector prefix, first-seen order."""
    matches = (code for code in _SIX_DIGIT_RE.findall(text) if code[:2] in _NAICS_SECTOR_PREFIXES)
    return list(dict.fromkeys(matches))


def _company_paragraphs(soup, company_name: str) -> List[str]:
    """Return stripped <p> texts longer than 80 chars that mention *company_name* (case-insensitive)."""
    needle = company_name.lower()
    return [
        text
        for text in (p.get_text().strip() for p in soup.find_all('p'))
        if len(text) > 80 and needle in text.lower()
    ]


def combined_google_search(company_name: str) -> Tuple[str, List[str]]:
    """
    Combined search function that finds both company information and NAICS codes.

    Runs three "what does the company do" queries followed by three NAICS
    queries, fetching the top 3 Google results of each (each distinct URL is
    fetched at most once). From every page it collects:

    * paragraphs longer than 80 characters that mention the company, until
      roughly 1000 characters have been gathered (duplicates are skipped);
    * six-digit numbers in the page's *visible* text whose first two digits
      are a NAICS sector, counted once per page.

    All search/fetch/parse failures are logged and skipped: this is a
    best-effort lookup.

    Args:
        company_name: Company to research. A blank name returns ``("", [])``
            without searching.

    Returns:
        Tuple containing (company_info, naics_code_candidates): the collected
        paragraphs joined by blank lines, and up to 10 candidate codes ordered
        by how many distinct pages mentioned them (ties: first seen first).
    """
    from collections import Counter

    if not company_name or not company_name.strip():
        return "", []

    info_queries = [
        f"what is {company_name} company business industry sector",
        f"{company_name} company about us business description",
        f"{company_name} company profile what they do"
    ]
    naics_queries = [
        f"2022 NAICS code for {company_name} company",
        f"{company_name} NAICS 2022 classification",
        f"what is {company_name} industry NAICS code 2022"
    ]
    all_queries = info_queries + naics_queries

    info_parts: List[str] = []
    info_len = 0  # length of the "\n\n"-joined info, tracked incrementally
    seen_paragraphs = set()
    visited_urls = set()
    code_hits = Counter()  # code -> number of distinct pages mentioning it

    try:
        print(f"🔍 Searching for information about '{company_name}'...")
        for query in all_queries:
            print(f" Query: {query}")
            try:
                for result_url in search(query, stop=3, pause=2):
                    # Different queries often return the same page; fetching it
                    # again would only duplicate paragraphs and inflate counts.
                    if result_url in visited_urls:
                        continue
                    visited_urls.add(result_url)
                    try:
                        response = requests.get(result_url, timeout=5)
                        if response.status_code == 200:
                            soup = _visible_soup(response.text)

                            # Separator stops digits from adjacent elements
                            # merging into one long number.
                            found_codes = _extract_naics_codes(soup.get_text(" "))
                            if found_codes:
                                code_hits.update(found_codes)
                                print(f" Found codes in {result_url}: {found_codes}")

                            if info_len < 1000:  # Only if we need more info
                                for text in _company_paragraphs(soup, company_name):
                                    if text in seen_paragraphs:
                                        continue
                                    seen_paragraphs.add(text)
                                    info_parts.append(text)
                                    info_len += len(text) + 2
                                    if info_len > 1000:
                                        break
                    except Exception as e:
                        print(f" ⚠️ Error fetching {result_url}: {e}")
                    # If we have enough information, move to the next query
                    if info_len > 1000 and code_hits:
                        break
            except Exception as e:
                print(f" ⚠️ Error with query '{query}': {e}")
                continue

        candidates = [code for code, _ in code_hits.most_common(10)]
        return "\n\n".join(info_parts).strip(), candidates
    except Exception as e:
        print(f"❌ Error during Google search: {str(e)}")
        return "", []
def analyze_naics_code(model, company_name: str, context: str, company_info: str, naics_candidates: List[str]) -> dict:
    """
    Use Gemini AI to determine the most appropriate NAICS code.

    Args:
        model: Object exposing ``generate_content(prompt)`` whose result has a
            ``.text`` attribute (e.g. the model built by ``initialize_gemini``).
        company_name: Name of the company being classified.
        context: Optional user-supplied description; may be empty.
        company_info: Paragraphs found online; may be empty.
        naics_candidates: Six-digit codes found during search; may be empty,
            which selects a prompt that asks the model to classify unaided.

    Returns:
        Dict with ``"reasoning"`` (str) and ``"naics_code"`` (six-digit str).
        ``naics_code`` is ``"000000"`` when no code can be parsed from the
        reply; on any exception both keys describe the error instead.
    """
    try:
        print("🤖 AI is analyzing NAICS classification...")
        # Combine provided context with discovered company info
        if company_info:
            if context:
                combined_context = f"{context}\n\nAdditional information found online:\n{company_info}"
            else:
                combined_context = f"Information found online:\n{company_info}"
        else:
            combined_context = context

        # Create the prompt based on whether we have candidate codes
        if naics_candidates:
            prompt = f"""
You are a NAICS code classification expert. Based on the company information provided and any NAICS code candidates found from online research, determine the most appropriate NAICS code.
Company Name: {company_name}
Information about the company: {combined_context}
NAICS Code Candidates found in research: {naics_candidates}
First, analyze what these NAICS codes represent and which industry this company belongs to based on the information provided.
Then select the single most appropriate 6-digit NAICS code.
Your response should be in this format:
REASONING: [Your detailed reasoning about why the chosen industry classification is most appropriate for this company, including what business activities it performs]
NAICS_CODE: [6-digit NAICS code]
"""
        else:
            prompt = f"""
You are a NAICS code classification expert. Based on the company information provided, determine the most appropriate NAICS code.
Company Name: {company_name}
Information about the company: {combined_context}
Analyze what industry this company likely belongs to based on its name and the provided information.
Consider standard business classifications and determine the most appropriate category.
Then provide the single most appropriate 6-digit NAICS code.
Your response should be in this format:
REASONING: [Your detailed reasoning about the company's industry classification, including what business activities it likely performs]
NAICS_CODE: [6-digit NAICS code]
"""

        response = model.generate_content(prompt)
        response_text = response.text.strip()

        result = {}

        # Extract reasoning. Models often wrap labels in Markdown bold
        # ("**REASONING:**"), so trim stray asterisks along with whitespace.
        reasoning_match = re.search(r'REASONING:(.*?)NAICS_CODE:', response_text, re.DOTALL | re.IGNORECASE)
        reasoning = reasoning_match.group(1).strip(" \t\r\n*") if reasoning_match else ""
        result["reasoning"] = reasoning or "No reasoning provided."

        # Extract NAICS code. The label is matched case-insensitively, just as
        # REASONING is; otherwise a lowercase label falls through to the
        # fallback below, which may pick up a code quoted in the reasoning.
        # \b stops us taking the first six digits of a longer number.
        naics_match = re.search(r'NAICS_CODE:.*?\b(\d{6})\b', response_text, re.DOTALL | re.IGNORECASE)
        if naics_match:
            result["naics_code"] = naics_match.group(1)
        else:
            # Try to find any 6-digit code in the response
            code_match = re.search(r'\b(\d{6})\b', response_text)
            result["naics_code"] = code_match.group(1) if code_match else "000000"

        return result
    except Exception as e:
        print(f"❌ Error getting NAICS classification: {str(e)}")
        return {
            "naics_code": "000000",
            "reasoning": f"Error analyzing company: {str(e)}"
        }
def find_naics_code(company_name: str, context: str = "", api_key: Optional[str] = None) -> Dict:
    """
    Classify a company end-to-end: web search first, then Gemini analysis.

    Args:
        company_name: Company to classify.
        context: Optional free-text description supplied by the caller.
        api_key: Gemini API key. When falsy, the ``GEMINI_API_KEY``
            environment variable is used instead.

    Returns:
        On success, the dict from ``analyze_naics_code`` (``naics_code``,
        ``reasoning``) extended with ``company_name``, ``context``,
        ``company_info`` and ``candidates``. If no key is available, or
        building the Gemini model raises, a dict with ``error``,
        ``naics_code`` (``"000000"``) and ``reasoning`` is returned and no
        search is performed.
    """
    key = api_key or os.environ.get('GEMINI_API_KEY')
    if not key:
        return {
            "error": "No API key provided. Set GEMINI_API_KEY environment variable or pass as parameter.",
            "naics_code": "000000",
            "reasoning": "Error: API key missing",
        }

    try:
        model = initialize_gemini(key)
    except Exception as exc:
        detail = str(exc)
        return {
            "error": f"Failed to initialize Gemini API: {detail}",
            "naics_code": "000000",
            "reasoning": f"Error: {detail}",
        }

    company_info, naics_candidates = combined_google_search(company_name)
    result = analyze_naics_code(model, company_name, context, company_info, naics_candidates)

    # Echo the inputs and intermediate search findings back to the caller.
    result.update(
        company_name=company_name,
        context=context,
        company_info=company_info,
        candidates=naics_candidates,
    )
    return result
# Create the Gradio interface
def create_gradio_interface():
    """Build (but do not launch) the Gradio Blocks UI for the NAICS Code Finder.

    The Gemini API-key textbox is shown only when ``GEMINI_API_KEY`` is not set
    in the environment. Otherwise a hidden, empty textbox takes its place, so
    ``process_company`` always receives the same three inputs.

    Returns:
        The ``gr.Blocks`` instance.
    """
    # Check if API key is set in environment
    has_api_key = bool(os.environ.get('GEMINI_API_KEY'))

    with gr.Blocks(title="NAICS Code Finder") as demo:
        gr.Markdown("# NAICS Code Finder")
        gr.Markdown("Enter a company name to find its appropriate NAICS code. The tool will search for information about the company and find the most appropriate classification.")

        with gr.Row():
            with gr.Column():
                company_name = gr.Textbox(label="Company Name", placeholder="Enter company name")
                company_description = gr.Textbox(label="Additional Context (optional)", placeholder="Any additional information about the company")
                # Only show API key input if not set in environment
                if not has_api_key:
                    api_key = gr.Textbox(
                        label="Gemini API Key (required)",
                        placeholder="Enter your Google Gemini API key",
                        type="password"
                    )
                else:
                    api_key = gr.Textbox(visible=False, value="")
                submit_btn = gr.Button("Find NAICS Code", variant="primary")

            with gr.Column():
                status_output = gr.Markdown(label="Status")
                naics_output = gr.Markdown(label="NAICS Code")
                with gr.Accordion("Company Information", open=False):
                    company_info_output = gr.Markdown()
                with gr.Accordion("Classification Reasoning", open=True):
                    reasoning_output = gr.Markdown()

        # Functions for the interface
        def process_company(company_name, company_description, api_key):
            """Stream (status, naics, company_info, reasoning) Markdown updates for one lookup.

            This is a generator. Gradio renders only *yielded* values and
            discards a generator's ``return`` value, so every outcome,
            including validation errors and the final result, is yielded.
            """
            if not company_name:
                yield "Please enter a company name", "", "", ""
                return

            # Use API key from input or environment
            key_to_use = api_key if api_key else os.environ.get('GEMINI_API_KEY')
            if not key_to_use:
                yield "No API key provided. Please enter your Gemini API key.", "", "", ""
                return

            status_md = "🔍 Searching for company information and NAICS codes...\n\n"
            yield status_md, "", "", ""

            # Run the core functionality
            result = find_naics_code(company_name, company_description, key_to_use)

            # find_naics_code reports setup failures (e.g. Gemini initialisation)
            # via an "error" key. Surface it rather than claiming success.
            if result.get("error"):
                yield (
                    f"❌ {result['error']}",
                    f"## NAICS Code: {result['naics_code']}",
                    "",
                    f"## Analysis\n\n{result['reasoning']}",
                )
                return

            # Update status based on results
            if result.get("company_info"):
                status_md += "✅ Found company information\n\n"
                company_info_md = f"## Information found about {company_name}\n\n{result['company_info']}"
            else:
                status_md += "⚠️ Limited company information found\n\n"
                company_info_md = f"Limited information found for {company_name}"

            if result.get("candidates"):
                status_md += f"✅ Found {len(result['candidates'])} potential NAICS codes: {', '.join(result['candidates'])}\n\n"
            else:
                status_md += "⚠️ No specific NAICS codes found in search results\n\n"

            status_md += "🤖 Analyzing classification...\n\n"
            yield status_md, "", company_info_md, ""

            # Format the NAICS code output
            naics_code_md = f"## NAICS Code: {result['naics_code']}"
            # Format the reasoning output
            reasoning_md = f"## Analysis\n\n{result['reasoning']}"
            status_md += "✅ Classification complete!"
            yield status_md, naics_code_md, company_info_md, reasoning_md

        submit_btn.click(
            process_company,
            inputs=[company_name, company_description, api_key],
            outputs=[status_output, naics_output, company_info_output, reasoning_output]
        )

        gr.Examples(
            [
                ["Apple Inc", "Tech company"],
                ["Walmart", "Retail store"],
                ["Goldman Sachs", "Investment bank"],
                ["Ford Motor Company", "Automobile manufacturer"]
            ],
            inputs=[company_name, company_description]
        )

    return demo
# Create and launch the interface
# The UI is built at import time, so the module-level name `demo` exists
# whether this file is executed or imported.
# NOTE(review): presumably hosting tooling looks `demo` up by name; confirm
# before renaming it.
demo = create_gradio_interface()
# For Spaces deployment
if __name__ == "__main__":
    # Launch only when this file is executed directly, not when it is imported.
    demo.launch()