# Search-engine clients (Google Custom Search, Bocha) for Hong Kong tourism queries.
# Standard library.
import json
import logging
from abc import ABC, abstractmethod
from typing import Dict, List

# Third-party.
import requests

# Configure root logging once at import time (DEBUG for development builds).
logging.basicConfig(level=logging.DEBUG)
class SearchEngine(ABC):
    """Abstract base class for search backends.

    Provides shared ad-URL filtering and Hong Kong tourism query
    enhancement; concrete engines implement :meth:`search`.
    """

    def __init__(self):
        # Blacklist of ad / e-commerce domain fragments. Matching is a plain
        # substring test (see is_ad_url), so entries such as '.ads.' catch
        # any host containing that fragment.
        self.ad_domains = {
            'ads.google.com',
            'doubleclick.net',
            'affiliate.',
            '.ads.',
            'promotion.',
            'sponsored.',
            'partner.',
            'tracking.',
            '.shop.',
            'taobao.com',
            'tmall.com',
            'jd.com',
            'mafengwo.cn',      # Mafengwo (travel e-commerce)
            'ctrip.com',        # Ctrip
            'tour.aoyou.com',   # Aoyou / Tongcheng tours
            'wannar.com',       # Wannar
        }

    def is_ad_url(self, url: str) -> bool:
        """Return True if *url* contains any blacklisted ad-domain fragment."""
        url_lower = url.lower()
        return any(ad_domain in url_lower for ad_domain in self.ad_domains)

    def enhance_query(self, query: str) -> str:
        """Prefix *query* with 'Hong Kong Tourism' unless it mentions Hong Kong.

        The containment check is case-insensitive so 'hong kong hotels' is
        not double-prefixed, and a separating space is inserted between the
        prefix and the original query.
        """
        if "hong kong" not in query.lower():
            query = f"Hong Kong Tourism {query}"
        return query

    @abstractmethod
    def search(self, query: str) -> List[Dict]:
        """Run the search and return result dicts.

        Each dict is expected to carry at least 'title', 'snippet' and
        'url' keys; implementations return [] on failure.
        """
        ...
class GoogleSearch(SearchEngine):
    """Search backend using the Google Custom Search JSON API."""

    def __init__(self, api_key: str, cx: str, proxies: Dict[str, str] = None):
        """
        Args:
            api_key: Google API key.
            cx: Custom Search Engine id.
            proxies: optional requests-style proxy mapping.
        """
        super().__init__()
        self.api_key = api_key
        self.cx = cx
        self.base_url = "https://www.googleapis.com/customsearch/v1"
        self.proxies = proxies or {}

    def filter_results(self, results: List[Dict]) -> List[Dict]:
        """Return *results* with ad-domain URLs removed."""
        return [r for r in results if not self.is_ad_url(r['url'])]

    def search(self, query: str) -> List[Dict]:
        """Query the Custom Search API and return ad-filtered result dicts.

        Returns an empty list on any HTTP error or network failure.
        """
        enhanced_query = self.enhance_query(query)
        params = {
            'key': self.api_key,
            'cx': self.cx,
            'q': enhanced_query,
        }
        try:
            # Fix: honor the configured proxies and bound the request time
            # (previously self.proxies was ignored and there was no timeout).
            response = requests.get(
                self.base_url,
                params=params,
                proxies=self.proxies or None,
                timeout=10,
            )
        except requests.RequestException as e:
            logging.error(f"Google search request failed: {e}")
            return []
        if response.status_code != 200:
            return []
        results = response.json()
        items = [{
            'title': item['title'],
            'snippet': item['snippet'],
            'url': item['link'],
        } for item in results.get('items', [])]
        # Fix: filter_results existed but was never applied to search output.
        return self.filter_results(items)
