|
import asyncio
import enum
import logging
import os
import re
from datetime import datetime, timedelta, timezone
from email.utils import parsedate_to_datetime
from urllib.parse import urlencode, urlparse

from playwright.async_api import Download, Page, Request, Response, async_playwright
|
|
|
|
|
# Configure the root logger at import time: INFO and above, timestamped lines,
# written through a single StreamHandler (stderr by default).
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)

# Logger shared by everything in this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
class VoiceType(enum.Enum):
    """Emotion labels mapped to the voice ids used in pi.ai voice requests.

    A member's ``.value`` (``"voice1"`` … ``"voice6"``) is the string placed in
    the ``voice`` query parameter of the audio download URL.

    NOTE(review): whether each voice id actually sounds like its label is not
    verifiable from this file.
    """

    # Fallback voice used when no caller is waiting for a reply's audio.
    NEUTRAL = "voice1"
    HAPPY = "voice2"
    SAD = "voice3"
    ANGRY = "voice4"
    EXCITED = "voice5"
    CALM = "voice6"
|
|
|
|
|
|
|
class PiAIClient:
    """Automate the pi.ai web chat with Playwright and download TTS audio for replies.

    Typical use::

        client = PiAIClient(headless=True)
        await client.setup()
        path = await client.say("Hello", VoiceType.CALM)
        await client.close()

    :meth:`setup` starts a background task (:meth:`monitor_page_and_act`) that
    dismisses onboarding prompts, fills in the name field and delivers any
    pending chat message. POST responses from ``/api/chat`` are scanned for
    ``"sid"`` values; a new sid is turned into an audio download whose file path
    resolves the future registered by :meth:`say`.
    """

    # Selectors shared by the monitor loop and the direct send path.
    CHAT_INPUT_SELECTOR = 'textarea[placeholder="Talk with Pi"]'
    SUBMIT_BUTTON_SELECTOR = 'button[aria-label="Submit text"]'

    # Page-side helper used by process_sid: click a temporary <a download> link
    # so the browser handles the audio URL as a download. The URL and filename
    # arrive as a Playwright evaluate() argument instead of being spliced into
    # the source, so no value can break out of the JavaScript string.
    _DOWNLOAD_LINK_JS = """([href, name]) => {
        const link = document.createElement('a');
        link.href = href;
        link.download = name;
        document.body.appendChild(link);
        link.click();
        document.body.removeChild(link);
    }"""

    def __init__(self, headless: bool = False, download_dir: str = "/tmp/Audio"):
        """Create an unconnected client; call :meth:`setup` before use.

        Creates *download_dir* on disk if it does not exist yet.

        :param headless: launch Chromium without a visible window.
        :param download_dir: directory where downloaded audio files are saved.
        """
        self.headless = headless
        self.download_dir = download_dir
        self.playwright = None
        self.browser = None
        self.context = None
        self.page = None
        self.initialized = False

        # Message waiting to be typed into the chat box, or None when nothing is
        # pending. Cleared after a successful submit so it is sent only once.
        self.user_input = None
        # Serializes fill+click on the chat box between the monitor loop and
        # send_message so one pending message cannot be submitted twice.
        self._send_lock = asyncio.Lock()

        # Page states checked in order by monitor_page_and_act. "break_after"
        # stops the scan once the chat box (the steady state) has been handled.
        self.actions = [
            {
                "selector": self.CHAT_INPUT_SELECTOR,
                "handler": self.send_chat_message,
                "description": "Chat input detected, sending message.",
                "break_after": True,
            },
            {
                "selector": 'button:has-text("I’ll do it later")',
                "handler": self.click_element,
                "description": "'I’ll do it later' button found, clicking it.",
            },
            {
                "selector": 'button:has-text("Next")',
                "handler": self.click_element,
                "description": "'Next' button found, clicking it.",
            },
            {
                "selector": 'textarea[placeholder="Your first name"]',
                "handler": self.fill_name,
                "description": "Name input detected, filling it.",
            },
        ]

        # Pulls message ids out of the raw /api/chat response body.
        self.sid_regex = re.compile(r'"sid":"([\w\-]+)"')
        # Sids already turned into downloads; responses may repeat them.
        self.processed_sids = set()

        self.ensure_download_directory()

        # At most five concurrent audio downloads.
        self.semaphore = asyncio.Semaphore(5)

        # Aware UTC datetime until which sending should pause; None = not limited.
        self.rate_limit_until = None
        self.rate_limit_lock = asyncio.Lock()

        # FIFO of (future, voice) pairs registered by say(); the next new sid is
        # paired with the oldest entry whose future is still pending.
        self.sid_futures = asyncio.Queue()

        # Strong references to background tasks: the event loop only keeps weak
        # references, so an unreferenced task can be garbage-collected mid-run.
        self._monitor_task = None
        self._background_tasks = set()

    def _spawn(self, coro):
        """Schedule *coro* as a task and hold a reference to it until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    @staticmethod
    def _utcnow():
        """Return the current time as a timezone-aware UTC datetime."""
        return datetime.now(timezone.utc)

    @staticmethod
    def _parse_retry_after(value, default=60):
        """Convert a ``Retry-After`` header value into whole seconds to wait.

        The header may be either a number of seconds or an HTTP-date.

        :param value: raw header value, or None when the header is absent.
        :param default: seconds to use when the value is missing or unparseable.
        :return: a non-negative number of seconds.
        """
        if not value:
            return default
        try:
            return max(int(value), 0)
        except ValueError:
            pass
        try:
            retry_at = parsedate_to_datetime(value.strip())
        except (TypeError, ValueError, IndexError):
            return default
        if retry_at is None:  # older Pythons return None instead of raising
            return default
        if retry_at.tzinfo is None:
            retry_at = retry_at.replace(tzinfo=timezone.utc)
        return max(int((retry_at - datetime.now(timezone.utc)).total_seconds()), 0)

    def ensure_download_directory(self):
        """Ensure that the downloads directory exists."""
        # makedirs + FileExistsError avoids the exists()/makedirs() race.
        try:
            os.makedirs(self.download_dir)
        except FileExistsError:
            logger.info(f"Directory '{self.download_dir}' already exists.")
        else:
            logger.info(
                f"Created directory '{self.download_dir}' for storing downloads."
            )

    async def setup(self):
        """Start Playwright, open a persistent Chromium context, load /talk and start monitoring.

        Browser state is stored in ``./user_data`` (relative to the current
        working directory) so logins and onboarding progress persist across runs.
        """
        self.playwright = await async_playwright().start()

        user_data_dir = os.path.join(os.getcwd(), "user_data")
        if not os.path.exists(user_data_dir):
            os.makedirs(user_data_dir)
            logger.info(f"Created user data directory at '{user_data_dir}'.")
        else:
            logger.info(f"Using existing user data directory at '{user_data_dir}'.")

        self.context = await self.playwright.chromium.launch_persistent_context(
            user_agent=(
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                "AppleWebKit/537.36 (KHTML, like Gecko) "
                "Chrome/114.0.0.0 Safari/537.36"
            ),
            user_data_dir=user_data_dir,
            headless=self.headless,
            args=["--no-sandbox"],
        )

        self.page = await self.context.new_page()
        self.page.on("request", self.handle_request)
        self.page.on("response", self.handle_response)
        await self.navigate("https://pi.ai/talk")

        # Keep the task referenced so it is not garbage-collected and can be
        # cancelled in close().
        self._monitor_task = asyncio.create_task(self.monitor_page_and_act())
        self.initialized = True

    async def navigate(self, url: str):
        """Navigate to *url* and wait until the network is idle."""
        await self.page.goto(url)
        await self.page.wait_for_load_state("networkidle")
        logger.info(f"Navigated to {url}")

    async def monitor_page_and_act(self):
        """Poll the page every ~2 s and react to whichever known state is visible.

        Runs until cancelled (see :meth:`close`). Exceptions from a single
        iteration are logged and the loop continues.
        """
        counter = 0
        while True:
            try:
                if self.is_rate_limited():
                    wait_seconds = self._rate_limit_remaining()
                    logger.warning(
                        f"Rate limited. Waiting for {wait_seconds:.2f} seconds before retrying."
                    )
                    await asyncio.sleep(wait_seconds)
                    continue

                action_performed = False
                for action in self.actions:
                    if await self.page.is_visible(action["selector"]):
                        logger.info(action["description"])
                        await action["handler"](action["selector"])
                        action_performed = True
                        if action.get("break_after"):
                            # Reset so the route rotation below still runs while
                            # on the chat screen. NOTE(review): this sends the
                            # page to /discover every 5th idle iteration, even
                            # while a reply may still be streaming — confirm intent.
                            action_performed = False
                            break

                if not action_performed:
                    logger.info(
                        "No matching state detected. Navigating to /talk or /discover route."
                    )
                    # Every 5th idle iteration visits /discover; otherwise /talk.
                    if counter % 5 == 0:
                        await self.navigate_to_route("/discover")
                        logger.info("Navigated to /discover route.")
                        counter = 0
                    else:
                        await self.navigate_to_route("/talk")
                        logger.info("Navigated to /talk route.")
                    counter += 1

                await asyncio.sleep(2)
            except Exception as e:
                logger.error(f"Error during monitoring: {e}")
                await asyncio.sleep(2)

    def is_rate_limited(self):
        """Return True while the current time is before ``rate_limit_until``."""
        return self.rate_limit_until is not None and self._utcnow() < self.rate_limit_until

    def _rate_limit_remaining(self):
        """Return seconds left until the rate limit expires (0.0 when not limited)."""
        if self.rate_limit_until is None:
            return 0.0
        return max((self.rate_limit_until - self._utcnow()).total_seconds(), 0.0)

    async def navigate_to_route(self, route):
        """Navigate to *route* on the current host unless the page is already there.

        Errors are logged, not raised.

        :return: True if no navigation was needed or it completed without raising,
            False if it raised.
        """
        try:
            current_url = self.page.url
            # Compare paths so query strings / trailing slashes don't cause a reload.
            if urlparse(current_url).path.rstrip("/") == route.rstrip("/"):
                logger.info(f"Already on the {route} route.")
            else:
                new_url = self.construct_route_url(current_url, route)
                await self.navigate(new_url)
            return True
        except Exception as e:
            logger.error(f"Error navigating to {route} route: {e}")
            return False

    def construct_route_url(self, current_url, route):
        """Return *current_url* with its path replaced by *route* (scheme, host, query kept)."""
        new_url = urlparse(current_url)._replace(path=route).geturl()
        logger.info(f"Constructed new URL: {new_url}")
        return new_url

    async def click_element(self, selector: str):
        """Wait up to 3 s for *selector* and click it; failures are logged only."""
        try:
            await self.page.wait_for_selector(selector, timeout=3000)
            await self.page.click(selector)
            logger.info(f"Clicked element: {selector}")
        except Exception as e:
            logger.error(f"Error clicking element {selector}: {e}")

    async def fill_name(self, selector: str):
        """Fill the onboarding name field and submit; on failure, try to recover the page."""
        try:
            name = "Cassandra"
            await self.page.fill(selector, name)
            await self.page.click(self.SUBMIT_BUTTON_SELECTOR)
            logger.info(f"Name '{name}' submitted")
        except Exception as e:
            logger.error(f"Error submitting name: {e}")
            await self.handle_send_failure()

    async def send_chat_message(self, selector: str):
        """Monitor-loop handler: submit the pending message, if any, into *selector*.

        Does nothing when no message is pending, so a visible chat box does not
        cause the last message to be sent again on every poll.
        """
        async with self._send_lock:
            message = self.user_input
            if not message:
                return
            try:
                await self.page.fill(selector, message)
                await self.page.click(self.SUBMIT_BUTTON_SELECTOR)
                self.user_input = None
                logger.info("Chat message submitted")
                return
            except Exception as e:
                logger.error(f"Could not send chat message: {e}")
        # Recover outside the lock: navigation can be slow.
        await self.handle_send_failure()

    async def handle_send_failure(self):
        """Recover from a failed send by navigating to /talk, falling back to /discover."""
        if await self.navigate_to_route("/talk"):
            logger.info("Navigated to /talk route after failing to send message.")
            return
        if await self.navigate_to_route("/discover"):
            logger.info("Navigated to /discover route after failing to send message.")
            return
        logger.error("Failed to navigate after send_message failure.")

    async def handle_request(self, request: Request):
        """Log every outgoing request at DEBUG level."""
        logger.debug(f"Request: {request.method} {request.url}")

    def _pop_pending_request(self):
        """Return the oldest queued (future, voice) whose future is still pending, or None.

        Entries whose future is already done (e.g. cancelled when say() timed
        out) are discarded, so they cannot claim a later reply's sid.
        """
        while True:
            try:
                future, voice = self.sid_futures.get_nowait()
            except asyncio.QueueEmpty:
                return None
            if not future.done():
                return future, voice

    async def handle_response(self, response: Response):
        """Inspect POST /api/chat responses: apply rate limits and start audio downloads for new sids."""
        url = response.url
        if "/api/chat" not in url or response.request.method != "POST":
            return

        logger.info(f"Handling response for: {url}")
        try:
            response_status = response.status
            response_text = await asyncio.wait_for(response.text(), timeout=5)
            logger.info(f"Response received from {url}: {response_text}")

            if response_status == 429:
                logger.warning("Received 429 Too Many Requests.")
                # Playwright reports header names lower-cased.
                wait_seconds = self._parse_retry_after(
                    response.headers.get("retry-after")
                )
                await self.trigger_rate_limit(wait_seconds)
                return

            # Rate limits can also arrive as an error payload with a non-429 status.
            if "error" in response_text and "Too Many Requests" in response_text:
                logger.warning("Received error response: Too Many Requests.")
                await self.trigger_rate_limit(60)
                return

            sids = self.sid_regex.findall(response_text)
            if not sids:
                logger.info("No 'sid's found in the response.")
                return

            logger.info(f"Extracted 'sid's: {sids}")
            for sid in sids:
                if sid in self.processed_sids:
                    continue
                self.processed_sids.add(sid)
                logger.info(f"Processing sid: {sid}")

                pending = self._pop_pending_request()
                if pending is not None:
                    future, voice = pending
                else:
                    # Nobody is waiting: still fetch the audio with the default voice.
                    future, voice = None, VoiceType.NEUTRAL.value
                self._spawn(self.process_sid(sid, voice, future))
                # Only the first new sid of a response is processed.
                break
        except asyncio.TimeoutError:
            logger.warning(
                "Timed out waiting for the response body (possibly streaming)."
            )
        except Exception as e:
            logger.error(f"Error processing response: {e}")

    async def trigger_rate_limit(self, wait_seconds: int):
        """Pause sending for *wait_seconds*.

        If a longer limit is already active, that limit is extended by
        *wait_seconds* instead.
        """
        async with self.rate_limit_lock:
            candidate = self._utcnow() + timedelta(seconds=wait_seconds)
            if self.rate_limit_until is None or candidate > self.rate_limit_until:
                self.rate_limit_until = candidate
                logger.warning(
                    f"Rate limited. Will resume after {self.rate_limit_until} UTC."
                )
            else:
                self.rate_limit_until += timedelta(seconds=wait_seconds)
                logger.warning("Already rate limited. Extending the wait time.")

    async def process_sid(self, sid: str, voice: str, future: "asyncio.Future | None"):
        """Download the TTS audio for *sid* and resolve *future* with the saved path.

        The file is saved as ``<download_dir>/<sid>_<voice lowercased>.mp3``,
        which is exactly the path handed to *future*. At most five downloads run
        concurrently.

        :param sid: message id extracted from an /api/chat response.
        :param voice: voice id sent as the ``voice`` query parameter.
        :param future: awaited by :meth:`say`; None when nobody is waiting.
        """
        async with self.semaphore:
            new_page = None
            try:
                logger.info(f"Processing sid: {sid} with voice: {voice}")
                query = urlencode({"mode": "eager", "voice": voice, "messageSid": sid})
                url = f"https://pi.ai/api/chat/voice?{query}"
                logger.info(f"Initiating download from URL: {url}")

                new_page = await self.context.new_page()
                await new_page.goto(url)
                logger.info(f"Opened URL: {url}")

                file_path = os.path.join(self.download_dir, f"{sid}_{voice.lower()}.mp3")

                # Wait for the download itself and save it at file_path, so the
                # path returned to the caller is a file that actually exists.
                async with new_page.expect_download() as download_info:
                    await new_page.evaluate(self._DOWNLOAD_LINK_JS, [url, f"{sid}.mp3"])
                logger.info(f"Triggered download for sid: {sid}")
                download = await download_info.value
                await download.save_as(file_path)
                logger.info(f"Downloaded audio to {file_path}")

                if future is not None and not future.done():
                    future.set_result(file_path)
            except asyncio.CancelledError:
                if future is not None and not future.done():
                    future.cancel()
                raise
            except Exception as e:
                logger.error(f"Error processing sid {sid}: {e}")
                if future is not None and not future.done():
                    future.set_exception(e)
            finally:
                if new_page is not None:
                    try:
                        await new_page.close()
                    except Exception as e:
                        logger.warning(f"Could not close download page for sid {sid}: {e}")

    async def handle_download(self, download: Download):
        """Save a Playwright download into download_dir under its suggested filename."""
        try:
            filename = download.suggested_filename or "audio.mp3"
            download_path = os.path.join(self.download_dir, filename)
            await download.save_as(download_path)
            logger.info(f"Downloaded audio to {download_path}")
        except Exception as e:
            logger.error(f"Error downloading audio: {e}")

    async def close(self):
        """Cancel background tasks, then close the browser context and Playwright."""
        tasks = list(self._background_tasks)
        if self._monitor_task is not None:
            tasks.append(self._monitor_task)
            self._monitor_task = None
        for task in tasks:
            task.cancel()
        if tasks:
            # Let cancellations finish so nothing touches the page after it closes.
            await asyncio.gather(*tasks, return_exceptions=True)

        if self.context:
            await self.context.close()
        if self.playwright:
            await self.playwright.stop()
        logger.info("Browser closed")

    async def say(self, message: str, voice: str) -> str:
        """
        Send a message and retrieve the path to the downloaded TTS audio.

        :param message: The message to send.
        :param voice: The voice id (e.g. ``"voice1"``) or a ``VoiceType`` member.
        :return: The file path of the downloaded audio, or ``""`` if none
            arrived within 60 seconds.
        :raises Exception: whatever error aborted the audio download.
        """
        if isinstance(voice, enum.Enum):
            voice = voice.value

        future = asyncio.get_running_loop().create_future()
        # Register before sending so the reply's sid is paired with this call.
        await self.sid_futures.put((future, voice))

        await self.send_message(message)

        try:
            return await asyncio.wait_for(future, timeout=60)
        except asyncio.TimeoutError:
            logger.error("Timeout while waiting for audio download.")
            return ""

    async def send_message(
        self, message: str, retry_count: int = 3, retry_delay: int = 60
    ):
        """
        Send a message through the chat interface with retry logic.

        The message also becomes the pending message, so the monitor loop can
        deliver it if the chat box only appears later; whichever path submits
        first clears it, and the other then does nothing.

        :param message: The message to send.
        :param retry_count: Number of attempts before giving up.
        :param retry_delay: Seconds to wait before retrying.
        """
        self.user_input = message
        for attempt in range(1, retry_count + 1):
            try:
                if self.is_rate_limited():
                    wait_seconds = self._rate_limit_remaining()
                    logger.warning(
                        f"Currently rate limited. Waiting for {wait_seconds:.2f} seconds before retrying."
                    )
                    await asyncio.sleep(wait_seconds)

                async with self._send_lock:
                    if self.user_input != message:
                        # Already submitted elsewhere (normally the monitor loop).
                        return
                    await self.page.fill(self.CHAT_INPUT_SELECTOR, message)
                    await self.page.click(self.SUBMIT_BUTTON_SELECTOR)
                    self.user_input = None
                logger.info("Chat message submitted")
                return
            except Exception as e:
                logger.error(f"Could not send chat message: {e}")
                if attempt < retry_count:
                    logger.info(
                        f"Retrying to send message in {retry_delay} seconds... (Attempt {attempt}/{retry_count})"
                    )
                    await asyncio.sleep(retry_delay)

        logger.error("Max retry attempts reached. Failed to send the message.")
        await self.handle_send_failure()
|
|
|
|
|
import asyncio |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|