UN-2793 [FEAT] Add retry logic with exponential backoff to SDK1 (#1564)
* UN-2793 [FEAT] Add retry logic with exponential backoff to SDK1 Implemented automatic retry logic for platform and prompt service calls with configurable exponential backoff, comprehensive test coverage, and CI integration. Features: - Exponential backoff with jitter for transient failures - Configurable via environment variables (MAX_RETRIES, MAX_TIME, BASE_DELAY, etc.) - Retries ConnectionError, Timeout, HTTPError (502/503/504), OSError - 67 tests with 100% pass rate - CI integration with test reporting 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * [SECURITY] Use full commit SHA for sticky-pull-request-comment action Replace tag reference with full commit SHA for better security: - marocchino/sticky-pull-request-comment@v2 → @7737449 (v2.9.4) This prevents potential supply chain attacks where tags could be moved to point to malicious code. Commit SHAs are immutable. Fixes SonarQube security hotspot for external GitHub action. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * [FIX] Allow retryable HTTP errors (502/503/504) to propagate for retry Fixed HTTPError handling in _get_adapter_configuration to check status codes and re-raise retryable errors (502, 503, 504) so the retry decorator can handle them. Non-retryable errors are still converted to SdkError as before. Changes: - Check HTTPError status code before converting to SdkError - Re-raise HTTPError for 502/503/504 to allow retry decorator to retry - Added parametrized test for all retryable status codes (502, 503, 504) - All 12 platform tests pass This fixes a bug where 502/503/504 errors were not being retried because they were converted to SdkError before the retry decorator could see them. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * [FIX] Use pytest.approx() for floating point comparisons in tests Replaced direct equality comparisons (==) with pytest.approx() for floating point values to avoid precision issues and satisfy SonarQube code quality check (python:S1244). Changes in test_retry_utils.py: - test_exponential_backoff_without_jitter: Use pytest.approx() for 1.0, 2.0, 4.0, 8.0 - test_max_delay_cap: Use pytest.approx() for 5.0 This is the proper way to compare floating point values in tests, accounting for floating point precision limitations. All 4 TestCalculateDelay tests pass. Fixes SonarQube: python:S1244 - Do not perform equality checks with floating point values. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * minor: Addressed code smells, ruff fixes * misc: Fixed tox config for sdk1 tests * misc: Ruff issues fixed * misc: tox tests fixed * prompt service lock file for venv * updated lock files for backend and prompt-service * UN-2793 [FEAT] Update to unstract-sdk v0.78.0 with retry logic support (#1567) [FEAT] Update unstract-sdk to v0.78.0 across all services and tools - Updated unstract-sdk dependency from v0.77.3 to v0.78.0 in all pyproject.toml files - Main repository, backend, workers, platform-service, prompt-service - filesystem and tool-registry modules - Updated tool requirements.txt files (structure, classifier, text_extractor) - Bumped tool versions in properties.json: - Structure tool: 0.0.88 → 0.0.89 - Classifier tool: 0.0.68 → 0.0.69 - Text extractor tool: 0.0.64 → 0.0.65 - Updated tool versions in backend/sample.env and public_tools.json - Regenerated all uv.lock files with new SDK version This update brings in the retry logic with exponential backoff from unstract-sdk v0.78.0 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude <noreply@anthropic.com> --------- Signed-off-by: Chandrasekharan M <117059509+chandrasekharan-zipstack@users.noreply.github.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
committed by
GitHub
parent
9997ee10e5
commit
0c0c8c1034
20
.github/workflows/ci-test.yaml
vendored
20
.github/workflows/ci-test.yaml
vendored
@@ -48,13 +48,22 @@ jobs:
|
||||
run: |
|
||||
tox
|
||||
|
||||
- name: Render the report to the PR
|
||||
uses: marocchino/sticky-pull-request-comment@v2
|
||||
- name: Render the Runner report to the PR
|
||||
uses: marocchino/sticky-pull-request-comment@773744901bac0e8cbb5a0dc842800d45e9b2b405 # v2.9.4
|
||||
if: always() && hashFiles('runner-report.md') != ''
|
||||
with:
|
||||
header: runner-test-report
|
||||
recreate: true
|
||||
path: runner-report.md
|
||||
|
||||
- name: Render the SDK1 report to the PR
|
||||
uses: marocchino/sticky-pull-request-comment@773744901bac0e8cbb5a0dc842800d45e9b2b405 # v2.9.4
|
||||
if: always() && hashFiles('sdk1-report.md') != ''
|
||||
with:
|
||||
header: sdk1-test-report
|
||||
recreate: true
|
||||
path: sdk1-report.md
|
||||
|
||||
- name: Output reports to the job summary when tests fail
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -65,3 +74,10 @@ jobs:
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "</details>" >> $GITHUB_STEP_SUMMARY
|
||||
fi
|
||||
if [ -f "sdk1-report.md" ]; then
|
||||
echo "<details><summary>SDK1 Test Report</summary>" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
cat "sdk1-report.md" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "</details>" >> $GITHUB_STEP_SUMMARY
|
||||
fi
|
||||
|
||||
@@ -35,9 +35,7 @@ dependencies = [
|
||||
"python-socketio==5.9.0", # For log_events
|
||||
"social-auth-app-django==5.3.0", # For OAuth
|
||||
"social-auth-core==4.4.2", # For OAuth
|
||||
# TODO: Temporarily removing the extra dependencies of aws and gcs from unstract-sdk
|
||||
# to resolve lock file. Will have to be re-looked into
|
||||
"unstract-sdk[azure]~=0.77.3",
|
||||
"unstract-sdk[aws,gcs,azure]~=0.78.0",
|
||||
"azure-identity==1.16.0",
|
||||
"azure-mgmt-apimanagement==3.0.0",
|
||||
"croniter>=3.0.3",
|
||||
|
||||
@@ -78,9 +78,9 @@ PROMPT_STUDIO_FILE_PATH=/app/prompt-studio-data
|
||||
|
||||
# Structure Tool Image (Runs prompt studio exported tools)
|
||||
# https://hub.docker.com/r/unstract/tool-structure
|
||||
STRUCTURE_TOOL_IMAGE_URL="docker:unstract/tool-structure:0.0.88"
|
||||
STRUCTURE_TOOL_IMAGE_URL="docker:unstract/tool-structure:0.0.89"
|
||||
STRUCTURE_TOOL_IMAGE_NAME="unstract/tool-structure"
|
||||
STRUCTURE_TOOL_IMAGE_TAG="0.0.88"
|
||||
STRUCTURE_TOOL_IMAGE_TAG="0.0.89"
|
||||
|
||||
# Feature Flags
|
||||
EVALUATION_SERVER_IP=unstract-flipt
|
||||
|
||||
2137
backend/uv.lock
generated
2137
backend/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -15,7 +15,7 @@ dependencies = [
|
||||
"redis~=5.2.1",
|
||||
"cryptography>=41.0.7",
|
||||
"requests>=2.31.0",
|
||||
"unstract-sdk[gcs, azure, aws]~=0.77.3", # Add version
|
||||
"unstract-sdk[gcs, azure, aws]~=0.78.0", # Add version
|
||||
"unstract-flags",
|
||||
"unstract-core[flask]",
|
||||
"unstract-sdk1[gcs, azure, aws]"
|
||||
|
||||
10
platform-service/uv.lock
generated
10
platform-service/uv.lock
generated
@@ -3184,7 +3184,7 @@ requires-dist = [
|
||||
{ name = "requests", specifier = ">=2.31.0" },
|
||||
{ name = "unstract-core", extras = ["flask"], editable = "../unstract/core" },
|
||||
{ name = "unstract-flags", editable = "../unstract/flags" },
|
||||
{ name = "unstract-sdk", extras = ["gcs", "azure", "aws"], specifier = "~=0.77.3" },
|
||||
{ name = "unstract-sdk", extras = ["gcs", "azure", "aws"], specifier = "~=0.78.0" },
|
||||
{ name = "unstract-sdk1", extras = ["gcs", "azure", "aws"], editable = "../unstract/sdk1" },
|
||||
]
|
||||
|
||||
@@ -3202,7 +3202,7 @@ test = [{ name = "pytest", specifier = ">=8.0.1" }]
|
||||
|
||||
[[package]]
|
||||
name = "unstract-sdk"
|
||||
version = "0.77.3"
|
||||
version = "0.78.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "filetype" },
|
||||
@@ -3241,9 +3241,9 @@ dependencies = [
|
||||
{ name = "tiktoken" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c5/e8/f7e1045fee076c75c42bb27b4fa5a077836cfa8fd3fd580f6c6193fa19dc/unstract_sdk-0.77.3.tar.gz", hash = "sha256:378c19129a91e861b7235e92411ef72792c6c851320cb7897b380d08d3489c9d", size = 2375843 }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7b/b9/ca49124e2548e7baa57c00dab2300d59a1dac13ccc6b0d08e5c81f3f7b94/unstract_sdk-0.78.0.tar.gz", hash = "sha256:35cf15acd946996be3871b35aec099b9fb03d30e96893dd5f34991a8aa2535a0", size = 2364734 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d5/84/d596295fce3a713b1953a4fd7be3be1d788920e5633e28fbea7c7331b68a/unstract_sdk-0.77.3-py3-none-any.whl", hash = "sha256:74e83cbb68eef98fbaecfc83e0cfd2228e5958506e1b51a65fae0d891027aaee", size = 266640 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/23/2f0f9cf0f4e489a4be8133622e968f8a7dd639f0f582252625ca3a22a749/unstract_sdk-0.78.0-py3-none-any.whl", hash = "sha256:ecaf24386c1eaefa47e6c06d1f978ca26406cc8cfffcbe9d96197fa1b2fba826", size = 270348 },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
@@ -3335,6 +3335,8 @@ docs = [{ name = "lazydocs", specifier = "~=0.4.8" }]
|
||||
test = [
|
||||
{ name = "parameterized", specifier = "==0.9.0" },
|
||||
{ name = "pytest", specifier = "==8.3.3" },
|
||||
{ name = "pytest-cov", specifier = ">=6.0.0" },
|
||||
{ name = "pytest-md-report", specifier = ">=0.6.2" },
|
||||
{ name = "pytest-mock", specifier = "==3.14.0" },
|
||||
]
|
||||
|
||||
|
||||
@@ -15,9 +15,7 @@ dependencies = [
|
||||
"python-dotenv==1.0.1",
|
||||
"json-repair~=0.42.0",
|
||||
"requests>=2.28,<3.0",
|
||||
# TODO: Temporarily removing the extra dependencies of aws and gcs from unstract-sdk
|
||||
# to resolve lock file. Will have to be re-looked into
|
||||
"unstract-sdk[azure]~=0.77.3",
|
||||
"unstract-sdk[aws,gcs,azure]~=0.78.0",
|
||||
"redis>=5.0.3,<5.3",
|
||||
"unstract-core",
|
||||
"unstract-flags",
|
||||
|
||||
1721
prompt-service/uv.lock
generated
1721
prompt-service/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -41,7 +41,7 @@ hook-check-django-migrations = [
|
||||
"psycopg2-binary==2.9.9",
|
||||
"python-dotenv==1.0.1",
|
||||
"python-magic==0.4.27",
|
||||
"unstract-sdk~=0.77.3",
|
||||
"unstract-sdk~=0.78.0",
|
||||
"unstract-connectors",
|
||||
"unstract-core",
|
||||
"unstract-flags",
|
||||
|
||||
@@ -4,6 +4,6 @@
|
||||
# aws alone is needed here
|
||||
# because tools use transient temporary storage.
|
||||
|
||||
unstract-sdk[aws]~=0.77.3
|
||||
unstract-sdk[aws]~=0.78.0
|
||||
-e file:/unstract/sdk1
|
||||
-e file:/unstract/flags
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"schemaVersion": "0.0.1",
|
||||
"displayName": "File Classifier",
|
||||
"functionName": "classify",
|
||||
"toolVersion": "0.0.68",
|
||||
"toolVersion": "0.0.69",
|
||||
"description": "Classifies a file into a bin based on its contents",
|
||||
"input": {
|
||||
"description": "File to be classified"
|
||||
|
||||
@@ -3,6 +3,6 @@
|
||||
# Required for all unstract tools
|
||||
# aws alone is needed here
|
||||
# because tools use transient temporary storage.
|
||||
unstract-sdk[aws]~=0.77.3
|
||||
unstract-sdk[aws]~=0.78.0
|
||||
-e file:/unstract/sdk1
|
||||
-e file:/unstract/flags
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"schemaVersion": "0.0.1",
|
||||
"displayName": "Structure Tool",
|
||||
"functionName": "structure_tool",
|
||||
"toolVersion": "0.0.88",
|
||||
"toolVersion": "0.0.89",
|
||||
"description": "This is a template tool which can answer set of input prompts designed in the Prompt Studio",
|
||||
"input": {
|
||||
"description": "File that needs to be indexed and parsed for answers"
|
||||
|
||||
@@ -4,6 +4,6 @@
|
||||
# aws alone is needed here
|
||||
# because tools use transient temporary storage.
|
||||
|
||||
unstract-sdk[aws]~=0.77.3
|
||||
unstract-sdk[aws]~=0.78.0
|
||||
-e file:/unstract/sdk1
|
||||
-e file:/unstract/flags
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"schemaVersion": "0.0.1",
|
||||
"displayName": "Text Extractor",
|
||||
"functionName": "text_extractor",
|
||||
"toolVersion": "0.0.64",
|
||||
"toolVersion": "0.0.65",
|
||||
"description": "The Text Extractor is a powerful tool designed to convert documents to its text form or Extract texts from documents",
|
||||
"input": {
|
||||
"description": "Document"
|
||||
|
||||
15
tox.ini
15
tox.ini
@@ -1,5 +1,5 @@
|
||||
[tox]
|
||||
env_list = py{312}, runner
|
||||
env_list = py{312}, runner, sdk1
|
||||
requires =
|
||||
tox-uv>=0.2.0
|
||||
|
||||
@@ -30,3 +30,16 @@ commands_pre =
|
||||
uv pip install pytest pytest-cov pytest-md-report pytest-mock
|
||||
commands =
|
||||
pytest -v --md-report-verbose=1 --md-report --md-report-flavor gfm --md-report-output ../runner-report.md
|
||||
|
||||
[testenv:sdk1]
|
||||
changedir = unstract/sdk1
|
||||
deps = uv
|
||||
allowlist_externals=
|
||||
sh
|
||||
uv
|
||||
pytest
|
||||
commands_pre =
|
||||
# Install dependencies with test group
|
||||
uv sync --group test
|
||||
commands =
|
||||
uv run pytest -v -m "not slow" --cov=src/unstract/sdk1 --cov-report=term --cov-report=html --md-report-verbose=1 --md-report --md-report-flavor gfm --md-report-output ../../sdk1-report.md
|
||||
|
||||
@@ -6,7 +6,7 @@ authors = [{ name = "Zipstack Inc.", email = "devsupport@zipstack.com" }]
|
||||
requires-python = ">=3.12,<3.13"
|
||||
readme = "README.md"
|
||||
|
||||
dependencies = ["unstract-sdk~=0.77.1"]
|
||||
dependencies = ["unstract-sdk~=0.78.0"]
|
||||
|
||||
# [tool.uv.sources]
|
||||
# unstract-sdk = { git = "https://github.com/Zipstack/unstract-sdk", branch = "python-upgrade" }
|
||||
|
||||
1436
unstract/filesystem/uv.lock
generated
1436
unstract/filesystem/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -6,3 +6,36 @@
|
||||
The `unstract-sdk1` package helps with developing tools that are meant to be run on the Unstract platform. This includes
|
||||
modules to help with tool development and execution, caching, making calls to LLMs / vectorDBs / embeddings .etc.
|
||||
They also contain helper methods/classes to aid with other tasks such as indexing and auditing the LLM calls.
|
||||
|
||||
## Features
|
||||
|
||||
### Retry Configuration
|
||||
|
||||
The SDK automatically retries platform and prompt service calls on transient failures. Configure via environment variables (prefix: `PLATFORM_SERVICE_` or `PROMPT_SERVICE_`):
|
||||
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `MAX_RETRIES` | 3 | Maximum retry attempts |
|
||||
| `MAX_TIME` | 60 | Maximum total time (seconds) |
|
||||
| `BASE_DELAY` | 1.0 | Initial delay (seconds) |
|
||||
| `MULTIPLIER` | 2.0 | Backoff multiplier |
|
||||
| `JITTER` | true | Add random jitter (0-25%) |
|
||||
|
||||
**Retryable errors**: ConnectionError, Timeout, HTTPError (502/503/504), OSError (connection failures)
|
||||
|
||||
## Development
|
||||
|
||||
### Running Tests
|
||||
|
||||
Install test dependencies and run tests:
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
uv sync --group test
|
||||
|
||||
# Run all tests
|
||||
uv run pytest
|
||||
|
||||
# Run with coverage
|
||||
uv run pytest --cov=src/unstract/sdk1 --cov-report=html
|
||||
```
|
||||
|
||||
@@ -67,6 +67,8 @@ test = [
|
||||
"parameterized==0.9.0",
|
||||
"pytest==8.3.3",
|
||||
"pytest-mock==3.14.0",
|
||||
"pytest-cov>=6.0.0",
|
||||
"pytest-md-report>=0.6.2",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, Self
|
||||
|
||||
import requests
|
||||
from requests import RequestException, Response
|
||||
@@ -17,6 +17,7 @@ from unstract.sdk1.constants import (
|
||||
from unstract.sdk1.exceptions import SdkError
|
||||
from unstract.sdk1.tool.base import BaseTool
|
||||
from unstract.sdk1.utils.common import Utils
|
||||
from unstract.sdk1.utils.retry_utils import retry_platform_service_call
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -29,7 +30,7 @@ class PlatformHelper:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
self: Self,
|
||||
tool: BaseTool,
|
||||
platform_host: str,
|
||||
platform_port: str,
|
||||
@@ -50,7 +51,11 @@ class PlatformHelper:
|
||||
self.request_id = request_id
|
||||
|
||||
@classmethod
|
||||
def get_platform_base_url(cls, platform_host: str, platform_port: str) -> str:
|
||||
def get_platform_base_url(
|
||||
cls: type[Self],
|
||||
platform_host: str,
|
||||
platform_port: str,
|
||||
) -> str:
|
||||
"""Make base url from host and port.
|
||||
|
||||
Args:
|
||||
@@ -65,7 +70,7 @@ class PlatformHelper:
|
||||
return f"{platform_host}:{platform_port}"
|
||||
|
||||
@classmethod
|
||||
def is_public_adapter(cls, adapter_id: str) -> bool:
|
||||
def is_public_adapter(cls: type[Self], adapter_id: str) -> bool:
|
||||
"""Check if the given adapter_id is one of the public adapter keys.
|
||||
|
||||
This method iterates over the attributes of the PublicAdapterKeys class
|
||||
@@ -91,14 +96,24 @@ class PlatformHelper:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
@retry_platform_service_call
|
||||
def _get_adapter_configuration(
|
||||
cls,
|
||||
cls: type[Self],
|
||||
tool: BaseTool,
|
||||
adapter_instance_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Get adapter configuration from platform service.
|
||||
"""Get Adapter.
|
||||
|
||||
1. Get the adapter config from platform service using the adapter_instance_id.
|
||||
Get the adapter config from platform service
|
||||
using the adapter_instance_id. This method automatically
|
||||
retries on connection errors with exponential backoff.
|
||||
|
||||
Retry behavior is configurable via environment variables:
|
||||
- PLATFORM_SERVICE_MAX_RETRIES (default: 3)
|
||||
- PLATFORM_SERVICE_MAX_TIME (default: 60s)
|
||||
- PLATFORM_SERVICE_BASE_DELAY (default: 1.0s)
|
||||
- PLATFORM_SERVICE_MULTIPLIER (default: 2.0)
|
||||
- PLATFORM_SERVICE_JITTER (default: true)
|
||||
|
||||
Args:
|
||||
adapter_instance_id (str): Adapter instance ID
|
||||
@@ -130,11 +145,14 @@ class PlatformHelper:
|
||||
f"'{adapter_type}', provider: '{provider}', name: '{adapter_name}'",
|
||||
level=LogLevel.DEBUG,
|
||||
)
|
||||
except ConnectionError as e:
|
||||
raise SdkError(
|
||||
"Unable to connect to platform service, please contact the admin."
|
||||
) from e
|
||||
except HTTPError as e:
|
||||
# Check if this is a retryable HTTP error (502, 503, 504)
|
||||
# If so, re-raise it so the retry decorator can handle it
|
||||
if hasattr(e, "response") and e.response is not None:
|
||||
if e.response.status_code in (502, 503, 504):
|
||||
raise # Re-raise to allow retry decorator to retry
|
||||
|
||||
# Non-retryable error - convert to SdkError
|
||||
default_err = (
|
||||
"Error while calling the platform service, please contact the admin."
|
||||
)
|
||||
@@ -146,7 +164,7 @@ class PlatformHelper:
|
||||
|
||||
@classmethod
|
||||
def get_adapter_config(
|
||||
cls, tool: BaseTool, adapter_instance_id: str
|
||||
cls: type[Self], tool: BaseTool, adapter_instance_id: str
|
||||
) -> dict[str, Any] | None:
|
||||
"""Get adapter spec by the help of unstract DB tool.
|
||||
|
||||
@@ -174,9 +192,15 @@ class PlatformHelper:
|
||||
f"Retrieving config from DB for '{adapter_instance_id}'",
|
||||
level=LogLevel.DEBUG,
|
||||
)
|
||||
return cls._get_adapter_configuration(tool, adapter_instance_id)
|
||||
|
||||
def _get_headers(self, headers: dict[str, str] | None = None) -> dict[str, str]:
|
||||
try:
|
||||
return cls._get_adapter_configuration(tool, adapter_instance_id)
|
||||
except ConnectionError as e:
|
||||
raise SdkError(
|
||||
"Unable to connect to platform service, please contact the admin."
|
||||
) from e
|
||||
|
||||
def _get_headers(self: Self, headers: dict[str, str] | None = None) -> dict[str, str]:
|
||||
"""Get default headers for requests.
|
||||
|
||||
Returns:
|
||||
@@ -190,8 +214,9 @@ class PlatformHelper:
|
||||
request_headers.update(headers)
|
||||
return request_headers
|
||||
|
||||
@retry_platform_service_call
|
||||
def _call_service(
|
||||
self,
|
||||
self: Self,
|
||||
url_path: str,
|
||||
payload: dict[str, Any] | None = None,
|
||||
params: dict[str, str] | None = None,
|
||||
@@ -201,6 +226,14 @@ class PlatformHelper:
|
||||
"""Talks to platform-service to make GET / POST calls.
|
||||
|
||||
Only GET calls are made to platform-service though functionality exists.
|
||||
This method automatically retries on connection errors with exponential backoff.
|
||||
|
||||
Retry behavior is configurable via environment variables:
|
||||
- PLATFORM_SERVICE_MAX_RETRIES (default: 3)
|
||||
- PLATFORM_SERVICE_MAX_TIME (default: 60s)
|
||||
- PLATFORM_SERVICE_BASE_DELAY (default: 1.0s)
|
||||
- PLATFORM_SERVICE_MULTIPLIER (default: 2.0)
|
||||
- PLATFORM_SERVICE_JITTER (default: true)
|
||||
|
||||
Args:
|
||||
url_path (str): URL path to the service endpoint
|
||||
@@ -234,9 +267,13 @@ class PlatformHelper:
|
||||
|
||||
response.raise_for_status()
|
||||
except ConnectionError as connect_err:
|
||||
msg = "Unable to connect to platform service. Please contact admin."
|
||||
msg += " \n" + str(connect_err)
|
||||
self.tool.stream_error_and_exit(msg)
|
||||
logger.exception("Connection error to platform service")
|
||||
msg = (
|
||||
"Unable to connect to platform service. Retrying with backoff, "
|
||||
"please contact admin if retries ultimately fail."
|
||||
)
|
||||
self.tool.stream_log(msg, level=LogLevel.ERROR)
|
||||
raise ConnectionError(msg) from connect_err
|
||||
except RequestException as e:
|
||||
# Extract error information from the response if available
|
||||
error_message = str(e)
|
||||
@@ -252,7 +289,7 @@ class PlatformHelper:
|
||||
)
|
||||
return response.json()
|
||||
|
||||
def get_platform_details(self) -> dict[str, Any] | None:
|
||||
def get_platform_details(self: Self) -> dict[str, Any] | None:
|
||||
"""Obtains platform details associated with the platform key.
|
||||
|
||||
Currently helps fetch organization ID related to the key.
|
||||
@@ -269,7 +306,7 @@ class PlatformHelper:
|
||||
)
|
||||
return response.get("details")
|
||||
|
||||
def get_prompt_studio_tool(self, prompt_registry_id: str) -> dict[str, Any]:
|
||||
def get_prompt_studio_tool(self: Self, prompt_registry_id: str) -> dict[str, Any]:
|
||||
"""Get exported custom tool by the help of unstract DB tool.
|
||||
|
||||
Args:
|
||||
@@ -287,7 +324,7 @@ class PlatformHelper:
|
||||
method="GET",
|
||||
)
|
||||
|
||||
def get_llm_profile(self, llm_profile_id: str) -> dict[str, Any]:
|
||||
def get_llm_profile(self: Self, llm_profile_id: str) -> dict[str, Any]:
|
||||
"""Get llm profile by the help of unstract DB tool.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -9,6 +9,7 @@ from unstract.sdk1.constants import MimeType, RequestHeader, ToolEnv
|
||||
from unstract.sdk1.platform import PlatformHelper
|
||||
from unstract.sdk1.tool.base import BaseTool
|
||||
from unstract.sdk1.utils.common import log_elapsed
|
||||
from unstract.sdk1.utils.retry_utils import retry_prompt_service_call
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -187,6 +188,7 @@ class PromptTool:
|
||||
request_headers.update(headers)
|
||||
return request_headers
|
||||
|
||||
@retry_prompt_service_call
|
||||
def _call_service(
|
||||
self,
|
||||
url_path: str,
|
||||
@@ -198,6 +200,14 @@ class PromptTool:
|
||||
"""Communicates to prompt service to fetch response for the prompt.
|
||||
|
||||
Only POST calls are made to prompt-service though functionality exists.
|
||||
This method automatically retries on connection errors with exponential backoff.
|
||||
|
||||
Retry behavior is configurable via environment variables:
|
||||
- PROMPT_SERVICE_MAX_RETRIES (default: 3)
|
||||
- PROMPT_SERVICE_MAX_TIME (default: 60s)
|
||||
- PROMPT_SERVICE_BASE_DELAY (default: 1.0s)
|
||||
- PROMPT_SERVICE_MULTIPLIER (default: 2.0)
|
||||
- PROMPT_SERVICE_JITTER (default: true)
|
||||
|
||||
Args:
|
||||
url_path (str): URL path to the service endpoint
|
||||
|
||||
305
unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py
Normal file
305
unstract/sdk1/src/unstract/sdk1/utils/retry_utils.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""Generic retry utilities with custom exponential backoff implementation."""
|
||||
|
||||
import errno
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from requests.exceptions import ConnectionError, HTTPError, Timeout
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_retryable_error(error: Exception) -> bool:
|
||||
"""Check if an error is retryable.
|
||||
|
||||
Handles:
|
||||
- ConnectionError and Timeout from requests
|
||||
- HTTPError with status codes 502, 503, 504
|
||||
- OSError with specific errno codes (ECONNREFUSED, ECONNRESET, etc.)
|
||||
|
||||
Args:
|
||||
error: The exception to check
|
||||
|
||||
Returns:
|
||||
True if the error should trigger a retry
|
||||
"""
|
||||
# Requests connection and timeout errors
|
||||
if isinstance(error, ConnectionError | Timeout):
|
||||
return True
|
||||
|
||||
# HTTP errors with specific status codes
|
||||
if isinstance(error, HTTPError):
|
||||
if hasattr(error, "response") and error.response is not None:
|
||||
status_code = error.response.status_code
|
||||
# Retry on server errors and bad gateway
|
||||
if status_code in [502, 503, 504]:
|
||||
return True
|
||||
|
||||
# OS-level connection failures (preserving existing errno checks)
|
||||
if isinstance(error, OSError) and error.errno in {
|
||||
errno.ECONNREFUSED, # Connection refused
|
||||
getattr(errno, "ECONNRESET", 104), # Connection reset by peer
|
||||
getattr(errno, "ETIMEDOUT", 110), # Connection timed out
|
||||
getattr(errno, "EHOSTUNREACH", 113), # No route to host
|
||||
getattr(errno, "ENETUNREACH", 101), # Network is unreachable
|
||||
}:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def calculate_delay(
|
||||
attempt: int,
|
||||
base_delay: float,
|
||||
multiplier: float,
|
||||
max_delay: float,
|
||||
jitter: bool = True,
|
||||
) -> float:
|
||||
"""Calculate delay for the next retry attempt with exponential backoff.
|
||||
|
||||
Args:
|
||||
attempt: Current attempt number (0-indexed)
|
||||
base_delay: Base delay in seconds
|
||||
multiplier: Backoff multiplier
|
||||
max_delay: Maximum delay in seconds
|
||||
jitter: Whether to add random jitter
|
||||
|
||||
Returns:
|
||||
Delay in seconds before the next retry
|
||||
"""
|
||||
# Calculate exponential backoff
|
||||
base = base_delay * (multiplier**attempt)
|
||||
|
||||
# Add jitter if enabled (0-25% of base)
|
||||
if jitter:
|
||||
delay = base + (base * random.uniform(0, 0.25))
|
||||
else:
|
||||
delay = base
|
||||
|
||||
# Enforce cap after jitter
|
||||
return min(delay, max_delay)
|
||||
|
||||
|
||||
def retry_with_exponential_backoff( # noqa: C901
|
||||
max_retries: int,
|
||||
max_time: float,
|
||||
base_delay: float,
|
||||
multiplier: float,
|
||||
jitter: bool,
|
||||
exceptions: tuple[type[Exception], ...],
|
||||
logger_instance: logging.Logger,
|
||||
prefix: str,
|
||||
retry_predicate: Callable[[Exception], bool] | None = None,
|
||||
) -> Callable:
|
||||
"""Create retry decorator with exponential backoff.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of retry attempts
|
||||
max_time: Maximum total time in seconds
|
||||
base_delay: Initial delay in seconds
|
||||
multiplier: Backoff multiplier
|
||||
jitter: Whether to add jitter
|
||||
exceptions: Exception types to catch for retries
|
||||
logger_instance: Logger instance for retry messages
|
||||
prefix: Service prefix for logging
|
||||
retry_predicate: Optional callable to determine if exception should trigger retry
|
||||
|
||||
Returns:
|
||||
Decorator function
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable: # noqa: C901
|
||||
@wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: C901, ANN401
|
||||
start_time = time.time()
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(max_retries + 1): # +1 for initial attempt
|
||||
try:
|
||||
# Try to execute the function
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
# If successful and we had retried, log success
|
||||
if attempt > 0:
|
||||
logger_instance.info(
|
||||
"Successfully completed '%s' after %d retry attempt(s)",
|
||||
func.__name__,
|
||||
attempt,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except exceptions as e:
|
||||
last_exception = e
|
||||
|
||||
# Check if the error should trigger a retry
|
||||
# First check if it's in the allowed exception types (already caught)
|
||||
# Then check using the predicate if provided
|
||||
should_retry = True
|
||||
if retry_predicate is not None:
|
||||
should_retry = retry_predicate(e)
|
||||
|
||||
# Check if we've exceeded max time
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time >= max_time:
|
||||
logger_instance.exception(
|
||||
"Giving up '%s' after %.1fs (max time exceeded): %s",
|
||||
func.__name__,
|
||||
elapsed_time,
|
||||
e,
|
||||
)
|
||||
raise
|
||||
|
||||
# If not retryable or last attempt, raise the error
|
||||
if not should_retry or attempt == max_retries:
|
||||
if attempt > 0:
|
||||
logger_instance.exception(
|
||||
"Giving up '%s' after %d attempt(s) for %s",
|
||||
func.__name__,
|
||||
attempt + 1,
|
||||
prefix,
|
||||
)
|
||||
raise
|
||||
|
||||
# Calculate delay for next retry
|
||||
delay = calculate_delay(
|
||||
attempt, base_delay, multiplier, max_time, jitter
|
||||
)
|
||||
|
||||
# Ensure we don't exceed max_time with the delay
|
||||
remaining_time = max_time - elapsed_time
|
||||
if delay >= remaining_time:
|
||||
logger_instance.exception(
|
||||
"Giving up '%s' - next delay %.1fs would exceed "
|
||||
"max time %.1fs",
|
||||
func.__name__,
|
||||
delay,
|
||||
max_time,
|
||||
)
|
||||
raise
|
||||
|
||||
# Log retry attempt
|
||||
logger_instance.warning(
|
||||
"Retry %d/%d for %s: %s (waiting %.1fs)",
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
prefix,
|
||||
e,
|
||||
delay,
|
||||
)
|
||||
|
||||
# Wait before retrying
|
||||
time.sleep(delay)
|
||||
|
||||
except Exception as e:
|
||||
# Exception not in the exceptions tuple - don't retry
|
||||
last_exception = e
|
||||
raise
|
||||
|
||||
# This should never be reached, but just in case
|
||||
if last_exception:
|
||||
raise last_exception
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def create_retry_decorator(
|
||||
prefix: str,
|
||||
exceptions: tuple[type[Exception], ...] | None = None,
|
||||
retry_predicate: Callable[[Exception], bool] | None = None,
|
||||
logger_instance: logging.Logger | None = None,
|
||||
) -> Callable:
|
||||
"""Create a configured retry decorator for a specific service.
|
||||
|
||||
Args:
|
||||
prefix: Environment variable prefix for configuration
|
||||
exceptions: Tuple of exception types to retry on.
|
||||
Defaults to (ConnectionError, HTTPError, Timeout, OSError)
|
||||
retry_predicate: Optional callable to determine if exception should trigger retry.
|
||||
If only exceptions list provided, retry on those exceptions.
|
||||
If only predicate provided, use predicate (catch all exceptions).
|
||||
If both provided, first filter by exceptions then check predicate.
|
||||
logger_instance: Optional logger for retry events
|
||||
|
||||
Environment variables (using prefix):
|
||||
{prefix}_MAX_RETRIES: Maximum retry attempts (default: 3)
|
||||
{prefix}_MAX_TIME: Maximum total time in seconds (default: 60)
|
||||
{prefix}_BASE_DELAY: Initial delay in seconds (default: 1.0)
|
||||
{prefix}_MULTIPLIER: Backoff multiplier (default: 2.0)
|
||||
{prefix}_JITTER: Enable jitter true/false (default: true)
|
||||
|
||||
Returns:
|
||||
Configured retry decorator
|
||||
"""
|
||||
# Handle different combinations of exceptions and predicate
|
||||
if exceptions is None and retry_predicate is None:
|
||||
# Default case: use specific exceptions with is_retryable_error predicate
|
||||
exceptions = (ConnectionError, HTTPError, Timeout, OSError)
|
||||
retry_predicate = is_retryable_error
|
||||
elif exceptions is None and retry_predicate is not None:
|
||||
# Only predicate provided: catch all exceptions and use predicate
|
||||
exceptions = (Exception,)
|
||||
elif exceptions is not None and retry_predicate is None:
|
||||
# Only exceptions provided: retry on those exceptions
|
||||
pass # exceptions already set, no predicate needed
|
||||
# Both provided: use both (exceptions filter first, then predicate)
|
||||
|
||||
# Load configuration from environment
|
||||
max_retries = int(os.getenv(f"{prefix}_MAX_RETRIES", "3"))
|
||||
max_time = float(os.getenv(f"{prefix}_MAX_TIME", "60"))
|
||||
base_delay = float(os.getenv(f"{prefix}_BASE_DELAY", "1.0"))
|
||||
multiplier = float(os.getenv(f"{prefix}_MULTIPLIER", "2.0"))
|
||||
use_jitter = os.getenv(f"{prefix}_JITTER", "true").strip().lower() in {
|
||||
"true",
|
||||
"1",
|
||||
"yes",
|
||||
"on",
|
||||
}
|
||||
|
||||
if max_retries < 0:
|
||||
raise ValueError(f"{prefix}_MAX_RETRIES must be >= 0")
|
||||
if max_time <= 0:
|
||||
raise ValueError(f"{prefix}_MAX_TIME must be > 0")
|
||||
if base_delay <= 0:
|
||||
raise ValueError(f"{prefix}_BASE_DELAY must be > 0")
|
||||
if multiplier <= 0:
|
||||
raise ValueError(f"{prefix}_MULTIPLIER must be > 0")
|
||||
|
||||
if logger_instance is None:
|
||||
logger_instance = logger
|
||||
|
||||
return retry_with_exponential_backoff(
|
||||
max_retries=max_retries,
|
||||
max_time=max_time,
|
||||
base_delay=base_delay,
|
||||
multiplier=multiplier,
|
||||
jitter=use_jitter,
|
||||
exceptions=exceptions,
|
||||
logger_instance=logger_instance,
|
||||
prefix=prefix,
|
||||
retry_predicate=retry_predicate,
|
||||
)
|
||||
|
||||
|
||||
# Retry configured through below envs.
|
||||
# - PLATFORM_SERVICE_MAX_RETRIES (default: 3)
|
||||
# - PLATFORM_SERVICE_MAX_TIME (default: 60s)
|
||||
# - PLATFORM_SERVICE_BASE_DELAY (default: 1.0s)
|
||||
# - PLATFORM_SERVICE_MULTIPLIER (default: 2.0)
|
||||
# - PLATFORM_SERVICE_JITTER (default: true)
|
||||
retry_platform_service_call = create_retry_decorator("PLATFORM_SERVICE")
|
||||
|
||||
# Retry configured through below envs.
|
||||
# - PROMPT_SERVICE_MAX_RETRIES (default: 3)
|
||||
# - PROMPT_SERVICE_MAX_TIME (default: 60s)
|
||||
# - PROMPT_SERVICE_BASE_DELAY (default: 1.0s)
|
||||
# - PROMPT_SERVICE_MULTIPLIER (default: 2.0)
|
||||
# - PROMPT_SERVICE_JITTER (default: true)
|
||||
retry_prompt_service_call = create_retry_decorator("PROMPT_SERVICE")
|
||||
1
unstract/sdk1/tests/__init__.py
Normal file
1
unstract/sdk1/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for unstract-sdk1."""
|
||||
55
unstract/sdk1/tests/conftest.py
Normal file
55
unstract/sdk1/tests/conftest.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Pytest configuration and fixtures for unstract-sdk1 tests."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger() -> MagicMock:
|
||||
"""Create a mock logger for testing."""
|
||||
logger = MagicMock(spec=logging.Logger)
|
||||
return logger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def clean_env(monkeypatch: MonkeyPatch) -> MonkeyPatch:
|
||||
"""Clean environment variables before each test."""
|
||||
# Remove any retry-related environment variables
|
||||
env_vars = [
|
||||
"PLATFORM_SERVICE_MAX_RETRIES",
|
||||
"PLATFORM_SERVICE_MAX_TIME",
|
||||
"PLATFORM_SERVICE_BASE_DELAY",
|
||||
"PLATFORM_SERVICE_MULTIPLIER",
|
||||
"PLATFORM_SERVICE_JITTER",
|
||||
"PROMPT_SERVICE_MAX_RETRIES",
|
||||
"PROMPT_SERVICE_MAX_TIME",
|
||||
"PROMPT_SERVICE_BASE_DELAY",
|
||||
"PROMPT_SERVICE_MULTIPLIER",
|
||||
"PROMPT_SERVICE_JITTER",
|
||||
]
|
||||
for var in env_vars:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
return monkeypatch
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def set_env(monkeypatch: MonkeyPatch) -> Callable[..., None]:
|
||||
"""Helper fixture to set environment variables."""
|
||||
|
||||
def _set_env(prefix: str, **kwargs: Any) -> None: # noqa: ANN401
|
||||
"""Set environment variables with given prefix.
|
||||
|
||||
Args:
|
||||
prefix: Environment variable prefix (e.g., 'PLATFORM_SERVICE')
|
||||
**kwargs: Key-value pairs to set (e.g., max_retries=5)
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
env_key = f"{prefix}_{key.upper()}"
|
||||
monkeypatch.setenv(env_key, str(value))
|
||||
|
||||
return _set_env
|
||||
233
unstract/sdk1/tests/test_platform.py
Normal file
233
unstract/sdk1/tests/test_platform.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""Integration tests for platform module with retry logic."""
|
||||
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from requests.exceptions import ConnectionError, HTTPError
|
||||
from unstract.sdk1.exceptions import SdkError
|
||||
from unstract.sdk1.platform import PlatformHelper
|
||||
|
||||
|
||||
class TestPlatformHelperRetry:
|
||||
"""Tests for PlatformHelper retry functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool(self) -> MagicMock:
|
||||
"""Create a mock tool for testing."""
|
||||
tool = MagicMock()
|
||||
tool.get_env_or_die.side_effect = lambda key: {
|
||||
"PLATFORM_HOST": "http://localhost",
|
||||
"PLATFORM_PORT": "3001",
|
||||
"PLATFORM_API_KEY": "test-api-key",
|
||||
}.get(key, "mock-value")
|
||||
tool.stream_log = MagicMock()
|
||||
tool.stream_error_and_exit = MagicMock()
|
||||
return tool
|
||||
|
||||
@pytest.fixture
|
||||
def platform_helper(self, mock_tool: MagicMock) -> PlatformHelper:
|
||||
"""Create a PlatformHelper instance."""
|
||||
return PlatformHelper(
|
||||
tool=mock_tool,
|
||||
platform_host="http://localhost",
|
||||
platform_port="3001",
|
||||
request_id="test-request-id",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"method_name,method_args,http_method",
|
||||
[
|
||||
("_get_adapter_configuration", ("test-adapter-id",), "GET"),
|
||||
("_call_service", ("test-endpoint",), "GET"),
|
||||
],
|
||||
)
|
||||
def test_success_on_first_attempt(
|
||||
self,
|
||||
mock_tool: MagicMock,
|
||||
platform_helper: PlatformHelper,
|
||||
method_name: str,
|
||||
method_args: tuple[str, ...],
|
||||
http_method: str,
|
||||
clean_env: MonkeyPatch,
|
||||
) -> None:
|
||||
"""Test successful calls on first attempt for various methods."""
|
||||
expected_data = {"adapter_id": "test", "config": {}}
|
||||
|
||||
patch_target = f"requests.{http_method.lower()}"
|
||||
with patch(patch_target) as mock_request:
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = expected_data
|
||||
mock_response.raise_for_status = Mock()
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
if method_name == "_get_adapter_configuration":
|
||||
PlatformHelper._get_adapter_configuration(mock_tool, *method_args)
|
||||
else:
|
||||
getattr(platform_helper, method_name)(*method_args)
|
||||
|
||||
assert mock_request.call_count == 1
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"method_name,method_args",
|
||||
[
|
||||
("_get_adapter_configuration", ("test-adapter-id",)),
|
||||
("_call_service", ("test-endpoint",)),
|
||||
],
|
||||
)
|
||||
def test_retry_on_connection_error(
|
||||
self,
|
||||
mock_tool: MagicMock,
|
||||
platform_helper: PlatformHelper,
|
||||
method_name: str,
|
||||
method_args: tuple[str, ...],
|
||||
clean_env: MonkeyPatch,
|
||||
) -> None:
|
||||
"""Test methods retry on ConnectionError."""
|
||||
expected_data = {"result": "success"}
|
||||
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = expected_data
|
||||
mock_response.raise_for_status = Mock()
|
||||
|
||||
mock_get.side_effect = [
|
||||
ConnectionError("Transient failure"),
|
||||
mock_response,
|
||||
]
|
||||
|
||||
if method_name == "_get_adapter_configuration":
|
||||
PlatformHelper._get_adapter_configuration(mock_tool, *method_args)
|
||||
else:
|
||||
getattr(platform_helper, method_name)(*method_args)
|
||||
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_max_retries_exceeded(
|
||||
self, mock_tool: MagicMock, clean_env: MonkeyPatch
|
||||
) -> None:
|
||||
"""Test service call fails after exceeding max retries."""
|
||||
platform_helper = PlatformHelper(
|
||||
tool=mock_tool,
|
||||
platform_host="http://localhost",
|
||||
platform_port="3001",
|
||||
)
|
||||
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_get.side_effect = ConnectionError("Persistent failure")
|
||||
|
||||
with pytest.raises(ConnectionError):
|
||||
platform_helper._call_service("test-endpoint")
|
||||
|
||||
# Default: 3 retries + 1 initial = 4 attempts
|
||||
assert mock_get.call_count == 4
|
||||
|
||||
def test_non_retryable_http_error(
|
||||
self, mock_tool: MagicMock, clean_env: MonkeyPatch
|
||||
) -> None:
|
||||
"""Test non-retryable HTTP errors (404, 400) don't trigger retry."""
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 404
|
||||
mock_response.json.return_value = {"error": "Not found"}
|
||||
|
||||
http_error = HTTPError()
|
||||
http_error.response = mock_response
|
||||
mock_get.side_effect = http_error
|
||||
|
||||
with pytest.raises(SdkError, match="Error retrieving adapter"):
|
||||
PlatformHelper._get_adapter_configuration(mock_tool, "test-adapter-id")
|
||||
|
||||
# Should not retry 404
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
@pytest.mark.parametrize("status_code", [502, 503, 504])
|
||||
def test_retryable_http_errors(
|
||||
self, mock_tool: MagicMock, status_code: int, clean_env: MonkeyPatch
|
||||
) -> None:
|
||||
"""Test retryable HTTP errors (502, 503, 504) trigger retry."""
|
||||
expected_data = {"adapter_id": "test", "config": {}}
|
||||
|
||||
with patch("requests.get") as mock_get:
|
||||
# First attempt: retryable HTTP error
|
||||
http_error = HTTPError()
|
||||
error_response = Mock()
|
||||
error_response.status_code = status_code
|
||||
error_response.json.return_value = {"error": "Service unavailable"}
|
||||
http_error.response = error_response
|
||||
|
||||
# Second attempt: success
|
||||
success_response = Mock()
|
||||
success_response.json.return_value = expected_data
|
||||
success_response.raise_for_status = Mock()
|
||||
|
||||
mock_get.side_effect = [http_error, success_response]
|
||||
|
||||
result = PlatformHelper._get_adapter_configuration(
|
||||
mock_tool, "test-adapter-id"
|
||||
)
|
||||
|
||||
# Should retry and succeed
|
||||
assert mock_get.call_count == 2
|
||||
assert result == expected_data
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_connection_error_converted_to_sdk_error(
|
||||
self, mock_tool: MagicMock, clean_env: MonkeyPatch
|
||||
) -> None:
|
||||
"""Test get_adapter_config wraps ConnectionError as SdkError."""
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_get.side_effect = ConnectionError("Connection failed")
|
||||
|
||||
with pytest.raises(SdkError, match="Unable to connect to platform service"):
|
||||
PlatformHelper.get_adapter_config(mock_tool, "test-adapter-id")
|
||||
|
||||
def test_post_method_retry(
|
||||
self, platform_helper: PlatformHelper, clean_env: MonkeyPatch
|
||||
) -> None:
|
||||
"""Test POST requests also support retry."""
|
||||
payload = {"key": "value"}
|
||||
expected_response = {"status": "OK"}
|
||||
|
||||
with patch("requests.post") as mock_post:
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = expected_response
|
||||
mock_response.raise_for_status = Mock()
|
||||
|
||||
mock_post.side_effect = [
|
||||
ConnectionError("Transient failure"),
|
||||
mock_response,
|
||||
]
|
||||
|
||||
result = platform_helper._call_service(
|
||||
"test-endpoint", payload=payload, method="POST"
|
||||
)
|
||||
|
||||
assert result == expected_response
|
||||
assert mock_post.call_count == 2
|
||||
|
||||
def test_retry_logging(self, mock_tool: MagicMock, clean_env: MonkeyPatch) -> None:
|
||||
"""Test that retry attempts are logged."""
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {}
|
||||
mock_response.raise_for_status = Mock()
|
||||
|
||||
mock_get.side_effect = [
|
||||
ConnectionError("Transient failure"),
|
||||
mock_response,
|
||||
]
|
||||
|
||||
helper = PlatformHelper(
|
||||
tool=mock_tool,
|
||||
platform_host="http://localhost",
|
||||
platform_port="3001",
|
||||
)
|
||||
|
||||
helper._call_service("test-endpoint")
|
||||
|
||||
# Verify logging occurred
|
||||
mock_tool.stream_log.assert_called()
|
||||
log_calls = [str(c) for c in mock_tool.stream_log.call_args_list]
|
||||
assert any("retry" in call.lower() for call in log_calls)
|
||||
174
unstract/sdk1/tests/test_prompt.py
Normal file
174
unstract/sdk1/tests/test_prompt.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""Integration tests for prompt module with retry logic."""
|
||||
|
||||
from typing import Any, Self
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from pytest import MonkeyPatch
|
||||
from requests.exceptions import ConnectionError, Timeout
|
||||
from unstract.sdk1.prompt import PromptTool
|
||||
|
||||
|
||||
class TestPromptToolRetry:
|
||||
"""Tests for PromptTool retry functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool(self: Self) -> MagicMock:
|
||||
"""Create a mock tool for testing."""
|
||||
tool = MagicMock()
|
||||
tool.get_env_or_die.side_effect = lambda key: {
|
||||
"PLATFORM_API_KEY": "test-api-key",
|
||||
}.get(key, "mock-value")
|
||||
tool.stream_log = MagicMock()
|
||||
tool.stream_error_and_exit = MagicMock()
|
||||
return tool
|
||||
|
||||
@pytest.fixture
|
||||
def prompt_tool(self: Self, mock_tool: MagicMock) -> PromptTool:
|
||||
"""Create a PromptTool instance."""
|
||||
return PromptTool(
|
||||
tool=mock_tool,
|
||||
prompt_host="http://localhost",
|
||||
prompt_port="3003",
|
||||
is_public_call=False,
|
||||
request_id="test-request-id",
|
||||
)
|
||||
|
||||
def test_success_on_first_attempt(
|
||||
self: Self, prompt_tool: PromptTool, clean_env: MonkeyPatch
|
||||
) -> None:
|
||||
"""Test successful service call on first attempt."""
|
||||
expected_response = {"result": "success"}
|
||||
payload = {"prompt": "test"}
|
||||
|
||||
with patch("requests.post") as mock_post:
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = expected_response
|
||||
mock_response.raise_for_status = Mock()
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = prompt_tool._call_service("answer-prompt", payload=payload)
|
||||
|
||||
assert result == expected_response
|
||||
assert mock_post.call_count == 1
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"error_type,error_instance",
|
||||
[
|
||||
("ConnectionError", ConnectionError("Connection failed")),
|
||||
("Timeout", Timeout("Request timed out")),
|
||||
],
|
||||
)
|
||||
def test_retry_on_errors(
|
||||
self: Self,
|
||||
prompt_tool: PromptTool,
|
||||
error_type: str,
|
||||
error_instance: Exception,
|
||||
clean_env: MonkeyPatch,
|
||||
) -> None:
|
||||
"""Test service retries on ConnectionError and Timeout."""
|
||||
expected_response = {"result": "success"}
|
||||
payload = {"prompt": "test"}
|
||||
|
||||
with patch("requests.post") as mock_post:
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = expected_response
|
||||
mock_response.raise_for_status = Mock()
|
||||
|
||||
mock_post.side_effect = [
|
||||
error_instance,
|
||||
mock_response,
|
||||
]
|
||||
|
||||
result = prompt_tool._call_service("answer-prompt", payload=payload)
|
||||
|
||||
assert result == expected_response
|
||||
assert mock_post.call_count == 2
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_max_retries_exceeded(
|
||||
self: Self, mock_tool: MagicMock, clean_env: MonkeyPatch
|
||||
) -> None:
|
||||
"""Test service call fails after exceeding max retries."""
|
||||
prompt_tool = PromptTool(
|
||||
tool=mock_tool,
|
||||
prompt_host="http://localhost",
|
||||
prompt_port="3003",
|
||||
is_public_call=False,
|
||||
request_id="test-request-id",
|
||||
)
|
||||
|
||||
payload = {"prompt": "test"}
|
||||
|
||||
with patch("requests.post") as mock_post:
|
||||
mock_post.side_effect = ConnectionError("Persistent failure")
|
||||
|
||||
# Exception handled by decorator
|
||||
with pytest.raises(ConnectionError):
|
||||
prompt_tool._call_service("answer-prompt", payload=payload)
|
||||
|
||||
# Default: 3 retries + 1 initial = 4 attempts
|
||||
assert mock_post.call_count == 4
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"method_name,payload",
|
||||
[
|
||||
("answer_prompt", {"prompts": ["test"]}),
|
||||
("index", {"document": "test"}),
|
||||
("extract", {"doc_id": "123"}),
|
||||
("summarize", {"text": "test"}),
|
||||
],
|
||||
)
|
||||
def test_wrapper_methods_retry(
|
||||
self: Self,
|
||||
prompt_tool: PromptTool,
|
||||
method_name: str,
|
||||
payload: dict[str, Any],
|
||||
clean_env: MonkeyPatch,
|
||||
) -> None:
|
||||
"""Test that wrapper methods inherit retry behavior."""
|
||||
expected_response = {
|
||||
"answers": ["result"],
|
||||
"doc_id": "doc-123",
|
||||
"extracted_text": "text",
|
||||
"summary": "summary",
|
||||
}
|
||||
|
||||
with patch("requests.post") as mock_post:
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = expected_response
|
||||
mock_response.raise_for_status = Mock()
|
||||
|
||||
mock_post.side_effect = [
|
||||
ConnectionError("Transient failure"),
|
||||
mock_response,
|
||||
]
|
||||
|
||||
getattr(prompt_tool, method_name)(payload)
|
||||
|
||||
assert mock_post.call_count == 2
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_error_handling_with_retry(
|
||||
self: Self, mock_tool: MagicMock, clean_env: MonkeyPatch
|
||||
) -> None:
|
||||
"""Test error handling decorator works with retry."""
|
||||
prompt_tool = PromptTool(
|
||||
tool=mock_tool,
|
||||
prompt_host="http://localhost",
|
||||
prompt_port="3003",
|
||||
is_public_call=False,
|
||||
request_id="test-request-id",
|
||||
)
|
||||
|
||||
payload = {"prompt": "test"}
|
||||
|
||||
with patch("requests.post") as mock_post:
|
||||
mock_post.side_effect = ConnectionError("Persistent failure")
|
||||
|
||||
# Error handler should catch after all retries
|
||||
result = prompt_tool.answer_prompt(payload)
|
||||
|
||||
# handle_service_exceptions decorator calls stream_error_and_exit
|
||||
assert result is None
|
||||
prompt_tool.tool.stream_error_and_exit.assert_called()
|
||||
1
unstract/sdk1/tests/utils/__init__.py
Normal file
1
unstract/sdk1/tests/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for unstract-sdk1 utilities."""
|
||||
667
unstract/sdk1/tests/utils/test_retry_utils.py
Normal file
667
unstract/sdk1/tests/utils/test_retry_utils.py
Normal file
@@ -0,0 +1,667 @@
|
||||
"""Unit tests for retry_utils module."""
|
||||
|
||||
import errno
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from requests.exceptions import ConnectionError, HTTPError, Timeout
|
||||
from unstract.sdk1.utils.retry_utils import (
|
||||
calculate_delay,
|
||||
create_retry_decorator,
|
||||
is_retryable_error,
|
||||
retry_platform_service_call,
|
||||
retry_prompt_service_call,
|
||||
retry_with_exponential_backoff,
|
||||
)
|
||||
|
||||
|
||||
class TestIsRetryableError:
|
||||
"""Tests for is_retryable_error function."""
|
||||
|
||||
def test_connection_error_is_retryable(self) -> None:
|
||||
"""ConnectionError should be retryable."""
|
||||
error = ConnectionError("Connection failed")
|
||||
assert is_retryable_error(error) is True
|
||||
|
||||
def test_timeout_is_retryable(self) -> None:
|
||||
"""Timeout error should be retryable."""
|
||||
error = Timeout("Request timed out")
|
||||
assert is_retryable_error(error) is True
|
||||
|
||||
@pytest.mark.parametrize("status_code", [502, 503, 504])
|
||||
def test_http_error_retryable_status_codes(self, status_code: int) -> None:
|
||||
"""HTTPError with 502, 503, 504 should be retryable."""
|
||||
response = Mock()
|
||||
response.status_code = status_code
|
||||
error = HTTPError()
|
||||
error.response = response
|
||||
assert is_retryable_error(error) is True
|
||||
|
||||
@pytest.mark.parametrize("status_code", [400, 401, 403, 404, 500])
|
||||
def test_http_error_non_retryable_status_codes(self, status_code: int) -> None:
|
||||
"""HTTPError with other status codes should not be retryable."""
|
||||
response = Mock()
|
||||
response.status_code = status_code
|
||||
error = HTTPError()
|
||||
error.response = response
|
||||
assert is_retryable_error(error) is False
|
||||
|
||||
def test_http_error_without_response(self) -> None:
|
||||
"""HTTPError without response should not be retryable."""
|
||||
error = HTTPError()
|
||||
error.response = None
|
||||
assert is_retryable_error(error) is False
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"errno_code",
|
||||
[
|
||||
errno.ECONNREFUSED,
|
||||
getattr(errno, "ECONNRESET", 104),
|
||||
getattr(errno, "ETIMEDOUT", 110),
|
||||
getattr(errno, "EHOSTUNREACH", 113),
|
||||
getattr(errno, "ENETUNREACH", 101),
|
||||
],
|
||||
)
|
||||
def test_os_error_retryable_errno(self, errno_code: int) -> None:
|
||||
"""OSError with specific errno codes should be retryable."""
|
||||
error = OSError()
|
||||
error.errno = errno_code
|
||||
assert is_retryable_error(error) is True
|
||||
|
||||
def test_os_error_non_retryable_errno(self) -> None:
|
||||
"""OSError with other errno codes should not be retryable."""
|
||||
error = OSError()
|
||||
error.errno = errno.ENOENT # File not found
|
||||
assert is_retryable_error(error) is False
|
||||
|
||||
def test_other_exception_not_retryable(self) -> None:
|
||||
"""Other exceptions should not be retryable."""
|
||||
error = ValueError("Invalid value")
|
||||
assert is_retryable_error(error) is False
|
||||
|
||||
|
||||
class TestCalculateDelay:
|
||||
"""Tests for calculate_delay function."""
|
||||
|
||||
def test_exponential_backoff_without_jitter(self) -> None:
|
||||
"""Test exponential backoff calculation without jitter."""
|
||||
base_delay = 1.0
|
||||
multiplier = 2.0
|
||||
max_delay = 60.0
|
||||
|
||||
# Attempt 0: 1.0 * (2.0^0) = 1.0
|
||||
assert calculate_delay(
|
||||
0, base_delay, multiplier, max_delay, jitter=False
|
||||
) == pytest.approx(1.0)
|
||||
|
||||
# Attempt 1: 1.0 * (2.0^1) = 2.0
|
||||
assert calculate_delay(
|
||||
1, base_delay, multiplier, max_delay, jitter=False
|
||||
) == pytest.approx(2.0)
|
||||
|
||||
# Attempt 2: 1.0 * (2.0^2) = 4.0
|
||||
assert calculate_delay(
|
||||
2, base_delay, multiplier, max_delay, jitter=False
|
||||
) == pytest.approx(4.0)
|
||||
|
||||
# Attempt 3: 1.0 * (2.0^3) = 8.0
|
||||
assert calculate_delay(
|
||||
3, base_delay, multiplier, max_delay, jitter=False
|
||||
) == pytest.approx(8.0)
|
||||
|
||||
def test_exponential_backoff_with_jitter(self) -> None:
|
||||
"""Test exponential backoff calculation with jitter."""
|
||||
base_delay = 1.0
|
||||
multiplier = 2.0
|
||||
max_delay = 60.0
|
||||
|
||||
# With jitter, delay should be in range [base, base * 1.25]
|
||||
for attempt in range(4):
|
||||
base = base_delay * (multiplier**attempt)
|
||||
delay = calculate_delay(
|
||||
attempt, base_delay, multiplier, max_delay, jitter=True
|
||||
)
|
||||
assert base <= delay <= base * 1.25
|
||||
|
||||
def test_max_delay_cap(self) -> None:
|
||||
"""Test that max_delay caps the calculated delay."""
|
||||
base_delay = 1.0
|
||||
multiplier = 2.0
|
||||
max_delay = 5.0
|
||||
|
||||
# Attempt 10: 1.0 * (2.0^10) = 1024.0, but capped at 5.0
|
||||
delay = calculate_delay(10, base_delay, multiplier, max_delay, jitter=False)
|
||||
assert delay == pytest.approx(5.0)
|
||||
|
||||
def test_max_delay_cap_with_jitter(self) -> None:
|
||||
"""Test that max_delay caps the delay even with jitter."""
|
||||
base_delay = 1.0
|
||||
multiplier = 2.0
|
||||
max_delay = 5.0
|
||||
|
||||
# Even with jitter, should not exceed max_delay
|
||||
delay = calculate_delay(10, base_delay, multiplier, max_delay, jitter=True)
|
||||
assert delay <= max_delay
|
||||
|
||||
|
||||
class TestRetryWithExponentialBackoff:
|
||||
"""Tests for retry_with_exponential_backoff decorator."""
|
||||
|
||||
def test_successful_call_first_attempt(self, mock_logger: MagicMock) -> None:
|
||||
"""Test successful call on first attempt."""
|
||||
mock_func = Mock(return_value="success")
|
||||
|
||||
decorator = retry_with_exponential_backoff(
|
||||
max_retries=3,
|
||||
max_time=60.0,
|
||||
base_delay=1.0,
|
||||
multiplier=2.0,
|
||||
jitter=False,
|
||||
exceptions=(Exception,),
|
||||
logger_instance=mock_logger,
|
||||
prefix="TEST",
|
||||
)
|
||||
decorated_func = decorator(mock_func)
|
||||
|
||||
result = decorated_func()
|
||||
|
||||
assert result == "success"
|
||||
assert mock_func.call_count == 1
|
||||
# Should not log retry success message for first attempt
|
||||
mock_logger.info.assert_not_called()
|
||||
|
||||
def test_retry_after_transient_failure(self, mock_logger: MagicMock) -> None:
|
||||
"""Test retry after transient failure."""
|
||||
mock_func = Mock(
|
||||
side_effect=[ConnectionError("Failed"), "success"], __name__="test_func"
|
||||
)
|
||||
|
||||
decorator = retry_with_exponential_backoff(
|
||||
max_retries=3,
|
||||
max_time=60.0,
|
||||
base_delay=0.1, # Short delay for testing
|
||||
multiplier=2.0,
|
||||
jitter=False,
|
||||
exceptions=(ConnectionError,),
|
||||
logger_instance=mock_logger,
|
||||
prefix="TEST",
|
||||
)
|
||||
decorated_func = decorator(mock_func)
|
||||
|
||||
result = decorated_func()
|
||||
|
||||
assert result == "success"
|
||||
assert mock_func.call_count == 2
|
||||
# Should log success after retry
|
||||
mock_logger.info.assert_called_once()
|
||||
assert "Successfully completed" in str(mock_logger.info.call_args)
|
||||
|
||||
def test_max_retries_exceeded(self, mock_logger: MagicMock) -> None:
|
||||
"""Test that max retries causes failure."""
|
||||
mock_func = Mock(
|
||||
side_effect=ConnectionError("Always fails"), __name__="test_func"
|
||||
)
|
||||
|
||||
decorator = retry_with_exponential_backoff(
|
||||
max_retries=2,
|
||||
max_time=60.0,
|
||||
base_delay=0.1,
|
||||
multiplier=2.0,
|
||||
jitter=False,
|
||||
exceptions=(ConnectionError,),
|
||||
logger_instance=mock_logger,
|
||||
prefix="TEST",
|
||||
)
|
||||
decorated_func = decorator(mock_func)
|
||||
|
||||
with pytest.raises(ConnectionError, match="Always fails"):
|
||||
decorated_func()
|
||||
|
||||
# Should attempt 3 times (initial + 2 retries)
|
||||
assert mock_func.call_count == 3
|
||||
# Should log giving up
|
||||
mock_logger.exception.assert_called()
|
||||
|
||||
def test_max_time_exceeded(self, mock_logger: MagicMock) -> None:
|
||||
"""Test that max time causes failure."""
|
||||
mock_func = Mock(
|
||||
side_effect=ConnectionError("Always fails"), __name__="test_func"
|
||||
)
|
||||
|
||||
decorator = retry_with_exponential_backoff(
|
||||
max_retries=10,
|
||||
max_time=0.5, # Very short max time
|
||||
base_delay=0.2,
|
||||
multiplier=2.0,
|
||||
jitter=False,
|
||||
exceptions=(ConnectionError,),
|
||||
logger_instance=mock_logger,
|
||||
prefix="TEST",
|
||||
)
|
||||
decorated_func = decorator(mock_func)
|
||||
|
||||
with pytest.raises(ConnectionError):
|
||||
decorated_func()
|
||||
|
||||
# Should fail before reaching max retries due to time limit
|
||||
assert mock_func.call_count < 10
|
||||
|
||||
def test_retry_with_custom_predicate(self, mock_logger: MagicMock) -> None:
|
||||
"""Test retry with custom predicate."""
|
||||
|
||||
def custom_predicate(e: Exception) -> bool:
|
||||
# Only retry if message contains "retry"
|
||||
return "retry" in str(e)
|
||||
|
||||
mock_func = Mock(
|
||||
side_effect=[Exception("retry please"), "success"], __name__="test_func"
|
||||
)
|
||||
|
||||
decorator = retry_with_exponential_backoff(
|
||||
max_retries=3,
|
||||
max_time=60.0,
|
||||
base_delay=0.1,
|
||||
multiplier=2.0,
|
||||
jitter=False,
|
||||
exceptions=(Exception,),
|
||||
logger_instance=mock_logger,
|
||||
prefix="TEST",
|
||||
retry_predicate=custom_predicate,
|
||||
)
|
||||
decorated_func = decorator(mock_func)
|
||||
|
||||
result = decorated_func()
|
||||
|
||||
assert result == "success"
|
||||
assert mock_func.call_count == 2
|
||||
|
||||
def test_no_retry_with_predicate_false(self, mock_logger: MagicMock) -> None:
|
||||
"""Test no retry when predicate returns False."""
|
||||
|
||||
def custom_predicate(e: Exception) -> bool:
|
||||
return False
|
||||
|
||||
mock_func = Mock(__name__="test_func", side_effect=Exception("Error"))
|
||||
|
||||
decorator = retry_with_exponential_backoff(
|
||||
max_retries=3,
|
||||
max_time=60.0,
|
||||
base_delay=0.1,
|
||||
multiplier=2.0,
|
||||
jitter=False,
|
||||
exceptions=(Exception,),
|
||||
logger_instance=mock_logger,
|
||||
prefix="TEST",
|
||||
retry_predicate=custom_predicate,
|
||||
)
|
||||
decorated_func = decorator(mock_func)
|
||||
|
||||
with pytest.raises(Exception, match="Error"):
|
||||
decorated_func()
|
||||
|
||||
# Should not retry
|
||||
assert mock_func.call_count == 1
|
||||
|
||||
def test_exception_not_in_tuple_not_retried(self, mock_logger: MagicMock) -> None:
|
||||
"""Test that exceptions not in the tuple are not retried."""
|
||||
mock_func = Mock(__name__="test_func", side_effect=ValueError("Not retryable"))
|
||||
|
||||
decorator = retry_with_exponential_backoff(
|
||||
max_retries=3,
|
||||
max_time=60.0,
|
||||
base_delay=0.1,
|
||||
multiplier=2.0,
|
||||
jitter=False,
|
||||
exceptions=(ConnectionError,), # Only ConnectionError
|
||||
logger_instance=mock_logger,
|
||||
prefix="TEST",
|
||||
)
|
||||
decorated_func = decorator(mock_func)
|
||||
|
||||
with pytest.raises(ValueError, match="Not retryable"):
|
||||
decorated_func()
|
||||
|
||||
# Should not retry
|
||||
assert mock_func.call_count == 1
|
||||
|
||||
def test_delay_would_exceed_max_time(self, mock_logger: MagicMock) -> None:
|
||||
"""Test that delay exceeding max time causes immediate failure."""
|
||||
mock_func = Mock(
|
||||
__name__="test_func", side_effect=ConnectionError("Always fails")
|
||||
)
|
||||
|
||||
decorator = retry_with_exponential_backoff(
|
||||
max_retries=10,
|
||||
max_time=0.3, # Very short max time
|
||||
base_delay=1.0, # Large delay
|
||||
multiplier=2.0,
|
||||
jitter=False,
|
||||
exceptions=(ConnectionError,),
|
||||
logger_instance=mock_logger,
|
||||
prefix="TEST",
|
||||
)
|
||||
decorated_func = decorator(mock_func)
|
||||
|
||||
with pytest.raises(ConnectionError):
|
||||
decorated_func()
|
||||
|
||||
# Should fail quickly due to delay exceeding remaining time
|
||||
mock_logger.exception.assert_called()
|
||||
exception_calls = [str(c) for c in mock_logger.exception.call_args_list]
|
||||
assert any("would exceed max time" in c for c in exception_calls)
|
||||
|
||||
|
||||
class TestCreateRetryDecorator:
|
||||
"""Tests for create_retry_decorator function."""
|
||||
|
||||
def test_default_configuration(
|
||||
self, clean_env: MonkeyPatch, mock_logger: MagicMock
|
||||
) -> None:
|
||||
"""Test decorator with default configuration."""
|
||||
decorator = create_retry_decorator("TEST_SERVICE", logger_instance=mock_logger)
|
||||
|
||||
mock_func = Mock(return_value="success")
|
||||
decorated_func = decorator(mock_func)
|
||||
|
||||
result = decorated_func()
|
||||
|
||||
assert result == "success"
|
||||
|
||||
def test_environment_variable_configuration(
|
||||
self, clean_env: MonkeyPatch, set_env: Callable[..., None], mock_logger: MagicMock
|
||||
) -> None:
|
||||
"""Test decorator reads configuration from environment."""
|
||||
set_env(
|
||||
"TEST_SERVICE",
|
||||
max_retries=5,
|
||||
max_time=120,
|
||||
base_delay=2.0,
|
||||
multiplier=3.0,
|
||||
jitter="false",
|
||||
)
|
||||
|
||||
mock_func = Mock(
|
||||
__name__="test_func", side_effect=[ConnectionError("Failed"), "success"]
|
||||
)
|
||||
|
||||
decorator = create_retry_decorator("TEST_SERVICE", logger_instance=mock_logger)
|
||||
decorated_func = decorator(mock_func)
|
||||
|
||||
result = decorated_func()
|
||||
|
||||
assert result == "success"
|
||||
assert mock_func.call_count == 2
|
||||
|
||||
def test_invalid_max_retries(
|
||||
self, clean_env: MonkeyPatch, set_env: Callable[..., None]
|
||||
) -> None:
|
||||
"""Test that negative max_retries raises error."""
|
||||
set_env("TEST_SERVICE", max_retries=-1)
|
||||
|
||||
with pytest.raises(ValueError, match="MAX_RETRIES must be >= 0"):
|
||||
create_retry_decorator("TEST_SERVICE")
|
||||
|
||||
def test_invalid_max_time(
|
||||
self, clean_env: MonkeyPatch, set_env: Callable[..., None]
|
||||
) -> None:
|
||||
"""Test that non-positive max_time raises error."""
|
||||
set_env("TEST_SERVICE", max_time=0)
|
||||
|
||||
with pytest.raises(ValueError, match="MAX_TIME must be > 0"):
|
||||
create_retry_decorator("TEST_SERVICE")
|
||||
|
||||
def test_invalid_base_delay(
|
||||
self, clean_env: MonkeyPatch, set_env: Callable[..., None]
|
||||
) -> None:
|
||||
"""Test that non-positive base_delay raises error."""
|
||||
set_env("TEST_SERVICE", base_delay=-0.5)
|
||||
|
||||
with pytest.raises(ValueError, match="BASE_DELAY must be > 0"):
|
||||
create_retry_decorator("TEST_SERVICE")
|
||||
|
||||
def test_invalid_multiplier(
|
||||
self, clean_env: MonkeyPatch, set_env: Callable[..., None]
|
||||
) -> None:
|
||||
"""Test that non-positive multiplier raises error."""
|
||||
set_env("TEST_SERVICE", multiplier=0)
|
||||
|
||||
with pytest.raises(ValueError, match="MULTIPLIER must be > 0"):
|
||||
create_retry_decorator("TEST_SERVICE")
|
||||
|
||||
@pytest.mark.parametrize("jitter_value", ["true", "false"])
|
||||
def test_jitter_values(
|
||||
self,
|
||||
jitter_value: str,
|
||||
clean_env: MonkeyPatch,
|
||||
set_env: Callable[..., None],
|
||||
mock_logger: MagicMock,
|
||||
) -> None:
|
||||
"""Test jitter configuration values."""
|
||||
set_env("TEST_SERVICE", jitter=jitter_value)
|
||||
|
||||
decorator = create_retry_decorator("TEST_SERVICE", logger_instance=mock_logger)
|
||||
|
||||
# Should not raise error
|
||||
assert decorator is not None
|
||||
|
||||
def test_custom_exceptions_only(
|
||||
self, clean_env: MonkeyPatch, mock_logger: MagicMock
|
||||
) -> None:
|
||||
"""Test decorator with custom exceptions and no predicate."""
|
||||
decorator = create_retry_decorator(
|
||||
"TEST_SERVICE",
|
||||
exceptions=(ValueError, TypeError),
|
||||
logger_instance=mock_logger,
|
||||
)
|
||||
|
||||
mock_func = Mock(
|
||||
__name__="test_func", side_effect=[ValueError("Error"), "success"]
|
||||
)
|
||||
decorated_func = decorator(mock_func)
|
||||
|
||||
result = decorated_func()
|
||||
|
||||
assert result == "success"
|
||||
assert mock_func.call_count == 2
|
||||
|
||||
def test_custom_predicate_only(
|
||||
self, clean_env: MonkeyPatch, mock_logger: MagicMock
|
||||
) -> None:
|
||||
"""Test decorator with custom predicate and no exceptions."""
|
||||
|
||||
def custom_predicate(e: Exception) -> bool:
|
||||
return isinstance(e, ValueError)
|
||||
|
||||
decorator = create_retry_decorator(
|
||||
"TEST_SERVICE",
|
||||
retry_predicate=custom_predicate,
|
||||
logger_instance=mock_logger,
|
||||
)
|
||||
|
||||
mock_func = Mock(
|
||||
__name__="test_func", side_effect=[ValueError("Error"), "success"]
|
||||
)
|
||||
decorated_func = decorator(mock_func)
|
||||
|
||||
result = decorated_func()
|
||||
|
||||
assert result == "success"
|
||||
assert mock_func.call_count == 2
|
||||
|
||||
def test_both_exceptions_and_predicate(
|
||||
self, clean_env: MonkeyPatch, mock_logger: MagicMock
|
||||
) -> None:
|
||||
"""Test decorator with both exceptions and predicate."""
|
||||
|
||||
def custom_predicate(e: Exception) -> bool:
|
||||
return "retry" in str(e)
|
||||
|
||||
decorator = create_retry_decorator(
|
||||
"TEST_SERVICE",
|
||||
exceptions=(ValueError,),
|
||||
retry_predicate=custom_predicate,
|
||||
logger_instance=mock_logger,
|
||||
)
|
||||
|
||||
# Should retry - ValueError with "retry" in message
|
||||
mock_func = Mock(
|
||||
__name__="test_func", side_effect=[ValueError("retry please"), "success"]
|
||||
)
|
||||
decorated_func = decorator(mock_func)
|
||||
|
||||
result = decorated_func()
|
||||
|
||||
assert result == "success"
|
||||
assert mock_func.call_count == 2
|
||||
|
||||
def test_exceptions_match_but_predicate_false(
|
||||
self, clean_env: MonkeyPatch, mock_logger: MagicMock
|
||||
) -> None:
|
||||
"""Test that predicate can prevent retry even if exception matches."""
|
||||
|
||||
def custom_predicate(e: Exception) -> bool:
|
||||
return "retry" in str(e)
|
||||
|
||||
decorator = create_retry_decorator(
|
||||
"TEST_SERVICE",
|
||||
exceptions=(ValueError,),
|
||||
retry_predicate=custom_predicate,
|
||||
logger_instance=mock_logger,
|
||||
)
|
||||
|
||||
# ValueError matches but predicate returns False (no "retry" substring)
|
||||
mock_func = Mock(
|
||||
__name__="test_func", side_effect=ValueError("do not attempt again")
|
||||
)
|
||||
decorated_func = decorator(mock_func)
|
||||
|
||||
with pytest.raises(ValueError, match="do not attempt again"):
|
||||
decorated_func()
|
||||
|
||||
# Should not retry
|
||||
assert mock_func.call_count == 1
|
||||
|
||||
|
||||
class TestPreconfiguredDecorators:
|
||||
"""Tests for pre-configured decorators."""
|
||||
|
||||
def test_retry_platform_service_call_exists(self) -> None:
|
||||
"""Test that retry_platform_service_call decorator exists."""
|
||||
assert retry_platform_service_call is not None
|
||||
|
||||
def test_retry_prompt_service_call_exists(self) -> None:
|
||||
"""Test that retry_prompt_service_call decorator exists."""
|
||||
assert retry_prompt_service_call is not None
|
||||
|
||||
def test_platform_service_decorator_retries_on_connection_error(
|
||||
self, clean_env: MonkeyPatch
|
||||
) -> None:
|
||||
"""Test platform service decorator retries ConnectionError."""
|
||||
mock_func = Mock(
|
||||
__name__="test_func", side_effect=[ConnectionError("Failed"), "success"]
|
||||
)
|
||||
|
||||
decorated_func = retry_platform_service_call(mock_func)
|
||||
|
||||
result = decorated_func()
|
||||
|
||||
assert result == "success"
|
||||
assert mock_func.call_count == 2
|
||||
|
||||
def test_prompt_service_decorator_retries_on_timeout(
|
||||
self, clean_env: MonkeyPatch
|
||||
) -> None:
|
||||
"""Test prompt service decorator retries Timeout."""
|
||||
mock_func = Mock(
|
||||
__name__="test_func", side_effect=[Timeout("Timed out"), "success"]
|
||||
)
|
||||
|
||||
decorated_func = retry_prompt_service_call(mock_func)
|
||||
|
||||
result = decorated_func()
|
||||
|
||||
assert result == "success"
|
||||
assert mock_func.call_count == 2
|
||||
|
||||
|
||||
class TestRetryLogging:
|
||||
"""Tests for retry logging behavior."""
|
||||
|
||||
def test_warning_logged_on_retry(self, mock_logger: MagicMock) -> None:
|
||||
"""Test that warning is logged on retry attempt."""
|
||||
mock_func = Mock(
|
||||
__name__="test_func", side_effect=[ConnectionError("Failed"), "success"]
|
||||
)
|
||||
|
||||
decorator = retry_with_exponential_backoff(
|
||||
max_retries=3,
|
||||
max_time=60.0,
|
||||
base_delay=0.1,
|
||||
multiplier=2.0,
|
||||
jitter=False,
|
||||
exceptions=(ConnectionError,),
|
||||
logger_instance=mock_logger,
|
||||
prefix="TEST",
|
||||
)
|
||||
decorated_func = decorator(mock_func)
|
||||
|
||||
decorated_func()
|
||||
|
||||
# Should log warning about retry
|
||||
mock_logger.warning.assert_called_once()
|
||||
warning_msg = str(mock_logger.warning.call_args)
|
||||
assert "Retry" in warning_msg
|
||||
assert "TEST" in warning_msg
|
||||
|
||||
def test_info_logged_on_success_after_retry(self, mock_logger: MagicMock) -> None:
|
||||
"""Test that info is logged when successful after retry."""
|
||||
mock_func = Mock(
|
||||
__name__="test_func", side_effect=[ConnectionError("Failed"), "success"]
|
||||
)
|
||||
|
||||
decorator = retry_with_exponential_backoff(
|
||||
max_retries=3,
|
||||
max_time=60.0,
|
||||
base_delay=0.1,
|
||||
multiplier=2.0,
|
||||
jitter=False,
|
||||
exceptions=(ConnectionError,),
|
||||
logger_instance=mock_logger,
|
||||
prefix="TEST",
|
||||
)
|
||||
decorated_func = decorator(mock_func)
|
||||
|
||||
decorated_func()
|
||||
|
||||
# Should log success after retry
|
||||
mock_logger.info.assert_called_once()
|
||||
info_msg = str(mock_logger.info.call_args)
|
||||
assert "Successfully completed" in info_msg
|
||||
|
||||
def test_exception_logged_on_giving_up(self, mock_logger: MagicMock) -> None:
|
||||
"""Test that exception is logged when giving up."""
|
||||
mock_func = Mock(
|
||||
__name__="test_func", side_effect=ConnectionError("Always fails")
|
||||
)
|
||||
|
||||
decorator = retry_with_exponential_backoff(
|
||||
max_retries=1,
|
||||
max_time=60.0,
|
||||
base_delay=0.1,
|
||||
multiplier=2.0,
|
||||
jitter=False,
|
||||
exceptions=(ConnectionError,),
|
||||
logger_instance=mock_logger,
|
||||
prefix="TEST",
|
||||
)
|
||||
decorated_func = decorator(mock_func)
|
||||
|
||||
with pytest.raises(ConnectionError):
|
||||
decorated_func()
|
||||
|
||||
# Should log exception when giving up
|
||||
mock_logger.exception.assert_called()
|
||||
exception_msg = str(mock_logger.exception.call_args)
|
||||
assert "Giving up" in exception_msg or "exceeded" in exception_msg
|
||||
1651
unstract/sdk1/uv.lock
generated
1651
unstract/sdk1/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -11,7 +11,7 @@ dependencies = [
|
||||
"docker~=6.1.3",
|
||||
"jsonschema>=4.18.6,<5.0",
|
||||
"PyYAML~=6.0.1",
|
||||
"unstract-sdk~=0.77.1",
|
||||
"unstract-sdk~=0.78.0",
|
||||
"unstract-tool-sandbox",
|
||||
"unstract-flags",
|
||||
]
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
"schemaVersion": "0.0.1",
|
||||
"displayName": "File Classifier",
|
||||
"functionName": "classify",
|
||||
"toolVersion": "0.0.68",
|
||||
"toolVersion": "0.0.69",
|
||||
"description": "Classifies a file into a bin based on its contents",
|
||||
"input": {
|
||||
"description": "File to be classified"
|
||||
@@ -106,9 +106,9 @@
|
||||
"properties": {}
|
||||
},
|
||||
"icon": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<svg\n enable-background=\"new 0 0 20 20\"\n height=\"48\"\n viewBox=\"0 0 20 20\"\n width=\"48\"\n fill=\"#000000\"\n version=\"1.1\"\n id=\"svg8109\"\n sodipodi:docname=\"folder_copy_black_48dp.svg\"\n xmlns:inkscape=\"http://www.inkscape.org/namespaces/inkscape\"\n xmlns:sodipodi=\"http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd\"\n xmlns=\"http://www.w3.org/2000/svg\"\n xmlns:svg=\"http://www.w3.org/2000/svg\">\n <defs\n id=\"defs8113\" />\n <sodipodi:namedview\n id=\"namedview8111\"\n pagecolor=\"#ffffff\"\n bordercolor=\"#000000\"\n borderopacity=\"0.25\"\n inkscape:showpageshadow=\"2\"\n inkscape:pageopacity=\"0.0\"\n inkscape:pagecheckerboard=\"0\"\n inkscape:deskcolor=\"#d1d1d1\"\n showgrid=\"false\" />\n <g\n id=\"g8099\">\n <rect\n fill=\"none\"\n height=\"20\"\n width=\"20\"\n x=\"0\"\n id=\"rect8097\"\n y=\"0\" />\n </g>\n <g\n id=\"g8107\"\n style=\"fill:#ff4d6d;fill-opacity:1\">\n <g\n id=\"g8105\"\n style=\"fill:#ff4d6d;fill-opacity:1\">\n <path\n d=\"M 2.5,5 H 1 V 15.5 C 1,16.33 1.67,17 2.5,17 H 15.68 V 15.5 H 2.5 Z\"\n id=\"path8101\"\n style=\"fill:#ff4d6d;fill-opacity:1\" />\n <path\n d=\"M 16.5,4 H 11 L 9,2 H 5.5 C 4.67,2 4,2.67 4,3.5 v 9 C 4,13.33 4.67,14 5.5,14 h 11 c 0.83,0 1.5,-0.67 1.5,-1.5 v -7 C 18,4.67 17.33,4 16.5,4 Z m 0,8.5 h -11 v -9 h 2.88 l 2,2 h 6.12 z\"\n id=\"path8103\"\n style=\"fill:#ff4d6d;fill-opacity:1\" />\n </g>\n </g>\n</svg>\n",
|
||||
"image_url": "docker:unstract/tool-classifier:0.0.68",
|
||||
"image_url": "docker:unstract/tool-classifier:0.0.69",
|
||||
"image_name": "unstract/tool-classifier",
|
||||
"image_tag": "0.0.68"
|
||||
"image_tag": "0.0.69"
|
||||
},
|
||||
"text_extractor": {
|
||||
"tool_uid": "text_extractor",
|
||||
@@ -116,7 +116,7 @@
|
||||
"schemaVersion": "0.0.1",
|
||||
"displayName": "Text Extractor",
|
||||
"functionName": "text_extractor",
|
||||
"toolVersion": "0.0.64",
|
||||
"toolVersion": "0.0.65",
|
||||
"description": "The Text Extractor is a powerful tool designed to convert documents to its text form or Extract texts from documents",
|
||||
"input": {
|
||||
"description": "Document"
|
||||
@@ -191,8 +191,8 @@
|
||||
}
|
||||
},
|
||||
"icon": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<svg\n enable-background=\"new 0 0 20 20\"\n height=\"48\"\n viewBox=\"0 0 20 20\"\n width=\"48\"\n fill=\"#000000\"\n version=\"1.1\"\n id=\"svg8109\"\n sodipodi:docname=\"folder_copy_black_48dp.svg\"\n xmlns:inkscape=\"http://www.inkscape.org/namespaces/inkscape\"\n xmlns:sodipodi=\"http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd\"\n xmlns=\"http://www.w3.org/2000/svg\"\n xmlns:svg=\"http://www.w3.org/2000/svg\">\n <defs\n id=\"defs8113\" />\n <sodipodi:namedview\n id=\"namedview8111\"\n pagecolor=\"#ffffff\"\n bordercolor=\"#000000\"\n borderopacity=\"0.25\"\n inkscape:showpageshadow=\"2\"\n inkscape:pageopacity=\"0.0\"\n inkscape:pagecheckerboard=\"0\"\n inkscape:deskcolor=\"#d1d1d1\"\n showgrid=\"false\" />\n <g\n id=\"g8099\">\n <rect\n fill=\"none\"\n height=\"20\"\n width=\"20\"\n x=\"0\"\n id=\"rect8097\"\n y=\"0\" />\n </g>\n <g\n id=\"g8107\"\n style=\"fill:#ff4d6d;fill-opacity:1\">\n <g\n id=\"g8105\"\n style=\"fill:#ff4d6d;fill-opacity:1\">\n <path\n d=\"M 2.5,5 H 1 V 15.5 C 1,16.33 1.67,17 2.5,17 H 15.68 V 15.5 H 2.5 Z\"\n id=\"path8101\"\n style=\"fill:#ff4d6d;fill-opacity:1\" />\n <path\n d=\"M 16.5,4 H 11 L 9,2 H 5.5 C 4.67,2 4,2.67 4,3.5 v 9 C 4,13.33 4.67,14 5.5,14 h 11 c 0.83,0 1.5,-0.67 1.5,-1.5 v -7 C 18,4.67 17.33,4 16.5,4 Z m 0,8.5 h -11 v -9 h 2.88 l 2,2 h 6.12 z\"\n id=\"path8103\"\n style=\"fill:#ff4d6d;fill-opacity:1\" />\n </g>\n </g>\n</svg>\n",
|
||||
"image_url": "docker:unstract/tool-text-extractor:0.0.64",
|
||||
"image_url": "docker:unstract/tool-text-extractor:0.0.65",
|
||||
"image_name": "unstract/tool-text-extractor",
|
||||
"image_tag": "0.0.64"
|
||||
"image_tag": "0.0.65"
|
||||
}
|
||||
}
|
||||
|
||||
1452
unstract/tool-registry/uv.lock
generated
1452
unstract/tool-registry/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -22,7 +22,7 @@ dependencies = [
|
||||
"prometheus-client>=0.17.0,<1.0.0", # Metrics collection
|
||||
"psutil>=5.9.0,<6.0.0", # System resource monitoring
|
||||
# Essential Unstract packages - with Azure support for connectors
|
||||
"unstract-sdk[azure]~=0.77.3", # Core SDK with Azure connector support
|
||||
"unstract-sdk[azure]~=0.78.0", # Core SDK with Azure connector support
|
||||
"unstract-connectors",
|
||||
"unstract-core",
|
||||
"unstract-flags",
|
||||
|
||||
3352
workers/uv.lock
generated
3352
workers/uv.lock
generated
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user