class BochaSearch(SearchEngine):
    """Search backend using the Bocha AI search API."""

    def __init__(self, api_key: str, base_url: str, proxies: Dict[str, str] = None):
        """
        Args:
            api_key: Bocha API bearer token.
            base_url: API root URL; a trailing slash is tolerated.
            proxies: optional requests-style proxy mapping.
        """
        super().__init__()
        self.api_key = api_key
        self.base_url = base_url.rstrip('/')  # strip a possible trailing slash
        self.proxies = proxies or {}

    def search(self, query: str) -> List[Dict]:
        """POST the enhanced query to /v1/ai-search and return webpage results.

        Returns an empty list on HTTP errors, unexpected payload shapes or
        network failures.
        """
        try:
            enhanced_query = self.enhance_query(query)
            headers = {
                'Authorization': f'Bearer {self.api_key}',
                'Content-Type': 'application/json',
                'Connection': 'keep-alive',
                'Accept': '*/*'
            }
            payload = {
                'query': enhanced_query,
                'stream': False  # non-streaming response
            }
            endpoint = f"{self.base_url}/v1/ai-search"
            logging.info(f"正在请求博查API...")
            logging.info(f"增强后的查询词: {enhanced_query}")
            # Fix: honor the configured proxies (previously hard-coded to
            # None) and bound the request time with a timeout.
            response = requests.post(
                endpoint,
                headers=headers,
                json=payload,
                proxies=self.proxies or None,
                timeout=30
            )
            logging.info(f"API响应状态码: {response.status_code}")
            logging.info(f"API响应内容: {response.text[:500]}...")  # log first 500 chars only
            if response.status_code != 200:
                logging.error(f"API请求失败,状态码: {response.status_code}")
                logging.error(f"错误响应: {response.text}")
                return []
            response_json = response.json()
            # Expected shape: {'code': 200, 'messages': [...]} where a message
            # with type 'source' / content_type 'webpage' carries a JSON string
            # whose 'value' field is the list of results.
            if response_json.get('code') == 200 and 'messages' in response_json:
                messages = response_json['messages']
                if messages and isinstance(messages, list):
                    for msg in messages:
                        if msg.get('type') == 'source' and msg.get('content_type') == 'webpage':
                            try:
                                content = json.loads(msg['content'])
                                if 'value' in content:
                                    return content['value']
                            except json.JSONDecodeError:
                                logging.error(f"无法解析消息内容: {msg['content']}")
                                continue
            logging.error(f"API返回数据结构异常: {response_json}")
            return []
        except Exception as e:
            logging.error(f"处理API响应时出错: {str(e)}")
            return []

    def search_images(self, query: str, count: int = 3) -> List[Dict]:
        """Search /v1/web-search for images and return up to *count* of them.

        Each returned dict has 'url', 'width' and 'height'. Returns [] on
        any failure.
        """
        try:
            headers = {
                'Authorization': f'Bearer {self.api_key}',
                'Content-Type': 'application/json'
            }
            enhanced_query = self.enhance_query(query)
            logging.info(f"增强后的图片搜索查询: {enhanced_query}")
            payload = {
                'query': enhanced_query,
                'freshness': 'oneYear',
                'count': 10,  # over-fetch so enough results survive filtering
                'filter': 'images'
            }
            endpoint = f"{self.base_url}/v1/web-search"
            # Fix: honor the configured proxies (previously ignored).
            response = requests.post(
                endpoint,
                headers=headers,
                json=payload,
                proxies=self.proxies or None,
                timeout=10
            )
            if response.status_code == 200:
                try:
                    data = response.json()
                    logging.info(f"API返回数据结构: {data.keys()}")
                    if data.get('code') == 200 and 'data' in data:
                        data_content = data['data']
                        logging.info(f"data字段内容: {data_content.keys()}")
                        images = []
                        if 'images' in data_content:
                            image_items = data_content['images'].get('value', [])
                            logging.info(f"找到 {len(image_items)} 张图片")
                            for item in image_items:
                                # Keep only images that have a URL and meet a
                                # minimum 300x300 size.
                                if (item.get('contentUrl') and
                                        item.get('width', 0) >= 300 and
                                        item.get('height', 0) >= 300):
                                    images.append({
                                        'url': item['contentUrl'],
                                        'width': item['width'],
                                        'height': item['height']
                                    })
                                    if len(images) >= count:
                                        break
                        logging.info(f"最终返回 {len(images)} 张图片")
                        return images[:count]
                except json.JSONDecodeError as e:
                    logging.error(f"JSON解析错误: {str(e)}")
                    return []
                except Exception as e:
                    logging.error(f"处理图片数据时出错: {str(e)}")
                    return []
            logging.error(f"API请求失败,状态码: {response.status_code}")
            return []
        except Exception as e:
            logging.error(f"图片搜索出错: {str(e)}")
            return []
""" | |
class BingSearch(SearchEngine): | |
def __init__(self, api_key: str): | |
super().__init__() | |
self.api_key = api_key | |
self.base_url = "https://api.bing.microsoft.com/v7.0/search" | |
def search(self, query: str) -> List[Dict]: | |
# 只添加香港旅游关键词 | |
enhanced_query = f"香港旅游 {query}" | |
headers = {'Ocp-Apim-Subscription-Key': self.api_key} | |
params = { | |
'q': enhanced_query | |
} | |
response = requests.get( | |
self.base_url, | |
headers=headers, | |
params=params | |
) | |
results = response.json() | |
filtered_results = [] | |
for item in results.get('webPages', {}).get('value', []): | |
if not self.is_ad_url(item['url']): | |
filtered_results.append({ | |
'title': item['name'], | |
'snippet': item['snippet'], | |
'url': item['url'] | |
}) | |
return filtered_results | |
def is_trusted_domain(self, url: str) -> bool: | |
""检查是否为可信域名"" | |
return any( | |
trusted_domain in url.lower() | |
for trusted_domain in self.config['search_settings']['trusted_domains'] | |
) | |
""" |