UN-2793 [FEAT] Add retry logic with exponential backoff to SDK1 (#1564)

* UN-2793 [FEAT] Add retry logic with exponential backoff to SDK1

Implemented automatic retry logic for platform and prompt service calls
with configurable exponential backoff, comprehensive test coverage, and
CI integration.

Features:
- Exponential backoff with jitter for transient failures
- Configurable via environment variables (MAX_RETRIES, MAX_TIME, BASE_DELAY, etc.)
- Retries ConnectionError, Timeout, HTTPError (502/503/504), OSError
- 67 tests with 100% pass rate
- CI integration with test reporting

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* [SECURITY] Use full commit SHA for sticky-pull-request-comment action

Replace tag reference with full commit SHA for better security:
- marocchino/sticky-pull-request-comment@v2 → @7737449 (v2.9.4)

This prevents potential supply chain attacks where tags could be moved
to point to malicious code. Commit SHAs are immutable.

Fixes SonarQube security hotspot for external GitHub action.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* [FIX] Allow retryable HTTP errors (502/503/504) to propagate for retry

Fixed HTTPError handling in _get_adapter_configuration to check status
codes and re-raise retryable errors (502, 503, 504) so the retry
decorator can handle them. Non-retryable errors are still converted
to SdkError as before.

Changes:
- Check HTTPError status code before converting to SdkError
- Re-raise HTTPError for 502/503/504 to allow retry decorator to retry
- Added parametrized test for all retryable status codes (502, 503, 504)
- All 12 platform tests pass

This fixes a bug where 502/503/504 errors were not being retried
because they were converted to SdkError before the retry decorator
could see them.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* [FIX] Use pytest.approx() for floating point comparisons in tests

Replaced direct equality comparisons (==) with pytest.approx() for
floating point values to avoid precision issues and satisfy SonarQube
code quality check (python:S1244).

Changes in test_retry_utils.py:
- test_exponential_backoff_without_jitter: Use pytest.approx() for 1.0, 2.0, 4.0, 8.0
- test_max_delay_cap: Use pytest.approx() for 5.0

This is the proper way to compare floating point values in tests,
accounting for floating point precision limitations.

All 4 TestCalculateDelay tests pass.

Fixes SonarQube: python:S1244 - Do not perform equality checks with
floating point values.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* minor: Addressed code smells, ruff fixes

* misc: Fixed tox config for sdk1 tests

* misc: Ruff issues fixed

* misc: tox tests fixed

* prompt service lock file for venv

* updated lock files for backend and prompt-service

* UN-2793 [FEAT] Update to unstract-sdk v0.78.0 with retry logic support (#1567)

[FEAT] Update unstract-sdk to v0.78.0 across all services and tools

- Updated unstract-sdk dependency from v0.77.3 to v0.78.0 in all pyproject.toml files
  - Main repository, backend, workers, platform-service, prompt-service
  - filesystem and tool-registry modules
- Updated tool requirements.txt files (structure, classifier, text_extractor)
- Bumped tool versions in properties.json:
  - Structure tool: 0.0.88 → 0.0.89
  - Classifier tool: 0.0.68 → 0.0.69
  - Text extractor tool: 0.0.64 → 0.0.65
- Updated tool versions in backend/sample.env and public_tools.json
- Regenerated all uv.lock files with new SDK version

This update brings in the retry logic with exponential backoff from unstract-sdk v0.78.0

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>

---------

Signed-off-by: Chandrasekharan M <117059509+chandrasekharan-zipstack@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Chandrasekharan M
2025-10-09 10:48:19 +05:30
committed by GitHub
parent 9997ee10e5
commit 0c0c8c1034
36 changed files with 8616 additions and 6896 deletions

View File

@@ -48,13 +48,22 @@ jobs:
run: |
tox
- name: Render the report to the PR
uses: marocchino/sticky-pull-request-comment@v2
- name: Render the Runner report to the PR
uses: marocchino/sticky-pull-request-comment@773744901bac0e8cbb5a0dc842800d45e9b2b405 # v2.9.4
if: always() && hashFiles('runner-report.md') != ''
with:
header: runner-test-report
recreate: true
path: runner-report.md
- name: Render the SDK1 report to the PR
uses: marocchino/sticky-pull-request-comment@773744901bac0e8cbb5a0dc842800d45e9b2b405 # v2.9.4
if: always() && hashFiles('sdk1-report.md') != ''
with:
header: sdk1-test-report
recreate: true
path: sdk1-report.md
- name: Output reports to the job summary when tests fail
shell: bash
run: |
@@ -65,3 +74,10 @@ jobs:
echo "" >> $GITHUB_STEP_SUMMARY
echo "</details>" >> $GITHUB_STEP_SUMMARY
fi
if [ -f "sdk1-report.md" ]; then
echo "<details><summary>SDK1 Test Report</summary>" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
cat "sdk1-report.md" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "</details>" >> $GITHUB_STEP_SUMMARY
fi

View File

@@ -35,9 +35,7 @@ dependencies = [
"python-socketio==5.9.0", # For log_events
"social-auth-app-django==5.3.0", # For OAuth
"social-auth-core==4.4.2", # For OAuth
# TODO: Temporarily removing the extra dependencies of aws and gcs from unstract-sdk
# to resolve lock file. Will have to be re-looked into
"unstract-sdk[azure]~=0.77.3",
"unstract-sdk[aws,gcs,azure]~=0.78.0",
"azure-identity==1.16.0",
"azure-mgmt-apimanagement==3.0.0",
"croniter>=3.0.3",

View File

@@ -78,9 +78,9 @@ PROMPT_STUDIO_FILE_PATH=/app/prompt-studio-data
# Structure Tool Image (Runs prompt studio exported tools)
# https://hub.docker.com/r/unstract/tool-structure
STRUCTURE_TOOL_IMAGE_URL="docker:unstract/tool-structure:0.0.88"
STRUCTURE_TOOL_IMAGE_URL="docker:unstract/tool-structure:0.0.89"
STRUCTURE_TOOL_IMAGE_NAME="unstract/tool-structure"
STRUCTURE_TOOL_IMAGE_TAG="0.0.88"
STRUCTURE_TOOL_IMAGE_TAG="0.0.89"
# Feature Flags
EVALUATION_SERVER_IP=unstract-flipt

2137
backend/uv.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -15,7 +15,7 @@ dependencies = [
"redis~=5.2.1",
"cryptography>=41.0.7",
"requests>=2.31.0",
"unstract-sdk[gcs, azure, aws]~=0.77.3", # Add version
"unstract-sdk[gcs, azure, aws]~=0.78.0", # Add version
"unstract-flags",
"unstract-core[flask]",
"unstract-sdk1[gcs, azure, aws]"

View File

@@ -3184,7 +3184,7 @@ requires-dist = [
{ name = "requests", specifier = ">=2.31.0" },
{ name = "unstract-core", extras = ["flask"], editable = "../unstract/core" },
{ name = "unstract-flags", editable = "../unstract/flags" },
{ name = "unstract-sdk", extras = ["gcs", "azure", "aws"], specifier = "~=0.77.3" },
{ name = "unstract-sdk", extras = ["gcs", "azure", "aws"], specifier = "~=0.78.0" },
{ name = "unstract-sdk1", extras = ["gcs", "azure", "aws"], editable = "../unstract/sdk1" },
]
@@ -3202,7 +3202,7 @@ test = [{ name = "pytest", specifier = ">=8.0.1" }]
[[package]]
name = "unstract-sdk"
version = "0.77.3"
version = "0.78.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filetype" },
@@ -3241,9 +3241,9 @@ dependencies = [
{ name = "tiktoken" },
{ name = "transformers" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c5/e8/f7e1045fee076c75c42bb27b4fa5a077836cfa8fd3fd580f6c6193fa19dc/unstract_sdk-0.77.3.tar.gz", hash = "sha256:378c19129a91e861b7235e92411ef72792c6c851320cb7897b380d08d3489c9d", size = 2375843 }
sdist = { url = "https://files.pythonhosted.org/packages/7b/b9/ca49124e2548e7baa57c00dab2300d59a1dac13ccc6b0d08e5c81f3f7b94/unstract_sdk-0.78.0.tar.gz", hash = "sha256:35cf15acd946996be3871b35aec099b9fb03d30e96893dd5f34991a8aa2535a0", size = 2364734 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d5/84/d596295fce3a713b1953a4fd7be3be1d788920e5633e28fbea7c7331b68a/unstract_sdk-0.77.3-py3-none-any.whl", hash = "sha256:74e83cbb68eef98fbaecfc83e0cfd2228e5958506e1b51a65fae0d891027aaee", size = 266640 },
{ url = "https://files.pythonhosted.org/packages/ac/23/2f0f9cf0f4e489a4be8133622e968f8a7dd639f0f582252625ca3a22a749/unstract_sdk-0.78.0-py3-none-any.whl", hash = "sha256:ecaf24386c1eaefa47e6c06d1f978ca26406cc8cfffcbe9d96197fa1b2fba826", size = 270348 },
]
[package.optional-dependencies]
@@ -3335,6 +3335,8 @@ docs = [{ name = "lazydocs", specifier = "~=0.4.8" }]
test = [
{ name = "parameterized", specifier = "==0.9.0" },
{ name = "pytest", specifier = "==8.3.3" },
{ name = "pytest-cov", specifier = ">=6.0.0" },
{ name = "pytest-md-report", specifier = ">=0.6.2" },
{ name = "pytest-mock", specifier = "==3.14.0" },
]

View File

@@ -15,9 +15,7 @@ dependencies = [
"python-dotenv==1.0.1",
"json-repair~=0.42.0",
"requests>=2.28,<3.0",
# TODO: Temporarily removing the extra dependencies of aws and gcs from unstract-sdk
# to resolve lock file. Will have to be re-looked into
"unstract-sdk[azure]~=0.77.3",
"unstract-sdk[aws,gcs,azure]~=0.78.0",
"redis>=5.0.3,<5.3",
"unstract-core",
"unstract-flags",

1721
prompt-service/uv.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -41,7 +41,7 @@ hook-check-django-migrations = [
"psycopg2-binary==2.9.9",
"python-dotenv==1.0.1",
"python-magic==0.4.27",
"unstract-sdk~=0.77.3",
"unstract-sdk~=0.78.0",
"unstract-connectors",
"unstract-core",
"unstract-flags",

View File

@@ -4,6 +4,6 @@
# aws alone is needed here
# because tools use transient temporary storage.
unstract-sdk[aws]~=0.77.3
unstract-sdk[aws]~=0.78.0
-e file:/unstract/sdk1
-e file:/unstract/flags

View File

@@ -2,7 +2,7 @@
"schemaVersion": "0.0.1",
"displayName": "File Classifier",
"functionName": "classify",
"toolVersion": "0.0.68",
"toolVersion": "0.0.69",
"description": "Classifies a file into a bin based on its contents",
"input": {
"description": "File to be classified"

View File

@@ -3,6 +3,6 @@
# Required for all unstract tools
# aws alone is needed here
# because tools use transient temporary storage.
unstract-sdk[aws]~=0.77.3
unstract-sdk[aws]~=0.78.0
-e file:/unstract/sdk1
-e file:/unstract/flags

View File

@@ -2,7 +2,7 @@
"schemaVersion": "0.0.1",
"displayName": "Structure Tool",
"functionName": "structure_tool",
"toolVersion": "0.0.88",
"toolVersion": "0.0.89",
"description": "This is a template tool which can answer set of input prompts designed in the Prompt Studio",
"input": {
"description": "File that needs to be indexed and parsed for answers"

View File

@@ -4,6 +4,6 @@
# aws alone is needed here
# because tools use transient temporary storage.
unstract-sdk[aws]~=0.77.3
unstract-sdk[aws]~=0.78.0
-e file:/unstract/sdk1
-e file:/unstract/flags

View File

@@ -2,7 +2,7 @@
"schemaVersion": "0.0.1",
"displayName": "Text Extractor",
"functionName": "text_extractor",
"toolVersion": "0.0.64",
"toolVersion": "0.0.65",
"description": "The Text Extractor is a powerful tool designed to convert documents to its text form or Extract texts from documents",
"input": {
"description": "Document"

15
tox.ini
View File

@@ -1,5 +1,5 @@
[tox]
env_list = py{312}, runner
env_list = py{312}, runner, sdk1
requires =
tox-uv>=0.2.0
@@ -30,3 +30,16 @@ commands_pre =
uv pip install pytest pytest-cov pytest-md-report pytest-mock
commands =
pytest -v --md-report-verbose=1 --md-report --md-report-flavor gfm --md-report-output ../runner-report.md
[testenv:sdk1]
changedir = unstract/sdk1
deps = uv
allowlist_externals=
sh
uv
pytest
commands_pre =
# Install dependencies with test group
uv sync --group test
commands =
uv run pytest -v -m "not slow" --cov=src/unstract/sdk1 --cov-report=term --cov-report=html --md-report-verbose=1 --md-report --md-report-flavor gfm --md-report-output ../../sdk1-report.md

View File

@@ -6,7 +6,7 @@ authors = [{ name = "Zipstack Inc.", email = "devsupport@zipstack.com" }]
requires-python = ">=3.12,<3.13"
readme = "README.md"
dependencies = ["unstract-sdk~=0.77.1"]
dependencies = ["unstract-sdk~=0.78.0"]
# [tool.uv.sources]
# unstract-sdk = { git = "https://github.com/Zipstack/unstract-sdk", branch = "python-upgrade" }

File diff suppressed because it is too large Load Diff

View File

@@ -6,3 +6,36 @@
The `unstract-sdk1` package helps with developing tools that are meant to be run on the Unstract platform. This includes
modules to help with tool development and execution, caching, making calls to LLMs / vectorDBs / embeddings .etc.
They also contain helper methods/classes to aid with other tasks such as indexing and auditing the LLM calls.
## Features
### Retry Configuration
The SDK automatically retries platform and prompt service calls on transient failures. Configure via environment variables (prefix: `PLATFORM_SERVICE_` or `PROMPT_SERVICE_`):
| Variable | Default | Description |
|----------|---------|-------------|
| `MAX_RETRIES` | 3 | Maximum retry attempts |
| `MAX_TIME` | 60 | Maximum total time (seconds) |
| `BASE_DELAY` | 1.0 | Initial delay (seconds) |
| `MULTIPLIER` | 2.0 | Backoff multiplier |
| `JITTER` | true | Add random jitter (0-25%) |
**Retryable errors**: ConnectionError, Timeout, HTTPError (502/503/504), OSError (connection failures)
## Development
### Running Tests
Install test dependencies and run tests:
```bash
# Install dependencies
uv sync --group test
# Run all tests
uv run pytest
# Run with coverage
uv run pytest --cov=src/unstract/sdk1 --cov-report=html
```

View File

@@ -67,6 +67,8 @@ test = [
"parameterized==0.9.0",
"pytest==8.3.3",
"pytest-mock==3.14.0",
"pytest-cov>=6.0.0",
"pytest-md-report>=0.6.2",
]
[build-system]

View File

@@ -1,6 +1,6 @@
import json
import logging
from typing import Any
from typing import Any, Self
import requests
from requests import RequestException, Response
@@ -17,6 +17,7 @@ from unstract.sdk1.constants import (
from unstract.sdk1.exceptions import SdkError
from unstract.sdk1.tool.base import BaseTool
from unstract.sdk1.utils.common import Utils
from unstract.sdk1.utils.retry_utils import retry_platform_service_call
logger = logging.getLogger(__name__)
@@ -29,7 +30,7 @@ class PlatformHelper:
"""
def __init__(
self,
self: Self,
tool: BaseTool,
platform_host: str,
platform_port: str,
@@ -50,7 +51,11 @@ class PlatformHelper:
self.request_id = request_id
@classmethod
def get_platform_base_url(cls, platform_host: str, platform_port: str) -> str:
def get_platform_base_url(
cls: type[Self],
platform_host: str,
platform_port: str,
) -> str:
"""Make base url from host and port.
Args:
@@ -65,7 +70,7 @@ class PlatformHelper:
return f"{platform_host}:{platform_port}"
@classmethod
def is_public_adapter(cls, adapter_id: str) -> bool:
def is_public_adapter(cls: type[Self], adapter_id: str) -> bool:
"""Check if the given adapter_id is one of the public adapter keys.
This method iterates over the attributes of the PublicAdapterKeys class
@@ -91,14 +96,24 @@ class PlatformHelper:
return False
@classmethod
@retry_platform_service_call
def _get_adapter_configuration(
cls,
cls: type[Self],
tool: BaseTool,
adapter_instance_id: str,
) -> dict[str, Any]:
"""Get adapter configuration from platform service.
"""Get Adapter.
1. Get the adapter config from platform service using the adapter_instance_id.
Get the adapter config from platform service
using the adapter_instance_id. This method automatically
retries on connection errors with exponential backoff.
Retry behavior is configurable via environment variables:
- PLATFORM_SERVICE_MAX_RETRIES (default: 3)
- PLATFORM_SERVICE_MAX_TIME (default: 60s)
- PLATFORM_SERVICE_BASE_DELAY (default: 1.0s)
- PLATFORM_SERVICE_MULTIPLIER (default: 2.0)
- PLATFORM_SERVICE_JITTER (default: true)
Args:
adapter_instance_id (str): Adapter instance ID
@@ -130,11 +145,14 @@ class PlatformHelper:
f"'{adapter_type}', provider: '{provider}', name: '{adapter_name}'",
level=LogLevel.DEBUG,
)
except ConnectionError as e:
raise SdkError(
"Unable to connect to platform service, please contact the admin."
) from e
except HTTPError as e:
# Check if this is a retryable HTTP error (502, 503, 504)
# If so, re-raise it so the retry decorator can handle it
if hasattr(e, "response") and e.response is not None:
if e.response.status_code in (502, 503, 504):
raise # Re-raise to allow retry decorator to retry
# Non-retryable error - convert to SdkError
default_err = (
"Error while calling the platform service, please contact the admin."
)
@@ -146,7 +164,7 @@ class PlatformHelper:
@classmethod
def get_adapter_config(
cls, tool: BaseTool, adapter_instance_id: str
cls: type[Self], tool: BaseTool, adapter_instance_id: str
) -> dict[str, Any] | None:
"""Get adapter spec by the help of unstract DB tool.
@@ -174,9 +192,15 @@ class PlatformHelper:
f"Retrieving config from DB for '{adapter_instance_id}'",
level=LogLevel.DEBUG,
)
return cls._get_adapter_configuration(tool, adapter_instance_id)
def _get_headers(self, headers: dict[str, str] | None = None) -> dict[str, str]:
try:
return cls._get_adapter_configuration(tool, adapter_instance_id)
except ConnectionError as e:
raise SdkError(
"Unable to connect to platform service, please contact the admin."
) from e
def _get_headers(self: Self, headers: dict[str, str] | None = None) -> dict[str, str]:
"""Get default headers for requests.
Returns:
@@ -190,8 +214,9 @@ class PlatformHelper:
request_headers.update(headers)
return request_headers
@retry_platform_service_call
def _call_service(
self,
self: Self,
url_path: str,
payload: dict[str, Any] | None = None,
params: dict[str, str] | None = None,
@@ -201,6 +226,14 @@ class PlatformHelper:
"""Talks to platform-service to make GET / POST calls.
Only GET calls are made to platform-service though functionality exists.
This method automatically retries on connection errors with exponential backoff.
Retry behavior is configurable via environment variables:
- PLATFORM_SERVICE_MAX_RETRIES (default: 3)
- PLATFORM_SERVICE_MAX_TIME (default: 60s)
- PLATFORM_SERVICE_BASE_DELAY (default: 1.0s)
- PLATFORM_SERVICE_MULTIPLIER (default: 2.0)
- PLATFORM_SERVICE_JITTER (default: true)
Args:
url_path (str): URL path to the service endpoint
@@ -234,9 +267,13 @@ class PlatformHelper:
response.raise_for_status()
except ConnectionError as connect_err:
msg = "Unable to connect to platform service. Please contact admin."
msg += " \n" + str(connect_err)
self.tool.stream_error_and_exit(msg)
logger.exception("Connection error to platform service")
msg = (
"Unable to connect to platform service. Retrying with backoff, "
"please contact admin if retries ultimately fail."
)
self.tool.stream_log(msg, level=LogLevel.ERROR)
raise ConnectionError(msg) from connect_err
except RequestException as e:
# Extract error information from the response if available
error_message = str(e)
@@ -252,7 +289,7 @@ class PlatformHelper:
)
return response.json()
def get_platform_details(self) -> dict[str, Any] | None:
def get_platform_details(self: Self) -> dict[str, Any] | None:
"""Obtains platform details associated with the platform key.
Currently helps fetch organization ID related to the key.
@@ -269,7 +306,7 @@ class PlatformHelper:
)
return response.get("details")
def get_prompt_studio_tool(self, prompt_registry_id: str) -> dict[str, Any]:
def get_prompt_studio_tool(self: Self, prompt_registry_id: str) -> dict[str, Any]:
"""Get exported custom tool by the help of unstract DB tool.
Args:
@@ -287,7 +324,7 @@ class PlatformHelper:
method="GET",
)
def get_llm_profile(self, llm_profile_id: str) -> dict[str, Any]:
def get_llm_profile(self: Self, llm_profile_id: str) -> dict[str, Any]:
"""Get llm profile by the help of unstract DB tool.
Args:

View File

@@ -9,6 +9,7 @@ from unstract.sdk1.constants import MimeType, RequestHeader, ToolEnv
from unstract.sdk1.platform import PlatformHelper
from unstract.sdk1.tool.base import BaseTool
from unstract.sdk1.utils.common import log_elapsed
from unstract.sdk1.utils.retry_utils import retry_prompt_service_call
logger = logging.getLogger(__name__)
@@ -187,6 +188,7 @@ class PromptTool:
request_headers.update(headers)
return request_headers
@retry_prompt_service_call
def _call_service(
self,
url_path: str,
@@ -198,6 +200,14 @@ class PromptTool:
"""Communicates to prompt service to fetch response for the prompt.
Only POST calls are made to prompt-service though functionality exists.
This method automatically retries on connection errors with exponential backoff.
Retry behavior is configurable via environment variables:
- PROMPT_SERVICE_MAX_RETRIES (default: 3)
- PROMPT_SERVICE_MAX_TIME (default: 60s)
- PROMPT_SERVICE_BASE_DELAY (default: 1.0s)
- PROMPT_SERVICE_MULTIPLIER (default: 2.0)
- PROMPT_SERVICE_JITTER (default: true)
Args:
url_path (str): URL path to the service endpoint

View File

@@ -0,0 +1,305 @@
"""Generic retry utilities with custom exponential backoff implementation."""
import errno
import logging
import os
import random
import time
from collections.abc import Callable
from functools import wraps
from typing import Any
from requests.exceptions import ConnectionError, HTTPError, Timeout
logger = logging.getLogger(__name__)
def is_retryable_error(error: Exception) -> bool:
"""Check if an error is retryable.
Handles:
- ConnectionError and Timeout from requests
- HTTPError with status codes 502, 503, 504
- OSError with specific errno codes (ECONNREFUSED, ECONNRESET, etc.)
Args:
error: The exception to check
Returns:
True if the error should trigger a retry
"""
# Requests connection and timeout errors
if isinstance(error, ConnectionError | Timeout):
return True
# HTTP errors with specific status codes
if isinstance(error, HTTPError):
if hasattr(error, "response") and error.response is not None:
status_code = error.response.status_code
# Retry on server errors and bad gateway
if status_code in [502, 503, 504]:
return True
# OS-level connection failures (preserving existing errno checks)
if isinstance(error, OSError) and error.errno in {
errno.ECONNREFUSED, # Connection refused
getattr(errno, "ECONNRESET", 104), # Connection reset by peer
getattr(errno, "ETIMEDOUT", 110), # Connection timed out
getattr(errno, "EHOSTUNREACH", 113), # No route to host
getattr(errno, "ENETUNREACH", 101), # Network is unreachable
}:
return True
return False
def calculate_delay(
attempt: int,
base_delay: float,
multiplier: float,
max_delay: float,
jitter: bool = True,
) -> float:
"""Calculate delay for the next retry attempt with exponential backoff.
Args:
attempt: Current attempt number (0-indexed)
base_delay: Base delay in seconds
multiplier: Backoff multiplier
max_delay: Maximum delay in seconds
jitter: Whether to add random jitter
Returns:
Delay in seconds before the next retry
"""
# Calculate exponential backoff
base = base_delay * (multiplier**attempt)
# Add jitter if enabled (0-25% of base)
if jitter:
delay = base + (base * random.uniform(0, 0.25))
else:
delay = base
# Enforce cap after jitter
return min(delay, max_delay)
def retry_with_exponential_backoff( # noqa: C901
max_retries: int,
max_time: float,
base_delay: float,
multiplier: float,
jitter: bool,
exceptions: tuple[type[Exception], ...],
logger_instance: logging.Logger,
prefix: str,
retry_predicate: Callable[[Exception], bool] | None = None,
) -> Callable:
"""Create retry decorator with exponential backoff.
Args:
max_retries: Maximum number of retry attempts
max_time: Maximum total time in seconds
base_delay: Initial delay in seconds
multiplier: Backoff multiplier
jitter: Whether to add jitter
exceptions: Exception types to catch for retries
logger_instance: Logger instance for retry messages
prefix: Service prefix for logging
retry_predicate: Optional callable to determine if exception should trigger retry
Returns:
Decorator function
"""
def decorator(func: Callable) -> Callable: # noqa: C901
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: C901, ANN401
start_time = time.time()
last_exception = None
for attempt in range(max_retries + 1): # +1 for initial attempt
try:
# Try to execute the function
result = func(*args, **kwargs)
# If successful and we had retried, log success
if attempt > 0:
logger_instance.info(
"Successfully completed '%s' after %d retry attempt(s)",
func.__name__,
attempt,
)
return result
except exceptions as e:
last_exception = e
# Check if the error should trigger a retry
# First check if it's in the allowed exception types (already caught)
# Then check using the predicate if provided
should_retry = True
if retry_predicate is not None:
should_retry = retry_predicate(e)
# Check if we've exceeded max time
elapsed_time = time.time() - start_time
if elapsed_time >= max_time:
logger_instance.exception(
"Giving up '%s' after %.1fs (max time exceeded): %s",
func.__name__,
elapsed_time,
e,
)
raise
# If not retryable or last attempt, raise the error
if not should_retry or attempt == max_retries:
if attempt > 0:
logger_instance.exception(
"Giving up '%s' after %d attempt(s) for %s",
func.__name__,
attempt + 1,
prefix,
)
raise
# Calculate delay for next retry
delay = calculate_delay(
attempt, base_delay, multiplier, max_time, jitter
)
# Ensure we don't exceed max_time with the delay
remaining_time = max_time - elapsed_time
if delay >= remaining_time:
logger_instance.exception(
"Giving up '%s' - next delay %.1fs would exceed "
"max time %.1fs",
func.__name__,
delay,
max_time,
)
raise
# Log retry attempt
logger_instance.warning(
"Retry %d/%d for %s: %s (waiting %.1fs)",
attempt + 1,
max_retries,
prefix,
e,
delay,
)
# Wait before retrying
time.sleep(delay)
except Exception as e:
# Exception not in the exceptions tuple - don't retry
last_exception = e
raise
# This should never be reached, but just in case
if last_exception:
raise last_exception
return wrapper
return decorator
def create_retry_decorator(
prefix: str,
exceptions: tuple[type[Exception], ...] | None = None,
retry_predicate: Callable[[Exception], bool] | None = None,
logger_instance: logging.Logger | None = None,
) -> Callable:
"""Create a configured retry decorator for a specific service.
Args:
prefix: Environment variable prefix for configuration
exceptions: Tuple of exception types to retry on.
Defaults to (ConnectionError, HTTPError, Timeout, OSError)
retry_predicate: Optional callable to determine if exception should trigger retry.
If only exceptions list provided, retry on those exceptions.
If only predicate provided, use predicate (catch all exceptions).
If both provided, first filter by exceptions then check predicate.
logger_instance: Optional logger for retry events
Environment variables (using prefix):
{prefix}_MAX_RETRIES: Maximum retry attempts (default: 3)
{prefix}_MAX_TIME: Maximum total time in seconds (default: 60)
{prefix}_BASE_DELAY: Initial delay in seconds (default: 1.0)
{prefix}_MULTIPLIER: Backoff multiplier (default: 2.0)
{prefix}_JITTER: Enable jitter true/false (default: true)
Returns:
Configured retry decorator
"""
# Handle different combinations of exceptions and predicate
if exceptions is None and retry_predicate is None:
# Default case: use specific exceptions with is_retryable_error predicate
exceptions = (ConnectionError, HTTPError, Timeout, OSError)
retry_predicate = is_retryable_error
elif exceptions is None and retry_predicate is not None:
# Only predicate provided: catch all exceptions and use predicate
exceptions = (Exception,)
elif exceptions is not None and retry_predicate is None:
# Only exceptions provided: retry on those exceptions
pass # exceptions already set, no predicate needed
# Both provided: use both (exceptions filter first, then predicate)
# Load configuration from environment
max_retries = int(os.getenv(f"{prefix}_MAX_RETRIES", "3"))
max_time = float(os.getenv(f"{prefix}_MAX_TIME", "60"))
base_delay = float(os.getenv(f"{prefix}_BASE_DELAY", "1.0"))
multiplier = float(os.getenv(f"{prefix}_MULTIPLIER", "2.0"))
use_jitter = os.getenv(f"{prefix}_JITTER", "true").strip().lower() in {
"true",
"1",
"yes",
"on",
}
if max_retries < 0:
raise ValueError(f"{prefix}_MAX_RETRIES must be >= 0")
if max_time <= 0:
raise ValueError(f"{prefix}_MAX_TIME must be > 0")
if base_delay <= 0:
raise ValueError(f"{prefix}_BASE_DELAY must be > 0")
if multiplier <= 0:
raise ValueError(f"{prefix}_MULTIPLIER must be > 0")
if logger_instance is None:
logger_instance = logger
return retry_with_exponential_backoff(
max_retries=max_retries,
max_time=max_time,
base_delay=base_delay,
multiplier=multiplier,
jitter=use_jitter,
exceptions=exceptions,
logger_instance=logger_instance,
prefix=prefix,
retry_predicate=retry_predicate,
)
# Retry configured through below envs.
# - PLATFORM_SERVICE_MAX_RETRIES (default: 3)
# - PLATFORM_SERVICE_MAX_TIME (default: 60s)
# - PLATFORM_SERVICE_BASE_DELAY (default: 1.0s)
# - PLATFORM_SERVICE_MULTIPLIER (default: 2.0)
# - PLATFORM_SERVICE_JITTER (default: true)
retry_platform_service_call = create_retry_decorator("PLATFORM_SERVICE")
# Retry configured through below envs.
# - PROMPT_SERVICE_MAX_RETRIES (default: 3)
# - PROMPT_SERVICE_MAX_TIME (default: 60s)
# - PROMPT_SERVICE_BASE_DELAY (default: 1.0s)
# - PROMPT_SERVICE_MULTIPLIER (default: 2.0)
# - PROMPT_SERVICE_JITTER (default: true)
retry_prompt_service_call = create_retry_decorator("PROMPT_SERVICE")

View File

@@ -0,0 +1 @@
"""Tests for unstract-sdk1."""

View File

@@ -0,0 +1,55 @@
"""Pytest configuration and fixtures for unstract-sdk1 tests."""
import logging
from collections.abc import Callable
from typing import Any
from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
@pytest.fixture
def mock_logger() -> MagicMock:
"""Create a mock logger for testing."""
logger = MagicMock(spec=logging.Logger)
return logger
@pytest.fixture
def clean_env(monkeypatch: MonkeyPatch) -> MonkeyPatch:
"""Clean environment variables before each test."""
# Remove any retry-related environment variables
env_vars = [
"PLATFORM_SERVICE_MAX_RETRIES",
"PLATFORM_SERVICE_MAX_TIME",
"PLATFORM_SERVICE_BASE_DELAY",
"PLATFORM_SERVICE_MULTIPLIER",
"PLATFORM_SERVICE_JITTER",
"PROMPT_SERVICE_MAX_RETRIES",
"PROMPT_SERVICE_MAX_TIME",
"PROMPT_SERVICE_BASE_DELAY",
"PROMPT_SERVICE_MULTIPLIER",
"PROMPT_SERVICE_JITTER",
]
for var in env_vars:
monkeypatch.delenv(var, raising=False)
return monkeypatch
@pytest.fixture
def set_env(monkeypatch: MonkeyPatch) -> Callable[..., None]:
"""Helper fixture to set environment variables."""
def _set_env(prefix: str, **kwargs: Any) -> None: # noqa: ANN401
"""Set environment variables with given prefix.
Args:
prefix: Environment variable prefix (e.g., 'PLATFORM_SERVICE')
**kwargs: Key-value pairs to set (e.g., max_retries=5)
"""
for key, value in kwargs.items():
env_key = f"{prefix}_{key.upper()}"
monkeypatch.setenv(env_key, str(value))
return _set_env

View File

@@ -0,0 +1,233 @@
"""Integration tests for platform module with retry logic."""
from unittest.mock import MagicMock, Mock, patch
import pytest
from _pytest.monkeypatch import MonkeyPatch
from requests.exceptions import ConnectionError, HTTPError
from unstract.sdk1.exceptions import SdkError
from unstract.sdk1.platform import PlatformHelper
class TestPlatformHelperRetry:
"""Tests for PlatformHelper retry functionality."""
@pytest.fixture
def mock_tool(self) -> MagicMock:
"""Create a mock tool for testing."""
tool = MagicMock()
tool.get_env_or_die.side_effect = lambda key: {
"PLATFORM_HOST": "http://localhost",
"PLATFORM_PORT": "3001",
"PLATFORM_API_KEY": "test-api-key",
}.get(key, "mock-value")
tool.stream_log = MagicMock()
tool.stream_error_and_exit = MagicMock()
return tool
@pytest.fixture
def platform_helper(self, mock_tool: MagicMock) -> PlatformHelper:
"""Create a PlatformHelper instance."""
return PlatformHelper(
tool=mock_tool,
platform_host="http://localhost",
platform_port="3001",
request_id="test-request-id",
)
@pytest.mark.parametrize(
"method_name,method_args,http_method",
[
("_get_adapter_configuration", ("test-adapter-id",), "GET"),
("_call_service", ("test-endpoint",), "GET"),
],
)
def test_success_on_first_attempt(
self,
mock_tool: MagicMock,
platform_helper: PlatformHelper,
method_name: str,
method_args: tuple[str, ...],
http_method: str,
clean_env: MonkeyPatch,
) -> None:
"""Test successful calls on first attempt for various methods."""
expected_data = {"adapter_id": "test", "config": {}}
patch_target = f"requests.{http_method.lower()}"
with patch(patch_target) as mock_request:
mock_response = Mock()
mock_response.json.return_value = expected_data
mock_response.raise_for_status = Mock()
mock_request.return_value = mock_response
if method_name == "_get_adapter_configuration":
PlatformHelper._get_adapter_configuration(mock_tool, *method_args)
else:
getattr(platform_helper, method_name)(*method_args)
assert mock_request.call_count == 1
@pytest.mark.parametrize(
"method_name,method_args",
[
("_get_adapter_configuration", ("test-adapter-id",)),
("_call_service", ("test-endpoint",)),
],
)
def test_retry_on_connection_error(
self,
mock_tool: MagicMock,
platform_helper: PlatformHelper,
method_name: str,
method_args: tuple[str, ...],
clean_env: MonkeyPatch,
) -> None:
"""Test methods retry on ConnectionError."""
expected_data = {"result": "success"}
with patch("requests.get") as mock_get:
mock_response = Mock()
mock_response.json.return_value = expected_data
mock_response.raise_for_status = Mock()
mock_get.side_effect = [
ConnectionError("Transient failure"),
mock_response,
]
if method_name == "_get_adapter_configuration":
PlatformHelper._get_adapter_configuration(mock_tool, *method_args)
else:
getattr(platform_helper, method_name)(*method_args)
assert mock_get.call_count == 2
@pytest.mark.slow
def test_max_retries_exceeded(
self, mock_tool: MagicMock, clean_env: MonkeyPatch
) -> None:
"""Test service call fails after exceeding max retries."""
platform_helper = PlatformHelper(
tool=mock_tool,
platform_host="http://localhost",
platform_port="3001",
)
with patch("requests.get") as mock_get:
mock_get.side_effect = ConnectionError("Persistent failure")
with pytest.raises(ConnectionError):
platform_helper._call_service("test-endpoint")
# Default: 3 retries + 1 initial = 4 attempts
assert mock_get.call_count == 4
def test_non_retryable_http_error(
self, mock_tool: MagicMock, clean_env: MonkeyPatch
) -> None:
"""Test non-retryable HTTP errors (404, 400) don't trigger retry."""
with patch("requests.get") as mock_get:
mock_response = Mock()
mock_response.status_code = 404
mock_response.json.return_value = {"error": "Not found"}
http_error = HTTPError()
http_error.response = mock_response
mock_get.side_effect = http_error
with pytest.raises(SdkError, match="Error retrieving adapter"):
PlatformHelper._get_adapter_configuration(mock_tool, "test-adapter-id")
# Should not retry 404
assert mock_get.call_count == 1
@pytest.mark.parametrize("status_code", [502, 503, 504])
def test_retryable_http_errors(
self, mock_tool: MagicMock, status_code: int, clean_env: MonkeyPatch
) -> None:
"""Test retryable HTTP errors (502, 503, 504) trigger retry."""
expected_data = {"adapter_id": "test", "config": {}}
with patch("requests.get") as mock_get:
# First attempt: retryable HTTP error
http_error = HTTPError()
error_response = Mock()
error_response.status_code = status_code
error_response.json.return_value = {"error": "Service unavailable"}
http_error.response = error_response
# Second attempt: success
success_response = Mock()
success_response.json.return_value = expected_data
success_response.raise_for_status = Mock()
mock_get.side_effect = [http_error, success_response]
result = PlatformHelper._get_adapter_configuration(
mock_tool, "test-adapter-id"
)
# Should retry and succeed
assert mock_get.call_count == 2
assert result == expected_data
@pytest.mark.slow
def test_connection_error_converted_to_sdk_error(
self, mock_tool: MagicMock, clean_env: MonkeyPatch
) -> None:
"""Test get_adapter_config wraps ConnectionError as SdkError."""
with patch("requests.get") as mock_get:
mock_get.side_effect = ConnectionError("Connection failed")
with pytest.raises(SdkError, match="Unable to connect to platform service"):
PlatformHelper.get_adapter_config(mock_tool, "test-adapter-id")
def test_post_method_retry(
self, platform_helper: PlatformHelper, clean_env: MonkeyPatch
) -> None:
"""Test POST requests also support retry."""
payload = {"key": "value"}
expected_response = {"status": "OK"}
with patch("requests.post") as mock_post:
mock_response = Mock()
mock_response.json.return_value = expected_response
mock_response.raise_for_status = Mock()
mock_post.side_effect = [
ConnectionError("Transient failure"),
mock_response,
]
result = platform_helper._call_service(
"test-endpoint", payload=payload, method="POST"
)
assert result == expected_response
assert mock_post.call_count == 2
def test_retry_logging(self, mock_tool: MagicMock, clean_env: MonkeyPatch) -> None:
"""Test that retry attempts are logged."""
with patch("requests.get") as mock_get:
mock_response = Mock()
mock_response.json.return_value = {}
mock_response.raise_for_status = Mock()
mock_get.side_effect = [
ConnectionError("Transient failure"),
mock_response,
]
helper = PlatformHelper(
tool=mock_tool,
platform_host="http://localhost",
platform_port="3001",
)
helper._call_service("test-endpoint")
# Verify logging occurred
mock_tool.stream_log.assert_called()
log_calls = [str(c) for c in mock_tool.stream_log.call_args_list]
assert any("retry" in call.lower() for call in log_calls)

View File

@@ -0,0 +1,174 @@
"""Integration tests for prompt module with retry logic."""
from typing import Any, Self
from unittest.mock import MagicMock, Mock, patch
import pytest
from pytest import MonkeyPatch
from requests.exceptions import ConnectionError, Timeout
from unstract.sdk1.prompt import PromptTool
class TestPromptToolRetry:
"""Tests for PromptTool retry functionality."""
@pytest.fixture
def mock_tool(self: Self) -> MagicMock:
"""Create a mock tool for testing."""
tool = MagicMock()
tool.get_env_or_die.side_effect = lambda key: {
"PLATFORM_API_KEY": "test-api-key",
}.get(key, "mock-value")
tool.stream_log = MagicMock()
tool.stream_error_and_exit = MagicMock()
return tool
@pytest.fixture
def prompt_tool(self: Self, mock_tool: MagicMock) -> PromptTool:
"""Create a PromptTool instance."""
return PromptTool(
tool=mock_tool,
prompt_host="http://localhost",
prompt_port="3003",
is_public_call=False,
request_id="test-request-id",
)
def test_success_on_first_attempt(
self: Self, prompt_tool: PromptTool, clean_env: MonkeyPatch
) -> None:
"""Test successful service call on first attempt."""
expected_response = {"result": "success"}
payload = {"prompt": "test"}
with patch("requests.post") as mock_post:
mock_response = Mock()
mock_response.json.return_value = expected_response
mock_response.raise_for_status = Mock()
mock_post.return_value = mock_response
result = prompt_tool._call_service("answer-prompt", payload=payload)
assert result == expected_response
assert mock_post.call_count == 1
@pytest.mark.parametrize(
"error_type,error_instance",
[
("ConnectionError", ConnectionError("Connection failed")),
("Timeout", Timeout("Request timed out")),
],
)
def test_retry_on_errors(
self: Self,
prompt_tool: PromptTool,
error_type: str,
error_instance: Exception,
clean_env: MonkeyPatch,
) -> None:
"""Test service retries on ConnectionError and Timeout."""
expected_response = {"result": "success"}
payload = {"prompt": "test"}
with patch("requests.post") as mock_post:
mock_response = Mock()
mock_response.json.return_value = expected_response
mock_response.raise_for_status = Mock()
mock_post.side_effect = [
error_instance,
mock_response,
]
result = prompt_tool._call_service("answer-prompt", payload=payload)
assert result == expected_response
assert mock_post.call_count == 2
@pytest.mark.slow
def test_max_retries_exceeded(
self: Self, mock_tool: MagicMock, clean_env: MonkeyPatch
) -> None:
"""Test service call fails after exceeding max retries."""
prompt_tool = PromptTool(
tool=mock_tool,
prompt_host="http://localhost",
prompt_port="3003",
is_public_call=False,
request_id="test-request-id",
)
payload = {"prompt": "test"}
with patch("requests.post") as mock_post:
mock_post.side_effect = ConnectionError("Persistent failure")
# Exception handled by decorator
with pytest.raises(ConnectionError):
prompt_tool._call_service("answer-prompt", payload=payload)
# Default: 3 retries + 1 initial = 4 attempts
assert mock_post.call_count == 4
@pytest.mark.parametrize(
"method_name,payload",
[
("answer_prompt", {"prompts": ["test"]}),
("index", {"document": "test"}),
("extract", {"doc_id": "123"}),
("summarize", {"text": "test"}),
],
)
def test_wrapper_methods_retry(
self: Self,
prompt_tool: PromptTool,
method_name: str,
payload: dict[str, Any],
clean_env: MonkeyPatch,
) -> None:
"""Test that wrapper methods inherit retry behavior."""
expected_response = {
"answers": ["result"],
"doc_id": "doc-123",
"extracted_text": "text",
"summary": "summary",
}
with patch("requests.post") as mock_post:
mock_response = Mock()
mock_response.json.return_value = expected_response
mock_response.raise_for_status = Mock()
mock_post.side_effect = [
ConnectionError("Transient failure"),
mock_response,
]
getattr(prompt_tool, method_name)(payload)
assert mock_post.call_count == 2
@pytest.mark.slow
def test_error_handling_with_retry(
self: Self, mock_tool: MagicMock, clean_env: MonkeyPatch
) -> None:
"""Test error handling decorator works with retry."""
prompt_tool = PromptTool(
tool=mock_tool,
prompt_host="http://localhost",
prompt_port="3003",
is_public_call=False,
request_id="test-request-id",
)
payload = {"prompt": "test"}
with patch("requests.post") as mock_post:
mock_post.side_effect = ConnectionError("Persistent failure")
# Error handler should catch after all retries
result = prompt_tool.answer_prompt(payload)
# handle_service_exceptions decorator calls stream_error_and_exit
assert result is None
prompt_tool.tool.stream_error_and_exit.assert_called()

View File

@@ -0,0 +1 @@
"""Tests for unstract-sdk1 utilities."""

View File

@@ -0,0 +1,667 @@
"""Unit tests for retry_utils module."""
import errno
from collections.abc import Callable
from unittest.mock import MagicMock, Mock
import pytest
from _pytest.monkeypatch import MonkeyPatch
from requests.exceptions import ConnectionError, HTTPError, Timeout
from unstract.sdk1.utils.retry_utils import (
calculate_delay,
create_retry_decorator,
is_retryable_error,
retry_platform_service_call,
retry_prompt_service_call,
retry_with_exponential_backoff,
)
class TestIsRetryableError:
"""Tests for is_retryable_error function."""
def test_connection_error_is_retryable(self) -> None:
"""ConnectionError should be retryable."""
error = ConnectionError("Connection failed")
assert is_retryable_error(error) is True
def test_timeout_is_retryable(self) -> None:
"""Timeout error should be retryable."""
error = Timeout("Request timed out")
assert is_retryable_error(error) is True
@pytest.mark.parametrize("status_code", [502, 503, 504])
def test_http_error_retryable_status_codes(self, status_code: int) -> None:
"""HTTPError with 502, 503, 504 should be retryable."""
response = Mock()
response.status_code = status_code
error = HTTPError()
error.response = response
assert is_retryable_error(error) is True
@pytest.mark.parametrize("status_code", [400, 401, 403, 404, 500])
def test_http_error_non_retryable_status_codes(self, status_code: int) -> None:
"""HTTPError with other status codes should not be retryable."""
response = Mock()
response.status_code = status_code
error = HTTPError()
error.response = response
assert is_retryable_error(error) is False
def test_http_error_without_response(self) -> None:
"""HTTPError without response should not be retryable."""
error = HTTPError()
error.response = None
assert is_retryable_error(error) is False
@pytest.mark.parametrize(
"errno_code",
[
errno.ECONNREFUSED,
getattr(errno, "ECONNRESET", 104),
getattr(errno, "ETIMEDOUT", 110),
getattr(errno, "EHOSTUNREACH", 113),
getattr(errno, "ENETUNREACH", 101),
],
)
def test_os_error_retryable_errno(self, errno_code: int) -> None:
"""OSError with specific errno codes should be retryable."""
error = OSError()
error.errno = errno_code
assert is_retryable_error(error) is True
def test_os_error_non_retryable_errno(self) -> None:
"""OSError with other errno codes should not be retryable."""
error = OSError()
error.errno = errno.ENOENT # File not found
assert is_retryable_error(error) is False
def test_other_exception_not_retryable(self) -> None:
"""Other exceptions should not be retryable."""
error = ValueError("Invalid value")
assert is_retryable_error(error) is False
class TestCalculateDelay:
"""Tests for calculate_delay function."""
def test_exponential_backoff_without_jitter(self) -> None:
"""Test exponential backoff calculation without jitter."""
base_delay = 1.0
multiplier = 2.0
max_delay = 60.0
# Attempt 0: 1.0 * (2.0^0) = 1.0
assert calculate_delay(
0, base_delay, multiplier, max_delay, jitter=False
) == pytest.approx(1.0)
# Attempt 1: 1.0 * (2.0^1) = 2.0
assert calculate_delay(
1, base_delay, multiplier, max_delay, jitter=False
) == pytest.approx(2.0)
# Attempt 2: 1.0 * (2.0^2) = 4.0
assert calculate_delay(
2, base_delay, multiplier, max_delay, jitter=False
) == pytest.approx(4.0)
# Attempt 3: 1.0 * (2.0^3) = 8.0
assert calculate_delay(
3, base_delay, multiplier, max_delay, jitter=False
) == pytest.approx(8.0)
def test_exponential_backoff_with_jitter(self) -> None:
"""Test exponential backoff calculation with jitter."""
base_delay = 1.0
multiplier = 2.0
max_delay = 60.0
# With jitter, delay should be in range [base, base * 1.25]
for attempt in range(4):
base = base_delay * (multiplier**attempt)
delay = calculate_delay(
attempt, base_delay, multiplier, max_delay, jitter=True
)
assert base <= delay <= base * 1.25
def test_max_delay_cap(self) -> None:
"""Test that max_delay caps the calculated delay."""
base_delay = 1.0
multiplier = 2.0
max_delay = 5.0
# Attempt 10: 1.0 * (2.0^10) = 1024.0, but capped at 5.0
delay = calculate_delay(10, base_delay, multiplier, max_delay, jitter=False)
assert delay == pytest.approx(5.0)
def test_max_delay_cap_with_jitter(self) -> None:
"""Test that max_delay caps the delay even with jitter."""
base_delay = 1.0
multiplier = 2.0
max_delay = 5.0
# Even with jitter, should not exceed max_delay
delay = calculate_delay(10, base_delay, multiplier, max_delay, jitter=True)
assert delay <= max_delay
class TestRetryWithExponentialBackoff:
"""Tests for retry_with_exponential_backoff decorator."""
def test_successful_call_first_attempt(self, mock_logger: MagicMock) -> None:
"""Test successful call on first attempt."""
mock_func = Mock(return_value="success")
decorator = retry_with_exponential_backoff(
max_retries=3,
max_time=60.0,
base_delay=1.0,
multiplier=2.0,
jitter=False,
exceptions=(Exception,),
logger_instance=mock_logger,
prefix="TEST",
)
decorated_func = decorator(mock_func)
result = decorated_func()
assert result == "success"
assert mock_func.call_count == 1
# Should not log retry success message for first attempt
mock_logger.info.assert_not_called()
def test_retry_after_transient_failure(self, mock_logger: MagicMock) -> None:
"""Test retry after transient failure."""
mock_func = Mock(
side_effect=[ConnectionError("Failed"), "success"], __name__="test_func"
)
decorator = retry_with_exponential_backoff(
max_retries=3,
max_time=60.0,
base_delay=0.1, # Short delay for testing
multiplier=2.0,
jitter=False,
exceptions=(ConnectionError,),
logger_instance=mock_logger,
prefix="TEST",
)
decorated_func = decorator(mock_func)
result = decorated_func()
assert result == "success"
assert mock_func.call_count == 2
# Should log success after retry
mock_logger.info.assert_called_once()
assert "Successfully completed" in str(mock_logger.info.call_args)
def test_max_retries_exceeded(self, mock_logger: MagicMock) -> None:
"""Test that max retries causes failure."""
mock_func = Mock(
side_effect=ConnectionError("Always fails"), __name__="test_func"
)
decorator = retry_with_exponential_backoff(
max_retries=2,
max_time=60.0,
base_delay=0.1,
multiplier=2.0,
jitter=False,
exceptions=(ConnectionError,),
logger_instance=mock_logger,
prefix="TEST",
)
decorated_func = decorator(mock_func)
with pytest.raises(ConnectionError, match="Always fails"):
decorated_func()
# Should attempt 3 times (initial + 2 retries)
assert mock_func.call_count == 3
# Should log giving up
mock_logger.exception.assert_called()
def test_max_time_exceeded(self, mock_logger: MagicMock) -> None:
"""Test that max time causes failure."""
mock_func = Mock(
side_effect=ConnectionError("Always fails"), __name__="test_func"
)
decorator = retry_with_exponential_backoff(
max_retries=10,
max_time=0.5, # Very short max time
base_delay=0.2,
multiplier=2.0,
jitter=False,
exceptions=(ConnectionError,),
logger_instance=mock_logger,
prefix="TEST",
)
decorated_func = decorator(mock_func)
with pytest.raises(ConnectionError):
decorated_func()
# Should fail before reaching max retries due to time limit
assert mock_func.call_count < 10
def test_retry_with_custom_predicate(self, mock_logger: MagicMock) -> None:
"""Test retry with custom predicate."""
def custom_predicate(e: Exception) -> bool:
# Only retry if message contains "retry"
return "retry" in str(e)
mock_func = Mock(
side_effect=[Exception("retry please"), "success"], __name__="test_func"
)
decorator = retry_with_exponential_backoff(
max_retries=3,
max_time=60.0,
base_delay=0.1,
multiplier=2.0,
jitter=False,
exceptions=(Exception,),
logger_instance=mock_logger,
prefix="TEST",
retry_predicate=custom_predicate,
)
decorated_func = decorator(mock_func)
result = decorated_func()
assert result == "success"
assert mock_func.call_count == 2
def test_no_retry_with_predicate_false(self, mock_logger: MagicMock) -> None:
"""Test no retry when predicate returns False."""
def custom_predicate(e: Exception) -> bool:
return False
mock_func = Mock(__name__="test_func", side_effect=Exception("Error"))
decorator = retry_with_exponential_backoff(
max_retries=3,
max_time=60.0,
base_delay=0.1,
multiplier=2.0,
jitter=False,
exceptions=(Exception,),
logger_instance=mock_logger,
prefix="TEST",
retry_predicate=custom_predicate,
)
decorated_func = decorator(mock_func)
with pytest.raises(Exception, match="Error"):
decorated_func()
# Should not retry
assert mock_func.call_count == 1
def test_exception_not_in_tuple_not_retried(self, mock_logger: MagicMock) -> None:
"""Test that exceptions not in the tuple are not retried."""
mock_func = Mock(__name__="test_func", side_effect=ValueError("Not retryable"))
decorator = retry_with_exponential_backoff(
max_retries=3,
max_time=60.0,
base_delay=0.1,
multiplier=2.0,
jitter=False,
exceptions=(ConnectionError,), # Only ConnectionError
logger_instance=mock_logger,
prefix="TEST",
)
decorated_func = decorator(mock_func)
with pytest.raises(ValueError, match="Not retryable"):
decorated_func()
# Should not retry
assert mock_func.call_count == 1
def test_delay_would_exceed_max_time(self, mock_logger: MagicMock) -> None:
"""Test that delay exceeding max time causes immediate failure."""
mock_func = Mock(
__name__="test_func", side_effect=ConnectionError("Always fails")
)
decorator = retry_with_exponential_backoff(
max_retries=10,
max_time=0.3, # Very short max time
base_delay=1.0, # Large delay
multiplier=2.0,
jitter=False,
exceptions=(ConnectionError,),
logger_instance=mock_logger,
prefix="TEST",
)
decorated_func = decorator(mock_func)
with pytest.raises(ConnectionError):
decorated_func()
# Should fail quickly due to delay exceeding remaining time
mock_logger.exception.assert_called()
exception_calls = [str(c) for c in mock_logger.exception.call_args_list]
assert any("would exceed max time" in c for c in exception_calls)
class TestCreateRetryDecorator:
"""Tests for create_retry_decorator function."""
def test_default_configuration(
self, clean_env: MonkeyPatch, mock_logger: MagicMock
) -> None:
"""Test decorator with default configuration."""
decorator = create_retry_decorator("TEST_SERVICE", logger_instance=mock_logger)
mock_func = Mock(return_value="success")
decorated_func = decorator(mock_func)
result = decorated_func()
assert result == "success"
def test_environment_variable_configuration(
self, clean_env: MonkeyPatch, set_env: Callable[..., None], mock_logger: MagicMock
) -> None:
"""Test decorator reads configuration from environment."""
set_env(
"TEST_SERVICE",
max_retries=5,
max_time=120,
base_delay=2.0,
multiplier=3.0,
jitter="false",
)
mock_func = Mock(
__name__="test_func", side_effect=[ConnectionError("Failed"), "success"]
)
decorator = create_retry_decorator("TEST_SERVICE", logger_instance=mock_logger)
decorated_func = decorator(mock_func)
result = decorated_func()
assert result == "success"
assert mock_func.call_count == 2
def test_invalid_max_retries(
self, clean_env: MonkeyPatch, set_env: Callable[..., None]
) -> None:
"""Test that negative max_retries raises error."""
set_env("TEST_SERVICE", max_retries=-1)
with pytest.raises(ValueError, match="MAX_RETRIES must be >= 0"):
create_retry_decorator("TEST_SERVICE")
def test_invalid_max_time(
self, clean_env: MonkeyPatch, set_env: Callable[..., None]
) -> None:
"""Test that non-positive max_time raises error."""
set_env("TEST_SERVICE", max_time=0)
with pytest.raises(ValueError, match="MAX_TIME must be > 0"):
create_retry_decorator("TEST_SERVICE")
def test_invalid_base_delay(
self, clean_env: MonkeyPatch, set_env: Callable[..., None]
) -> None:
"""Test that non-positive base_delay raises error."""
set_env("TEST_SERVICE", base_delay=-0.5)
with pytest.raises(ValueError, match="BASE_DELAY must be > 0"):
create_retry_decorator("TEST_SERVICE")
def test_invalid_multiplier(
self, clean_env: MonkeyPatch, set_env: Callable[..., None]
) -> None:
"""Test that non-positive multiplier raises error."""
set_env("TEST_SERVICE", multiplier=0)
with pytest.raises(ValueError, match="MULTIPLIER must be > 0"):
create_retry_decorator("TEST_SERVICE")
@pytest.mark.parametrize("jitter_value", ["true", "false"])
def test_jitter_values(
self,
jitter_value: str,
clean_env: MonkeyPatch,
set_env: Callable[..., None],
mock_logger: MagicMock,
) -> None:
"""Test jitter configuration values."""
set_env("TEST_SERVICE", jitter=jitter_value)
decorator = create_retry_decorator("TEST_SERVICE", logger_instance=mock_logger)
# Should not raise error
assert decorator is not None
def test_custom_exceptions_only(
self, clean_env: MonkeyPatch, mock_logger: MagicMock
) -> None:
"""Test decorator with custom exceptions and no predicate."""
decorator = create_retry_decorator(
"TEST_SERVICE",
exceptions=(ValueError, TypeError),
logger_instance=mock_logger,
)
mock_func = Mock(
__name__="test_func", side_effect=[ValueError("Error"), "success"]
)
decorated_func = decorator(mock_func)
result = decorated_func()
assert result == "success"
assert mock_func.call_count == 2
def test_custom_predicate_only(
self, clean_env: MonkeyPatch, mock_logger: MagicMock
) -> None:
"""Test decorator with custom predicate and no exceptions."""
def custom_predicate(e: Exception) -> bool:
return isinstance(e, ValueError)
decorator = create_retry_decorator(
"TEST_SERVICE",
retry_predicate=custom_predicate,
logger_instance=mock_logger,
)
mock_func = Mock(
__name__="test_func", side_effect=[ValueError("Error"), "success"]
)
decorated_func = decorator(mock_func)
result = decorated_func()
assert result == "success"
assert mock_func.call_count == 2
def test_both_exceptions_and_predicate(
self, clean_env: MonkeyPatch, mock_logger: MagicMock
) -> None:
"""Test decorator with both exceptions and predicate."""
def custom_predicate(e: Exception) -> bool:
return "retry" in str(e)
decorator = create_retry_decorator(
"TEST_SERVICE",
exceptions=(ValueError,),
retry_predicate=custom_predicate,
logger_instance=mock_logger,
)
# Should retry - ValueError with "retry" in message
mock_func = Mock(
__name__="test_func", side_effect=[ValueError("retry please"), "success"]
)
decorated_func = decorator(mock_func)
result = decorated_func()
assert result == "success"
assert mock_func.call_count == 2
def test_exceptions_match_but_predicate_false(
self, clean_env: MonkeyPatch, mock_logger: MagicMock
) -> None:
"""Test that predicate can prevent retry even if exception matches."""
def custom_predicate(e: Exception) -> bool:
return "retry" in str(e)
decorator = create_retry_decorator(
"TEST_SERVICE",
exceptions=(ValueError,),
retry_predicate=custom_predicate,
logger_instance=mock_logger,
)
# ValueError matches but predicate returns False (no "retry" substring)
mock_func = Mock(
__name__="test_func", side_effect=ValueError("do not attempt again")
)
decorated_func = decorator(mock_func)
with pytest.raises(ValueError, match="do not attempt again"):
decorated_func()
# Should not retry
assert mock_func.call_count == 1
class TestPreconfiguredDecorators:
"""Tests for pre-configured decorators."""
def test_retry_platform_service_call_exists(self) -> None:
"""Test that retry_platform_service_call decorator exists."""
assert retry_platform_service_call is not None
def test_retry_prompt_service_call_exists(self) -> None:
"""Test that retry_prompt_service_call decorator exists."""
assert retry_prompt_service_call is not None
def test_platform_service_decorator_retries_on_connection_error(
self, clean_env: MonkeyPatch
) -> None:
"""Test platform service decorator retries ConnectionError."""
mock_func = Mock(
__name__="test_func", side_effect=[ConnectionError("Failed"), "success"]
)
decorated_func = retry_platform_service_call(mock_func)
result = decorated_func()
assert result == "success"
assert mock_func.call_count == 2
def test_prompt_service_decorator_retries_on_timeout(
self, clean_env: MonkeyPatch
) -> None:
"""Test prompt service decorator retries Timeout."""
mock_func = Mock(
__name__="test_func", side_effect=[Timeout("Timed out"), "success"]
)
decorated_func = retry_prompt_service_call(mock_func)
result = decorated_func()
assert result == "success"
assert mock_func.call_count == 2
class TestRetryLogging:
"""Tests for retry logging behavior."""
def test_warning_logged_on_retry(self, mock_logger: MagicMock) -> None:
"""Test that warning is logged on retry attempt."""
mock_func = Mock(
__name__="test_func", side_effect=[ConnectionError("Failed"), "success"]
)
decorator = retry_with_exponential_backoff(
max_retries=3,
max_time=60.0,
base_delay=0.1,
multiplier=2.0,
jitter=False,
exceptions=(ConnectionError,),
logger_instance=mock_logger,
prefix="TEST",
)
decorated_func = decorator(mock_func)
decorated_func()
# Should log warning about retry
mock_logger.warning.assert_called_once()
warning_msg = str(mock_logger.warning.call_args)
assert "Retry" in warning_msg
assert "TEST" in warning_msg
def test_info_logged_on_success_after_retry(self, mock_logger: MagicMock) -> None:
"""Test that info is logged when successful after retry."""
mock_func = Mock(
__name__="test_func", side_effect=[ConnectionError("Failed"), "success"]
)
decorator = retry_with_exponential_backoff(
max_retries=3,
max_time=60.0,
base_delay=0.1,
multiplier=2.0,
jitter=False,
exceptions=(ConnectionError,),
logger_instance=mock_logger,
prefix="TEST",
)
decorated_func = decorator(mock_func)
decorated_func()
# Should log success after retry
mock_logger.info.assert_called_once()
info_msg = str(mock_logger.info.call_args)
assert "Successfully completed" in info_msg
def test_exception_logged_on_giving_up(self, mock_logger: MagicMock) -> None:
"""Test that exception is logged when giving up."""
mock_func = Mock(
__name__="test_func", side_effect=ConnectionError("Always fails")
)
decorator = retry_with_exponential_backoff(
max_retries=1,
max_time=60.0,
base_delay=0.1,
multiplier=2.0,
jitter=False,
exceptions=(ConnectionError,),
logger_instance=mock_logger,
prefix="TEST",
)
decorated_func = decorator(mock_func)
with pytest.raises(ConnectionError):
decorated_func()
# Should log exception when giving up
mock_logger.exception.assert_called()
exception_msg = str(mock_logger.exception.call_args)
assert "Giving up" in exception_msg or "exceeded" in exception_msg

1651
unstract/sdk1/uv.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,7 @@ dependencies = [
"docker~=6.1.3",
"jsonschema>=4.18.6,<5.0",
"PyYAML~=6.0.1",
"unstract-sdk~=0.77.1",
"unstract-sdk~=0.78.0",
"unstract-tool-sandbox",
"unstract-flags",
]

View File

@@ -5,7 +5,7 @@
"schemaVersion": "0.0.1",
"displayName": "File Classifier",
"functionName": "classify",
"toolVersion": "0.0.68",
"toolVersion": "0.0.69",
"description": "Classifies a file into a bin based on its contents",
"input": {
"description": "File to be classified"
@@ -106,9 +106,9 @@
"properties": {}
},
"icon": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<svg\n enable-background=\"new 0 0 20 20\"\n height=\"48\"\n viewBox=\"0 0 20 20\"\n width=\"48\"\n fill=\"#000000\"\n version=\"1.1\"\n id=\"svg8109\"\n sodipodi:docname=\"folder_copy_black_48dp.svg\"\n xmlns:inkscape=\"http://www.inkscape.org/namespaces/inkscape\"\n xmlns:sodipodi=\"http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd\"\n xmlns=\"http://www.w3.org/2000/svg\"\n xmlns:svg=\"http://www.w3.org/2000/svg\">\n <defs\n id=\"defs8113\" />\n <sodipodi:namedview\n id=\"namedview8111\"\n pagecolor=\"#ffffff\"\n bordercolor=\"#000000\"\n borderopacity=\"0.25\"\n inkscape:showpageshadow=\"2\"\n inkscape:pageopacity=\"0.0\"\n inkscape:pagecheckerboard=\"0\"\n inkscape:deskcolor=\"#d1d1d1\"\n showgrid=\"false\" />\n <g\n id=\"g8099\">\n <rect\n fill=\"none\"\n height=\"20\"\n width=\"20\"\n x=\"0\"\n id=\"rect8097\"\n y=\"0\" />\n </g>\n <g\n id=\"g8107\"\n style=\"fill:#ff4d6d;fill-opacity:1\">\n <g\n id=\"g8105\"\n style=\"fill:#ff4d6d;fill-opacity:1\">\n <path\n d=\"M 2.5,5 H 1 V 15.5 C 1,16.33 1.67,17 2.5,17 H 15.68 V 15.5 H 2.5 Z\"\n id=\"path8101\"\n style=\"fill:#ff4d6d;fill-opacity:1\" />\n <path\n d=\"M 16.5,4 H 11 L 9,2 H 5.5 C 4.67,2 4,2.67 4,3.5 v 9 C 4,13.33 4.67,14 5.5,14 h 11 c 0.83,0 1.5,-0.67 1.5,-1.5 v -7 C 18,4.67 17.33,4 16.5,4 Z m 0,8.5 h -11 v -9 h 2.88 l 2,2 h 6.12 z\"\n id=\"path8103\"\n style=\"fill:#ff4d6d;fill-opacity:1\" />\n </g>\n </g>\n</svg>\n",
"image_url": "docker:unstract/tool-classifier:0.0.68",
"image_url": "docker:unstract/tool-classifier:0.0.69",
"image_name": "unstract/tool-classifier",
"image_tag": "0.0.68"
"image_tag": "0.0.69"
},
"text_extractor": {
"tool_uid": "text_extractor",
@@ -116,7 +116,7 @@
"schemaVersion": "0.0.1",
"displayName": "Text Extractor",
"functionName": "text_extractor",
"toolVersion": "0.0.64",
"toolVersion": "0.0.65",
"description": "The Text Extractor is a powerful tool designed to convert documents to its text form or Extract texts from documents",
"input": {
"description": "Document"
@@ -191,8 +191,8 @@
}
},
"icon": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<svg\n enable-background=\"new 0 0 20 20\"\n height=\"48\"\n viewBox=\"0 0 20 20\"\n width=\"48\"\n fill=\"#000000\"\n version=\"1.1\"\n id=\"svg8109\"\n sodipodi:docname=\"folder_copy_black_48dp.svg\"\n xmlns:inkscape=\"http://www.inkscape.org/namespaces/inkscape\"\n xmlns:sodipodi=\"http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd\"\n xmlns=\"http://www.w3.org/2000/svg\"\n xmlns:svg=\"http://www.w3.org/2000/svg\">\n <defs\n id=\"defs8113\" />\n <sodipodi:namedview\n id=\"namedview8111\"\n pagecolor=\"#ffffff\"\n bordercolor=\"#000000\"\n borderopacity=\"0.25\"\n inkscape:showpageshadow=\"2\"\n inkscape:pageopacity=\"0.0\"\n inkscape:pagecheckerboard=\"0\"\n inkscape:deskcolor=\"#d1d1d1\"\n showgrid=\"false\" />\n <g\n id=\"g8099\">\n <rect\n fill=\"none\"\n height=\"20\"\n width=\"20\"\n x=\"0\"\n id=\"rect8097\"\n y=\"0\" />\n </g>\n <g\n id=\"g8107\"\n style=\"fill:#ff4d6d;fill-opacity:1\">\n <g\n id=\"g8105\"\n style=\"fill:#ff4d6d;fill-opacity:1\">\n <path\n d=\"M 2.5,5 H 1 V 15.5 C 1,16.33 1.67,17 2.5,17 H 15.68 V 15.5 H 2.5 Z\"\n id=\"path8101\"\n style=\"fill:#ff4d6d;fill-opacity:1\" />\n <path\n d=\"M 16.5,4 H 11 L 9,2 H 5.5 C 4.67,2 4,2.67 4,3.5 v 9 C 4,13.33 4.67,14 5.5,14 h 11 c 0.83,0 1.5,-0.67 1.5,-1.5 v -7 C 18,4.67 17.33,4 16.5,4 Z m 0,8.5 h -11 v -9 h 2.88 l 2,2 h 6.12 z\"\n id=\"path8103\"\n style=\"fill:#ff4d6d;fill-opacity:1\" />\n </g>\n </g>\n</svg>\n",
"image_url": "docker:unstract/tool-text-extractor:0.0.64",
"image_url": "docker:unstract/tool-text-extractor:0.0.65",
"image_name": "unstract/tool-text-extractor",
"image_tag": "0.0.64"
"image_tag": "0.0.65"
}
}

File diff suppressed because it is too large Load Diff

2112
uv.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -22,7 +22,7 @@ dependencies = [
"prometheus-client>=0.17.0,<1.0.0", # Metrics collection
"psutil>=5.9.0,<6.0.0", # System resource monitoring
# Essential Unstract packages - with Azure support for connectors
"unstract-sdk[azure]~=0.77.3", # Core SDK with Azure connector support
"unstract-sdk[azure]~=0.78.0", # Core SDK with Azure connector support
"unstract-connectors",
"unstract-core",
"unstract-flags",

3352
workers/uv.lock generated

File diff suppressed because it is too large Load Diff