Spaces:
Running
Running
import time | |
import pytest | |
from typing import Optional, Union | |
import httpx | |
from backend.db.task.models import TaskStatus, Task | |
from backend.tests.test_backend_config import get_client | |
def fetch_task(identifier: str): | |
"""Get task status""" | |
client = get_client() | |
response = client.get( | |
f"/task/{identifier}" | |
) | |
if response.status_code == 200: | |
return response | |
return None | |
def fetch_file_response(identifier: str): | |
"""Get task status""" | |
client = get_client() | |
response = client.get( | |
f"/task/file/{identifier}" | |
) | |
if response.status_code == 200: | |
return response | |
return None | |
def wait_for_task_completion(identifier: str, | |
max_attempts: int = 20, | |
frequency: int = 3) -> httpx.Response: | |
""" | |
Polls the task status until it is completed, failed, or the maximum attempts are reached. | |
Args: | |
identifier (str): The unique identifier of the task to monitor. | |
max_attempts (int): The maximum number of polling attempts.. | |
frequency (int): The time (in seconds) to wait between polling attempts. | |
Returns: | |
bool: Returns json if the task completes successfully within the allowed attempts. | |
""" | |
attempts = 0 | |
while attempts < max_attempts: | |
task = fetch_task(identifier) | |
status = task.json()["status"] | |
if status == TaskStatus.COMPLETED: | |
return task | |
if status == TaskStatus.FAILED: | |
raise Exception("Task polling failed") | |
time.sleep(frequency) | |
attempts += 1 | |
return None | |