UN-2470 [FEAT] Remove Django dependency from Celery workers with internal APIs (#1494)

* UN-2470 [MISC] Remove Django dependency from Celery workers

This commit introduces a new worker architecture that decouples
Celery workers from Django where possible, enabling support for
gevent/eventlet pool types and reducing worker startup overhead.

Key changes:
- Created separate worker modules (api-deployment, callback, file_processing, general)
- Added internal API endpoints for worker communication (see the sketch after this list)
- Implemented Django-free task execution where appropriate
- Added shared utilities and client facades
- Updated container configurations for new worker architecture
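
A minimal sketch of the resulting worker shape, assuming illustrative names (module, queue, endpoint path, and env vars are placeholders; only the Bearer-token auth and X-Organization-ID header follow the internal API middleware added in this PR). Because the task imports neither Django nor the ORM, it can run under prefork, gevent, or eventlet pools (e.g. celery -A <app> worker -P gevent):

    # Hedged sketch: a Django-free worker task that fetches data via the internal API.
    import os

    import requests
    from celery import Celery

    app = Celery(
        "file_processing",
        broker=os.getenv("CELERY_BROKER_URL", "redis://redis:6379/0"),
    )

    @app.task(name="process_file_batch")
    def process_file_batch(workflow_id: str, organization_id: str) -> dict:
        # Illustrative endpoint path; headers match InternalAPIAuthMiddleware.
        base_url = os.getenv("INTERNAL_API_BASE_URL", "http://backend:8000")
        response = requests.get(
            f"{base_url}/internal/v1/workflow-manager/workflow/{workflow_id}/",
            headers={
                "Authorization": f"Bearer {os.environ['INTERNAL_SERVICE_API_KEY']}",
                "X-Organization-ID": organization_id,
            },
            timeout=30,
        )
        response.raise_for_status()
        return response.json()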

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix pre-commit issues: file permissions and ruff errors

Set up Docker for the new workers

- Add executable permissions to worker entrypoint files
- Fix import order in namespace package __init__.py
- Remove unused variable api_status in general worker
- Address ruff E402 and F841 errors

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* refactored, Dockerfiles, fixes

* flexibility in celery run commands

* added debug logs

* handled file history for API

* cleanup

* cleanup

* cloud plugin structure

* minor changes in import plugin

* added notification and logger workers under new worker module

* add docker compatibility for new workers

* handled docker issues

* log consumer worker fixes

* added scheduler worker

* minor env changes

* clean up the logs

* minor changes in logs

* resolved scheduler worker issues

* cleanup and refactor

* ensuring backward compatibility with existing workers

* added configuration internal apis and cache utils

* optimization

* Fix API client singleton pattern to share HTTP sessions

- Fix flawed singleton implementation that was trying to share BaseAPIClient instances
- Now properly shares HTTP sessions between specialized clients (see the sketch below)
- Eliminates 6x BaseAPIClient initialization by reusing the same underlying session
- Should reduce API deployment orchestration time by ~135ms (from 6 clients to 1 session)
- Added debug logging to verify singleton pattern activation
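
A minimal sketch of the shared-session idea, with hypothetical client class names (not the exact classes in this PR):

    # Hedged sketch: specialized clients reuse one HTTP session instead of each
    # constructing its own client/session.
    import requests

    _shared_session: requests.Session | None = None

    def get_shared_session() -> requests.Session:
        global _shared_session
        if _shared_session is None:
            _shared_session = requests.Session()  # created once, reused everywhere
        return _shared_session

    class PipelineAPIClient:
        def __init__(self) -> None:
            self.session = get_shared_session()

    class FileHistoryAPIClient:
        def __init__(self) -> None:
            self.session = get_shared_session()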

* cleanup and structuring

* cleanup in callback

* file system connectors issue

* celery env values changes

* optional gossip

* variables for sync, mingle and gossip

* Fix for file type check

* Task pipeline issue resolved

* api deployment failed response handled

* Task pipeline fixes

* updated file history cleanup with active file execution

* pipeline status update and workflow UI page execution

* cleanup and resolving conflicts

* remove unstract-core from connectors

* Commit uv.lock changes

* uv lock updates

* resolve migration issues

* defer connector-metadata

* Fix connector migration for production scale

- Add encryption key handling with defer() to prevent decryption failures (see the sketch below)
- Add final cleanup step to fix duplicate connector names
- Optimize for large datasets with batch processing and bulk operations
- Ensure unique constraint in migration 0004 can be created successfully
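
For context, a hedged sketch of the defer() idea used by the migration (the model is passed in via apps.get_model as in the diff further below; this is an illustration, not the migration code itself):

    def iter_decryptable_connectors(connector_instance_model):
        """Yield (connector, metadata) pairs, skipping rows whose encrypted
        metadata cannot be decrypted (e.g. after an ENCRYPTION_KEY change)."""
        queryset = (
            connector_instance_model.objects
            .defer("connector_metadata")  # avoid eager decryption when loading rows
            .all()
        )
        for connector in queryset:
            try:
                yield connector, connector.connector_metadata  # decrypts on access
            except Exception:  # any decryption failure means "skip this row"
                continue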

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* HITL fixes

* minor fixes on HITL

* api_hub related changes

* dockerfile fixes

* api client cache fixes with actual response class

* fix: tags and llm_profile_id

* optimized clear cache

* cleanup

* enhanced logs

* added more handling for is-file/is-dir checks and added loggers

* clean up the run-platform script

* internal APIs are exempted from CSRF

* SonarCloud issues

* SonarCloud issues

* resolving SonarCloud issues

* resolving SonarCloud issues

* Delta: added batch size fix in workers

* comments addressed

* celery configurational changes for new workers

* fixes in callback regarding the pipeline type check

* change internal URL registry logic

* gitignore changes

* gitignore changes

* addressing PR comments and cleaning up the code

* adding missed profiles for v2

* SonarCloud blocker issues resolved

* implement OTel

* Commit uv.lock changes

* handle execution time and some cleanup

* adding user_data in metadata, PR: https://github.com/Zipstack/unstract/pull/1544

* scheduler backward compatibility

* replace user_data with custom_data

* Commit uv.lock changes

* celery worker command issue resolved

* enhance package imports in connectors by changing to lazy imports

* Update runner.py by removing OTel from it

Signed-off-by: ali <117142933+muhammad-ali-e@users.noreply.github.com>

* added delta changes

* handle error to destination DB

* resolve tool instance ID validation and HITL queue name in API

* handled direct execution from workflow page to worker and logs

* handle cost logs

* Update health.py

Signed-off-by: Ritwik G <100672805+ritwik-g@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor log changes

* introducing log consumer scheduler for bulk creates, and socket.emit from worker for WebSocket

* Commit uv.lock changes

* time limit/timeout celery config cleanup

* implemented Redis client class in worker

* pipeline status enum mismatch

* notification worker fixes

* resolve uv lock conflicts

* workflow log fixes

* WebSocket channel name issue resolved; handle Redis down in status tracker and remove Redis keys

* default TTL changed for unified logs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: ali <117142933+muhammad-ali-e@users.noreply.github.com>
Signed-off-by: Ritwik G <100672805+ritwik-g@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Ritwik G <100672805+ritwik-g@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Authored by ali on 2025-10-03 11:24:07 +05:30, committed by GitHub
parent e2b72589bc
commit 0c5997f9a9
312 changed files with 68877 additions and 387 deletions

.gitignore (vendored, 10 changes)

@@ -1,6 +1,14 @@
# Created by https://www.toptal.com/developers/gitignore/api/windows,macos,linux,pycharm,pycharm+all,pycharm+iml,python,visualstudiocode,react,django
# Edit at https://www.toptal.com/developers/gitignore?templates=windows,macos,linux,pycharm,pycharm+all,pycharm+iml,python,visualstudiocode,react,django
# Development helper scripts
*.sh
# list Exceptional files with ! like !fix-and-test.sh
!run-platform.sh
!workers/run-worker.sh
!workers/run-worker-docker.sh
!workers/log_consumer/scheduler.sh
### Django ###
*.log
*.pot
@@ -622,6 +630,7 @@ backend/plugins/processor/*
# Subscription Plugins
backend/plugins/subscription/*
# API Deployment Plugins
backend/plugins/api/**
@@ -685,6 +694,7 @@ backend/requirements.txt
backend/backend/*_urls.py
!backend/backend/base_urls.py
!backend/backend/public_urls.py
!backend/backend/internal_base_urls.py
# TODO: Remove after v2 migration is completed
backend/backend/*_urls_v2.py
!backend/backend/public_urls_v2.py


@@ -8,6 +8,7 @@ from account_v2.authentication_plugin_registry import AuthenticationPluginRegist
from account_v2.authentication_service import AuthenticationService
from account_v2.constants import Common
from backend.constants import RequestHeader
from backend.internal_api_constants import INTERNAL_API_PREFIX
class CustomAuthMiddleware:
@@ -22,6 +23,10 @@ class CustomAuthMiddleware:
if any(request.path.startswith(path) for path in settings.WHITELISTED_PATHS):
return self.get_response(request)
# Skip internal API paths - they are handled by InternalAPIAuthMiddleware
if request.path.startswith(f"{INTERNAL_API_PREFIX}/"):
return self.get_response(request)
# Authenticating With API_KEY
x_api_key = request.headers.get(RequestHeader.X_API_KEY)
if (


@@ -0,0 +1,15 @@
"""Account Internal API Serializers
Handles serialization for organization context related endpoints.
"""
from rest_framework import serializers
class OrganizationContextSerializer(serializers.Serializer):
"""Serializer for organization context information."""
organization_id = serializers.CharField()
organization_name = serializers.CharField()
organization_slug = serializers.CharField(required=False, allow_blank=True)
created_at = serializers.CharField(required=False, allow_blank=True)
settings = serializers.DictField(required=False)


@@ -0,0 +1,20 @@
"""Internal API URLs for Organization Context
URL patterns for organization-related internal APIs.
"""
from django.urls import path
from .internal_views import OrganizationContextAPIView
urlpatterns = [
# Organization context endpoint (backward compatibility)
path(
"<str:org_id>/", OrganizationContextAPIView.as_view(), name="organization-context"
),
# Organization context endpoint (explicit path)
path(
"<str:org_id>/context/",
OrganizationContextAPIView.as_view(),
name="organization-context-explicit",
),
]


@@ -0,0 +1,40 @@
"""Account Internal API Views
Handles organization context related endpoints for internal services.
"""
import logging
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView
from utils.organization_utils import get_organization_context, resolve_organization
from .internal_serializers import OrganizationContextSerializer
logger = logging.getLogger(__name__)
class OrganizationContextAPIView(APIView):
"""Internal API endpoint for getting organization context."""
def get(self, request, org_id):
"""Get organization context information."""
try:
# Use shared utility to resolve organization
organization = resolve_organization(org_id, raise_on_not_found=True)
# Use shared utility to get context data
context_data = get_organization_context(organization)
serializer = OrganizationContextSerializer(context_data)
logger.info(f"Retrieved organization context for {org_id}")
return Response(serializer.data)
except Exception as e:
logger.error(f"Failed to get organization context for {org_id}: {str(e)}")
return Response(
{"error": "Failed to get organization context", "detail": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)


@@ -0,0 +1,16 @@
"""Account Internal API URLs
Defines internal API endpoints for organization operations.
"""
from django.urls import path
from .internal_views import OrganizationContextAPIView
urlpatterns = [
# Organization context API
path(
"<str:org_id>/context/",
OrganizationContextAPIView.as_view(),
name="organization-context",
),
]


@@ -21,8 +21,19 @@ class SubscriptionConfig:
METADATA_IS_ACTIVE = "is_active"
# Cache for loaded plugins to avoid repeated loading
_subscription_plugins_cache: list[Any] = []
_plugins_loaded = False
def load_plugins() -> list[Any]:
"""Iterate through the subscription plugins and register them."""
global _subscription_plugins_cache, _plugins_loaded
# Return cached plugins if already loaded
if _plugins_loaded:
return _subscription_plugins_cache
plugins_app = apps.get_app_config(SubscriptionConfig.PLUGINS_APP)
package_path = plugins_app.module.__package__
subscription_dir = os.path.join(plugins_app.path, SubscriptionConfig.PLUGIN_DIR)
@@ -30,6 +41,8 @@ def load_plugins() -> list[Any]:
subscription_plugins: list[Any] = []
if not os.path.exists(subscription_dir):
_subscription_plugins_cache = subscription_plugins
_plugins_loaded = True
return subscription_plugins
for item in os.listdir(subscription_dir):
@@ -56,10 +69,13 @@ def load_plugins() -> list[Any]:
SubscriptionConfig.METADATA: module.metadata,
}
)
name = metadata.get(
SubscriptionConfig.METADATA_NAME,
getattr(module, "__name__", "unknown"),
)
is_active = metadata.get(SubscriptionConfig.METADATA_IS_ACTIVE, False)
logger.info(
"Loaded subscription plugin: %s, is_active: %s",
module.metadata[SubscriptionConfig.METADATA_NAME],
module.metadata[SubscriptionConfig.METADATA_IS_ACTIVE],
"Loaded subscription plugin: %s, is_active: %s", name, is_active
)
else:
logger.info(
@@ -75,6 +91,10 @@ def load_plugins() -> list[Any]:
if len(subscription_plugins) == 0:
logger.info("No subscription plugins found.")
# Cache the results for future requests
_subscription_plugins_cache = subscription_plugins
_plugins_loaded = True
return subscription_plugins


@@ -0,0 +1,74 @@
"""Internal API Views for API v2
This module provides internal API endpoints for worker communication,
specifically optimized for type-aware pipeline data fetching.
Since we know the context from worker function calls:
- process_batch_callback_api -> APIDeployment model
- process_batch_callback -> Pipeline model (handled in workflow_manager)
This provides direct access to APIDeployment model data without
the overhead of checking both Pipeline and APIDeployment models.
"""
import logging
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView
from api_v2.models import APIDeployment
from api_v2.serializers import APIDeploymentSerializer
logger = logging.getLogger(__name__)
class APIDeploymentDataView(APIView):
"""Internal API endpoint for fetching APIDeployment data.
This endpoint is optimized for callback workers that know they're dealing
with API deployments. It directly queries the APIDeployment model without
checking the Pipeline model, improving performance.
Endpoint: GET /v2/api-deployments/{api_id}/data/
"""
def get(self, request, api_id):
"""Get APIDeployment model data by API ID.
Args:
request: HTTP request object
api_id: APIDeployment UUID
Returns:
Response with APIDeployment model data
"""
try:
logger.debug(f"Fetching APIDeployment data for ID: {api_id}")
# Query APIDeployment model directly (organization-scoped via DefaultOrganizationMixin)
api_deployment = APIDeployment.objects.get(id=api_id)
# Serialize the APIDeployment model
serializer = APIDeploymentSerializer(api_deployment)
# Use consistent response format with pipeline endpoint
response_data = {"status": "success", "pipeline": serializer.data}
logger.info(
f"Found APIDeployment {api_id}: name='{api_deployment.api_name}', display_name='{api_deployment.display_name}'"
)
return Response(response_data, status=status.HTTP_200_OK)
except APIDeployment.DoesNotExist:
logger.warning(f"APIDeployment not found for ID: {api_id}")
return Response(
{"error": f"APIDeployment with ID {api_id} not found"},
status=status.HTTP_404_NOT_FOUND,
)
except Exception as e:
logger.error(f"Error fetching APIDeployment data for {api_id}: {str(e)}")
return Response(
{"error": f"Failed to fetch APIDeployment data: {str(e)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)


@@ -0,0 +1,20 @@
"""Internal API URLs for API v2
Internal endpoints for worker communication, specifically optimized
for type-aware pipeline data fetching.
"""
from django.urls import path
from rest_framework.urlpatterns import format_suffix_patterns
from api_v2.internal_api_views import APIDeploymentDataView
urlpatterns = format_suffix_patterns(
[
path(
"<uuid:api_id>/",
APIDeploymentDataView.as_view(),
name="api_deployment_data_internal",
),
]
)


@@ -23,4 +23,6 @@ urlpatterns = [
include("pipeline_v2.public_api_urls"),
),
path("", include("health.urls")),
# Internal API for worker communication
path("internal/", include("backend.internal_base_urls")),
]


@@ -0,0 +1,100 @@
"""Internal API Constants
Centralized constants for internal API paths, versions, and configuration.
These constants can be overridden via environment variables for flexibility.
"""
import os
# Default constant for SonarCloud compliance
DEFAULT_INTERNAL_PREFIX = "/internal"
# Internal API Configuration
INTERNAL_API_PREFIX = os.getenv("INTERNAL_API_PREFIX", DEFAULT_INTERNAL_PREFIX)
INTERNAL_API_VERSION = os.getenv("INTERNAL_API_VERSION", "v1")
# Computed full prefix
INTERNAL_API_BASE_PATH = f"{INTERNAL_API_PREFIX}/{INTERNAL_API_VERSION}"
def build_internal_endpoint(path: str) -> str:
"""Build a complete internal API endpoint path.
Args:
path: The endpoint path without the internal prefix (e.g., "health/")
Returns:
Complete internal API path (e.g., "/internal/v1/health/")
"""
# Ensure path starts and ends with /
if not path.startswith("/"):
path = f"/{path}"
if not path.endswith("/"):
path = f"{path}/"
return f"{INTERNAL_API_BASE_PATH}{path}"
# Common endpoint builder shortcuts
class InternalEndpoints:
"""Convenience class for building internal API endpoints."""
@staticmethod
def health() -> str:
"""Health check endpoint."""
return build_internal_endpoint("health")
@staticmethod
def workflow(workflow_id: str = "{id}") -> str:
"""Workflow endpoint."""
return build_internal_endpoint(f"workflow/{workflow_id}")
@staticmethod
def workflow_status(workflow_id: str = "{id}") -> str:
"""Workflow status endpoint."""
return build_internal_endpoint(f"workflow/{workflow_id}/status")
@staticmethod
def file_execution(file_execution_id: str = "{id}") -> str:
"""File execution endpoint."""
return build_internal_endpoint(f"file-execution/{file_execution_id}")
@staticmethod
def file_execution_status(file_execution_id: str = "{id}") -> str:
"""File execution status endpoint."""
return build_internal_endpoint(f"file-execution/{file_execution_id}/status")
@staticmethod
def webhook_send() -> str:
"""Webhook send endpoint."""
return build_internal_endpoint("webhook/send")
@staticmethod
def organization(org_id: str = "{org_id}") -> str:
"""Organization endpoint."""
return build_internal_endpoint(f"organization/{org_id}")
# Environment variable documentation
ENVIRONMENT_VARIABLES = {
"INTERNAL_API_PREFIX": {
"description": "Base prefix for internal API endpoints",
"default": DEFAULT_INTERNAL_PREFIX,
"example": DEFAULT_INTERNAL_PREFIX,
},
"INTERNAL_API_VERSION": {
"description": "API version for internal endpoints",
"default": "v1",
"example": "v1",
},
}
def get_api_info() -> dict:
"""Get current internal API configuration info."""
return {
"prefix": INTERNAL_API_PREFIX,
"version": INTERNAL_API_VERSION,
"base_path": INTERNAL_API_BASE_PATH,
"environment_variables": ENVIRONMENT_VARIABLES,
}
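
With the default environment values above, the builder resolves as follows (a hedged usage note, not code from this PR):

    build_internal_endpoint("health")          # -> "/internal/v1/health/"
    build_internal_endpoint("workflow/123")    # -> "/internal/v1/workflow/123/"
    InternalEndpoints.organization("org-abc")  # -> "/internal/v1/organization/org-abc/"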


@@ -0,0 +1,266 @@
"""Internal API URL Configuration - OSS Base.
Base internal URL patterns for OSS deployment. This file contains
the foundational internal APIs available in all deployments.
Cloud deployments extend this via cloud_internal_urls.py following
the same pattern as base_urls.py / cloud_base_urls.py.
"""
import logging
import secrets
from django.conf import settings
from django.http import Http404, JsonResponse
from django.urls import include, path
from django.views.decorators.http import require_http_methods
from utils.websocket_views import emit_websocket
logger = logging.getLogger(__name__)
@require_http_methods(["GET"])
def internal_api_root(request):
"""Internal API root endpoint with comprehensive documentation."""
return JsonResponse(
{
"message": "Unstract Internal API",
"version": "1.0.0",
"description": "Internal service-to-service API for Celery workers",
"documentation": "https://docs.unstract.com/internal-api",
"endpoints": {
"description": "Various v1 endpoints for workflow execution, pipeline, organization, and other services",
"base_path": "/internal/v1/",
},
"authentication": {
"type": "Bearer Token",
"header": "Authorization: Bearer <internal_service_api_key>",
"organization": "X-Organization-ID header (optional for scoped requests)",
"requirements": [
"All requests must include Authorization header",
"API key must match INTERNAL_SERVICE_API_KEY setting",
"Organization ID header required for org-scoped operations",
],
},
"response_format": {
"success": {"status": "success", "data": "..."},
"error": {"error": "Error message", "detail": "Additional details"},
},
"rate_limits": {
"default": "No rate limits for internal services",
"note": "Monitor usage through application logs",
},
}
)
@require_http_methods(["GET"])
def internal_health_check(request):
"""Health check endpoint for internal API."""
try:
# Debug information (sanitized for security)
debug_info = {
"has_internal_service": hasattr(request, "internal_service"),
"internal_service_value": getattr(request, "internal_service", None),
"auth_header_present": bool(request.META.get("HTTP_AUTHORIZATION")),
"auth_scheme": (
request.META.get("HTTP_AUTHORIZATION", "").split()[0]
if request.META.get("HTTP_AUTHORIZATION", "").strip()
else "None"
),
"path": request.path,
"method": request.method,
}
# Check authentication - first check middleware, then fallback to direct key check
authenticated = False
if hasattr(request, "internal_service") and request.internal_service:
authenticated = True
else:
# Fallback: check API key directly if middleware didn't run
auth_header = request.META.get("HTTP_AUTHORIZATION", "")
if auth_header.startswith("Bearer "):
api_key = auth_header[7:] # Remove 'Bearer ' prefix
internal_api_key = getattr(settings, "INTERNAL_SERVICE_API_KEY", None)
if internal_api_key and secrets.compare_digest(api_key, internal_api_key):
authenticated = True
# Set the flag manually since middleware didn't run
request.internal_service = True
elif internal_api_key:
# Log authentication failure (without exposing the key)
logger.warning(
"Internal API authentication failed",
extra={
"path": request.path,
"method": request.method,
"remote_addr": request.META.get("REMOTE_ADDR"),
},
)
if not authenticated:
return JsonResponse(
{
"status": "error",
"message": "Not authenticated as internal service",
"debug": debug_info,
},
status=401,
)
# Basic health checks
health_data = {
"status": "healthy",
"service": "internal_api",
"version": "1.0.0",
"timestamp": request.META.get("HTTP_DATE"),
"authenticated": True,
"organization_id": getattr(request, "organization_id", None),
"debug": debug_info,
}
return JsonResponse(health_data)
except Exception as e:
logger.exception("internal_health_check failed")
return JsonResponse(
{
"status": "error",
"message": "Health check failed",
"error": str(e),
"debug": {
"has_internal_service": hasattr(request, "internal_service"),
"auth_header_present": bool(request.META.get("HTTP_AUTHORIZATION")),
"auth_scheme": (
request.META.get("HTTP_AUTHORIZATION", "").split()[0]
if request.META.get("HTTP_AUTHORIZATION", "").strip()
else "None"
),
"path": request.path,
},
},
status=500,
)
# Test endpoint to debug middleware (only available in DEBUG mode)
@require_http_methods(["GET"])
def test_middleware_debug(request):
"""Debug endpoint to check middleware execution - only in DEBUG mode."""
# Only available in DEBUG mode or with explicit flag
if not (settings.DEBUG or getattr(settings, "INTERNAL_API_DEBUG", False)):
raise Http404("Debug endpoint not available")
return JsonResponse(
{
"middleware_debug": {
"path": request.path,
"method": request.method,
"auth_header_present": bool(request.META.get("HTTP_AUTHORIZATION")),
"auth_scheme": (
request.META.get("HTTP_AUTHORIZATION", "").split()[0]
if request.META.get("HTTP_AUTHORIZATION", "").strip()
else "None"
),
"has_internal_service": hasattr(request, "internal_service"),
"internal_service_value": getattr(request, "internal_service", None),
"authenticated_via": getattr(request, "authenticated_via", None),
"organization_id": getattr(request, "organization_id", None),
"internal_api_key_configured": bool(
getattr(settings, "INTERNAL_SERVICE_API_KEY", None)
),
}
}
)
# Internal API URL patterns - OSS Base
urlpatterns = [
# Internal API root and utilities
path("", internal_api_root, name="internal_api_root"),
path("debug/", test_middleware_debug, name="test_middleware_debug"),
path("v1/health/", internal_health_check, name="internal_health"),
# WebSocket emission endpoint for workers
path("emit-websocket/", emit_websocket, name="emit_websocket"),
# ========================================
# CORE OSS INTERNAL API MODULES
# ========================================
# Workflow execution management APIs
path(
"v1/workflow-execution/",
include("workflow_manager.workflow_execution_internal_urls"),
name="workflow_execution_internal",
),
# Workflow management and pipeline APIs
path(
"v1/workflow-manager/",
include("workflow_manager.internal_urls"),
name="workflow_manager_internal",
),
# Pipeline APIs
path(
"v1/pipeline/",
include("pipeline_v2.internal_urls"),
name="pipeline_internal",
),
# Organization context and management APIs
path(
"v1/organization/",
include("account_v2.organization_internal_urls"),
name="organization_internal",
),
# File execution and batch processing APIs
path(
"v1/file-execution/",
include("workflow_manager.file_execution.internal_urls"),
name="file_execution_internal",
),
# Tool instance execution APIs
path(
"v1/tool-execution/",
include("tool_instance_v2.internal_urls"),
name="tool_execution_internal",
),
# File processing history and caching APIs
path(
"v1/file-history/",
include("workflow_manager.workflow_v2.file_history_internal_urls"),
name="file_history_internal",
),
# Webhook notification APIs
path(
"v1/webhook/",
include("notification_v2.internal_urls"),
name="webhook_internal",
),
# API deployment data APIs for type-aware worker optimization
path(
"v1/api-deployments/",
include("api_v2.internal_urls"),
name="api_deployments_internal",
),
# Platform configuration and settings APIs
path(
"v1/platform-settings/",
include("platform_settings_v2.internal_urls"),
name="platform_settings_internal",
),
# Execution log management and cache operations APIs
path(
"v1/execution-logs/",
include("workflow_manager.workflow_v2.execution_log_internal_urls"),
name="execution_logs_internal",
),
# Organization configuration management APIs
path(
"v1/configuration/",
include("configuration.internal_urls"),
name="configuration_internal",
),
# Usage data and token count APIs
path(
"v1/usage/",
include("usage_v2.internal_urls"),
name="usage_internal",
),
]


@@ -562,7 +562,6 @@ SOCIAL_AUTH_GOOGLE_OAUTH2_AUTH_EXTRA_ARGUMENTS = {
}
SOCIAL_AUTH_GOOGLE_OAUTH2_USE_UNIQUE_USER_ID = True
# Always keep this line at the bottom of the file.
if missing_settings:
ERROR_MESSAGE = "Below required settings are missing.\n" + ",\n".join(


@@ -61,4 +61,5 @@ urlpatterns = [
include("prompt_studio.prompt_studio_index_manager.urls"),
),
path("notifications/", include("notification.urls")),
path("internal/", include("backend.internal_base_urls")),
]


@@ -0,0 +1,15 @@
"""Internal API URLs for Configuration access by workers."""
from django.urls import path
from . import internal_views
app_name = "configuration_internal"
urlpatterns = [
path(
"<str:config_key>/",
internal_views.ConfigurationInternalView.as_view(),
name="configuration-detail",
),
]


@@ -0,0 +1,122 @@
"""Internal API views for Configuration access by workers."""
import logging
from account_v2.models import Organization
from django.http import JsonResponse
from rest_framework import status
from rest_framework.request import Request
from rest_framework.views import APIView
from .models import Configuration
logger = logging.getLogger(__name__)
class ConfigurationInternalView(APIView):
"""Internal API view for workers to access organization configurations.
This endpoint allows workers to get organization-specific configuration
values without direct database access, maintaining the same logic as
Configuration.get_value_by_organization() but over HTTP.
Workers can call this to get configs like MAX_PARALLEL_FILE_BATCHES
with proper organization-specific overrides and fallbacks.
"""
def get(self, request: Request, config_key: str) -> JsonResponse:
"""Get configuration value for an organization.
Args:
request: HTTP request with organization_id parameter
config_key: Configuration key name (e.g., "MAX_PARALLEL_FILE_BATCHES")
Returns:
JSON response with configuration value and metadata
"""
try:
organization_id = request.query_params.get("organization_id")
if not organization_id:
return JsonResponse(
{
"success": False,
"error": "organization_id parameter is required",
"config_key": config_key,
},
status=status.HTTP_400_BAD_REQUEST,
)
# Get the organization - handle both ID (int) and organization_id (string)
try:
# Try to get organization by primary key ID first (for backward compatibility)
if organization_id.isdigit():
organization = Organization.objects.get(id=int(organization_id))
else:
# Otherwise, lookup by organization_id field (string identifier)
organization = Organization.objects.get(
organization_id=organization_id
)
except (Organization.DoesNotExist, ValueError):
return JsonResponse(
{
"success": False,
"error": f"Organization {organization_id} not found",
"config_key": config_key,
},
status=status.HTTP_404_NOT_FOUND,
)
# Get the configuration value using the same logic as the backend
try:
config_value = Configuration.get_value_by_organization(
config_key=config_key, organization=organization
)
# Check if we found an organization-specific override
has_override = False
try:
Configuration.objects.get(
organization=organization, key=config_key, enabled=True
)
has_override = True
except Configuration.DoesNotExist:
has_override = False
return JsonResponse(
{
"success": True,
"data": {
"config_key": config_key,
"value": config_value,
"organization_id": organization_id,
"has_organization_override": has_override,
},
}
)
except ValueError as e:
# Configuration key not found in registry
return JsonResponse(
{
"success": False,
"error": str(e),
"config_key": config_key,
"organization_id": organization_id,
},
status=status.HTTP_400_BAD_REQUEST,
)
except Exception as e:
logger.error(
f"Error getting configuration {config_key} for organization {organization_id}: {e}",
exc_info=True,
)
return JsonResponse(
{
"success": False,
"error": "Internal server error",
"config_key": config_key,
},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
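
A hedged sketch of the worker side of this endpoint (base URL and env var names are illustrative; the path and response shape follow the URL config and view above):

    # Sketch only: fetch an organization-scoped config value over the internal API.
    import os

    import requests

    def get_org_config(config_key: str, organization_id: str):
        base_url = os.getenv("INTERNAL_API_BASE_URL", "http://backend:8000")
        response = requests.get(
            f"{base_url}/internal/v1/configuration/{config_key}/",
            params={"organization_id": organization_id},
            headers={"Authorization": f"Bearer {os.environ['INTERNAL_SERVICE_API_KEY']}"},
            timeout=30,
        )
        response.raise_for_status()
        body = response.json()
        return body["data"]["value"] if body.get("success") else None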


@@ -42,10 +42,22 @@ def _group_connectors(
) -> dict[tuple[Any, str, str | None], list[Any]]:
"""Group connectors by organization, connector type, and metadata hash."""
connector_groups = {}
skipped_connectors = 0
for connector in connector_instances:
try:
metadata_hash = _compute_metadata_hash(connector.connector_metadata)
# Try to access connector_metadata - this may fail due to encryption key mismatch
try:
metadata_hash = _compute_metadata_hash(connector.connector_metadata)
except Exception as decrypt_error:
# Log the encryption error and skip this connector
logger.warning(
f"Skipping connector {connector.id} due to encryption error: {str(decrypt_error)}. "
f"This is likely due to a changed ENCRYPTION_KEY."
)
skipped_connectors += 1
continue
connector_sys_name = _extract_connector_sys_name(connector.connector_id)
group_key = (
@@ -62,6 +74,11 @@ def _group_connectors(
logger.error(f"Error processing connector {connector.id}: {str(e)}")
raise
if skipped_connectors > 0:
logger.warning(
f"Skipped {skipped_connectors} connectors due to encryption key issues"
)
return connector_groups
@@ -70,9 +87,16 @@ def _process_single_connector(
processed_groups: int,
total_groups: int,
short_group_key: tuple[Any, str, str],
connector_instance_model: Any,
) -> None:
"""Process a group with only one connector."""
connector.connector_name = f"{connector.connector_name}-{uuid.uuid4().hex[:8]}"
base_name = connector.connector_name
new_name = f"{base_name}-{uuid.uuid4().hex[:8]}"
# For performance with large datasets, UUID collisions are extremely rare
# If uniqueness becomes critical, we can add collision detection later
connector.connector_name = new_name
logger.info(
f"[Group {processed_groups}/{total_groups}] {short_group_key}: "
f"Only 1 connector present, renaming to '{connector.connector_name}'"
@@ -85,6 +109,7 @@ def _centralize_connector_group(
processed_groups: int,
total_groups: int,
short_group_key: tuple[Any, str, str],
connector_instance_model: Any,
) -> tuple[Any, dict[Any, Any], set[Any]]:
"""Centralize a group of multiple connectors."""
logger.info(
@@ -95,7 +120,12 @@ def _centralize_connector_group(
# First connector becomes the centralized one
centralized_connector = connectors[0]
original_name = centralized_connector.connector_name
centralized_connector.connector_name = f"{original_name}-{uuid.uuid4().hex[:8]}"
new_name = f"{original_name}-{uuid.uuid4().hex[:8]}"
# For performance with large datasets, UUID collisions are extremely rare
# If uniqueness becomes critical, we can add collision detection later
centralized_connector.connector_name = new_name
logger.info(
f"[Group {processed_groups}/{total_groups}] {short_group_key}: "
@@ -164,6 +194,88 @@ def _delete_redundant_connectors(
raise
def _fix_remaining_duplicate_names(connector_instance_model: Any) -> int:
"""Fix any remaining duplicate connector names within organizations."""
from django.db.models import Count
# Find all organizations with duplicate connector names (optimized query)
duplicates = list(
connector_instance_model.objects.values("connector_name", "organization_id")
.annotate(count=Count("id"))
.filter(count__gt=1)
.order_by("organization_id", "connector_name")
)
total_duplicates = len(duplicates)
if total_duplicates == 0:
logger.info("No duplicate connector names found after migration")
return 0
logger.info(
f"Found {total_duplicates} groups with duplicate connector names - fixing"
)
fixed_count = 0
# Process in batches to avoid memory issues
batch_size = 20
for i in range(0, len(duplicates), batch_size):
batch = duplicates[i : i + batch_size]
logger.info(
f"Processing batch {i//batch_size + 1}/{(len(duplicates)-1)//batch_size + 1}"
)
for dup_info in batch:
connector_name = dup_info["connector_name"]
org_id = dup_info["organization_id"]
# Get all connectors with this name in this organization (select only needed fields)
duplicate_connectors = list(
connector_instance_model.objects.filter(
connector_name=connector_name, organization_id=org_id
)
.only("id", "connector_name", "organization_id")
.order_by("id")
)
if len(duplicate_connectors) <= 1:
continue # Skip if no longer duplicates
# Prepare batch updates (keep first, rename others)
updates = []
existing_names_in_org = set(
connector_instance_model.objects.filter(
organization_id=org_id
).values_list("connector_name", flat=True)
)
for j, connector in enumerate(duplicate_connectors[1:], 1): # Skip first
base_name = connector_name
new_name = f"{base_name}-{uuid.uuid4().hex[:8]}"
# Simple collision check against existing names in this org
attempt = 0
while new_name in existing_names_in_org and attempt < 5:
new_name = f"{base_name}-{uuid.uuid4().hex[:8]}"
attempt += 1
existing_names_in_org.add(new_name) # Track new names
connector.connector_name = new_name
updates.append(connector)
fixed_count += 1
# Bulk update for better performance
if updates:
connector_instance_model.objects.bulk_update(
updates, ["connector_name"], batch_size=100
)
logger.info(
f" Fixed {len(updates)} duplicates of '{connector_name}' in org {org_id}"
)
logger.info(f"Fixed {fixed_count} duplicate connector names")
return fixed_count
def migrate_to_centralized_connectors(apps, schema_editor): # noqa: ARG001
"""Migrate existing workflow-specific connectors to centralized connectors.
@@ -176,10 +288,15 @@ def migrate_to_centralized_connectors(apps, schema_editor): # noqa: ARG001
ConnectorInstance = apps.get_model("connector_v2", "ConnectorInstance") # NOSONAR
WorkflowEndpoint = apps.get_model("endpoint_v2", "WorkflowEndpoint") # NOSONAR
# Get all connector instances with select_related for performance
connector_instances = ConnectorInstance.objects.select_related(
"organization", "created_by", "modified_by"
).all()
# Get all connector instances, but defer the encrypted metadata field to avoid
# automatic decryption failures when the encryption key has changed
connector_instances = (
ConnectorInstance.objects.select_related(
"organization", "created_by", "modified_by"
)
.defer("connector_metadata")
.all()
)
total_connectors = connector_instances.count()
logger.info(f"Processing {total_connectors} connector instances for centralization")
@@ -187,6 +304,17 @@ def migrate_to_centralized_connectors(apps, schema_editor): # noqa: ARG001
# Group connectors by organization and unique credential fingerprint
connector_groups = _group_connectors(connector_instances)
# Safety check: If we have connectors but all were skipped, this indicates a serious issue
if total_connectors > 0 and len(connector_groups) == 0:
error_msg = (
f"CRITICAL: All {total_connectors} connectors were skipped due to encryption errors. "
f"This likely means the ENCRYPTION_KEY has changed. Please restore the correct "
f"ENCRYPTION_KEY and retry the migration. The migration has been aborted to prevent "
f"data loss."
)
logger.error(error_msg)
raise RuntimeError(error_msg)
# Process each group and centralize connectors
processed_groups = 0
centralized_count = 0
@@ -202,13 +330,21 @@ def migrate_to_centralized_connectors(apps, schema_editor): # noqa: ARG001
# Process single connector groups differently
if len(connectors) == 1:
_process_single_connector(
connectors[0], processed_groups, total_groups, short_group_key
connectors[0],
processed_groups,
total_groups,
short_group_key,
ConnectorInstance,
)
continue
# Centralize multiple connectors
_, connector_mapping, connectors_to_delete = _centralize_connector_group(
connectors, processed_groups, total_groups, short_group_key
connectors,
processed_groups,
total_groups,
short_group_key,
ConnectorInstance,
)
centralized_count += 1
@@ -232,6 +368,9 @@ def migrate_to_centralized_connectors(apps, schema_editor): # noqa: ARG001
# Delete redundant connectors
_delete_redundant_connectors(all_connectors_to_delete, ConnectorInstance)
# Final cleanup: Fix any remaining duplicate names within organizations
_fix_remaining_duplicate_names(ConnectorInstance)
logger.info(
f"Migration completed: {centralized_count} centralized connectors created"
)
@@ -273,19 +412,28 @@ def _create_workflow_specific_connector(
connector_instance_model: Any,
) -> Any:
"""Create a new workflow-specific connector from a centralized one."""
return connector_instance_model.objects.create(
connector_name=centralized_connector.connector_name,
connector_id=centralized_connector.connector_id,
connector_metadata=centralized_connector.connector_metadata,
connector_version=centralized_connector.connector_version,
connector_type=connector_type,
connector_auth=centralized_connector.connector_auth,
connector_mode=centralized_connector.connector_mode,
workflow=workflow,
organization=centralized_connector.organization,
created_by=centralized_connector.created_by,
modified_by=centralized_connector.modified_by,
)
try:
# Try to access connector_metadata to ensure it's readable
metadata = centralized_connector.connector_metadata
return connector_instance_model.objects.create(
connector_name=centralized_connector.connector_name,
connector_id=centralized_connector.connector_id,
connector_metadata=metadata,
connector_version=centralized_connector.connector_version,
connector_type=connector_type,
connector_auth=centralized_connector.connector_auth,
connector_mode=centralized_connector.connector_mode,
workflow=workflow,
organization=centralized_connector.organization,
created_by=centralized_connector.created_by,
modified_by=centralized_connector.modified_by,
)
except Exception as e:
logger.warning(
f"Skipping creation of workflow-specific connector from {centralized_connector.id} "
f"due to encryption error: {str(e)}"
)
raise
def _process_connector_endpoints(
@@ -359,10 +507,13 @@ def reverse_centralized_connectors(apps, schema_editor): # noqa: ARG001
ConnectorInstance = apps.get_model("connector_v2", "ConnectorInstance") # NOSONAR
WorkflowEndpoint = apps.get_model("endpoint_v2", "WorkflowEndpoint") # NOSONAR
# Get all centralized connectors with prefetch for better performance
centralized_connectors = ConnectorInstance.objects.prefetch_related(
"workflow_endpoints"
).all()
# Get all centralized connectors, but defer the encrypted metadata field to avoid
# automatic decryption failures when the encryption key has changed
centralized_connectors = (
ConnectorInstance.objects.prefetch_related("workflow_endpoints")
.defer("connector_metadata")
.all()
)
total_connectors = centralized_connectors.count()
logger.info(f"Processing {total_connectors} centralized connectors for reversal")
@@ -375,6 +526,7 @@ def reverse_centralized_connectors(apps, schema_editor): # noqa: ARG001
# Process connectors with endpoints to create workflow-specific copies
added_connector_count = 0
processed_connectors = 0
skipped_reverse_connectors = 0
for centralized_connector in centralized_connectors:
processed_connectors += 1
@@ -384,6 +536,17 @@ def reverse_centralized_connectors(apps, schema_editor): # noqa: ARG001
continue
try:
# Test if we can access encrypted fields before processing
try:
_ = centralized_connector.connector_metadata
except Exception as decrypt_error:
logger.warning(
f"Skipping reverse migration for connector {centralized_connector.id} "
f"due to encryption error: {str(decrypt_error)}"
)
skipped_reverse_connectors += 1
continue
endpoints = WorkflowEndpoint.objects.filter(
connector_instance=centralized_connector
)
@@ -404,6 +567,22 @@ def reverse_centralized_connectors(apps, schema_editor): # noqa: ARG001
)
raise
if skipped_reverse_connectors > 0:
logger.warning(
f"Skipped {skipped_reverse_connectors} connectors during reverse migration due to encryption issues"
)
# Safety check for reverse migration: if we skipped everything, abort
if skipped_reverse_connectors == total_connectors and total_connectors > 0:
error_msg = (
f"CRITICAL: All {total_connectors} connectors were skipped during reverse migration "
f"due to encryption errors. This likely means the ENCRYPTION_KEY has changed. "
f"Please restore the correct ENCRYPTION_KEY and retry the reverse migration. "
f"The reverse migration has been aborted to prevent data loss."
)
logger.error(error_msg)
raise RuntimeError(error_msg)
# Delete unused centralized connectors
_delete_unused_centralized_connectors(unused_connectors, ConnectorInstance)


@@ -0,0 +1,241 @@
"""Internal API Service Authentication Middleware
Handles service-to-service authentication for internal APIs.
"""
import logging
from typing import Any
from django.conf import settings
from django.http import HttpRequest, HttpResponse, JsonResponse
from django.utils.deprecation import MiddlewareMixin
from utils.constants import Account
from utils.local_context import StateStore
logger = logging.getLogger(__name__)
class InternalAPIAuthMiddleware(MiddlewareMixin):
"""Middleware for authenticating internal service API requests.
This middleware:
1. Checks for internal service API key in Authorization header
2. Validates the key against INTERNAL_SERVICE_API_KEY setting
3. Sets up organization context for requests
4. Bypasses normal user authentication for internal services
"""
def process_request(self, request: HttpRequest) -> HttpResponse | None:
"""Enhanced request processing with improved debugging and organization context handling."""
# Enhanced request logging with more context
request_info = {
"path": request.path,
"method": request.method,
"content_type": request.META.get("CONTENT_TYPE", "unknown"),
"user_agent": request.META.get("HTTP_USER_AGENT", "unknown")[:100],
"remote_addr": request.META.get("REMOTE_ADDR", "unknown"),
"auth_header_present": bool(request.META.get("HTTP_AUTHORIZATION")),
"org_header_present": bool(request.headers.get("X-Organization-ID")),
}
logger.debug(f"InternalAPIAuthMiddleware processing request: {request_info}")
# Only apply to internal API endpoints
if not request.path.startswith("/internal/"):
logger.debug(f"Skipping middleware for non-internal path: {request.path}")
return None
logger.info(f"Processing internal API request: {request.method} {request.path}")
# Enhanced authentication handling
auth_result = self._authenticate_request(request)
if auth_result["error"]:
logger.warning(
f"Authentication failed for {request.path}: {auth_result['message']}"
)
return JsonResponse(
{
"error": auth_result["message"],
"detail": auth_result["detail"],
"debug_info": auth_result.get("debug_info", {})
if settings.DEBUG
else {},
},
status=auth_result["status"],
)
# Enhanced organization context handling
org_result = self._setup_organization_context(request)
if org_result["warning"]:
logger.warning(
f"Organization context issue for {request.path}: {org_result['warning']}"
)
# Mark request as authenticated
request.internal_service = True
request.authenticated_via = "internal_service_api_key"
# Enhanced organization context logging
final_context = {
"path": request.path,
"request_org_id": getattr(request, "organization_id", "None"),
"statestore_org_id": StateStore.get(Account.ORGANIZATION_ID),
"org_context_set": org_result["context_set"],
"org_validated": org_result.get("organization_validated", False),
}
logger.info(f"Internal API request authenticated successfully: {final_context}")
return None # Continue with request processing
def _authenticate_request(self, request: HttpRequest) -> dict[str, Any]:
"""Enhanced authentication with detailed error reporting."""
auth_header = request.META.get("HTTP_AUTHORIZATION", "")
if not auth_header:
return {
"error": True,
"status": 401,
"message": "Authorization header required for internal APIs",
"detail": "Missing Authorization header",
"debug_info": {
"headers_present": list(request.META.keys()),
"expected_format": "Authorization: Bearer <api_key>",
},
}
if not auth_header.startswith("Bearer "):
return {
"error": True,
"status": 401,
"message": "Bearer token required for internal APIs",
"detail": f"Invalid authorization format: {auth_header[:20]}...",
"debug_info": {
"provided_format": auth_header.split(" ")[0]
if " " in auth_header
else auth_header[:10],
"expected_format": "Bearer <api_key>",
},
}
# Extract and validate API key
api_key = auth_header[7:] # Remove 'Bearer ' prefix
internal_api_key = getattr(settings, "INTERNAL_SERVICE_API_KEY", None)
if not internal_api_key:
logger.error("INTERNAL_SERVICE_API_KEY not configured in Django settings")
return {
"error": True,
"status": 500,
"message": "Internal API authentication not configured",
"detail": "INTERNAL_SERVICE_API_KEY setting missing",
}
if api_key != internal_api_key:
# Enhanced logging for key mismatch debugging
key_comparison = {
"provided_key_length": len(api_key),
"expected_key_length": len(internal_api_key),
"keys_match": api_key == internal_api_key,
"provided_key_prefix": api_key[:8] + "..."
if len(api_key) > 8
else api_key,
"expected_key_prefix": internal_api_key[:8] + "..."
if len(internal_api_key) > 8
else internal_api_key,
}
logger.warning(f"API key validation failed: {key_comparison}")
return {
"error": True,
"status": 401,
"message": "Invalid internal service API key",
"detail": "API key does not match configured value",
"debug_info": key_comparison if settings.DEBUG else {},
}
return {"error": False, "message": "Authentication successful"}
def _setup_organization_context(self, request: HttpRequest) -> dict[str, Any]:
"""Enhanced organization context setup with validation."""
org_id = request.headers.get("X-Organization-ID")
if not org_id:
return {
"warning": "No organization ID provided in X-Organization-ID header",
"context_set": False,
}
try:
# Validate organization ID format
if not org_id.strip():
return {"warning": "Empty organization ID provided", "context_set": False}
# Enhanced organization context validation
from utils.organization_utils import resolve_organization
try:
organization = resolve_organization(org_id, raise_on_not_found=False)
if organization:
# Use organization.organization_id (string field) for StateStore consistency
# This ensures UserContext.get_organization() can properly retrieve the organization
request.organization_id = organization.organization_id
request.organization_context = {
"id": str(organization.id),
"organization_id": organization.organization_id,
"name": organization.display_name,
"validated": True,
}
# Store the organization_id string field in StateStore for UserContext compatibility
StateStore.set(Account.ORGANIZATION_ID, organization.organization_id)
logger.debug(
f"Organization context validated and set: {organization.display_name} (org_id: {organization.organization_id}, pk: {organization.id})"
)
return {
"warning": None,
"context_set": True,
"organization_validated": True,
}
else:
logger.warning(f"Organization {org_id} not found in database")
# Still set the context for backward compatibility
request.organization_id = org_id
StateStore.set(Account.ORGANIZATION_ID, org_id)
return {
"warning": f"Organization {org_id} not found in database, using raw value",
"context_set": True,
"organization_validated": False,
}
except Exception as e:
logger.warning(f"Failed to validate organization {org_id}: {str(e)}")
# Fallback to raw organization ID
request.organization_id = org_id
StateStore.set(Account.ORGANIZATION_ID, org_id)
return {
"warning": f"Organization validation failed: {str(e)}, using raw value",
"context_set": True,
"organization_validated": False,
}
except Exception as e:
logger.error(f"Unexpected error setting organization context: {str(e)}")
return {
"warning": f"Failed to set organization context: {str(e)}",
"context_set": False,
}
def process_response(
self, request: HttpRequest, response: HttpResponse
) -> HttpResponse:
# Clean up organization context if we set it
if hasattr(request, "internal_service") and request.internal_service:
try:
org_id_before_clear = StateStore.get(Account.ORGANIZATION_ID)
if org_id_before_clear is not None:
StateStore.clear(Account.ORGANIZATION_ID)
logger.debug(
f"Cleaned up organization context for {request.path}: {org_id_before_clear}"
)
except AttributeError:
# StateStore key doesn't exist, which is fine
logger.debug(f"No organization context to clean up for {request.path}")
return response


@@ -0,0 +1,252 @@
"""Internal API views for notification data access by workers.
These endpoints provide notification configuration data to workers
without exposing full Django models or requiring Django dependencies.
Security Note:
- CSRF protection is disabled for internal service-to-service communication
- Authentication is handled by InternalAPIAuthMiddleware using Bearer tokens
- These endpoints are not accessible from browsers and don't use session cookies
"""
import logging
from api_v2.models import APIDeployment
from django.http import JsonResponse
from django.shortcuts import get_object_or_404
from django.views.decorators.csrf import csrf_exempt
from django.views.decorators.http import require_http_methods
from pipeline_v2.models import Pipeline
from utils.organization_utils import filter_queryset_by_organization
from notification_v2.models import Notification
logger = logging.getLogger(__name__)
# Constants for error messages
INTERNAL_SERVER_ERROR_MSG = "Internal server error"
@csrf_exempt # Safe: Internal API with Bearer token auth, service-to-service only
@require_http_methods(["GET"])
def get_pipeline_notifications(request, pipeline_id):
"""Get active notifications for a pipeline or API deployment.
Used by callback worker to fetch notification configuration.
"""
try:
# Try to find the pipeline ID in Pipeline model first
pipeline_queryset = Pipeline.objects.filter(id=pipeline_id)
pipeline_queryset = filter_queryset_by_organization(
pipeline_queryset, request, "organization"
)
if pipeline_queryset.exists():
pipeline = pipeline_queryset.first()
# Get active notifications for this pipeline
notifications = Notification.objects.filter(pipeline=pipeline, is_active=True)
notifications_data = []
for notification in notifications:
notifications_data.append(
{
"id": str(notification.id),
"notification_type": notification.notification_type,
"platform": notification.platform,
"url": notification.url,
"authorization_type": notification.authorization_type,
"authorization_key": notification.authorization_key,
"authorization_header": notification.authorization_header,
"max_retries": notification.max_retries,
"is_active": notification.is_active,
}
)
return JsonResponse(
{
"status": "success",
"pipeline_id": str(pipeline.id),
"pipeline_name": pipeline.pipeline_name,
"pipeline_type": pipeline.pipeline_type,
"notifications": notifications_data,
}
)
else:
# If not found in Pipeline, try APIDeployment model
api_queryset = APIDeployment.objects.filter(id=pipeline_id)
api_queryset = filter_queryset_by_organization(
api_queryset, request, "organization"
)
if api_queryset.exists():
api = api_queryset.first()
# Get active notifications for this API deployment
notifications = Notification.objects.filter(api=api, is_active=True)
notifications_data = []
for notification in notifications:
notifications_data.append(
{
"id": str(notification.id),
"notification_type": notification.notification_type,
"platform": notification.platform,
"url": notification.url,
"authorization_type": notification.authorization_type,
"authorization_key": notification.authorization_key,
"authorization_header": notification.authorization_header,
"max_retries": notification.max_retries,
"is_active": notification.is_active,
}
)
return JsonResponse(
{
"status": "success",
"pipeline_id": str(api.id),
"pipeline_name": api.api_name,
"pipeline_type": "API",
"notifications": notifications_data,
}
)
else:
return JsonResponse(
{
"status": "error",
"message": "Pipeline or API deployment not found",
},
status=404,
)
except Exception as e:
logger.error(f"Error getting pipeline notifications for {pipeline_id}: {e}")
return JsonResponse(
{"status": "error", "message": INTERNAL_SERVER_ERROR_MSG}, status=500
)
@csrf_exempt # Safe: Internal API with Bearer token auth, service-to-service only
@require_http_methods(["GET"])
def get_api_notifications(request, api_id):
"""Get active notifications for an API deployment.
Used by callback worker to fetch notification configuration.
"""
try:
# Get API deployment with organization filtering
api_queryset = APIDeployment.objects.filter(id=api_id)
api_queryset = filter_queryset_by_organization(
api_queryset, request, "organization"
)
api = get_object_or_404(api_queryset)
# Get active notifications for this API
notifications = Notification.objects.filter(api=api, is_active=True)
notifications_data = []
for notification in notifications:
notifications_data.append(
{
"id": str(notification.id),
"notification_type": notification.notification_type,
"platform": notification.platform,
"url": notification.url,
"authorization_type": notification.authorization_type,
"authorization_key": notification.authorization_key,
"authorization_header": notification.authorization_header,
"max_retries": notification.max_retries,
"is_active": notification.is_active,
}
)
return JsonResponse(
{
"status": "success",
"api_id": str(api.id),
"api_name": api.api_name,
"display_name": api.display_name,
"notifications": notifications_data,
}
)
except APIDeployment.DoesNotExist:
return JsonResponse(
{"status": "error", "message": "API deployment not found"}, status=404
)
except Exception as e:
logger.error(f"Error getting API notifications for {api_id}: {e}")
return JsonResponse(
{"status": "error", "message": INTERNAL_SERVER_ERROR_MSG}, status=500
)
@csrf_exempt # Safe: Internal API with Bearer token auth, service-to-service only
@require_http_methods(["GET"])
def get_pipeline_data(request, pipeline_id):
"""Get basic pipeline data for notification purposes.
Used by callback worker to determine pipeline type and name.
"""
try:
# Get pipeline with organization filtering
pipeline_queryset = Pipeline.objects.filter(id=pipeline_id)
pipeline_queryset = filter_queryset_by_organization(
pipeline_queryset, request, "organization"
)
pipeline = get_object_or_404(pipeline_queryset)
return JsonResponse(
{
"status": "success",
"pipeline_id": str(pipeline.id),
"pipeline_name": pipeline.pipeline_name,
"pipeline_type": pipeline.pipeline_type,
"last_run_status": pipeline.last_run_status,
}
)
except Pipeline.DoesNotExist:
return JsonResponse(
{"status": "error", "message": "Pipeline not found"}, status=404
)
except Exception as e:
logger.error(f"Error getting pipeline data for {pipeline_id}: {e}")
return JsonResponse(
{"status": "error", "message": INTERNAL_SERVER_ERROR_MSG}, status=500
)
@csrf_exempt # Safe: Internal API with Bearer token auth, service-to-service only
@require_http_methods(["GET"])
def get_api_data(request, api_id):
"""Get basic API deployment data for notification purposes.
Used by callback worker to determine API name and details.
"""
try:
# Get API deployment with organization filtering
api_queryset = APIDeployment.objects.filter(id=api_id)
api_queryset = filter_queryset_by_organization(
api_queryset, request, "organization"
)
api = get_object_or_404(api_queryset)
return JsonResponse(
{
"status": "success",
"api_id": str(api.id),
"api_name": api.api_name,
"display_name": api.display_name,
"is_active": api.is_active,
}
)
except APIDeployment.DoesNotExist:
return JsonResponse(
{"status": "error", "message": "API deployment not found"}, status=404
)
except Exception as e:
logger.error(f"Error getting API data for {api_id}: {e}")
return JsonResponse(
{"status": "error", "message": INTERNAL_SERVER_ERROR_MSG}, status=500
)
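For orientation (this sketch is not part of the diff): a callback worker consumes these endpoints over plain HTTP instead of importing Django models. The base URL, header names, and environment variables below are assumptions; the actual workers go through their shared API client facade.

# Illustrative worker-side call: fetch active notifications for an API deployment.
import os

import requests

INTERNAL_API_BASE = os.environ.get("INTERNAL_API_BASE_URL", "http://backend:8000/internal")
INTERNAL_TOKEN = os.environ.get("INTERNAL_SERVICE_API_KEY", "")


def fetch_api_notifications(api_id: str, organization_id: str) -> list[dict]:
    """Return active notification configs for an API deployment (empty list if none)."""
    response = requests.get(
        f"{INTERNAL_API_BASE}/webhook/api/{api_id}/notifications/",
        headers={
            "Authorization": f"Bearer {INTERNAL_TOKEN}",
            "X-Organization-ID": organization_id,
        },
        timeout=10,
    )
    response.raise_for_status()
    return response.json().get("notifications", [])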

View File

@@ -0,0 +1,128 @@
"""Internal API Serializers for Notification/Webhook Operations
Used by Celery workers for service-to-service communication.
"""
from rest_framework import serializers
from notification_v2.enums import AuthorizationType, NotificationType, PlatformType
from notification_v2.models import Notification
class NotificationSerializer(serializers.ModelSerializer):
"""Serializer for Notification model."""
class Meta:
model = Notification
fields = [
"id",
"url",
"authorization_type",
"authorization_key",
"authorization_header",
"notification_type",
"platform",
"max_retries",
"is_active",
"created_at",
"modified_at",
"pipeline",
"api",
]
class WebhookNotificationRequestSerializer(serializers.Serializer):
"""Serializer for webhook notification requests."""
notification_id = serializers.UUIDField(required=False)
url = serializers.URLField(required=True)
payload = serializers.JSONField(required=True)
authorization_type = serializers.ChoiceField(
choices=AuthorizationType.choices(), default=AuthorizationType.NONE.value
)
authorization_key = serializers.CharField(required=False, allow_blank=True)
authorization_header = serializers.CharField(required=False, allow_blank=True)
headers = serializers.DictField(required=False, default=dict)
timeout = serializers.IntegerField(default=30, min_value=1, max_value=300)
max_retries = serializers.IntegerField(default=3, min_value=0, max_value=10)
retry_delay = serializers.IntegerField(default=60, min_value=1, max_value=3600)
class WebhookNotificationResponseSerializer(serializers.Serializer):
"""Serializer for webhook notification responses."""
task_id = serializers.CharField()
notification_id = serializers.UUIDField(required=False)
url = serializers.URLField()
status = serializers.CharField()
queued_at = serializers.DateTimeField()
class WebhookStatusSerializer(serializers.Serializer):
"""Serializer for webhook delivery status."""
task_id = serializers.CharField()
status = serializers.CharField()
url = serializers.CharField()
attempts = serializers.IntegerField()
success = serializers.BooleanField()
error_message = serializers.CharField(required=False, allow_null=True)
class WebhookBatchRequestSerializer(serializers.Serializer):
"""Serializer for batch webhook requests."""
batch_name = serializers.CharField(required=False, max_length=255)
webhooks = serializers.ListField(
child=WebhookNotificationRequestSerializer(), min_length=1, max_length=100
)
delay_between_requests = serializers.IntegerField(
default=0, min_value=0, max_value=60
)
class WebhookBatchResponseSerializer(serializers.Serializer):
"""Serializer for batch webhook responses."""
batch_id = serializers.CharField()
batch_name = serializers.CharField()
total_webhooks = serializers.IntegerField()
queued_webhooks = serializers.ListField(child=WebhookNotificationResponseSerializer())
failed_webhooks = serializers.ListField(child=serializers.DictField())
class WebhookConfigurationSerializer(serializers.Serializer):
"""Serializer for webhook configuration."""
notification_id = serializers.UUIDField()
url = serializers.URLField()
authorization_type = serializers.ChoiceField(choices=AuthorizationType.choices())
authorization_key = serializers.CharField(required=False, allow_blank=True)
authorization_header = serializers.CharField(required=False, allow_blank=True)
max_retries = serializers.IntegerField()
is_active = serializers.BooleanField()
class NotificationListSerializer(serializers.Serializer):
"""Serializer for notification list filters."""
pipeline_id = serializers.UUIDField(required=False)
api_deployment_id = serializers.UUIDField(required=False)
notification_type = serializers.ChoiceField(
choices=NotificationType.choices(), required=False
)
platform = serializers.ChoiceField(choices=PlatformType.choices(), required=False)
is_active = serializers.BooleanField(required=False)
class WebhookTestSerializer(serializers.Serializer):
"""Serializer for webhook testing."""
url = serializers.URLField(required=True)
payload = serializers.JSONField(required=True)
authorization_type = serializers.ChoiceField(
choices=AuthorizationType.choices(), default=AuthorizationType.NONE.value
)
authorization_key = serializers.CharField(required=False, allow_blank=True)
authorization_header = serializers.CharField(required=False, allow_blank=True)
headers = serializers.DictField(required=False, default=dict)
timeout = serializers.IntegerField(default=30, min_value=1, max_value=300)
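As a quick usage sketch (not from the diff), validating an inbound payload with the request serializer above is a standard DRF round-trip; the payload values are invented and this assumes it runs inside the Django backend (e.g. a shell session).

# Sketch: validating a webhook request body with WebhookNotificationRequestSerializer.
from notification_v2.enums import AuthorizationType
from notification_v2.internal_serializers import WebhookNotificationRequestSerializer

payload = {
    "url": "https://hooks.example.com/unstract",  # invented endpoint
    "payload": {"execution_id": "abc-123", "status": "COMPLETED"},
    "authorization_type": AuthorizationType.BEARER.value,
    "authorization_key": "secret-token",
}

serializer = WebhookNotificationRequestSerializer(data=payload)
if serializer.is_valid():
    data = serializer.validated_data  # defaults filled in: timeout=30, max_retries=3, retry_delay=60
else:
    print(serializer.errors)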

View File

@@ -0,0 +1,56 @@
"""Internal API URLs for Notification/Webhook Operations
URL patterns for webhook notification internal APIs.
"""
from django.urls import include, path
from rest_framework.routers import DefaultRouter
from . import internal_api_views
from .internal_views import (
WebhookBatchAPIView,
WebhookBatchStatusAPIView,
WebhookInternalViewSet,
WebhookMetricsAPIView,
WebhookSendAPIView,
WebhookStatusAPIView,
WebhookTestAPIView,
)
# Create router for webhook viewsets
router = DefaultRouter()
router.register(r"", WebhookInternalViewSet, basename="webhook-internal")
urlpatterns = [
# Notification data endpoints for workers
path(
"pipeline/<str:pipeline_id>/notifications/",
internal_api_views.get_pipeline_notifications,
name="get_pipeline_notifications",
),
path(
"pipeline/<str:pipeline_id>/",
internal_api_views.get_pipeline_data,
name="get_pipeline_data",
),
path(
"api/<str:api_id>/notifications/",
internal_api_views.get_api_notifications,
name="get_api_notifications",
),
path(
"api/<str:api_id>/",
internal_api_views.get_api_data,
name="get_api_data",
),
# Webhook operation endpoints
path("send/", WebhookSendAPIView.as_view(), name="webhook-send"),
path("batch/", WebhookBatchAPIView.as_view(), name="webhook-batch"),
path("test/", WebhookTestAPIView.as_view(), name="webhook-test"),
path("status/<str:task_id>/", WebhookStatusAPIView.as_view(), name="webhook-status"),
path(
"batch-status/", WebhookBatchStatusAPIView.as_view(), name="webhook-batch-status"
),
path("metrics/", WebhookMetricsAPIView.as_view(), name="webhook-metrics"),
# Webhook configuration CRUD (via router)
path("", include(router.urls)),
]
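A note on consumption (illustrative, not from the diff): the project-level prefix these routes are mounted under is defined outside this file, so the worker-facing paths below assume a hypothetical /internal/webhook/ mount.

# Hypothetical relative paths as seen by a worker, assuming an /internal/webhook/ mount.
PIPELINE_NOTIFICATIONS = "/internal/webhook/pipeline/{pipeline_id}/notifications/"
API_NOTIFICATIONS = "/internal/webhook/api/{api_id}/notifications/"
WEBHOOK_SEND = "/internal/webhook/send/"
WEBHOOK_STATUS = "/internal/webhook/status/{task_id}/"
WEBHOOK_BATCH = "/internal/webhook/batch/"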

View File

@@ -0,0 +1,559 @@
"""Internal API Views for Webhook Operations
Handles webhook notification related endpoints for internal services.
"""
import logging
import uuid
from typing import Any
from celery import current_app as celery_app
from celery.result import AsyncResult
from django.utils import timezone
from rest_framework import status, viewsets
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.views import APIView
from utils.organization_utils import filter_queryset_by_organization
from notification_v2.enums import AuthorizationType, NotificationType, PlatformType
# Import serializers from notification_v2 internal API
from notification_v2.internal_serializers import (
NotificationListSerializer,
NotificationSerializer,
WebhookBatchRequestSerializer,
WebhookBatchResponseSerializer,
WebhookConfigurationSerializer,
WebhookNotificationRequestSerializer,
WebhookNotificationResponseSerializer,
WebhookStatusSerializer,
WebhookTestSerializer,
)
from notification_v2.models import Notification
from notification_v2.provider.webhook.webhook import send_webhook_notification
logger = logging.getLogger(__name__)
# Constants
APPLICATION_JSON = "application/json"
class WebhookInternalViewSet(viewsets.ReadOnlyModelViewSet):
"""Internal API ViewSet for Webhook/Notification operations."""
serializer_class = NotificationSerializer
lookup_field = "id"
def get_queryset(self):
"""Get notifications filtered by organization context."""
queryset = Notification.objects.all()
return filter_queryset_by_organization(queryset, self.request)
def list(self, request, *args, **kwargs):
"""List notifications with filtering options."""
try:
serializer = NotificationListSerializer(data=request.query_params)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
filters = serializer.validated_data
queryset = self.get_queryset()
# Apply filters
if filters.get("pipeline_id"):
queryset = queryset.filter(pipeline_id=filters["pipeline_id"])
if filters.get("api_deployment_id"):
queryset = queryset.filter(api_id=filters["api_deployment_id"])
if filters.get("notification_type"):
queryset = queryset.filter(notification_type=filters["notification_type"])
if filters.get("platform"):
queryset = queryset.filter(platform=filters["platform"])
if filters.get("is_active") is not None:
queryset = queryset.filter(is_active=filters["is_active"])
notifications = NotificationSerializer(queryset, many=True).data
return Response({"count": len(notifications), "notifications": notifications})
except Exception as e:
logger.error(f"Failed to list notifications: {str(e)}")
return Response(
{"error": "Failed to list notifications", "detail": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@action(detail=True, methods=["get"])
def configuration(self, request, id=None):
"""Get webhook configuration for a notification."""
try:
notification = self.get_object()
config_data = {
"notification_id": notification.id,
"url": notification.url,
"authorization_type": notification.authorization_type,
"authorization_key": notification.authorization_key,
"authorization_header": notification.authorization_header,
"max_retries": notification.max_retries,
"is_active": notification.is_active,
}
serializer = WebhookConfigurationSerializer(config_data)
return Response(serializer.data)
except Exception as e:
logger.error(f"Failed to get webhook configuration {id}: {str(e)}")
return Response(
{"error": "Failed to get webhook configuration", "detail": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
class WebhookSendAPIView(APIView):
"""Internal API endpoint for sending webhook notifications."""
def post(self, request):
"""Send a webhook notification."""
try:
serializer = WebhookNotificationRequestSerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
validated_data = serializer.validated_data
# Build headers based on authorization type
headers = self._build_headers(validated_data)
# Send webhook notification task
task = send_webhook_notification.delay(
url=validated_data["url"],
payload=validated_data["payload"],
headers=headers,
timeout=validated_data["timeout"],
max_retries=validated_data["max_retries"],
retry_delay=validated_data["retry_delay"],
)
# Prepare response
response_data = {
"task_id": task.id,
"notification_id": validated_data.get("notification_id"),
"url": validated_data["url"],
"status": "queued",
"queued_at": timezone.now(),
}
response_serializer = WebhookNotificationResponseSerializer(response_data)
logger.info(
f"Queued webhook notification task {task.id} for URL {validated_data['url']}"
)
return Response(response_serializer.data, status=status.HTTP_202_ACCEPTED)
except Exception as e:
logger.error(f"Failed to send webhook notification: {str(e)}")
return Response(
{"error": "Failed to send webhook notification", "detail": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
def _build_headers(self, validated_data: dict[str, Any]) -> dict[str, str]:
"""Build headers based on authorization configuration."""
headers = {"Content-Type": APPLICATION_JSON}
auth_type = validated_data.get("authorization_type", AuthorizationType.NONE.value)
auth_key = validated_data.get("authorization_key")
auth_header = validated_data.get("authorization_header")
if validated_data.get("headers"):
headers.update(validated_data["headers"])
if auth_type == AuthorizationType.BEARER.value and auth_key:
headers["Authorization"] = f"Bearer {auth_key}"
elif auth_type == AuthorizationType.API_KEY.value and auth_key:
headers["Authorization"] = auth_key
elif (
auth_type == AuthorizationType.CUSTOM_HEADER.value
and auth_header
and auth_key
):
headers[auth_header] = auth_key
return headers
class WebhookStatusAPIView(APIView):
"""Internal API endpoint for checking webhook delivery status."""
def get(self, request, task_id):
"""Get webhook delivery status by task ID."""
try:
task_result = AsyncResult(task_id, app=celery_app)
status_data = {
"task_id": task_id,
"status": task_result.status,
"url": "unknown",
"attempts": 0,
"success": task_result.successful(),
"error_message": None,
}
if task_result.failed():
status_data["error_message"] = str(task_result.result)
elif task_result.successful():
status_data["attempts"] = getattr(task_result.result, "attempts", 1)
serializer = WebhookStatusSerializer(status_data)
return Response(serializer.data)
except Exception as e:
logger.error(f"Failed to get webhook status for task {task_id}: {str(e)}")
return Response(
{"error": "Failed to get webhook status", "detail": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
class WebhookBatchAPIView(APIView):
"""Internal API endpoint for sending batch webhook notifications."""
def post(self, request):
"""Send multiple webhook notifications in batch."""
try:
serializer = WebhookBatchRequestSerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
validated_data = serializer.validated_data
webhooks = validated_data["webhooks"]
delay_between = validated_data.get("delay_between_requests", 0)
batch_id = str(uuid.uuid4())
queued_webhooks = []
failed_webhooks = []
for i, webhook_data in enumerate(webhooks):
try:
headers = self._build_headers(webhook_data)
countdown = i * delay_between if delay_between > 0 else 0
task = send_webhook_notification.apply_async(
args=[
webhook_data["url"],
webhook_data["payload"],
headers,
webhook_data["timeout"],
],
kwargs={
"max_retries": webhook_data["max_retries"],
"retry_delay": webhook_data["retry_delay"],
},
countdown=countdown,
)
queued_webhooks.append(
{
"task_id": task.id,
"notification_id": webhook_data.get("notification_id"),
"url": webhook_data["url"],
"status": "queued",
"queued_at": timezone.now(),
}
)
except Exception as e:
failed_webhooks.append({"url": webhook_data["url"], "error": str(e)})
response_data = {
"batch_id": batch_id,
"batch_name": validated_data.get("batch_name", f"Batch-{batch_id[:8]}"),
"total_webhooks": len(webhooks),
"queued_webhooks": queued_webhooks,
"failed_webhooks": failed_webhooks,
}
response_serializer = WebhookBatchResponseSerializer(response_data)
logger.info(
f"Queued batch {batch_id} with {len(queued_webhooks)} webhooks, {len(failed_webhooks)} failed"
)
return Response(response_serializer.data, status=status.HTTP_202_ACCEPTED)
except Exception as e:
logger.error(f"Failed to send webhook batch: {str(e)}")
return Response(
{"error": "Failed to send webhook batch", "detail": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
def _build_headers(self, webhook_data: dict[str, Any]) -> dict[str, str]:
"""Build headers for webhook request."""
headers = {"Content-Type": APPLICATION_JSON}
auth_type = webhook_data.get("authorization_type", AuthorizationType.NONE.value)
auth_key = webhook_data.get("authorization_key")
auth_header = webhook_data.get("authorization_header")
if webhook_data.get("headers"):
headers.update(webhook_data["headers"])
if auth_type == AuthorizationType.BEARER.value and auth_key:
headers["Authorization"] = f"Bearer {auth_key}"
elif auth_type == AuthorizationType.API_KEY.value and auth_key:
headers["Authorization"] = auth_key
elif (
auth_type == AuthorizationType.CUSTOM_HEADER.value
and auth_header
and auth_key
):
headers[auth_header] = auth_key
return headers
class WebhookTestAPIView(APIView):
"""Internal API endpoint for testing webhook configurations."""
def post(self, request):
"""Test a webhook configuration without queuing."""
try:
serializer = WebhookTestSerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
validated_data = serializer.validated_data
headers = self._build_headers(validated_data)
import requests
try:
response = requests.post(
url=validated_data["url"],
json=validated_data["payload"],
headers=headers,
timeout=validated_data["timeout"],
)
test_result = {
"success": response.status_code < 400,
"status_code": response.status_code,
"response_headers": dict(response.headers),
"response_body": response.text[:1000],
"url": validated_data["url"],
"request_headers": headers,
"request_payload": validated_data["payload"],
}
logger.info(
f"Webhook test to {validated_data['url']} completed with status {response.status_code}"
)
return Response(test_result)
except requests.exceptions.RequestException as e:
test_result = {
"success": False,
"error": str(e),
"url": validated_data["url"],
"request_headers": headers,
"request_payload": validated_data["payload"],
}
return Response(test_result, status=status.HTTP_400_BAD_REQUEST)
except Exception as e:
logger.error(f"Failed to test webhook: {str(e)}")
return Response(
{"error": "Failed to test webhook", "detail": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
def _build_headers(self, validated_data: dict[str, Any]) -> dict[str, str]:
"""Build headers for webhook test."""
headers = {"Content-Type": APPLICATION_JSON}
auth_type = validated_data.get("authorization_type", AuthorizationType.NONE.value)
auth_key = validated_data.get("authorization_key")
auth_header = validated_data.get("authorization_header")
if validated_data.get("headers"):
headers.update(validated_data["headers"])
if auth_type == AuthorizationType.BEARER.value and auth_key:
headers["Authorization"] = f"Bearer {auth_key}"
elif auth_type == AuthorizationType.API_KEY.value and auth_key:
headers["Authorization"] = auth_key
elif (
auth_type == AuthorizationType.CUSTOM_HEADER.value
and auth_header
and auth_key
):
headers[auth_header] = auth_key
return headers
class WebhookBatchStatusAPIView(APIView):
"""Internal API endpoint for checking batch webhook delivery status."""
def get(self, request):
"""Get batch webhook delivery status."""
try:
batch_id = request.query_params.get("batch_id")
task_ids = [t.strip() for t in request.query_params.get("task_ids", "").split(",") if t.strip()]  # drop empty entries so the required-parameter check below can fire
if not batch_id and not task_ids:
return Response(
{"error": "Either batch_id or task_ids parameter is required"},
status=status.HTTP_400_BAD_REQUEST,
)
batch_results = []
if task_ids and task_ids[0]: # task_ids is not empty
for task_id in task_ids:
if task_id.strip():
try:
task_result = AsyncResult(task_id.strip(), app=celery_app)
batch_results.append(
{
"task_id": task_id.strip(),
"status": task_result.status,
"success": task_result.successful(),
"error_message": str(task_result.result)
if task_result.failed()
else None,
}
)
except Exception as e:
batch_results.append(
{
"task_id": task_id.strip(),
"status": "ERROR",
"success": False,
"error_message": f"Failed to get task status: {str(e)}",
}
)
response_data = {
"batch_id": batch_id,
"total_tasks": len(batch_results),
"results": batch_results,
"summary": {
"completed": sum(
1 for r in batch_results if r["status"] == "SUCCESS"
),
"failed": sum(1 for r in batch_results if r["status"] == "FAILURE"),
"pending": sum(1 for r in batch_results if r["status"] == "PENDING"),
"running": sum(1 for r in batch_results if r["status"] == "STARTED"),
},
}
return Response(response_data)
except Exception as e:
logger.error(f"Failed to get batch webhook status: {str(e)}")
return Response(
{"error": "Failed to get batch webhook status", "detail": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
class WebhookMetricsAPIView(APIView):
"""Internal API endpoint for webhook delivery metrics."""
def get(self, request):
"""Get webhook delivery metrics."""
try:
# Get query parameters
organization_id = request.query_params.get("organization_id")
start_date = request.query_params.get("start_date")
end_date = request.query_params.get("end_date")
# Get base queryset
queryset = Notification.objects.all()
queryset = filter_queryset_by_organization(queryset, request)
# Apply filters
if organization_id:
queryset = queryset.filter(organization_id=organization_id)
if start_date:
from datetime import datetime
try:
start_dt = datetime.fromisoformat(start_date.replace("Z", "+00:00"))
queryset = queryset.filter(created_at__gte=start_dt)
except ValueError:
return Response(
{"error": "Invalid start_date format. Use ISO format."},
status=status.HTTP_400_BAD_REQUEST,
)
if end_date:
from datetime import datetime
try:
end_dt = datetime.fromisoformat(end_date.replace("Z", "+00:00"))
queryset = queryset.filter(created_at__lte=end_dt)
except ValueError:
return Response(
{"error": "Invalid end_date format. Use ISO format."},
status=status.HTTP_400_BAD_REQUEST,
)
# Calculate metrics
total_webhooks = queryset.count()
active_webhooks = queryset.filter(is_active=True).count()
inactive_webhooks = queryset.filter(is_active=False).count()
# Group by notification type
type_breakdown = {}
for notification_type in NotificationType:
count = queryset.filter(notification_type=notification_type.value).count()
if count > 0:
type_breakdown[notification_type.value] = count
# Group by platform
platform_breakdown = {}
for platform_type in PlatformType:
count = queryset.filter(platform=platform_type.value).count()
if count > 0:
platform_breakdown[platform_type.value] = count
# Group by authorization type
auth_breakdown = {}
for auth_type in AuthorizationType:
count = queryset.filter(authorization_type=auth_type.value).count()
if count > 0:
auth_breakdown[auth_type.value] = count
metrics = {
"total_webhooks": total_webhooks,
"active_webhooks": active_webhooks,
"inactive_webhooks": inactive_webhooks,
"type_breakdown": type_breakdown,
"platform_breakdown": platform_breakdown,
"authorization_breakdown": auth_breakdown,
"filters_applied": {
"organization_id": organization_id,
"start_date": start_date,
"end_date": end_date,
},
}
return Response(metrics)
except Exception as e:
logger.error(f"Failed to get webhook metrics: {str(e)}")
return Response(
{"error": "Failed to get webhook metrics", "detail": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
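Worker-side usage of the send and status endpoints would look roughly like this (a sketch, not part of the diff; base URL and token handling are assumptions):

# Sketch: queue a webhook via the internal send endpoint and poll its task status.
import os
import time

import requests

BASE = os.environ.get("INTERNAL_API_BASE_URL", "http://backend:8000/internal") + "/webhook"
HEADERS = {"Authorization": f"Bearer {os.environ.get('INTERNAL_SERVICE_API_KEY', '')}"}

resp = requests.post(
    f"{BASE}/send/",
    json={"url": "https://hooks.example.com/unstract", "payload": {"status": "COMPLETED"}},
    headers=HEADERS,
    timeout=10,
)
resp.raise_for_status()
task_id = resp.json()["task_id"]  # 202 Accepted with the queued Celery task id

for _ in range(5):  # poll until the Celery task settles
    state = requests.get(f"{BASE}/status/{task_id}/", headers=HEADERS, timeout=10).json()
    if state["status"] in ("SUCCESS", "FAILURE"):
        break
    time.sleep(2)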

View File

@@ -0,0 +1,167 @@
import logging
from api_v2.models import APIDeployment
from rest_framework.response import Response
from rest_framework.viewsets import ViewSet
from utils.organization_utils import filter_queryset_by_organization
from pipeline_v2.models import Pipeline
from .serializers.internal import APIDeploymentSerializer, PipelineSerializer
logger = logging.getLogger(__name__)
class PipelineInternalViewSet(ViewSet):
def retrieve(self, request, pk=None):
logger.info(f"[PipelineInternalViewSet] Retrieving data for ID: {pk}")
try:
# 1⃣ Try in Pipeline
pipeline_data = self._fetch_single_record(
pk,
request,
Pipeline.objects.filter(id=pk),
PipelineSerializer,
"Pipeline",
)
if isinstance(pipeline_data, dict): # Found successfully
return Response({"status": "success", "pipeline": pipeline_data})
elif isinstance(pipeline_data, Response): # Integrity error
return pipeline_data
# 2⃣ Try in APIDeployment
api_data = self._fetch_single_record(
pk,
request,
APIDeployment.objects.filter(id=pk),
APIDeploymentSerializer,
"APIDeployment",
)
if isinstance(api_data, dict):
return Response({"status": "success", "pipeline": api_data})
elif isinstance(api_data, Response):
return api_data
# 3⃣ Not found anywhere
logger.warning(f"⚠️ No Pipeline or APIDeployment found for {pk}")
return Response(
{"status": "error", "message": "Pipeline not found"}, status=404
)
except Exception:
logger.exception(f"💥 Error retrieving pipeline or deployment for {pk}")
return Response(
{"status": "error", "message": "Internal server error"}, status=500
)
# Helper function for DRY logic
def _fetch_single_record(self, pk, request, qs, serializer_cls, model_name):
qs = filter_queryset_by_organization(qs, request, "organization")
count = qs.count()
if count == 1:
obj = qs.first()
logger.info(f"✅ Found {model_name} entry: {obj}")
return serializer_cls(obj).data
elif count > 1:
logger.error(f"❌ Multiple {model_name} entries found for {pk}")
return Response(
{
"status": "error",
"message": f"Data integrity error: multiple {model_name} entries found",
},
status=500,
)
return None # Not found in this model
def update(self, request, pk=None):
"""Update pipeline status with support for completion states."""
try:
new_status = request.data.get("status")
if not new_status:
return Response(
{"status": "error", "message": "Status is required"}, status=400
)
# Extract additional parameters for completion states
is_end = request.data.get("is_end", False)
# Import here to avoid circular imports
from pipeline_v2.pipeline_processor import PipelineProcessor
# Try to update pipeline first
try:
pipeline_qs = Pipeline.objects.filter(id=pk)
pipeline_qs = filter_queryset_by_organization(
pipeline_qs, request, "organization"
)
pipeline = pipeline_qs.first()
if pipeline:
# Use PipelineProcessor.update_pipeline() without execution_id and error_message
# This will update status but skip notifications (since execution_id=None)
PipelineProcessor.update_pipeline(
pipeline_guid=pk,
status=new_status,
is_end=is_end,
)
return Response(
{
"status": "success",
"pipeline_id": pk,
"new_status": new_status,
"is_end": is_end,
"message": "Pipeline status updated successfully",
}
)
except Exception as e:
logger.error(f"Error updating pipeline status: {e}")
return Response(
{"status": "error", "message": f"Failed to update pipeline: {e}"},
status=500,
)
# Try API deployment if pipeline not found
try:
api_qs = APIDeployment.objects.filter(id=pk)
api_qs = filter_queryset_by_organization(api_qs, request, "organization")
api_deployment = api_qs.first()
if api_deployment:
# For API deployments, log the status update
logger.info(f"Updated API deployment {pk} status to {new_status}")
return Response(
{
"status": "success",
"pipeline_id": pk,
"new_status": new_status,
"message": "API deployment status updated successfully",
}
)
except Exception as e:
logger.error(f"Error updating API deployment status: {e}")
return Response(
{
"status": "error",
"message": f"Failed to update API deployment: {e}",
},
status=500,
)
# Not found in either model
return Response(
{"status": "error", "message": "Pipeline or API deployment not found"},
status=404,
)
except Exception as e:
logger.error(f"Error updating pipeline/API deployment status for {pk}: {e}")
return Response(
{"status": "error", "message": "Internal server error"}, status=500
)
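For reference (sketch only, not part of the diff): a callback worker would drive the update() action above with a plain PUT; the /internal/pipeline/ prefix and the "SUCCESS" status value are assumptions.

# Sketch: mark a pipeline run finished through the internal viewset.
import requests


def mark_pipeline_complete(base_url: str, token: str, org_id: str, pipeline_id: str) -> dict:
    response = requests.put(
        f"{base_url}/internal/pipeline/{pipeline_id}/",
        json={"status": "SUCCESS", "is_end": True},
        headers={"Authorization": f"Bearer {token}", "X-Organization-ID": org_id},
        timeout=10,
    )
    response.raise_for_status()
    return response.json()  # {"status": "success", "new_status": "SUCCESS", ...}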

View File

@@ -0,0 +1,17 @@
"""Internal API URLs for Pipeline Operations"""
from django.urls import include, path
from rest_framework.routers import DefaultRouter
from .internal_api_views import (
PipelineInternalViewSet,
)
# Create router for pipeline viewsets
router = DefaultRouter()
router.register(r"", PipelineInternalViewSet, basename="pipeline-internal")
urlpatterns = [
# Pipeline internal APIs
path("", include(router.urls)),
]

View File

@@ -113,7 +113,12 @@ class PipelineProcessor:
pipeline = PipelineProcessor._update_pipeline_status(
pipeline=pipeline, is_end=is_end, status=status, is_active=is_active
)
PipelineProcessor._send_notification(
pipeline=pipeline, execution_id=execution_id, error_message=error_message
)
# Only send notifications if execution_id is provided
# This avoids duplicate notifications when called from workers (who handle notifications separately)
if execution_id:
PipelineProcessor._send_notification(
pipeline=pipeline, execution_id=execution_id, error_message=error_message
)
logger.info(f"Updated pipeline {pipeline_guid} status: {status}")

View File

@@ -0,0 +1,59 @@
from api_v2.models import APIDeployment
from pipeline_v2.models import Pipeline
from rest_framework import serializers
class PipelineSerializer(serializers.ModelSerializer):
# Add computed fields for callback worker
is_api = serializers.SerializerMethodField()
resolved_pipeline_type = serializers.SerializerMethodField()
resolved_pipeline_name = serializers.SerializerMethodField()
pipeline_name = serializers.SerializerMethodField()
class Meta:
model = Pipeline
fields = "__all__"
def get_is_api(self, obj):
"""Returns False for Pipeline model entries."""
return False
def get_resolved_pipeline_type(self, obj):
"""Returns the pipeline type from the Pipeline model."""
return obj.pipeline_type
def get_resolved_pipeline_name(self, obj):
"""Returns the pipeline name from the Pipeline model."""
return obj.pipeline_name
def get_pipeline_name(self, obj):
"""Returns the pipeline name for callback worker compatibility."""
return obj.pipeline_name
class APIDeploymentSerializer(serializers.ModelSerializer):
# Add computed fields for callback worker
is_api = serializers.SerializerMethodField()
resolved_pipeline_type = serializers.SerializerMethodField()
resolved_pipeline_name = serializers.SerializerMethodField()
pipeline_name = serializers.SerializerMethodField()
class Meta:
model = APIDeployment
fields = "__all__"
def get_is_api(self, obj):
"""Returns True for APIDeployment model entries."""
return True
def get_resolved_pipeline_type(self, obj):
"""Returns 'API' for APIDeployment model entries."""
return "API"
def get_resolved_pipeline_name(self, obj):
"""Returns the api_name from the APIDeployment model."""
return obj.api_name
def get_pipeline_name(self, obj):
"""Returns the api_name for callback worker compatibility."""
return obj.api_name
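The computed fields above let the callback worker treat both record types uniformly; a minimal consumer sketch follows (the record dict is whatever the internal retrieve endpoint returned, not code from this diff):

# Sketch: branch-free handling of serialized Pipeline/APIDeployment records.
def describe(record: dict) -> str:
    kind = "API deployment" if record["is_api"] else record["resolved_pipeline_type"]
    return f"{record['resolved_pipeline_name']} ({kind})"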

View File

@@ -0,0 +1,18 @@
"""Internal URLs for platform settings
Routes for internal API endpoints used by workers.
"""
from django.urls import path
from .internal_views import InternalPlatformKeyView
app_name = "platform_settings_internal"
urlpatterns = [
path(
"platform-key/",
InternalPlatformKeyView.as_view(),
name="platform_key",
),
]

View File

@@ -0,0 +1,76 @@
"""Internal API views for platform settings
Provides internal endpoints for workers to access platform settings
without direct database access.
"""
import logging
from account_v2.models import PlatformKey
from account_v2.organization import OrganizationService
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView
from platform_settings_v2.platform_auth_service import PlatformAuthenticationService
logger = logging.getLogger(__name__)
class InternalPlatformKeyView(APIView):
"""Internal API to get active platform key for an organization."""
def get(self, request):
"""Get active platform key for organization.
Uses X-Organization-ID header to identify the organization.
Args:
request: HTTP request with X-Organization-ID header
Returns:
Response with platform key
"""
try:
# Get organization ID from header
org_id = request.headers.get("X-Organization-ID")
if not org_id:
return Response(
{"error": "X-Organization-ID header is required"},
status=status.HTTP_400_BAD_REQUEST,
)
# Get organization
organization = OrganizationService.get_organization_by_org_id(org_id=org_id)
if not organization:
return Response(
{"error": f"Organization {org_id} not found"},
status=status.HTTP_404_NOT_FOUND,
)
# Get active platform key
platform_key = PlatformAuthenticationService.get_active_platform_key(
organization_id=org_id
)
return Response(
{
"platform_key": str(platform_key.key),
"key_name": platform_key.key_name,
"organization_id": org_id,
},
status=status.HTTP_200_OK,
)
except PlatformKey.DoesNotExist:
return Response(
{"error": f"No active platform key found for organization {org_id}"},
status=status.HTTP_404_NOT_FOUND,
)
except Exception as e:
logger.error(f"Error getting platform key for org {org_id}: {str(e)}")
return Response(
{"error": "Internal server error"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
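A worker fetches the key with a single GET (sketch; URL prefix and env vars are assumptions):

# Sketch: worker-side lookup of the active platform key for its organization.
import os

import requests


def get_platform_key(org_id: str) -> str:
    base = os.environ.get("INTERNAL_API_BASE_URL", "http://backend:8000/internal")
    response = requests.get(
        f"{base}/platform-settings/platform-key/",
        headers={
            "Authorization": f"Bearer {os.environ.get('INTERNAL_SERVICE_API_KEY', '')}",
            "X-Organization-ID": org_id,
        },
        timeout=10,
    )
    response.raise_for_status()
    return response.json()["platform_key"]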

View File

@@ -20,8 +20,19 @@ class ModifierConfig:
METADATA_IS_ACTIVE = "is_active"
# Cache for loaded plugins to avoid repeated loading
_modifier_plugins_cache: list[Any] = []
_plugins_loaded = False
def load_plugins() -> list[Any]:
"""Iterate through the extraction plugins and register them."""
global _modifier_plugins_cache, _plugins_loaded
# Return cached plugins if already loaded
if _plugins_loaded:
return _modifier_plugins_cache
plugins_app = apps.get_app_config(ModifierConfig.PLUGINS_APP)
package_path = plugins_app.module.__package__
modifier_dir = os.path.join(plugins_app.path, ModifierConfig.PLUGIN_DIR)
@@ -29,6 +40,8 @@ def load_plugins() -> list[Any]:
modifier_plugins: list[Any] = []
if not os.path.exists(modifier_dir):
_modifier_plugins_cache = modifier_plugins
_plugins_loaded = True
return modifier_plugins
for item in os.listdir(modifier_dir):
@@ -69,4 +82,8 @@ def load_plugins() -> list[Any]:
if len(modifier_plugins) == 0:
logger.info("No modifier plugins found.")
# Cache the results for future requests
_modifier_plugins_cache = modifier_plugins
_plugins_loaded = True
return modifier_plugins

View File

@@ -20,14 +20,29 @@ class ProcessorConfig:
METADATA_IS_ACTIVE = "is_active"
# Cache for loaded plugins to avoid repeated loading
_processor_plugins_cache: list[Any] = []
_plugins_loaded = False
def load_plugins() -> list[Any]:
"""Iterate through the processor plugins and register them."""
global _processor_plugins_cache, _plugins_loaded
# Return cached plugins if already loaded
if _plugins_loaded:
return _processor_plugins_cache
plugins_app = apps.get_app_config(ProcessorConfig.PLUGINS_APP)
package_path = plugins_app.module.__package__
processor_dir = os.path.join(plugins_app.path, ProcessorConfig.PLUGIN_DIR)
processor_package_path = f"{package_path}.{ProcessorConfig.PLUGIN_DIR}"
processor_plugins: list[Any] = []
if not os.path.exists(processor_dir):
logger.info("No processor directory found at %s.", processor_dir)
return []
for item in os.listdir(processor_dir):
# Loads a plugin if it is in a directory.
if os.path.isdir(os.path.join(processor_dir, item)):
@@ -71,6 +86,10 @@ def load_plugins() -> list[Any]:
if len(processor_plugins) == 0:
logger.info("No processor plugins found.")
# Cache the results for future requests
_processor_plugins_cache = processor_plugins
_plugins_loaded = True
return processor_plugins
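Both loaders apply the same module-level memoization; stripped of the plugin-discovery details, the pattern is just a cache plus a loaded flag (generic sketch, not code from the diff):

# Generic sketch of the caching pattern: scan once, then serve the cached list.
from typing import Any

_plugins_cache: list[Any] = []
_plugins_loaded = False


def load_plugins() -> list[Any]:
    global _plugins_cache, _plugins_loaded
    if _plugins_loaded:
        return _plugins_cache
    plugins = _scan_plugin_directory()  # stand-in for the real discovery walk
    _plugins_cache = plugins
    _plugins_loaded = True
    return plugins


def _scan_plugin_directory() -> list[Any]:
    return []  # placeholder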

View File

@@ -72,7 +72,10 @@ dev = [
# For file watching
"inotify>=0.2.10",
"poethepoet>=0.33.1",
"debugpy>=1.8.14"
"debugpy>=1.8.14",
"pytest>=8.3.5",
"responses>=0.25.7",
"psutil>=7.0.0",
]
test = ["pytest>=8.0.1", "pytest-dotenv==0.5.2"]
deploy = [

View File

@@ -65,7 +65,7 @@ PLATFORM_SERVICE_PORT=3001
# Tool Runner
UNSTRACT_RUNNER_HOST=http://unstract-runner
UNSTRACT_RUNNER_PORT=5002
UNSTRACT_RUNNER_API_TIMEOUT=120 # (in seconds) 2 mins
UNSTRACT_RUNNER_API_TIMEOUT=240 # (in seconds) 4 mins
UNSTRACT_RUNNER_API_RETRY_COUNT=5 # Number of retries for failed requests
UNSTRACT_RUNNER_API_BACKOFF_FACTOR=3 # Exponential backoff factor for retries

View File

@@ -0,0 +1,16 @@
"""Internal API URLs for tool instance operations."""
from django.urls import path
from .internal_views import tool_by_id_internal, validate_tool_instances_internal
urlpatterns = [
# Tool by ID endpoint - critical for worker functionality
path("tool/<str:tool_id>/", tool_by_id_internal, name="tool-by-id-internal"),
# Tool instance validation endpoint - used by workers before execution
path(
"validate/",
validate_tool_instances_internal,
name="validate-tool-instances-internal",
),
]

View File

@@ -0,0 +1,403 @@
"""Internal API Views for Tool Instance Operations
This module contains internal API endpoints used by workers for tool execution.
"""
import logging
from django.views.decorators.csrf import csrf_exempt
from rest_framework import status, viewsets
from rest_framework.decorators import api_view
from rest_framework.response import Response
from utils.organization_utils import filter_queryset_by_organization
from tool_instance_v2.models import ToolInstance
from tool_instance_v2.serializers import ToolInstanceSerializer
from tool_instance_v2.tool_instance_helper import ToolInstanceHelper
from tool_instance_v2.tool_processor import ToolProcessor
logger = logging.getLogger(__name__)
class ToolExecutionInternalViewSet(viewsets.ModelViewSet):
"""Internal API for tool execution operations used by lightweight workers."""
serializer_class = ToolInstanceSerializer
def get_queryset(self):
# Filter by organization context set by internal API middleware
# Use relationship path: ToolInstance -> Workflow -> Organization
queryset = ToolInstance.objects.all()
return filter_queryset_by_organization(
queryset, self.request, "workflow__organization"
)
def execute_tool(self, request, pk=None):
"""Execute a specific tool with provided input data.
This replaces the direct tool execution that was previously done
in the heavy Django workers.
"""
try:
tool_instance = self.get_object()
# Extract execution parameters from request
input_data = request.data.get("input_data", {})
file_data = request.data.get("file_data", {})
execution_context = request.data.get("execution_context", {})
# Execute tool using existing tool processor
execution_result = ToolProcessor.execute_tool(
tool_instance=tool_instance,
input_data=input_data,
file_data=file_data,
context=execution_context,
user=request.user,
)
return Response(
{
"status": "success",
"tool_instance_id": str(tool_instance.id),
"execution_result": execution_result,
"tool_function": tool_instance.tool_function,
"step": tool_instance.step,
},
status=status.HTTP_200_OK,
)
except Exception as e:
logger.error(f"Tool execution failed for tool {pk}: {e}")
return Response(
{
"status": "error",
"error_message": str(e),
"tool_instance_id": str(pk) if pk else None,
},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
@api_view(["GET"])
def tool_execution_status_internal(request, execution_id):
"""Get tool execution status for internal API calls."""
try:
# This would track tool execution status
# For now, return a basic status structure
return Response(
{
"execution_id": execution_id,
"status": "completed", # Could be: pending, running, completed, failed
"progress": 100,
"results": [],
"error_message": None,
},
status=status.HTTP_200_OK,
)
except Exception as e:
logger.error(f"Failed to get tool execution status for {execution_id}: {e}")
return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
@api_view(["GET"])
def tool_by_id_internal(request, tool_id):
"""Get tool information by tool ID for internal API calls."""
try:
logger.info(f"Getting tool information for tool ID: {tool_id}")
# Get tool from registry using ToolProcessor
try:
tool = ToolProcessor.get_tool_by_uid(tool_id)
logger.info(f"Successfully retrieved tool from ToolProcessor: {tool_id}")
except Exception as tool_fetch_error:
logger.error(
f"Failed to fetch tool {tool_id} from ToolProcessor: {tool_fetch_error}"
)
# Return fallback using Structure Tool image (which actually exists)
from django.conf import settings
return Response(
{
"tool": {
"tool_id": tool_id,
"properties": {
"displayName": f"Missing Tool ({tool_id[:8]}...)",
"functionName": tool_id,
"description": "Tool not found in registry or Prompt Studio",
"toolVersion": "unknown",
},
"image_name": settings.STRUCTURE_TOOL_IMAGE_NAME,
"image_tag": settings.STRUCTURE_TOOL_IMAGE_TAG,
"name": f"Missing Tool ({tool_id[:8]}...)",
"description": "Tool not found in registry or Prompt Studio",
"version": "unknown",
"note": "Fallback data for missing tool",
}
},
status=status.HTTP_200_OK,
)
# Convert Properties object to dict for JSON serialization
properties_dict = {}
try:
if hasattr(tool.properties, "to_dict"):
# Use the to_dict method if available (which handles Adapter serialization)
properties_dict = tool.properties.to_dict()
logger.info(f"Properties serialized using to_dict() for tool {tool_id}")
elif hasattr(tool.properties, "dict"):
properties_dict = tool.properties.dict()
logger.info(f"Properties serialized using dict() for tool {tool_id}")
elif hasattr(tool.properties, "__dict__"):
properties_dict = tool.properties.__dict__
logger.info(f"Properties serialized using __dict__ for tool {tool_id}")
else:
# Try to convert to dict if it's iterable
try:
properties_dict = dict(tool.properties)
logger.info(
f"Properties serialized using dict conversion for tool {tool_id}"
)
except (TypeError, ValueError):
properties_dict = {"default": "true"} # Fallback
logger.warning(f"Using fallback properties for tool {tool_id}")
except Exception as props_error:
logger.error(
f"Failed to serialize properties for tool {tool_id}: {props_error}"
)
properties_dict = {"error": "serialization_failed"}
# Handle spec serialization if needed
if hasattr(tool, "spec") and tool.spec:
if hasattr(tool.spec, "to_dict"):
tool.spec.to_dict()
elif hasattr(tool.spec, "__dict__"):
pass
# Return tool information with essential fields only to avoid serialization issues
return Response(
{
"tool": {
"tool_id": tool_id,
"properties": properties_dict,
"image_name": str(tool.image_name)
if tool.image_name
else "default-tool",
"image_tag": str(tool.image_tag) if tool.image_tag else "latest",
"name": getattr(tool, "name", tool_id),
"description": getattr(tool, "description", ""),
"version": getattr(tool, "version", "latest"),
}
},
status=status.HTTP_200_OK,
)
except Exception as e:
logger.error(f"Failed to get tool information for {tool_id}: {e}")
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
# Always return fallback data instead of error to allow workflow to continue
from django.conf import settings
return Response(
{
"tool": {
"tool_id": tool_id,
"properties": {
"displayName": f"Error Tool ({tool_id[:8]}...)",
"functionName": tool_id,
"description": f"Error processing tool: {str(e)[:100]}",
"toolVersion": "error",
},
"image_name": settings.STRUCTURE_TOOL_IMAGE_NAME,
"image_tag": settings.STRUCTURE_TOOL_IMAGE_TAG,
"name": f"Error Tool ({tool_id[:8]}...)",
"description": f"Error: {str(e)[:100]}",
"version": "error",
"error": str(e),
"note": "Fallback data for tool processing error",
}
},
status=status.HTTP_200_OK, # Return 200 to allow workflow to continue
)
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
@api_view(["GET"])
def tool_instances_by_workflow_internal(request, workflow_id):
"""Get tool instances for a workflow for internal API calls."""
try:
from workflow_manager.workflow_v2.models.workflow import Workflow
logger.info(f"Getting tool instances for workflow: {workflow_id}")
# Get workflow with organization filtering first (via DefaultOrganizationManagerMixin)
try:
workflow = Workflow.objects.get(id=workflow_id)
logger.info(f"Found workflow: {workflow.id}")
except Workflow.DoesNotExist:
logger.error(f"Workflow not found: {workflow_id}")
return Response(
{"error": "Workflow not found or access denied"},
status=status.HTTP_404_NOT_FOUND,
)
# Get tool instances for the workflow with organization filtering
# Filter through the relationship: ToolInstance -> Workflow -> Organization
tool_instances_queryset = ToolInstance.objects.filter(workflow=workflow)
tool_instances_queryset = filter_queryset_by_organization(
tool_instances_queryset, request, "workflow__organization"
)
tool_instances = tool_instances_queryset.order_by("step")
logger.info(f"Found {len(tool_instances)} tool instances")
# Serialize the tool instances
try:
logger.info("Starting serialization of tool instances")
serializer = ToolInstanceSerializer(tool_instances, many=True)
logger.info("Accessing serializer.data")
serializer_data = serializer.data
logger.info(f"Serialization completed, got {len(serializer_data)} items")
except Exception as serializer_error:
logger.error(f"Serialization error: {serializer_error}")
# Try to return basic data without enhanced tool information
basic_data = []
for instance in tool_instances:
basic_data.append(
{
"id": str(instance.id),
"tool_id": instance.tool_id,
"step": instance.step,
"metadata": instance.metadata,
}
)
logger.info(f"Returning {len(basic_data)} basic tool instances")
return Response(
{
"workflow_id": workflow_id,
"tool_instances": basic_data,
"total_count": len(tool_instances),
"note": "Basic data returned due to serialization error",
},
status=status.HTTP_200_OK,
)
return Response(
{
"workflow_id": workflow_id,
"tool_instances": serializer_data,
"total_count": len(tool_instances),
},
status=status.HTTP_200_OK,
)
except Exception as e:
logger.exception(f"Failed to get tool instances for workflow {workflow_id}: {e}")
return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
@api_view(["POST"])
def validate_tool_instances_internal(request):
"""Validate tool instances and ensure adapter IDs are migrated.
This internal endpoint validates tool instances for a workflow, ensuring:
1. Adapter names are migrated to IDs
2. User has permissions to access tools and adapters
3. Tool settings match JSON schema requirements
Used by workers to validate tools before execution.
Args:
request: Request containing:
- workflow_id: ID of the workflow
- tool_instances: List of tool instance IDs
Returns:
Response with validation results and migrated metadata
"""
workflow_id = request.data.get("workflow_id")
tool_instance_ids = request.data.get("tool_instances", [])
if not workflow_id:
return Response(
{"error": "workflow_id is required"}, status=status.HTTP_400_BAD_REQUEST
)
validated_instances = []
validation_errors = []
try:
# Get tool instances from database with organization filtering
tool_instances_queryset = ToolInstance.objects.filter(
workflow_id=workflow_id, id__in=tool_instance_ids
).select_related("workflow", "workflow__created_by")
# Apply organization filtering
tool_instances_queryset = filter_queryset_by_organization(
tool_instances_queryset, request, "workflow__organization"
)
tool_instances = list(tool_instances_queryset)
# Validate each tool instance
for tool in tool_instances:
try:
# Get the user who created the workflow
user = tool.workflow.created_by
# Ensure adapter IDs are migrated from names to IDs
migrated_metadata = ToolInstanceHelper.ensure_adapter_ids_in_metadata(
tool, user=user
)
# Validate tool settings
ToolInstanceHelper.validate_tool_settings(
user=user,
tool_uid=tool.tool_id,
tool_meta=migrated_metadata,
)
# Add to validated list with migrated metadata
validated_instances.append(
{
"id": str(tool.id),
"tool_id": tool.tool_id,
"metadata": migrated_metadata,
"step": tool.step,
"status": "valid",
}
)
except Exception as e:
validation_errors.append(
{
"tool_id": tool.tool_id,
"tool_instance_id": str(tool.id),
"error": str(e),
}
)
logger.error(f"Tool validation failed for {tool.tool_id}: {e}")
# Return validation results
response_data = {
"success": len(validation_errors) == 0,
"validated_instances": validated_instances,
"errors": validation_errors,
"workflow_id": workflow_id,
}
if validation_errors:
return Response(response_data, status=status.HTTP_422_UNPROCESSABLE_ENTITY)
return Response(response_data, status=status.HTTP_200_OK)
except Exception as e:
logger.error(f"Tool validation failed: {e}", exc_info=True)
return Response(
{"error": f"Tool validation failed: {str(e)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
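Workers call the validation endpoint before execution roughly as follows (sketch; prefix and env vars are assumptions):

# Sketch: validate tool instances for a workflow; 200 = all valid, 422 = per-tool errors.
import os

import requests


def validate_tools(workflow_id: str, instance_ids: list[str], org_id: str) -> dict:
    base = os.environ.get("INTERNAL_API_BASE_URL", "http://backend:8000/internal")
    response = requests.post(
        f"{base}/tool-instance/validate/",
        json={"workflow_id": workflow_id, "tool_instances": instance_ids},
        headers={
            "Authorization": f"Bearer {os.environ.get('INTERNAL_SERVICE_API_KEY', '')}",
            "X-Organization-ID": org_id,
        },
        timeout=30,
    )
    result = response.json()
    if not result.get("success"):
        for err in result.get("errors", []):
            print(f"{err['tool_id']}: {err['error']}")
    return result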

View File

@@ -63,7 +63,6 @@ class ToolInstanceSerializer(AuditSerializer):
rep[TIKey.METADATA] = self._transform_adapter_ids_to_names_for_display(
metadata, tool_function
)
return rep
def _transform_adapter_ids_to_names_for_display(

View File

@@ -0,0 +1,15 @@
"""Internal API URLs for Usage access by workers."""
from django.urls import path
from . import internal_views
app_name = "usage_internal"
urlpatterns = [
path(
"aggregated-token-count/<str:file_execution_id>/",
internal_views.UsageInternalView.as_view(),
name="aggregated-token-count",
),
]

View File

@@ -0,0 +1,79 @@
"""Internal API views for Usage access by workers."""
import logging
from django.http import JsonResponse
from rest_framework import status
from rest_framework.request import Request
from rest_framework.views import APIView
from unstract.core.data_models import UsageResponseData
from .helper import UsageHelper
logger = logging.getLogger(__name__)
class UsageInternalView(APIView):
"""Internal API view for workers to access usage data.
This endpoint allows workers to get aggregated token usage data
for a specific file execution without direct database access.
"""
def get(self, request: Request, file_execution_id: str) -> JsonResponse:
"""Get aggregated token usage for a file execution.
Args:
request: HTTP request (no additional parameters needed)
file_execution_id: File execution ID to get usage data for
Returns:
JSON response with aggregated usage data using core data models
"""
try:
if not file_execution_id:
return JsonResponse(
{
"success": False,
"error": "file_execution_id parameter is required",
},
status=status.HTTP_400_BAD_REQUEST,
)
# Get aggregated token count using the existing helper
result = UsageHelper.get_aggregated_token_count(run_id=file_execution_id)
# Create UsageResponseData for type safety and consistency
usage_data = UsageResponseData(
file_execution_id=file_execution_id,
embedding_tokens=result.get("embedding_tokens"),
prompt_tokens=result.get("prompt_tokens"),
completion_tokens=result.get("completion_tokens"),
total_tokens=result.get("total_tokens"),
cost_in_dollars=result.get("cost_in_dollars"),
)
return JsonResponse(
{
"success": True,
"data": {
"file_execution_id": file_execution_id,
"usage": usage_data.to_dict(),
},
}
)
except Exception as e:
logger.error(
f"Error getting usage data for file_execution_id {file_execution_id}: {e}",
exc_info=True,
)
return JsonResponse(
{
"success": False,
"error": "Internal server error",
"file_execution_id": file_execution_id,
},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
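Worker-side consumption (sketch; prefix and env vars are assumptions):

# Sketch: read aggregated token usage for a file execution.
import os

import requests


def get_usage(file_execution_id: str, org_id: str) -> dict:
    base = os.environ.get("INTERNAL_API_BASE_URL", "http://backend:8000/internal")
    response = requests.get(
        f"{base}/usage/aggregated-token-count/{file_execution_id}/",
        headers={
            "Authorization": f"Bearer {os.environ.get('INTERNAL_SERVICE_API_KEY', '')}",
            "X-Organization-ID": org_id,
        },
        timeout=10,
    )
    response.raise_for_status()
    return response.json()["data"]["usage"]  # token counts plus cost_in_dollars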

View File

@@ -1,4 +1,6 @@
import json
import logging
import time
from typing import Any
from django.conf import settings
@@ -7,6 +9,8 @@ from django_redis import get_redis_connection
redis_cache = get_redis_connection("default")
logger = logging.getLogger(__name__)
class CacheService:
@staticmethod
@@ -38,6 +42,72 @@ class CacheService:
"""Delete keys in bulk based on the key pattern."""
cache.delete_pattern(key_pattern)
@staticmethod
def clear_cache_optimized(key_pattern: str) -> Any:
"""Delete keys in bulk using optimized SCAN approach for large datasets.
Uses Redis SCAN instead of KEYS to avoid blocking Redis during deletion.
Safe for production with large key sets. Use this for heavy operations
like workflow history clearing.
"""
TIMEOUT_SECONDS = 90 # Generous but bounded timeout
BATCH_SIZE = 1000
start_time = time.time()
deleted_count = 0
cursor = 0
completed_naturally = False
try:
while True:
# Check timeout first
if time.time() - start_time > TIMEOUT_SECONDS:
logger.warning(
f"Cache clearing timed out after {TIMEOUT_SECONDS}s, "
f"deleted {deleted_count} keys matching '{key_pattern}'"
)
break
# SCAN returns (next_cursor, keys_list)
cursor, keys = redis_cache.scan(
cursor=cursor, match=key_pattern, count=BATCH_SIZE
)
if keys:
# Delete keys in pipeline for efficiency
pipe = redis_cache.pipeline()
for key in keys:
pipe.delete(key)
pipe.execute()
deleted_count += len(keys)
# SCAN is complete when cursor returns to 0
if cursor == 0:
completed_naturally = True
break
# Log completion status
if completed_naturally:
logger.info(
f"Cache clearing completed: deleted {deleted_count} keys matching '{key_pattern}'"
)
else:
logger.warning(
f"Cache clearing incomplete: deleted {deleted_count} keys before timeout"
)
except (ConnectionError, TimeoutError, OSError) as e:
logger.error(f"Failed to clear cache pattern '{key_pattern}': {str(e)}")
# Fallback to old method for backward compatibility
try:
cache.delete_pattern(key_pattern)
logger.warning(f"Used fallback delete_pattern for '{key_pattern}'")
except (ConnectionError, TimeoutError, OSError) as fallback_error:
logger.error(
f"Fallback cache clearing also failed: {str(fallback_error)}"
)
raise e
@staticmethod
def check_a_key_exist(key: str, version: Any = None) -> bool:
data: bool = cache.has_key(key, version)
@@ -70,6 +140,10 @@ class CacheService:
def lpop(key: str) -> Any:
return redis_cache.lpop(key)
@staticmethod
def llen(key: str) -> int:
return redis_cache.llen(key)
@staticmethod
def lrem(key: str, value: str) -> None:
redis_cache.lrem(key, value)
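The SCAN-plus-pipeline deletion that clear_cache_optimized wraps can be exercised directly against redis-py; the standalone sketch below uses an assumed local connection rather than the django-redis client used in the backend.

# Minimal sketch of SCAN-based pattern deletion (connection details assumed).
import redis

r = redis.Redis(host="localhost", port=6379, db=0)


def delete_pattern(pattern: str, batch_size: int = 1000) -> int:
    deleted = 0
    cursor = 0
    while True:
        # SCAN walks the keyspace incrementally and never blocks like KEYS does.
        cursor, keys = r.scan(cursor=cursor, match=pattern, count=batch_size)
        if keys:
            pipe = r.pipeline()
            for key in keys:
                pipe.delete(key)
            pipe.execute()
            deleted += len(keys)
        if cursor == 0:  # full pass complete
            return deleted

# e.g. delete_pattern("workflow_history:*")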

View File

@@ -1,5 +1,4 @@
import http
import json
import logging
import os
from typing import Any
@@ -9,10 +8,9 @@ import socketio
from django.conf import settings
from django.core.wsgi import WSGIHandler
from unstract.core.constants import LogFieldName
from unstract.workflow_execution.enums import LogType
from unstract.core.data_models import LogDataDTO
from unstract.core.log_utils import get_validated_log_data, store_execution_log
from utils.constants import ExecutionLogConstants
from utils.dto import LogDataDTO
logger = logging.getLogger(__name__)
@@ -79,71 +77,23 @@ def _get_user_session_id_from_cookies(sid: str, environ: Any) -> str | None:
return session_id.value
# Functions moved to unstract.core.log_utils for sharing with workers
# Keep these as wrapper functions for backward compatibility
def _get_validated_log_data(json_data: Any) -> LogDataDTO | None:
"""Validate log data to persist history. This function takes log data in
JSON format, validates it, and returns a `LogDataDTO` object if the data is
valid. The validation process includes decoding bytes to string, parsing
the string as JSON, and checking for required fields and log type.
Args:
json_data (Any): Log data in JSON format
Returns:
Optional[LogDataDTO]: Log data DTO object
"""
if isinstance(json_data, bytes):
json_data = json_data.decode("utf-8")
if isinstance(json_data, str):
try:
# Parse the string as JSON
json_data = json.loads(json_data)
except json.JSONDecodeError:
logger.error(f"Error decoding JSON data while validating {json_data}")
return
if not isinstance(json_data, dict):
logger.warning(f"Getting invalid data type while validating {json_data}")
return
# Extract required fields from the JSON data
execution_id = json_data.get(LogFieldName.EXECUTION_ID)
organization_id = json_data.get(LogFieldName.ORGANIZATION_ID)
timestamp = json_data.get(LogFieldName.TIMESTAMP)
log_type = json_data.get(LogFieldName.TYPE)
file_execution_id = json_data.get(LogFieldName.FILE_EXECUTION_ID)
# Ensure the log type is LogType.LOG
if log_type != LogType.LOG.value:
return
# Check if all required fields are present
if not all((execution_id, organization_id, timestamp)):
logger.debug(f"Missing required fields while validating {json_data}")
return
return LogDataDTO(
execution_id=execution_id,
file_execution_id=file_execution_id,
organization_id=organization_id,
timestamp=timestamp,
log_type=log_type,
data=json_data,
)
"""Validate log data to persist history (backward compatibility wrapper)."""
return get_validated_log_data(json_data)
def _store_execution_log(data: dict[str, Any]) -> None:
"""Store execution log in database
Args:
data (dict[str, Any]): Execution log data
"""
if not ExecutionLogConstants.IS_ENABLED:
return
try:
log_data = _get_validated_log_data(json_data=data)
if log_data:
redis_conn.rpush(ExecutionLogConstants.LOG_QUEUE_NAME, log_data.to_json())
except Exception as e:
logger.error(f"Error storing execution log: {e}")
"""Store execution log in database (backward compatibility wrapper)."""
store_execution_log(
data=data,
redis_client=redis_conn,
log_queue_name=ExecutionLogConstants.LOG_QUEUE_NAME,
is_enabled=ExecutionLogConstants.IS_ENABLED,
)
def _emit_websocket_event(room: str, event: str, data: dict[str, Any]) -> None:

View File

@@ -0,0 +1,95 @@
"""Organization utilities for internal APIs.
Provides shared functions for organization context resolution.
"""
import logging
from typing import Any
from account_v2.models import Organization
from django.shortcuts import get_object_or_404
logger = logging.getLogger(__name__)
def resolve_organization(
org_id: str, raise_on_not_found: bool = False
) -> Organization | None:
"""Resolve organization by either organization.id (int) or organization.organization_id (string).
Args:
org_id: Organization identifier - can be either the primary key (numeric string)
or the organization_id field (string)
raise_on_not_found: If True, raises Http404 on not found. If False, returns None.
Returns:
Organization instance if found, None if not found and raise_on_not_found=False
Raises:
Http404: If organization not found and raise_on_not_found=True
"""
try:
if org_id.isdigit():
# If it's numeric, treat as primary key
if raise_on_not_found:
return get_object_or_404(Organization, id=org_id)
else:
return Organization.objects.get(id=org_id)
else:
# If it's string, treat as organization_id field
if raise_on_not_found:
return get_object_or_404(Organization, organization_id=org_id)
else:
return Organization.objects.get(organization_id=org_id)
except Organization.DoesNotExist:
if raise_on_not_found:
raise
logger.warning(f"Organization {org_id} not found")
return None
def get_organization_context(organization: Organization) -> dict[str, Any]:
"""Get standardized organization context data.
Args:
organization: Organization instance
Returns:
Dictionary with organization context information
"""
return {
"organization_id": str(organization.id),
"organization_name": organization.display_name,
"organization_slug": getattr(organization, "slug", ""),
"created_at": organization.created_at.isoformat()
if hasattr(organization, "created_at")
else None,
"settings": {
# Add organization-specific settings here
"subscription_active": True, # This would come from subscription model
"features_enabled": [], # This would come from feature flags
},
}
def filter_queryset_by_organization(queryset, request, organization_field="organization"):
"""Filter a Django queryset by organization context from request.
Args:
queryset: Django QuerySet to filter
request: HTTP request object with organization_id attribute
organization_field: Field name for organization relationship (default: 'organization')
Returns:
Filtered queryset or empty queryset if organization not found
"""
org_id = getattr(request, "organization_id", None)
if org_id:
organization = resolve_organization(org_id, raise_on_not_found=False)
if organization:
# Use dynamic field lookup
filter_kwargs = {organization_field: organization}
return queryset.filter(**filter_kwargs)
else:
# Return empty queryset if organization not found
return queryset.none()
return queryset
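A minimal usage sketch (not part of the diff) showing how these helpers compose inside an internal view; the view class is hypothetical, and `request.organization_id` is assumed to be populated by the internal auth middleware.
# Hypothetical internal view (illustrative only).
from rest_framework.response import Response
from rest_framework.views import APIView
from utils.organization_utils import (
    filter_queryset_by_organization,
    get_organization_context,
    resolve_organization,
)
from workflow_manager.workflow_v2.models import Workflow

class WorkflowListInternalAPIView(APIView):
    """Hypothetical endpoint: list workflows scoped to the caller's organization."""

    def get(self, request):
        # request.organization_id is expected to be set by the internal auth middleware.
        workflows = filter_queryset_by_organization(
            Workflow.objects.all(), request, organization_field="organization"
        )
        org = resolve_organization(
            getattr(request, "organization_id", "") or "", raise_on_not_found=False
        )
        context = get_organization_context(org) if org else None
        return Response({"count": workflows.count(), "organization": context})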

View File

@@ -0,0 +1,87 @@
"""WebSocket emission views for internal API.
This module provides endpoints for workers to trigger WebSocket events
through the backend's SocketIO server.
Security Note:
- CSRF protection is disabled for internal service-to-service communication
- Authentication is handled by InternalAPIAuthMiddleware using Bearer tokens
- This endpoint is for worker → backend WebSocket event triggering only
"""
import json
import logging
from django.http import JsonResponse
from django.views.decorators.csrf import csrf_exempt
from django.views.decorators.http import require_http_methods
from utils.log_events import _emit_websocket_event
logger = logging.getLogger(__name__)
# CSRF exemption is safe here because:
# 1. Internal service-to-service communication (workers → backend)
# 2. Protected by InternalAPIAuthMiddleware Bearer token authentication
# 3. No browser sessions or cookies involved
# 4. Used for WebSocket event triggering, not state modification
@csrf_exempt
@require_http_methods(["POST"])
def emit_websocket(request):
"""Internal API endpoint for workers to emit WebSocket events.
Expected payload:
{
"room": "session_id",
"event": "logs:session_id",
"data": {...}
}
Returns:
JSON response with success/error status
"""
try:
# Parse request data (standard Django view)
data = json.loads(request.body.decode("utf-8"))
# Extract required fields
room = data.get("room")
event = data.get("event")
message_data = data.get("data", {})
# Validate required fields
if not room or not event:
return JsonResponse(
{
"status": "error",
"message": "Missing required fields: room and event are required",
},
status=400,
)
# Emit the WebSocket event
_emit_websocket_event(room=room, event=event, data=message_data)
logger.debug(f"WebSocket event emitted: room={room}, event={event}")
return JsonResponse(
{
"status": "success",
"message": "WebSocket event emitted successfully",
"room": room,
"event": event,
}
)
except json.JSONDecodeError as e:
logger.error(f"Invalid JSON in WebSocket emission request: {e}")
return JsonResponse(
{"status": "error", "message": "Invalid JSON payload"}, status=400
)
except Exception as e:
logger.error(f"Error emitting WebSocket event: {e}")
return JsonResponse(
{"status": "error", "message": f"Failed to emit WebSocket event: {str(e)}"},
status=500,
)
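For context, a worker-side caller might look roughly like the sketch below; the URL path and the base-URL/token environment variable names are assumptions, not taken from this diff.
# Hypothetical worker-side call (illustrative only).
import os
import requests

INTERNAL_API_BASE = os.environ.get("INTERNAL_API_BASE_URL", "http://backend:8000")  # assumed env var
INTERNAL_API_TOKEN = os.environ.get("INTERNAL_SERVICE_API_KEY", "")  # assumed env var

def emit_log_event(session_id: str, payload: dict) -> None:
    """Ask the backend to forward a WebSocket event to the given room."""
    response = requests.post(
        f"{INTERNAL_API_BASE}/internal/emit-websocket/",  # assumed route for emit_websocket
        json={"room": session_id, "event": f"logs:{session_id}", "data": payload},
        headers={"Authorization": f"Bearer {INTERNAL_API_TOKEN}"},
        timeout=5,
    )
    response.raise_for_status()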

backend/uv.lock (generated)
View File

@@ -2975,6 +2975,21 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/0c/c1/6aece0ab5209981a70cd186f164c133fdba2f51e124ff92b73de7fd24d78/protobuf-4.25.8-py3-none-any.whl", hash = "sha256:15a0af558aa3b13efef102ae6e4f3efac06f1eea11afb3a57db2901447d9fb59", size = 156757, upload-time = "2025-05-28T14:22:24.135Z" },
]
[[package]]
name = "psutil"
version = "7.0.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/2a/80/336820c1ad9286a4ded7e845b2eccfcb27851ab8ac6abece774a6ff4d3de/psutil-7.0.0.tar.gz", hash = "sha256:7be9c3eba38beccb6495ea33afd982a44074b78f28c434a1f51cc07fd315c456", size = 497003 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ed/e6/2d26234410f8b8abdbf891c9da62bee396583f713fb9f3325a4760875d22/psutil-7.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:101d71dc322e3cffd7cea0650b09b3d08b8e7c4109dd6809fe452dfd00e58b25", size = 238051 },
{ url = "https://files.pythonhosted.org/packages/04/8b/30f930733afe425e3cbfc0e1468a30a18942350c1a8816acfade80c005c4/psutil-7.0.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:39db632f6bb862eeccf56660871433e111b6ea58f2caea825571951d4b6aa3da", size = 239535 },
{ url = "https://files.pythonhosted.org/packages/2a/ed/d362e84620dd22876b55389248e522338ed1bf134a5edd3b8231d7207f6d/psutil-7.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fcee592b4c6f146991ca55919ea3d1f8926497a713ed7faaf8225e174581e91", size = 275004 },
{ url = "https://files.pythonhosted.org/packages/bf/b9/b0eb3f3cbcb734d930fdf839431606844a825b23eaf9a6ab371edac8162c/psutil-7.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b1388a4f6875d7e2aff5c4ca1cc16c545ed41dd8bb596cefea80111db353a34", size = 277986 },
{ url = "https://files.pythonhosted.org/packages/eb/a2/709e0fe2f093556c17fbafda93ac032257242cabcc7ff3369e2cb76a97aa/psutil-7.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5f098451abc2828f7dc6b58d44b532b22f2088f4999a937557b603ce72b1993", size = 279544 },
{ url = "https://files.pythonhosted.org/packages/50/e6/eecf58810b9d12e6427369784efe814a1eec0f492084ce8eb8f4d89d6d61/psutil-7.0.0-cp37-abi3-win32.whl", hash = "sha256:ba3fcef7523064a6c9da440fc4d6bd07da93ac726b5733c29027d7dc95b39d99", size = 241053 },
{ url = "https://files.pythonhosted.org/packages/50/1b/6921afe68c74868b4c9fa424dad3be35b095e16687989ebbb50ce4fceb7c/psutil-7.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553", size = 244885 },
]
[[package]]
name = "psycopg2-binary"
version = "2.9.9"
@@ -3502,6 +3517,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3f/51/d4db610ef29373b879047326cbf6fa98b6c1969d6f6dc423279de2b1be2c/requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06", size = 54481, upload-time = "2023-05-01T04:11:28.427Z" },
]
[[package]]
name = "responses"
version = "0.25.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pyyaml" },
{ name = "requests" },
{ name = "urllib3" },
]
sdist = { url = "https://files.pythonhosted.org/packages/81/7e/2345ac3299bd62bd7163216702bbc88976c099cfceba5b889f2a457727a1/responses-0.25.7.tar.gz", hash = "sha256:8ebae11405d7a5df79ab6fd54277f6f2bc29b2d002d0dd2d5c632594d1ddcedb", size = 79203 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e4/fc/1d20b64fa90e81e4fa0a34c9b0240a6cfb1326b7e06d18a5432a9917c316/responses-0.25.7-py3-none-any.whl", hash = "sha256:92ca17416c90fe6b35921f52179bff29332076bb32694c0df02dcac2c6bc043c", size = 34732 },
]
[[package]]
name = "rpds-py"
version = "0.27.1"
@@ -4042,6 +4071,9 @@ dev = [
{ name = "debugpy" },
{ name = "inotify" },
{ name = "poethepoet" },
{ name = "psutil" },
{ name = "pytest" },
{ name = "responses" },
{ name = "unstract-connectors" },
{ name = "unstract-core" },
{ name = "unstract-filesystem" },
@@ -4109,6 +4141,9 @@ dev = [
{ name = "debugpy", specifier = ">=1.8.14" },
{ name = "inotify", specifier = ">=0.2.10" },
{ name = "poethepoet", specifier = ">=0.33.1" },
{ name = "psutil", specifier = ">=7.0.0" },
{ name = "pytest", specifier = ">=8.3.5" },
{ name = "responses", specifier = ">=0.25.7" },
{ name = "unstract-connectors", editable = "../unstract/connectors" },
{ name = "unstract-core", editable = "../unstract/core" },
{ name = "unstract-filesystem", editable = "../unstract/filesystem" },

View File

@@ -871,7 +871,6 @@ class DestinationConnector(BaseConnector):
).to_dict()
queue_result_json = json.dumps(queue_result)
conn = QueueUtils.get_queue_inst()
conn.enqueue(queue_name=q_name, message=queue_result_json)
logger.info(f"Pushed {file_name} to queue {q_name} with file content")
@@ -891,11 +890,13 @@ class DestinationConnector(BaseConnector):
q_name = self._get_review_queue_name()
if meta_data:
whisper_hash = meta_data.get("whisper-hash")
extracted_text = meta_data.get("extracted_text")
else:
whisper_hash = None
extracted_text = None
# Get extracted text from metadata (added by structure tool)
extracted_text = meta_data.get("extracted_text") if meta_data else None
# Get TTL from workflow settings
ttl_seconds = WorkflowUtil.get_hitl_ttl_seconds(workflow)
# Create QueueResult with TTL metadata
queue_result_obj = QueueResult(
@@ -907,8 +908,8 @@ class DestinationConnector(BaseConnector):
whisper_hash=whisper_hash,
file_execution_id=file_execution_id,
extracted_text=extracted_text,
ttl_seconds=ttl_seconds,
)
# Add TTL metadata based on HITLSettings
queue_result_obj.ttl_seconds = WorkflowUtil.get_hitl_ttl_seconds(workflow)

View File

@@ -28,3 +28,27 @@ class WorkflowEndpointUtils:
workflow=workflow
)
return endpoints
@staticmethod
def get_endpoint_for_workflow_by_type(
workflow_id: str, endpoint_type: WorkflowEndpoint.EndpointType
) -> WorkflowEndpoint:
"""Get endpoint for a given workflow by type.
Args:
workflow_id (str): The ID of the workflow.
endpoint_type (WorkflowEndpoint.EndpointType): The type of the endpoint.
Returns:
WorkflowEndpoint: The endpoint for the given workflow and type.
"""
workflow = WorkflowHelper.get_workflow_by_id(workflow_id)
endpoint: WorkflowEndpoint = WorkflowEndpoint.objects.get(
workflow=workflow,
endpoint_type=endpoint_type,
)
if endpoint.connector_instance:
endpoint.connector_instance.connector_metadata = (
endpoint.connector_instance.metadata
)
return endpoint
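An illustrative call of the new helper, assuming the usual SOURCE/DESTINATION endpoint types and the import paths noted in the comments.
# Illustrative only; module paths are assumptions.
from workflow_manager.endpoint_v2.endpoint_utils import WorkflowEndpointUtils  # assumed module path
from workflow_manager.endpoint_v2.models import WorkflowEndpoint  # assumed module path

endpoint = WorkflowEndpointUtils.get_endpoint_for_workflow_by_type(
    workflow_id="00000000-0000-0000-0000-000000000000",  # placeholder workflow ID
    endpoint_type=WorkflowEndpoint.EndpointType.SOURCE,
)
if endpoint.connector_instance:
    # The helper copies .metadata onto .connector_metadata before returning.
    connector_settings = endpoint.connector_instance.connector_metadata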

View File

@@ -578,7 +578,7 @@ class SourceConnector(BaseConnector):
return WorkflowExecution.objects.filter(
workflow=self.workflow,
workflow__organization_id=organization.id, # Security: Organization isolation
status__in=[ExecutionStatus.EXECUTING, ExecutionStatus.PENDING],
status__in=[ExecutionStatus.EXECUTING.value, ExecutionStatus.PENDING.value],
)
def _has_blocking_file_execution(self, execution, file_hash: FileHash) -> bool:
@@ -616,7 +616,10 @@ class SourceConnector(BaseConnector):
workflow_execution=execution,
file_hash=file_hash.file_hash,
file_path=file_hash.file_path,
status__in=ExecutionStatus.get_skip_processing_statuses(),
status__in=[
status.value
for status in ExecutionStatus.get_skip_processing_statuses()
],
)
except WorkflowFileExecution.DoesNotExist:
return None
@@ -633,7 +636,10 @@ class SourceConnector(BaseConnector):
workflow_execution=execution,
provider_file_uuid=file_hash.provider_file_uuid,
file_path=file_hash.file_path,
status__in=ExecutionStatus.get_skip_processing_statuses(),
status__in=[
status.value
for status in ExecutionStatus.get_skip_processing_statuses()
],
)
except WorkflowFileExecution.DoesNotExist:
return None

View File

@@ -26,8 +26,8 @@ class ExecutionSerializer(serializers.ModelSerializer):
def get_successful_files(self, obj: WorkflowExecution) -> int:
"""Return the count of successfully executed files"""
return obj.file_executions.filter(status=ExecutionStatus.COMPLETED).count()
return obj.file_executions.filter(status=ExecutionStatus.COMPLETED.value).count()
def get_failed_files(self, obj: WorkflowExecution) -> int:
"""Return the count of failed executed files"""
return obj.file_executions.filter(status=ExecutionStatus.ERROR).count()
return obj.file_executions.filter(status=ExecutionStatus.ERROR.value).count()
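The `.value` comparisons above (and in the connector changes) assume the status column stores the enum's string value; a minimal sketch of the pattern:
from workflow_manager.workflow_v2.enums import ExecutionStatus

def is_completed(stored_status: str) -> bool:
    # Compare against the member's .value so the check does not depend on
    # whether ExecutionStatus mixes in str.
    return stored_status == ExecutionStatus.COMPLETED.value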

View File

@@ -0,0 +1,42 @@
"""Internal API URLs for File Execution
URL patterns for file execution internal APIs.
"""
from django.urls import include, path
from rest_framework.routers import DefaultRouter
from .internal_views import (
FileExecutionBatchCreateAPIView,
FileExecutionBatchHashUpdateAPIView,
FileExecutionBatchStatusUpdateAPIView,
FileExecutionInternalViewSet,
FileExecutionMetricsAPIView,
)
# Create router for file execution viewsets
router = DefaultRouter()
router.register(r"", FileExecutionInternalViewSet, basename="file-execution-internal")
urlpatterns = [
# Batch operations
path(
"batch-create/",
FileExecutionBatchCreateAPIView.as_view(),
name="file-execution-batch-create",
),
path(
"batch-status-update/",
FileExecutionBatchStatusUpdateAPIView.as_view(),
name="file-execution-batch-status-update",
),
path(
"batch-hash-update/",
FileExecutionBatchHashUpdateAPIView.as_view(),
name="file-execution-batch-hash-update",
),
path(
"metrics/", FileExecutionMetricsAPIView.as_view(), name="file-execution-metrics"
),
# File execution CRUD (via router)
path("", include(router.urls)),
]

View File

@@ -0,0 +1,777 @@
"""Internal API Views for File Execution
Handles file execution related endpoints for internal services.
"""
import logging
from django.db import transaction
from rest_framework import status, viewsets
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.views import APIView
from utils.organization_utils import filter_queryset_by_organization
from workflow_manager.endpoint_v2.dto import FileHash
from workflow_manager.file_execution.models import WorkflowFileExecution
# Import serializers from workflow_manager internal API
from workflow_manager.internal_serializers import (
FileExecutionStatusUpdateSerializer,
WorkflowFileExecutionSerializer,
)
logger = logging.getLogger(__name__)
class FileExecutionInternalViewSet(viewsets.ModelViewSet):
"""Internal API ViewSet for File Execution operations."""
serializer_class = WorkflowFileExecutionSerializer
lookup_field = "id"
def get_queryset(self):
"""Get file executions filtered by organization context and query parameters."""
queryset = WorkflowFileExecution.objects.all()
# Filter through the relationship: WorkflowFileExecution -> WorkflowExecution -> Workflow -> Organization
queryset = filter_queryset_by_organization(
queryset, self.request, "workflow_execution__workflow__organization"
)
# Debug: Log initial queryset count after organization filtering
org_filtered_count = queryset.count()
logger.debug(
f"After organization filtering: {org_filtered_count} file executions"
)
# Support filtering by query parameters for get-or-create operations
execution_id = self.request.query_params.get("execution_id")
file_hash = self.request.query_params.get("file_hash")
provider_file_uuid = self.request.query_params.get("provider_file_uuid")
workflow_id = self.request.query_params.get("workflow_id")
file_path = self.request.query_params.get(
"file_path"
) # CRITICAL: Add file_path parameter
logger.debug(
f"Query parameters: execution_id={execution_id}, file_hash={file_hash}, provider_file_uuid={provider_file_uuid}, workflow_id={workflow_id}, file_path={file_path}"
)
# Apply filters step by step with debugging
if execution_id:
queryset = queryset.filter(workflow_execution_id=execution_id)
logger.info(
f"DEBUG: After execution_id filter: {queryset.count()} file executions"
)
# CRITICAL FIX: Include file_path filter to match unique constraints
if file_path:
queryset = queryset.filter(file_path=file_path)
logger.debug(f"After file_path filter: {queryset.count()} file executions")
# CRITICAL FIX: Match backend manager logic - use file_hash OR provider_file_uuid (not both)
if file_hash:
queryset = queryset.filter(file_hash=file_hash)
logger.info(
f"DEBUG: After file_hash filter: {queryset.count()} file executions"
)
elif provider_file_uuid:
queryset = queryset.filter(provider_file_uuid=provider_file_uuid)
logger.info(
f"DEBUG: After provider_file_uuid filter: {queryset.count()} file executions"
)
if workflow_id:
queryset = queryset.filter(workflow_execution__workflow_id=workflow_id)
logger.info(
f"DEBUG: After workflow_id filter: {queryset.count()} file executions"
)
final_count = queryset.count()
logger.info(
f"Final queryset count: {final_count} file executions for params: execution_id={execution_id}, file_hash={file_hash}, provider_file_uuid={provider_file_uuid}, workflow_id={workflow_id}, file_path={file_path}"
)
# If we still have too many results, something is wrong with the filtering
if final_count > 10: # Reasonable threshold
logger.warning(
f"Query returned {final_count} file executions - filtering may not be working correctly"
)
logger.warning(
f"Query params: execution_id={execution_id}, file_hash={file_hash}, workflow_id={workflow_id}"
)
return queryset
def list(self, request, *args, **kwargs):
"""List file executions with enhanced filtering validation."""
queryset = self.get_queryset()
count = queryset.count()
# If we get too many results, it means the filtering failed
if count > 50: # Conservative threshold
logger.error(
f"GET request returned {count} file executions - this suggests broken query parameter filtering"
)
logger.error(f"Request query params: {dict(request.query_params)}")
# For debugging, show a sample of what we're returning
sample_ids = list(queryset.values_list("id", flat=True)[:5])
logger.error(f"Sample file execution IDs: {sample_ids}")
# Return error response instead of broken list
return Response(
{
"error": "Query returned too many results",
"detail": f"Expected 0-1 file executions but got {count}. Check query parameters.",
"count": count,
"query_params": dict(request.query_params),
},
status=status.HTTP_400_BAD_REQUEST,
)
# Continue with normal list behavior for reasonable result counts
logger.info(f"GET request successfully filtered to {count} file executions")
return super().list(request, *args, **kwargs)
@action(detail=True, methods=["post"])
def status(self, request, id=None):
"""Update file execution status."""
try:
# Get file execution by ID with organization filtering
# Don't use self.get_object() as it applies query parameter filtering
base_queryset = WorkflowFileExecution.objects.all()
base_queryset = filter_queryset_by_organization(
base_queryset, request, "workflow_execution__workflow__organization"
)
try:
file_execution = base_queryset.get(id=id)
except WorkflowFileExecution.DoesNotExist:
logger.warning(f"WorkflowFileExecution {id} not found for status update")
return Response(
{
"error": "WorkflowFileExecution not found",
"detail": f"No file execution record found with ID {id}",
},
status=status.HTTP_404_NOT_FOUND,
)
serializer = FileExecutionStatusUpdateSerializer(data=request.data)
if serializer.is_valid():
validated_data = serializer.validated_data
# Update file execution using the model's update_status method
file_execution.update_status(
status=validated_data["status"],
execution_error=validated_data.get("error_message"),
execution_time=validated_data.get("execution_time"),
)
logger.info(
f"Updated file execution {id} status to {validated_data['status']}"
)
# Return consistent dataclass response
from unstract.core.data_models import FileExecutionStatusUpdateRequest
response_data = FileExecutionStatusUpdateRequest(
status=file_execution.status,
error_message=file_execution.execution_error,
result=getattr(file_execution, "result", None),
)
return Response(
{
"status": "updated",
"file_execution_id": str(file_execution.id),
"data": response_data.to_dict(),
},
status=status.HTTP_200_OK,
)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
except Exception as e:
logger.error(f"Failed to update file execution status {id}: {str(e)}")
return Response(
{"error": "Failed to update file execution status", "detail": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
def create(self, request, *args, **kwargs):
"""Create or get existing workflow file execution using existing manager method."""
try:
from workflow_manager.workflow_v2.models.execution import WorkflowExecution
data = request.data
execution_id = data.get("execution_id")
file_hash_data = data.get("file_hash", {})
workflow_id = data.get("workflow_id")
if not execution_id:
return Response(
{"error": "execution_id is required"},
status=status.HTTP_400_BAD_REQUEST,
)
# Get workflow execution with organization filtering
try:
workflow_execution = WorkflowExecution.objects.get(id=execution_id)
# Verify organization access
filter_queryset_by_organization(
WorkflowExecution.objects.filter(id=execution_id),
request,
"workflow__organization",
).get()
except WorkflowExecution.DoesNotExist:
return Response(
{"error": "WorkflowExecution not found or access denied"},
status=status.HTTP_404_NOT_FOUND,
)
# Convert request data to FileHash object that the manager expects
file_hash = FileHash(
file_path=file_hash_data.get("file_path", ""),
file_name=file_hash_data.get("file_name", ""),
source_connection_type=file_hash_data.get("source_connection_type", ""),
file_hash=file_hash_data.get("file_hash"),
file_size=file_hash_data.get("file_size"),
provider_file_uuid=file_hash_data.get("provider_file_uuid"),
mime_type=file_hash_data.get("mime_type"),
fs_metadata=file_hash_data.get("fs_metadata"),
file_destination=file_hash_data.get("file_destination"),
is_executed=file_hash_data.get("is_executed", False),
file_number=file_hash_data.get("file_number"),
)
# Determine if this is an API request (affects file_path handling in manager)
is_api = file_hash_data.get("source_connection_type", "") == "API"
# Use existing manager method - this handles get_or_create logic properly
file_execution = WorkflowFileExecution.objects.get_or_create_file_execution(
workflow_execution=workflow_execution, file_hash=file_hash, is_api=is_api
)
# Return single object (not list!) using serializer
serializer = self.get_serializer(file_execution)
response_data = serializer.data
# ROOT CAUSE FIX: Ensure file_path is always present in API response
# The backend model sets file_path to None for API files, but workers require it
if not response_data.get("file_path") and file_hash.file_path:
logger.info(
f"Backend stored null file_path for API file, including original: {file_hash.file_path}"
)
response_data["file_path"] = file_hash.file_path
logger.info(
f"Retrieved/created file execution {file_execution.id} for workflow {workflow_id}"
)
logger.debug(f"Response data: {response_data}")
# Determine status code based on whether it was created or retrieved
# Note: We can't easily tell if it was created or retrieved from the manager,
# but 201 is fine for both cases in this context
return Response(response_data, status=status.HTTP_201_CREATED)
except Exception as e:
logger.error(f"Failed to get/create file execution: {str(e)}")
return Response(
{"error": "Failed to get/create file execution", "detail": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@action(detail=True, methods=["patch"])
def update_hash(self, request, id=None):
"""Update file execution with computed file hash."""
try:
# Get file execution by ID with organization filtering
base_queryset = WorkflowFileExecution.objects.all()
base_queryset = filter_queryset_by_organization(
base_queryset, request, "workflow_execution__workflow__organization"
)
try:
file_execution = base_queryset.get(id=id)
except WorkflowFileExecution.DoesNotExist:
logger.warning(f"WorkflowFileExecution {id} not found for hash update")
return Response(
{
"error": "WorkflowFileExecution not found",
"detail": f"No file execution record found with ID {id}",
},
status=status.HTTP_404_NOT_FOUND,
)
# Extract update data
file_hash = request.data.get("file_hash")
fs_metadata = request.data.get("fs_metadata")
mime_type = request.data.get("mime_type")
if not file_hash and not fs_metadata and not mime_type:
return Response(
{"error": "file_hash, fs_metadata, or mime_type is required"},
status=status.HTTP_400_BAD_REQUEST,
)
# Use the model's update method for efficient field-specific updates
file_execution.update(
file_hash=file_hash, fs_metadata=fs_metadata, mime_type=mime_type
)
logger.info(
f"Updated file execution {id} with file_hash: {file_hash[:16] if file_hash else 'none'}..."
)
# Return updated record
serializer = self.get_serializer(file_execution)
return Response(
{
"status": "updated",
"file_execution_id": str(file_execution.id),
"data": serializer.data,
},
status=status.HTTP_200_OK,
)
except Exception as e:
logger.error(f"Failed to update file execution hash {id}: {str(e)}")
return Response(
{"error": "Failed to update file execution hash", "detail": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
class FileExecutionBatchCreateAPIView(APIView):
"""Internal API endpoint for creating multiple file executions in a single batch."""
def post(self, request):
"""Create multiple file executions in a single batch request."""
try:
file_executions = request.data.get("file_executions", [])
if not file_executions:
return Response(
{"error": "file_executions list is required"},
status=status.HTTP_400_BAD_REQUEST,
)
successful_creations = []
failed_creations = []
with transaction.atomic():
for file_execution_data in file_executions:
try:
from workflow_manager.workflow_v2.models.execution import (
WorkflowExecution,
)
execution_id = file_execution_data.get("execution_id")
file_hash_data = file_execution_data.get("file_hash", {})
if not execution_id:
failed_creations.append(
{
"file_name": file_hash_data.get(
"file_name", "unknown"
),
"error": "execution_id is required",
}
)
continue
# Get workflow execution with organization filtering
try:
workflow_execution = WorkflowExecution.objects.get(
id=execution_id
)
# Verify organization access
filter_queryset_by_organization(
WorkflowExecution.objects.filter(id=execution_id),
request,
"workflow__organization",
).get()
except WorkflowExecution.DoesNotExist:
failed_creations.append(
{
"file_name": file_hash_data.get(
"file_name", "unknown"
),
"error": "WorkflowExecution not found or access denied",
}
)
continue
# Convert request data to FileHash object
file_hash = FileHash(
file_path=file_hash_data.get("file_path", ""),
file_name=file_hash_data.get("file_name", ""),
source_connection_type=file_hash_data.get(
"source_connection_type", ""
),
file_hash=file_hash_data.get("file_hash"),
file_size=file_hash_data.get("file_size"),
provider_file_uuid=file_hash_data.get("provider_file_uuid"),
mime_type=file_hash_data.get("mime_type"),
fs_metadata=file_hash_data.get("fs_metadata"),
file_destination=file_hash_data.get("file_destination"),
is_executed=file_hash_data.get("is_executed", False),
file_number=file_hash_data.get("file_number"),
)
# Determine if this is an API request
is_api = file_hash_data.get("source_connection_type", "") == "API"
# Use existing manager method
file_execution = (
WorkflowFileExecution.objects.get_or_create_file_execution(
workflow_execution=workflow_execution,
file_hash=file_hash,
is_api=is_api,
)
)
# ROOT CAUSE FIX: Ensure file_path is always present in batch response
# The backend model sets file_path to None for API files, but workers require it
response_file_path = file_execution.file_path
if not response_file_path and file_hash.file_path:
response_file_path = file_hash.file_path
successful_creations.append(
{
"id": str(file_execution.id),
"file_name": file_execution.file_name,
"file_path": response_file_path,
"status": file_execution.status,
}
)
except Exception as e:
failed_creations.append(
{
"file_name": file_execution_data.get("file_hash", {}).get(
"file_name", "unknown"
),
"error": str(e),
}
)
logger.info(
f"Batch file execution creation: {len(successful_creations)} successful, {len(failed_creations)} failed"
)
return Response(
{
"successful_creations": successful_creations,
"failed_creations": failed_creations,
"total_processed": len(file_executions),
},
status=status.HTTP_201_CREATED,
)
except Exception as e:
logger.error(f"Failed to process batch file execution creation: {str(e)}")
return Response(
{
"error": "Failed to process batch file execution creation",
"detail": str(e),
},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
class FileExecutionBatchStatusUpdateAPIView(APIView):
"""Internal API endpoint for updating multiple file execution statuses in a single batch."""
def post(self, request):
"""Update multiple file execution statuses in a single batch request."""
try:
status_updates = request.data.get("status_updates", [])
if not status_updates:
return Response(
{"error": "status_updates list is required"},
status=status.HTTP_400_BAD_REQUEST,
)
successful_updates = []
failed_updates = []
with transaction.atomic():
for update_data in status_updates:
try:
file_execution_id = update_data.get("file_execution_id")
status_value = update_data.get("status")
if not file_execution_id or not status_value:
failed_updates.append(
{
"file_execution_id": file_execution_id,
"error": "file_execution_id and status are required",
}
)
continue
# Get file execution with organization filtering
base_queryset = WorkflowFileExecution.objects.all()
base_queryset = filter_queryset_by_organization(
base_queryset,
request,
"workflow_execution__workflow__organization",
)
try:
file_execution = base_queryset.get(id=file_execution_id)
except WorkflowFileExecution.DoesNotExist:
failed_updates.append(
{
"file_execution_id": file_execution_id,
"error": "WorkflowFileExecution not found",
}
)
continue
# Update file execution using the model's update_status method
file_execution.update_status(
status=status_value,
execution_error=update_data.get("error_message"),
execution_time=update_data.get("execution_time"),
)
successful_updates.append(
{
"file_execution_id": str(file_execution.id),
"status": file_execution.status,
"file_name": file_execution.file_name,
}
)
except Exception as e:
failed_updates.append(
{"file_execution_id": file_execution_id, "error": str(e)}
)
logger.info(
f"Batch file execution status update: {len(successful_updates)} successful, {len(failed_updates)} failed"
)
return Response(
{
"successful_updates": successful_updates,
"failed_updates": failed_updates,
"total_processed": len(status_updates),
}
)
except Exception as e:
logger.error(
f"Failed to process batch file execution status update: {str(e)}"
)
return Response(
{
"error": "Failed to process batch file execution status update",
"detail": str(e),
},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
class FileExecutionBatchHashUpdateAPIView(APIView):
"""Internal API endpoint for updating multiple file execution hashes in a single batch."""
def post(self, request):
"""Update multiple file execution hashes in a single batch request."""
try:
hash_updates = request.data.get("hash_updates", [])
if not hash_updates:
return Response(
{"error": "hash_updates list is required"},
status=status.HTTP_400_BAD_REQUEST,
)
successful_updates = []
failed_updates = []
with transaction.atomic():
for update_data in hash_updates:
try:
file_execution_id = update_data.get("file_execution_id")
file_hash = update_data.get("file_hash")
if not file_execution_id or not file_hash:
failed_updates.append(
{
"file_execution_id": file_execution_id,
"error": "file_execution_id and file_hash are required",
}
)
continue
# Get file execution with organization filtering
base_queryset = WorkflowFileExecution.objects.all()
base_queryset = filter_queryset_by_organization(
base_queryset,
request,
"workflow_execution__workflow__organization",
)
try:
file_execution = base_queryset.get(id=file_execution_id)
except WorkflowFileExecution.DoesNotExist:
failed_updates.append(
{
"file_execution_id": file_execution_id,
"error": "WorkflowFileExecution not found",
}
)
continue
# Update file execution hash using the model's update method
file_execution.update(
file_hash=file_hash,
fs_metadata=update_data.get("fs_metadata"),
)
successful_updates.append(
{
"file_execution_id": str(file_execution.id),
"file_hash": file_hash[:16] + "..."
if file_hash
else None,
"file_name": file_execution.file_name,
}
)
except Exception as e:
failed_updates.append(
{"file_execution_id": file_execution_id, "error": str(e)}
)
logger.info(
f"Batch file execution hash update: {len(successful_updates)} successful, {len(failed_updates)} failed"
)
return Response(
{
"successful_updates": successful_updates,
"failed_updates": failed_updates,
"total_processed": len(hash_updates),
}
)
except Exception as e:
logger.error(f"Failed to process batch file execution hash update: {str(e)}")
return Response(
{
"error": "Failed to process batch file execution hash update",
"detail": str(e),
},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
class FileExecutionMetricsAPIView(APIView):
"""Internal API endpoint for getting file execution metrics."""
def get(self, request):
"""Get file execution metrics with optional filtering."""
try:
# Get query parameters
start_date = request.query_params.get("start_date")
end_date = request.query_params.get("end_date")
workflow_id = request.query_params.get("workflow_id")
execution_id = request.query_params.get("execution_id")
status = request.query_params.get("status")
# Build base queryset with organization filtering
file_executions = WorkflowFileExecution.objects.all()
file_executions = filter_queryset_by_organization(
file_executions, request, "workflow_execution__workflow__organization"
)
# Apply filters
if start_date:
from datetime import datetime
file_executions = file_executions.filter(
created_at__gte=datetime.fromisoformat(start_date)
)
if end_date:
from datetime import datetime
file_executions = file_executions.filter(
created_at__lte=datetime.fromisoformat(end_date)
)
if workflow_id:
file_executions = file_executions.filter(
workflow_execution__workflow_id=workflow_id
)
if execution_id:
file_executions = file_executions.filter(
workflow_execution_id=execution_id
)
if status:
file_executions = file_executions.filter(status=status)
# Calculate metrics
from django.db.models import Avg, Count, Sum
total_file_executions = file_executions.count()
# Status breakdown
status_counts = file_executions.values("status").annotate(count=Count("id"))
status_breakdown = {item["status"]: item["count"] for item in status_counts}
# Success rate
completed_count = status_breakdown.get("COMPLETED", 0)
success_rate = (
(completed_count / total_file_executions)
if total_file_executions > 0
else 0
)
# Average execution time
avg_execution_time = (
file_executions.aggregate(avg_time=Avg("execution_time"))["avg_time"] or 0
)
# File size statistics
total_file_size = (
file_executions.aggregate(total_size=Sum("file_size"))["total_size"] or 0
)
avg_file_size = (
file_executions.aggregate(avg_size=Avg("file_size"))["avg_size"] or 0
)
metrics = {
"total_file_executions": total_file_executions,
"status_breakdown": status_breakdown,
"success_rate": success_rate,
"average_execution_time": avg_execution_time,
"total_file_size": total_file_size,
"average_file_size": avg_file_size,
"filters_applied": {
"start_date": start_date,
"end_date": end_date,
"workflow_id": workflow_id,
"execution_id": execution_id,
"status": status,
},
}
logger.info(
f"Generated file execution metrics: {total_file_executions} executions, {success_rate:.2%} success rate"
)
return Response(metrics)
except Exception as e:
logger.error(f"Failed to get file execution metrics: {str(e)}")
return Response(
{"error": "Failed to get file execution metrics", "detail": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
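A hedged worker-side sketch of the batch status update flow; the mount prefix, auth header, and organization header are assumptions based on the routes and middleware referenced above.
# Hypothetical worker-side batch status update (illustrative only).
import requests

def update_file_statuses(base_url: str, token: str, org_id: str, updates: list[dict]) -> dict:
    """POST a batch of {file_execution_id, status, error_message?} updates."""
    response = requests.post(
        f"{base_url}/v1/file-execution/batch-status-update/",  # assumed mount prefix
        json={"status_updates": updates},
        headers={
            "Authorization": f"Bearer {token}",
            "X-Organization-ID": org_id,
        },
        timeout=30,
    )
    response.raise_for_status()
    # Expected shape: {"successful_updates": [...], "failed_updates": [...], "total_processed": N}
    return response.json()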

View File

@@ -120,32 +120,30 @@ class WorkflowFileExecution(BaseModel):
def update_status(
self,
status: ExecutionStatus,
status: ExecutionStatus | str,
execution_error: str = None,
execution_time: float = None,
) -> None:
"""Updates the status and execution details of an input file.
Args:
status: The new status of the file (ExecutionStatus enum or string)
execution_time: The execution time for processing the file (optional)
execution_error: (Optional) Error message if processing failed
Returns:
None
"""
self.status = status
if (
status
in [
ExecutionStatus.COMPLETED,
ExecutionStatus.ERROR,
ExecutionStatus.STOPPED,
]
and not self.execution_time
):
self.execution_time = CommonUtils.time_since(self.created_at)
# Set execution_time if provided, otherwise calculate it for final states
status = ExecutionStatus(status)
self.status = status.value
if status in [
ExecutionStatus.COMPLETED,
ExecutionStatus.ERROR,
ExecutionStatus.STOPPED,
]:
self.execution_time = CommonUtils.time_since(self.created_at, 3)
self.execution_error = execution_error
self.save()
@@ -210,6 +208,13 @@ class WorkflowFileExecution(BaseModel):
fields=["workflow_execution", "provider_file_uuid", "file_path"],
name="unique_workflow_provider_uuid_path",
),
# CRITICAL FIX: Add constraint for API files where file_path is None
# This prevents duplicate entries for same file_hash
models.UniqueConstraint(
fields=["workflow_execution", "file_hash"],
condition=models.Q(file_path__isnull=True),
name="unique_workflow_api_file_hash",
),
]
@property
@@ -219,17 +224,20 @@ class WorkflowFileExecution(BaseModel):
Returns:
bool: True if the execution status is completed, False otherwise.
"""
return self.status is not None and self.status == ExecutionStatus.COMPLETED
return self.status is not None and self.status == ExecutionStatus.COMPLETED.value
def update(
self,
file_hash: str = None,
fs_metadata: dict[str, Any] = None,
mime_type: str = None,
) -> None:
"""Updates the file execution details.
Args:
file_hash: (Optional) Hash of the file content
fs_metadata: (Optional) File system metadata
mime_type: (Optional) MIME type of the file
Returns:
None
@@ -242,5 +250,8 @@ class WorkflowFileExecution(BaseModel):
if fs_metadata is not None:
self.fs_metadata = fs_metadata
update_fields.append("fs_metadata")
if mime_type is not None:
self.mime_type = mime_type
update_fields.append("mime_type")
if update_fields: # Save only if there's an actual update
self.save(update_fields=update_fields)
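Illustrative only: with the change above, `update_status()` accepts either an enum member or its string value, and terminal states (COMPLETED/ERROR/STOPPED) also stamp `execution_time` from `created_at`.
from workflow_manager.file_execution.models import WorkflowFileExecution
from workflow_manager.workflow_v2.enums import ExecutionStatus

def mark_completed(file_execution: WorkflowFileExecution) -> None:
    # update_status() normalizes its input via ExecutionStatus(status), so an
    # enum member or the plain string "COMPLETED" are interchangeable here.
    file_execution.update_status(status=ExecutionStatus.COMPLETED)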

View File

@@ -30,9 +30,9 @@ class FileCentricExecutionSerializer(serializers.ModelSerializer):
exclude = ["file_hash"]
def get_status_msg(self, obj: FileExecution) -> dict[str, any] | None:
if obj.status in [ExecutionStatus.PENDING]:
if obj.status in [ExecutionStatus.PENDING.value]:
return self.INIT_STATUS_MSG
elif obj.status == ExecutionStatus.ERROR:
elif obj.status == ExecutionStatus.ERROR.value:
return obj.execution_error
latest_log = (

View File

@@ -0,0 +1,447 @@
"""Internal API Views for Worker Communication
This module provides internal API endpoints that workers use to communicate
with the Django backend for database operations only. All business logic has been
moved to workers.
NOTE: Many sophisticated endpoints are now implemented in internal_views.py
using class-based views. This file contains simpler function-based views
for basic operations.
"""
import logging
from account_v2.models import Organization
from django.views.decorators.csrf import csrf_exempt
from rest_framework import status
from rest_framework.decorators import api_view
from rest_framework.response import Response
from tool_instance_v2.models import ToolInstance
from workflow_manager.workflow_v2.enums import ExecutionStatus
from workflow_manager.workflow_v2.models import Workflow, WorkflowExecution
logger = logging.getLogger(__name__)
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
@api_view(["GET"])
def get_workflow_execution_data(request, execution_id: str):
"""Get workflow execution data for workers.
Args:
execution_id: Workflow execution ID
Returns:
JSON response with workflow and execution data
"""
try:
# Get organization from header
org_id = request.headers.get("X-Organization-ID")
if not org_id:
return Response(
{"error": "X-Organization-ID header is required"},
status=status.HTTP_400_BAD_REQUEST,
)
# Get execution with organization filtering
execution = WorkflowExecution.objects.select_related("workflow").get(
id=execution_id, workflow__organization_id=org_id
)
workflow = execution.workflow
# Prepare workflow data
workflow_data = {
"id": str(workflow.id),
"workflow_name": workflow.workflow_name,
"execution_details": workflow.execution_details,
"organization_id": workflow.organization_id,
}
# Prepare execution data
execution_data = {
"id": str(execution.id),
"status": execution.status,
"execution_mode": execution.execution_mode,
"execution_method": execution.execution_method,
"execution_type": execution.execution_type,
"pipeline_id": execution.pipeline_id,
"total_files": execution.total_files,
"completed_files": execution.completed_files,
"failed_files": execution.failed_files,
"execution_log_id": execution.execution_log_id, # Include for WebSocket messaging
}
return Response(
{
"workflow": workflow_data,
"execution": execution_data,
}
)
except WorkflowExecution.DoesNotExist:
return Response(
{"error": f"Workflow execution {execution_id} not found"},
status=status.HTTP_404_NOT_FOUND,
)
except Exception as e:
logger.error(f"Error getting workflow execution data: {e}")
return Response(
{"error": "Internal server error"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
@api_view(["GET"])
def get_tool_instances_by_workflow(request, workflow_id: str):
"""Get tool instances for a workflow.
Args:
workflow_id: Workflow ID
Returns:
JSON response with tool instances data
"""
try:
# Get organization from header
org_id = request.headers.get("X-Organization-ID")
if not org_id:
logger.error(f"Missing X-Organization-ID header for workflow {workflow_id}")
return Response(
{"error": "X-Organization-ID header is required"},
status=status.HTTP_400_BAD_REQUEST,
)
logger.info(f"Getting tool instances for workflow {workflow_id}, org {org_id}")
# Get tool instances with organization filtering
# First check if workflow exists and belongs to organization
try:
# Get organization object first (org_id is the organization_id string field)
logger.info(f"Looking up organization with organization_id: {org_id}")
organization = Organization.objects.get(organization_id=org_id)
logger.info(
f"Found organization: {organization.id} - {organization.display_name}"
)
logger.info(
f"Looking up workflow {workflow_id} for organization {organization.id}"
)
workflow = Workflow.objects.get(id=workflow_id, organization=organization)
logger.info(f"Found workflow: {workflow.workflow_name}")
except Organization.DoesNotExist:
logger.error(f"Organization not found: {org_id}")
return Response(
{"error": f"Organization {org_id} not found"},
status=status.HTTP_404_NOT_FOUND,
)
except Workflow.DoesNotExist:
logger.error(f"Workflow {workflow_id} not found for organization {org_id}")
return Response(
{"error": f"Workflow {workflow_id} not found for organization {org_id}"},
status=status.HTTP_404_NOT_FOUND,
)
except Exception as e:
logger.error(
f"Unexpected error during organization/workflow lookup: {e}",
exc_info=True,
)
return Response(
{"error": "Database lookup error", "detail": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
# Get tool instances for the workflow
tool_instances = ToolInstance.objects.filter(workflow=workflow).order_by("step")
# Prepare tool instances data
instances_data = []
for instance in tool_instances:
instance_data = {
"id": str(instance.id),
"tool_id": instance.tool_id,
"step": instance.step,
"status": instance.status,
"version": instance.version,
"metadata": instance.metadata,
"input": instance.input,
"output": instance.output,
}
instances_data.append(instance_data)
return Response(
{
"tool_instances": instances_data,
}
)
except Exception as e:
logger.error(
f"Error getting tool instances for workflow {workflow_id}: {e}", exc_info=True
)
return Response(
{"error": "Internal server error", "detail": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
@api_view(["POST"])
def create_file_execution_batch(request):
"""Create a batch of file executions for workers.
Returns:
JSON response with batch creation result
"""
try:
# Get organization from header
org_id = request.headers.get("X-Organization-ID")
if not org_id:
logger.error(
"Missing X-Organization-ID header for file execution batch creation"
)
return Response(
{"error": "X-Organization-ID header is required"},
status=status.HTTP_400_BAD_REQUEST,
)
# For now, return a simple response indicating batch creation
# This would be expanded based on actual requirements
return Response(
{
"batch_id": "temp-batch-id",
"status": "created",
"organization_id": org_id,
}
)
except Exception as e:
logger.error(f"Error creating file execution batch: {e}", exc_info=True)
return Response(
{"error": "Internal server error", "detail": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
@api_view(["POST"])
def update_file_execution_batch_status(request):
"""Update file execution batch status for workers.
Returns:
JSON response with batch status update result
"""
try:
# Get organization from header
org_id = request.headers.get("X-Organization-ID")
if not org_id:
logger.error(
"Missing X-Organization-ID header for file execution batch status update"
)
return Response(
{"error": "X-Organization-ID header is required"},
status=status.HTTP_400_BAD_REQUEST,
)
# For now, return a simple response indicating status update
# This would be expanded based on actual requirements
return Response(
{
"status": "updated",
"organization_id": org_id,
}
)
except Exception as e:
logger.error(f"Error updating file execution batch status: {e}", exc_info=True)
return Response(
{"error": "Internal server error", "detail": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
@api_view(["POST"])
def create_workflow_execution(request):
"""Create a new workflow execution.
Returns:
JSON response with execution ID
"""
try:
data = request.data
# Get organization from header
org_id = request.headers.get("X-Organization-ID")
if not org_id:
return Response(
{"error": "X-Organization-ID header is required"},
status=status.HTTP_400_BAD_REQUEST,
)
# Get workflow with organization filtering
# First get organization object, then lookup workflow
try:
organization = Organization.objects.get(organization_id=org_id)
workflow = Workflow.objects.get(
id=data["workflow_id"], organization=organization
)
except Organization.DoesNotExist:
return Response(
{"error": f"Organization {org_id} not found"},
status=status.HTTP_404_NOT_FOUND,
)
# Create execution with log_events_id for WebSocket messaging
log_events_id = data.get("log_events_id")
# If log_events_id not provided, fall back to pipeline_id for backward compatibility
execution_log_id = log_events_id if log_events_id else data.get("pipeline_id")
execution = WorkflowExecution.objects.create(
workflow=workflow,
pipeline_id=data.get("pipeline_id"),
execution_mode=data.get("mode", WorkflowExecution.Mode.INSTANT),
execution_method=WorkflowExecution.Method.SCHEDULED
if data.get("scheduled")
else WorkflowExecution.Method.DIRECT,
execution_type=WorkflowExecution.Type.STEP
if data.get("single_step")
else WorkflowExecution.Type.COMPLETE,
status=ExecutionStatus.PENDING.value,
total_files=data.get("total_files", 0),
execution_log_id=execution_log_id, # Set execution_log_id for WebSocket messaging
)
# Set tags if provided
if data.get("tags"):
# Handle tags logic if needed
pass
return Response(
{
"execution_id": str(execution.id),
"status": execution.status,
"execution_log_id": execution.execution_log_id, # Return for workers to use
}
)
except Workflow.DoesNotExist:
return Response({"error": "Workflow not found"}, status=status.HTTP_404_NOT_FOUND)
except Exception as e:
logger.error(f"Error creating workflow execution: {e}")
return Response(
{"error": "Internal server error"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
@api_view(["POST"])
def compile_workflow(request):
"""Compile workflow for workers.
This is a database-only operation that workers need.
Returns:
JSON response with compilation result
"""
try:
data = request.data
workflow_id = data.get("workflow_id")
# Get organization from header
org_id = request.headers.get("X-Organization-ID")
if not org_id:
return Response(
{"error": "X-Organization-ID header is required"},
status=status.HTTP_400_BAD_REQUEST,
)
# For now, return success since compilation logic needs to be migrated
# TODO: Implement actual compilation logic in workers
return Response(
{
"success": True,
"workflow_id": workflow_id,
}
)
except Exception as e:
logger.error(f"Error compiling workflow: {e}")
return Response(
{"error": "Internal server error"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
@api_view(["POST"])
def submit_file_batch_for_processing(request):
"""Submit file batch for processing by workers.
This endpoint receives batch data and returns immediately,
as actual processing is handled by Celery workers.
Returns:
JSON response with batch submission status
"""
try:
batch_data = request.data
# Get organization from header
org_id = request.headers.get("X-Organization-ID")
if not org_id:
return Response(
{"error": "X-Organization-ID header is required"},
status=status.HTTP_400_BAD_REQUEST,
)
# Add organization ID to file_data where WorkerFileData expects it
if "file_data" in batch_data:
batch_data["file_data"]["organization_id"] = org_id
else:
# Fallback: add at top level for backward compatibility
batch_data["organization_id"] = org_id
# Submit to file processing worker queue using Celery
try:
from backend.celery_service import app as celery_app
# Submit the batch data to the file processing worker using send_task
# This calls the task by name without needing to import it
task_result = celery_app.send_task(
"process_file_batch", # Task name as defined in workers/file_processing/tasks.py
args=[batch_data], # Pass batch_data as first argument
queue="file_processing", # Send to file processing queue
)
logger.info(
f"Successfully submitted file batch {batch_data.get('batch_id')} to worker queue (task: {task_result.id})"
)
return Response(
{
"success": True,
"batch_id": batch_data.get("batch_id"),
"celery_task_id": task_result.id,
"message": "Batch submitted for processing",
}
)
except Exception as e:
logger.error(f"Failed to submit batch to worker queue: {e}")
return Response(
{"error": f"Failed to submit batch for processing: {str(e)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
except Exception as e:
logger.error(f"Error submitting file batch: {e}")
return Response(
{"error": "Internal server error"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
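A hypothetical worker-side call to the create-execution endpoint; the URL prefix is an assumption (only the `execution/create/` route name appears in the URL conf below).
# Illustrative only; the mount prefix for the internal URL conf is assumed.
import requests

def create_execution(
    base_url: str, token: str, org_id: str, workflow_id: str, pipeline_id: str | None = None
) -> str:
    response = requests.post(
        f"{base_url}/workflow-manager/execution/create/",  # assumed prefix + route name
        json={"workflow_id": workflow_id, "pipeline_id": pipeline_id, "total_files": 0},
        headers={"Authorization": f"Bearer {token}", "X-Organization-ID": org_id},
        timeout=10,
    )
    response.raise_for_status()
    return response.json()["execution_id"]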

View File

@@ -0,0 +1,220 @@
"""Workflow Manager Internal API Serializers
Handles serialization for workflow execution related internal endpoints.
"""
import logging
from pipeline_v2.models import Pipeline
from rest_framework import serializers
# Import shared dataclasses for type safety and consistency
from unstract.core.data_models import (
FileExecutionStatusUpdateRequest,
WorkflowFileExecutionData,
)
from workflow_manager.file_execution.models import WorkflowFileExecution
from workflow_manager.workflow_v2.enums import ExecutionStatus
from workflow_manager.workflow_v2.models.execution import WorkflowExecution
logger = logging.getLogger(__name__)
class WorkflowExecutionSerializer(serializers.ModelSerializer):
"""Serializer for WorkflowExecution model for internal API."""
workflow_id = serializers.CharField(source="workflow.id", read_only=True)
workflow_name = serializers.CharField(source="workflow.workflow_name", read_only=True)
pipeline_id = serializers.SerializerMethodField()
tags = serializers.SerializerMethodField()
def get_pipeline_id(self, obj):
"""ROOT CAUSE FIX: Return None for pipeline_id if the referenced pipeline doesn't exist.
This prevents callback workers from attempting to update deleted pipelines.
PERFORMANCE: Cache pipeline existence to avoid repeated DB queries.
"""
if not obj.pipeline_id:
return None
# Use instance-level cache to avoid repeated DB queries within same request
cache_key = f"_pipeline_exists_{obj.pipeline_id}"
if hasattr(self, cache_key):
return getattr(self, cache_key)
# Import here to avoid circular imports
from api_v2.models import APIDeployment
try:
# First check if it's a Pipeline
Pipeline.objects.get(id=obj.pipeline_id)
result = str(obj.pipeline_id)
setattr(self, cache_key, result)
return result
except Pipeline.DoesNotExist:
# Not a Pipeline, check if it's an APIDeployment
try:
APIDeployment.objects.get(id=obj.pipeline_id)
result = str(obj.pipeline_id)
setattr(self, cache_key, result)
return result
except APIDeployment.DoesNotExist:
# Neither Pipeline nor APIDeployment exists - return None to prevent stale reference usage
setattr(self, cache_key, None)
return None
def get_tags(self, obj):
"""Serialize tags as full objects with id, name, and description.
This method ensures tags are serialized as:
[{"id": "uuid", "name": "tag_name", "description": "..."}, ...]
instead of just ["uuid1", "uuid2", ...]
"""
try:
return [
{
"id": str(tag.id),
"name": tag.name,
"description": tag.description or "",
}
for tag in obj.tags.all()
]
except Exception as e:
logger.warning(f"Failed to serialize tags for execution {obj.id}: {str(e)}")
return []
class Meta:
model = WorkflowExecution
fields = [
"id",
"workflow_id",
"workflow_name",
"pipeline_id",
"task_id",
"execution_mode",
"execution_method",
"execution_type",
"execution_log_id",
"status",
"result_acknowledged",
"total_files",
"error_message",
"attempts",
"execution_time",
"created_at",
"modified_at",
"tags",
]
read_only_fields = ["id", "created_at", "modified_at"]
class WorkflowFileExecutionSerializer(serializers.ModelSerializer):
"""Serializer for WorkflowFileExecution model for internal API.
Enhanced with shared dataclass integration for type safety.
"""
workflow_execution_id = serializers.CharField(
source="workflow_execution.id", read_only=True
)
class Meta:
model = WorkflowFileExecution
fields = [
"id",
"workflow_execution_id",
"file_name",
"file_path",
"file_size",
"file_hash",
"provider_file_uuid",
"mime_type",
"fs_metadata",
"status",
"execution_error",
"created_at",
"modified_at",
]
read_only_fields = ["id", "created_at", "modified_at"]
def to_dataclass(self, instance=None) -> WorkflowFileExecutionData:
"""Convert serialized data to shared dataclass."""
if instance is None:
instance = self.instance
return WorkflowFileExecutionData.from_dict(self.to_representation(instance))
@classmethod
def from_dataclass(cls, data: WorkflowFileExecutionData) -> dict:
"""Convert shared dataclass to serializer-compatible dict."""
return data.to_dict()
class FileExecutionStatusUpdateSerializer(serializers.Serializer):
"""Serializer for updating file execution status.
Enhanced with shared dataclass integration for type safety.
"""
status = serializers.ChoiceField(choices=ExecutionStatus.choices)
error_message = serializers.CharField(required=False, allow_blank=True)
result = serializers.CharField(required=False, allow_blank=True)
execution_time = serializers.FloatField(required=False, min_value=0)
def to_dataclass(self) -> FileExecutionStatusUpdateRequest:
"""Convert validated data to shared dataclass."""
return FileExecutionStatusUpdateRequest(
status=self.validated_data["status"],
error_message=self.validated_data.get("error_message"),
result=self.validated_data.get("result"),
)
@classmethod
def from_dataclass(cls, data: FileExecutionStatusUpdateRequest):
"""Create serializer from shared dataclass."""
return cls(data=data.to_dict())
class WorkflowExecutionStatusUpdateSerializer(serializers.Serializer):
"""Serializer for updating workflow execution status."""
status = serializers.ChoiceField(choices=ExecutionStatus.choices)
error_message = serializers.CharField(required=False, allow_blank=True)
total_files = serializers.IntegerField(
required=False, min_value=0
) # Allow 0 but backend will only update if > 0
attempts = serializers.IntegerField(required=False, min_value=0)
execution_time = serializers.FloatField(required=False, min_value=0)
class OrganizationContextSerializer(serializers.Serializer):
"""Serializer for organization context information."""
organization_id = serializers.CharField(allow_null=True, required=False)
organization_name = serializers.CharField(required=False, allow_blank=True)
settings = serializers.DictField(required=False)
class WorkflowExecutionContextSerializer(serializers.Serializer):
"""Serializer for complete workflow execution context."""
execution = WorkflowExecutionSerializer()
workflow_definition = serializers.DictField()
source_config = serializers.DictField()
destination_config = serializers.DictField(required=False)
organization_context = OrganizationContextSerializer()
file_executions = serializers.ListField(required=False)
aggregated_usage_cost = serializers.FloatField(required=False, allow_null=True)
class FileBatchCreateSerializer(serializers.Serializer):
"""Serializer for creating file execution batches."""
workflow_execution_id = serializers.UUIDField()
files = serializers.ListField(child=serializers.DictField(), allow_empty=False)
is_api = serializers.BooleanField(default=False)
class FileBatchResponseSerializer(serializers.Serializer):
"""Serializer for file batch creation response."""
batch_id = serializers.CharField()
workflow_execution_id = serializers.CharField()
total_files = serializers.IntegerField()
created_file_executions = serializers.ListField()
skipped_files = serializers.ListField(required=False)
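For context, a minimal sketch of validating a batch payload with FileBatchCreateSerializer; the keys inside each file dict are illustrative, since the receiving view defines what it actually expects:

```python
# Illustrative payload validation for FileBatchCreateSerializer.
payload = {
    "workflow_execution_id": "2f9c3a9e-8a1a-4f7e-9a51-2d5d8f1c0b7e",  # sample UUID
    "files": [{"file_name": "invoice.pdf", "file_path": "in/invoice.pdf"}],  # illustrative keys
    "is_api": False,
}
serializer = FileBatchCreateSerializer(data=payload)
serializer.is_valid(raise_exception=True)
batch_request = serializer.validated_data  # consumed by the batch-create view
```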

View File

@@ -0,0 +1,139 @@
"""Internal API URLs for Workflow Manager
URLs for internal APIs that workers use to communicate with Django backend.
These handle only database operations while business logic remains in workers.
"""
from django.urls import path
from . import internal_api_views, internal_views
app_name = "workflow_manager_internal"
urlpatterns = [
# Workflow execution endpoints - specific paths first
path(
"execution/create/",
internal_api_views.create_workflow_execution,
name="create_workflow_execution",
),
path(
"execution/<str:execution_id>/",
internal_api_views.get_workflow_execution_data,
name="get_workflow_execution_data",
),
# Tool instance endpoints
path(
"workflow/<str:workflow_id>/tool-instances/",
internal_api_views.get_tool_instances_by_workflow,
name="get_tool_instances_by_workflow",
),
# Workflow compilation
path(
"workflow/compile/",
internal_api_views.compile_workflow,
name="compile_workflow",
),
# File batch processing
path(
"file-batch/submit/",
internal_api_views.submit_file_batch_for_processing,
name="submit_file_batch_for_processing",
),
# Workflow definition and type detection (using sophisticated class-based views)
path(
"workflow/<str:workflow_id>/",
internal_views.WorkflowDefinitionAPIView.as_view(),
name="get_workflow_definition",
),
path(
"<str:workflow_id>/endpoint/",
internal_views.WorkflowEndpointAPIView.as_view(),
name="get_workflow_endpoints",
),
path(
"pipeline-type/<str:pipeline_id>/",
internal_views.PipelineTypeAPIView.as_view(),
name="get_pipeline_type",
),
path(
"pipeline-name/<str:pipeline_id>/",
internal_views.PipelineNameAPIView.as_view(),
name="get_pipeline_name",
),
# Batch operations (using sophisticated class-based views)
path(
"batch-status-update/",
internal_views.BatchStatusUpdateAPIView.as_view(),
name="batch_update_execution_status",
),
path(
"file-batch/",
internal_views.FileBatchCreateAPIView.as_view(),
name="create_file_batch",
),
# File management (using sophisticated class-based views)
path(
"increment-files/",
internal_views.FileCountIncrementAPIView.as_view(),
name="increment_files",
),
path(
"file-history/create/",
internal_views.FileHistoryCreateView.as_view(),
name="create_file_history_entry",
),
path(
"file-history/check-batch/",
internal_views.FileHistoryBatchCheckView.as_view(),
name="check_file_history_batch",
),
# Additional endpoints available in internal_views.py
path(
"source-files/<str:workflow_id>/",
internal_views.WorkflowSourceFilesAPIView.as_view(),
name="get_workflow_source_files",
),
# path("execution/finalize/<str:execution_id>/", removed - ExecutionFinalizationAPIView was unused dead code
path(
"execution/cleanup/",
internal_views.WorkflowExecutionCleanupAPIView.as_view(),
name="cleanup_executions",
),
path(
"execution/metrics/",
internal_views.WorkflowExecutionMetricsAPIView.as_view(),
name="get_execution_metrics",
),
path(
"file-execution/",
internal_views.WorkflowFileExecutionAPIView.as_view(),
name="workflow_file_execution",
),
path(
"file-execution/check-active",
internal_views.WorkflowFileExecutionCheckActiveAPIView.as_view(),
name="workflow_file_execution_check_active",
),
path(
"execute-file/",
internal_views.WorkflowExecuteFileAPIView.as_view(),
name="execute_workflow_file",
),
path(
"pipeline/<str:pipeline_id>/status/",
internal_views.PipelineStatusUpdateAPIView.as_view(),
name="update_pipeline_status",
),
# File execution batch operations (using simple function views for now)
path(
"file-execution/batch-create/",
internal_api_views.create_file_execution_batch,
name="file_execution_batch_create",
),
path(
"file-execution/batch-status-update/",
internal_api_views.update_file_execution_batch_status,
name="file_execution_batch_status_update",
),
]
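To make the worker/backend split concrete, here is a hedged sketch of a worker calling one of these routes over HTTP; the base URL, route prefix, and auth header are assumptions, since the real calls go through the workers' shared API client:

```python
# Hypothetical worker-side call; prefix and auth scheme are assumptions.
import requests

INTERNAL_API_BASE = "http://backend:8000/internal/workflow-manager"  # assumed prefix

def fetch_execution_data(execution_id: str, api_key: str) -> dict:
    """GET the payload served by get_workflow_execution_data."""
    response = requests.get(
        f"{INTERNAL_API_BASE}/execution/{execution_id}/",
        headers={"Authorization": f"Bearer {api_key}"},  # assumed auth scheme
        timeout=30,
    )
    response.raise_for_status()
    return response.json()
```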

File diff suppressed because it is too large

View File

@@ -0,0 +1,27 @@
"""Workflow Manager Internal API URLs
Defines internal API endpoints for workflow execution operations.
"""
from django.urls import include, path
from rest_framework.routers import DefaultRouter
from .internal_views import FileBatchCreateAPIView, WorkflowExecutionInternalViewSet
# Create router for internal API viewsets
router = DefaultRouter()
router.register(
r"",
WorkflowExecutionInternalViewSet,
basename="workflow-execution-internal",
)
urlpatterns = [
# Workflow execution internal APIs
path(
"create-file-batch/",
FileBatchCreateAPIView.as_view(),
name="create-file-batch",
),
# Include router URLs for viewsets (this creates the CRUD endpoints)
path("", include(router.urls)),
]

View File

@@ -2,7 +2,10 @@ from __future__ import annotations
from enum import Enum
from django.db.models import TextChoices
from unstract.core.data_models import ExecutionStatus as SharedExecutionStatus
# Alias shared ExecutionStatus to ensure consistency between backend and workers
ExecutionStatus = SharedExecutionStatus
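A small sketch of what the alias guarantees, assuming the shared enum keeps the same member values as the TextChoices class it replaces below:

```python
# Backend code importing ExecutionStatus from this module and worker code
# importing it from unstract.core now resolve to the same class, so status
# strings compare identically on both sides.
from unstract.core.data_models import ExecutionStatus as SharedExecutionStatus

assert SharedExecutionStatus.COMPLETED.value == "COMPLETED"
```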
class WorkflowExecutionMethod(Enum):
@@ -10,105 +13,6 @@ class WorkflowExecutionMethod(Enum):
QUEUED = "QUEUED"
class ExecutionStatus(TextChoices):
"""An enumeration representing the various statuses of an execution
process.
Statuses:
PENDING: The execution's entry has been created in the database.
EXECUTING: The execution is currently in progress.
COMPLETED: The execution has been successfully completed.
STOPPED: The execution was stopped by the user
(applicable to step executions).
ERROR: An error occurred during the execution process.
Note:
Intermediate statuses might not be experienced due to
Django's query triggering once all processes are completed.
"""
PENDING = "PENDING"
EXECUTING = "EXECUTING"
COMPLETED = "COMPLETED"
STOPPED = "STOPPED"
ERROR = "ERROR"
@classmethod
def is_completed(cls, status: str | ExecutionStatus) -> bool:
"""Check if the execution status is completed."""
try:
status_enum = cls(status)
except ValueError:
raise ValueError(
f"Invalid status: {status}. Must be a valid ExecutionStatus."
)
return status_enum in [cls.COMPLETED, cls.STOPPED, cls.ERROR]
@classmethod
def is_active(cls, status: str | ExecutionStatus) -> bool:
"""Check if the workflow execution status is active (in progress)."""
try:
status_enum = cls(status)
except ValueError:
raise ValueError(
f"Invalid status: {status}. Must be a valid ExecutionStatus."
)
return status_enum in [cls.PENDING, cls.EXECUTING]
@classmethod
def get_skip_processing_statuses(cls) -> list[ExecutionStatus]:
"""Get list of statuses that should skip file processing.
Skip processing if:
- EXECUTING: File is currently being processed
- PENDING: File is queued to be processed
- COMPLETED: File has already been successfully processed
Returns:
list[ExecutionStatus]: List of statuses where file processing should be skipped
"""
return [cls.EXECUTING, cls.PENDING, cls.COMPLETED]
@classmethod
def should_skip_file_processing(cls, status: str | ExecutionStatus) -> bool:
"""Check if file processing should be skipped based on status.
Allow processing (retry) if:
- STOPPED: Processing was stopped, can retry
- ERROR: Processing failed, can retry
"""
try:
status_enum = cls(status)
except ValueError:
raise ValueError(
f"Invalid status: {status}. Must be a valid ExecutionStatus."
)
return status_enum in cls.get_skip_processing_statuses()
@classmethod
def can_update_to_pending(cls, status: str | ExecutionStatus) -> bool:
"""Check if a status can be updated to PENDING.
Allow updating to PENDING if:
- Status is STOPPED or ERROR (can retry)
- Status is None (new record)
Don't allow updating to PENDING if:
- Status is EXECUTING (currently processing)
- Status is COMPLETED (already done)
- Status is already PENDING (no change needed)
"""
if status is None:
return True
try:
status_enum = cls(status)
except ValueError:
return True # Invalid status, allow update
return status_enum in [cls.STOPPED, cls.ERROR]
class SchemaType(Enum):
"""Possible types for workflow module's JSON schema.

View File

@@ -78,12 +78,12 @@ class WorkflowExecutionServiceHelper(WorkflowExecutionService):
log_events_id = StateStore.get(Common.LOG_EVENTS_ID)
self.execution_log_id = log_events_id if log_events_id else pipeline_id
self.execution_mode = mode
self.execution_method: tuple[str, str] = (
self.execution_method = (
WorkflowExecution.Method.SCHEDULED
if scheduled
else WorkflowExecution.Method.DIRECT
)
self.execution_type: tuple[str, str] = (
self.execution_type = (
WorkflowExecution.Type.STEP
if single_step
else WorkflowExecution.Type.COMPLETE
@@ -94,7 +94,7 @@ class WorkflowExecutionServiceHelper(WorkflowExecutionService):
execution_mode=mode,
execution_method=self.execution_method,
execution_type=self.execution_type,
status=ExecutionStatus.EXECUTING,
status=ExecutionStatus.EXECUTING.value,
execution_log_id=self.execution_log_id,
)
workflow_execution.save()
@@ -140,12 +140,12 @@ class WorkflowExecutionServiceHelper(WorkflowExecutionService):
if existing_execution:
return existing_execution
execution_method: tuple[str, str] = (
execution_method = (
WorkflowExecution.Method.SCHEDULED
if scheduled
else WorkflowExecution.Method.DIRECT
)
execution_type: tuple[str, str] = (
execution_type = (
WorkflowExecution.Type.STEP
if single_step
else WorkflowExecution.Type.COMPLETE
@@ -159,7 +159,7 @@ class WorkflowExecutionServiceHelper(WorkflowExecutionService):
execution_mode=mode,
execution_method=execution_method,
execution_type=execution_type,
status=ExecutionStatus.PENDING,
status=ExecutionStatus.PENDING.value,
execution_log_id=execution_log_id,
total_files=total_files,
)
@@ -396,7 +396,7 @@ class WorkflowExecutionServiceHelper(WorkflowExecutionService):
def update_execution_err(execution_id: str, err_msg: str = "") -> WorkflowExecution:
try:
execution = WorkflowExecution.objects.get(pk=execution_id)
execution.status = ExecutionStatus.ERROR
execution.status = ExecutionStatus.ERROR.value
execution.error_message = err_msg[:EXECUTION_ERROR_LENGTH]
execution.save()
return execution

View File

@@ -0,0 +1,36 @@
"""Internal API URLs for Execution Log Operations
URLs for internal APIs that workers use to communicate with Django backend
for execution log operations. These handle database operations while business
logic remains in workers.
"""
from django.urls import path
from . import execution_log_internal_views
app_name = "execution_log_internal"
urlpatterns = [
# Execution log management endpoints
path(
"workflow-executions/by-ids/",
execution_log_internal_views.GetWorkflowExecutionsByIdsAPIView.as_view(),
name="get_workflow_executions_by_ids",
),
path(
"file-executions/by-ids/",
execution_log_internal_views.GetFileExecutionsByIdsAPIView.as_view(),
name="get_file_executions_by_ids",
),
path(
"executions/validate/",
execution_log_internal_views.ValidateExecutionReferencesAPIView.as_view(),
name="validate_execution_references",
),
path(
"process-log-history/",
execution_log_internal_views.ProcessLogHistoryAPIView.as_view(),
name="process_log_history",
),
]

View File

@@ -0,0 +1,171 @@
"""Internal API Views for Execution Log Operations
These views handle internal API requests from workers for execution log operations.
They provide database access while keeping business logic in workers.
"""
import logging
from rest_framework import status
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView
from workflow_manager.file_execution.models import WorkflowFileExecution
from workflow_manager.workflow_v2.execution_log_utils import (
process_log_history_from_cache,
)
from workflow_manager.workflow_v2.models import WorkflowExecution
logger = logging.getLogger(__name__)
class GetWorkflowExecutionsByIdsAPIView(APIView):
"""API view for getting workflow executions by IDs."""
def post(self, request: Request) -> Response:
"""Get workflow execution data for given IDs.
Args:
request: HTTP request containing execution IDs
Returns:
JSON response with execution data
"""
try:
execution_ids = request.data.get("execution_ids", [])
executions = WorkflowExecution.objects.filter(id__in=execution_ids)
execution_data = {}
for execution in executions:
execution_data[str(execution.id)] = {
"id": str(execution.id),
"workflow_id": str(execution.workflow.id)
if execution.workflow
else None,
"status": execution.status,
"created_at": execution.created_at.isoformat()
if execution.created_at
else None,
}
return Response({"executions": execution_data})
except Exception as e:
logger.error(f"Error getting workflow executions: {e}", exc_info=True)
return Response(
{"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
class GetFileExecutionsByIdsAPIView(APIView):
"""API view for getting file executions by IDs."""
def post(self, request: Request) -> Response:
"""Get file execution data for given IDs.
Args:
request: HTTP request containing file execution IDs
Returns:
JSON response with file execution data
"""
try:
file_execution_ids = request.data.get("file_execution_ids", [])
file_executions = WorkflowFileExecution.objects.filter(
id__in=file_execution_ids
)
file_execution_data = {}
for file_execution in file_executions:
file_execution_data[str(file_execution.id)] = {
"id": str(file_execution.id),
"workflow_execution_id": str(file_execution.workflow_execution.id)
if file_execution.workflow_execution
else None,
"status": file_execution.status,
"created_at": file_execution.created_at.isoformat()
if file_execution.created_at
else None,
}
return Response({"file_executions": file_execution_data})
except Exception as e:
logger.error(f"Error getting file executions: {e}", exc_info=True)
return Response(
{"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
class ValidateExecutionReferencesAPIView(APIView):
"""API view for validating execution references exist."""
def post(self, request: Request) -> Response:
"""Validate that execution references exist.
Args:
request: HTTP request containing execution and file execution IDs
Returns:
JSON response with validation results
"""
try:
execution_ids = request.data.get("execution_ids", [])
file_execution_ids = request.data.get("file_execution_ids", [])
# Check which executions exist
existing_executions = {
str(obj.id)
for obj in WorkflowExecution.objects.filter(id__in=execution_ids)
}
# Check which file executions exist
existing_file_executions = {
str(obj.id)
for obj in WorkflowFileExecution.objects.filter(id__in=file_execution_ids)
}
return Response(
{
"valid_executions": list(existing_executions),
"valid_file_executions": list(existing_file_executions),
}
)
except Exception as e:
logger.error(f"Error validating execution references: {e}", exc_info=True)
return Response(
{"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
class ProcessLogHistoryAPIView(APIView):
"""API view for processing log history from scheduler.
This endpoint is called by the log history scheduler when logs exist in Redis queue.
It reuses the existing business logic from execution_log_utils.process_log_history_from_cache().
"""
def post(self, request: Request) -> Response:
"""Process log history batch from Redis cache.
Args:
request: HTTP request (no parameters needed)
Returns:
JSON response with processing results
"""
try:
# Reuse existing business logic (uses ExecutionLogConstants for config)
result = process_log_history_from_cache()
return Response(result)
except Exception as e:
logger.error(f"Error processing log history: {e}", exc_info=True)
return Response(
{"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
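A hedged sketch of how the log history scheduler could hit this endpoint; the base URL and auth header are assumptions:

```python
# Illustrative scheduler-side trigger; URL prefix and auth are assumptions.
import requests

def trigger_log_history_processing(base_url: str, api_key: str) -> dict:
    """POST to process-log-history/; no request body is needed."""
    response = requests.post(
        f"{base_url}/process-log-history/",
        headers={"Authorization": f"Bearer {api_key}"},  # assumed auth scheme
        timeout=30,
    )
    response.raise_for_status()
    # Same shape as process_log_history_from_cache(): processed_count,
    # skipped_count, total_logs, organizations_processed.
    return response.json()
```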

View File

@@ -16,15 +16,35 @@ from workflow_manager.workflow_v2.models import ExecutionLog, WorkflowExecution
logger = logging.getLogger(__name__)
@shared_task(name=ExecutionLogConstants.TASK_V2)
def consume_log_history() -> None:
def process_log_history_from_cache(
queue_name: str = ExecutionLogConstants.LOG_QUEUE_NAME,
batch_limit: int = ExecutionLogConstants.LOGS_BATCH_LIMIT,
) -> dict:
"""Process log history from Redis cache.
This function contains the core business logic for processing execution logs
from Redis cache to database. It can be called by both the Celery task and
internal API endpoints.
Args:
queue_name: Redis queue name to process logs from
batch_limit: Maximum number of logs to process in one batch
Returns:
Dictionary with processing results:
- processed_count: Number of logs successfully stored
- skipped_count: Number of logs skipped (invalid references)
- total_logs: Total number of logs retrieved from cache
- organizations_processed: Number of organizations affected
"""
organization_logs = defaultdict(list)
logs_count = 0
logs_to_process = []
skipped_count = 0
# Collect logs from cache (batch retrieval)
while logs_count < ExecutionLogConstants.LOGS_BATCH_LIMIT:
log = CacheService.lpop(ExecutionLogConstants.LOG_QUEUE_NAME)
while logs_count < batch_limit:
log = CacheService.lpop(queue_name)
if not log:
break
@@ -34,9 +54,14 @@ def consume_log_history() -> None:
logs_count += 1
if not logs_to_process:
return # No logs to process
return {
"processed_count": 0,
"skipped_count": 0,
"total_logs": 0,
"organizations_processed": 0,
}
logger.info(f"Logs count: {logs_count}")
logger.info(f"Processing {logs_count} logs from queue '{queue_name}'")
# Preload required WorkflowExecution and WorkflowFileExecution objects
execution_ids = {log.execution_id for log in logs_to_process}
@@ -60,7 +85,8 @@ def consume_log_history() -> None:
f"Execution not found for execution_id: {log_data.execution_id}, "
"skipping log push"
)
continue # Skip logs with missing execution reference
skipped_count += 1
continue
execution_log = ExecutionLog(
wf_execution=execution,
@@ -69,16 +95,42 @@ def consume_log_history() -> None:
)
if log_data.file_execution_id:
execution_log.file_execution = file_execution_map.get(
log_data.file_execution_id
)
file_execution = file_execution_map.get(log_data.file_execution_id)
if file_execution:
execution_log.file_execution = file_execution
else:
logger.warning(
f"File execution not found for file_execution_id: {log_data.file_execution_id}, "
"skipping log push"
)
skipped_count += 1
continue
organization_logs[log_data.organization_id].append(execution_log)
# Bulk insert logs for each organization
processed_count = 0
for organization_id, logs in organization_logs.items():
logger.info(f"Storing '{len(logs)}' logs for org: {organization_id}")
logger.info(f"Storing {len(logs)} logs for org: {organization_id}")
ExecutionLog.objects.bulk_create(objs=logs, ignore_conflicts=True)
processed_count += len(logs)
return {
"processed_count": processed_count,
"skipped_count": skipped_count,
"total_logs": logs_count,
"organizations_processed": len(organization_logs),
}
@shared_task(name=ExecutionLogConstants.TASK_V2)
def consume_log_history() -> None:
"""Celery task to consume log history from Redis cache.
This task is a thin wrapper around process_log_history_from_cache() for
backward compatibility with existing Celery Beat schedules.
"""
process_log_history_from_cache()
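For reference, calling the extracted helper directly (for example from an internal API view) might look like this; the batch size is illustrative:

```python
# Sketch: direct invocation instead of going through the Celery task.
result = process_log_history_from_cache(
    queue_name=ExecutionLogConstants.LOG_QUEUE_NAME,
    batch_limit=50,  # illustrative; default is ExecutionLogConstants.LOGS_BATCH_LIMIT
)
logger.info(
    "Stored %s logs, skipped %s, across %s organizations (retrieved %s)",
    result["processed_count"],
    result["skipped_count"],
    result["organizations_processed"],
    result["total_logs"],
)
```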
def create_log_consumer_scheduler_if_not_exists() -> None:

View File

@@ -5,6 +5,7 @@ from typing import Any
from django.db.models import Q
from django.db.utils import IntegrityError
from django.utils import timezone
from utils.cache_service import CacheService
from workflow_manager.endpoint_v2.dto import FileHash
from workflow_manager.endpoint_v2.models import WorkflowEndpoint
@@ -245,8 +246,8 @@ class FileHistoryHelper:
metadata: str | None,
error: str | None = None,
is_api: bool = False,
) -> None:
"""Create a new file history record.
) -> FileHistory:
"""Create a new file history record or return existing one.
Args:
file_hash (FileHash): The file hash for the file.
@@ -255,35 +256,90 @@ class FileHistoryHelper:
result (Any): The result from the execution.
metadata (str | None): The metadata from the execution.
error (str | None): The error from the execution.
is_api (bool): Whether this is an API call.
"""
try:
file_path = file_hash.file_path if not is_api else None
is_api (bool): Whether this is for API workflow (affects file_path handling).
FileHistory.objects.create(
Returns:
FileHistory: Either newly created or existing file history record.
"""
file_path = file_hash.file_path if not is_api else None
# Prepare data for creation
create_data = {
"workflow": workflow,
"cache_key": file_hash.file_hash,
"provider_file_uuid": file_hash.provider_file_uuid,
"status": status,
"result": str(result),
"metadata": str(metadata) if metadata else "",
"error": str(error) if error else "",
"file_path": file_path,
}
try:
# Try to create the file history record
file_history = FileHistory.objects.create(**create_data)
logger.info(
f"Created new FileHistory record - "
f"file_name='{file_hash.file_name}', file_path='{file_hash.file_path}', "
f"file_hash='{file_hash.file_hash[:16] if file_hash.file_hash else 'None'}', "
f"workflow={workflow}"
)
return file_history
except IntegrityError as e:
# Race condition detected - another worker created the record
# Try to retrieve the existing record
logger.info(
f"FileHistory constraint violation (expected in concurrent environment) - "
f"file_name='{file_hash.file_name}', file_path='{file_hash.file_path}', "
f"file_hash='{file_hash.file_hash[:16] if file_hash.file_hash else 'None'}', "
f"workflow={workflow}. Error: {str(e)}"
)
# Use the existing get_file_history method to retrieve the record
existing_record = FileHistoryHelper.get_file_history(
workflow=workflow,
cache_key=file_hash.file_hash,
provider_file_uuid=file_hash.provider_file_uuid,
status=status,
result=str(result),
metadata=str(metadata) if metadata else "",
error=str(error) if error else "",
file_path=file_path,
)
except IntegrityError as e:
# TODO: Need to find why duplicate insert is coming
logger.warning(
f"Trying to insert duplication data for filename {file_hash.file_name} "
f"for workflow {workflow}. Error: {str(e)} with metadata {metadata}",
)
if existing_record:
logger.info(
f"Retrieved existing FileHistory record after constraint violation - "
f"ID: {existing_record.id}, workflow={workflow}"
)
return existing_record
else:
# This should rarely happen, but if we can't find the existing record,
# log the issue and re-raise the original exception
logger.error(
f"Failed to retrieve existing FileHistory record after constraint violation - "
f"file_name='{file_hash.file_name}', workflow={workflow}"
)
@staticmethod
def clear_history_for_workflow(
workflow: Workflow,
) -> None:
"""Clear all file history records associated with a workflow.
"""Clear all file history records and Redis caches associated with a workflow.
Args:
workflow (Workflow): The workflow to clear the history for.
"""
# Clear database records
FileHistory.objects.filter(workflow=workflow).delete()
logger.info(f"Cleared database records for workflow {workflow.id}")
# Clear Redis caches for file_active entries
pattern = f"file_active:{workflow.id}:*"
try:
CacheService.clear_cache_optimized(pattern)
logger.info(
f"Cleared Redis cache entries for workflow {workflow.id} with pattern: {pattern}"
)
except Exception as e:
logger.warning(
f"Failed to clear Redis caches for workflow {workflow.id}: {str(e)}"
)

View File

@@ -0,0 +1,45 @@
"""Internal API URLs for file history operations."""
from django.urls import path
from .views import (
create_file_history_internal,
file_history_batch_lookup_internal,
file_history_by_cache_key_internal,
file_history_status_internal,
get_file_history_internal,
reserve_file_processing_internal,
)
urlpatterns = [
# File history endpoints
path(
"cache-key/<str:cache_key>/",
file_history_by_cache_key_internal,
name="file-history-by-cache-key-internal",
),
# Flexible lookup endpoint (supports both cache_key and provider_file_uuid)
path(
"lookup/",
file_history_by_cache_key_internal,
name="file-history-lookup-internal",
),
# Batch lookup endpoint for multiple files
path(
"batch-lookup/",
file_history_batch_lookup_internal,
name="file-history-batch-lookup-internal",
),
path("create/", create_file_history_internal, name="create-file-history-internal"),
path(
"status/<str:file_history_id>/",
file_history_status_internal,
name="file-history-status-internal",
),
path(
"reserve/",
reserve_file_processing_internal,
name="reserve-file-processing-internal",
),
path("get/", get_file_history_internal, name="get-file-history-internal"),
]

View File

@@ -6,7 +6,6 @@ from api_v2.models import APIDeployment
from django.core.exceptions import ObjectDoesNotExist
from django.db import models
from django.db.models import QuerySet, Sum
from django.utils import timezone
from pipeline_v2.models import Pipeline
from tags.models import Tag
from usage_v2.constants import UsageKeys
@@ -226,6 +225,17 @@ class WorkflowExecution(BaseModel):
def is_completed(self) -> bool:
return ExecutionStatus.is_completed(self.status)
@property
def organization_id(self) -> str | None:
"""Get the organization ID from the associated workflow."""
if (
self.workflow
and hasattr(self.workflow, "organization")
and self.workflow.organization
):
return str(self.workflow.organization.organization_id)
return None
def __str__(self) -> str:
return (
f"Workflow execution: {self.id} ("
@@ -250,19 +260,15 @@ class WorkflowExecution(BaseModel):
increment_attempt (bool, optional): Whether to increment attempt counter. Defaults to False.
"""
if status is not None:
status = ExecutionStatus(status)
self.status = status.value
if (
status
in [
ExecutionStatus.COMPLETED,
ExecutionStatus.ERROR,
ExecutionStatus.STOPPED,
]
and not self.execution_time
):
self.execution_time = round(
(timezone.now() - self.created_at).total_seconds(), 3
)
if status in [
ExecutionStatus.COMPLETED,
ExecutionStatus.ERROR,
ExecutionStatus.STOPPED,
]:
self.execution_time = CommonUtils.time_since(self.created_at, 3)
if error:
self.error_message = error[:EXECUTION_ERROR_LENGTH]
if increment_attempt:

View File

@@ -18,7 +18,7 @@ class FileHistory(BaseModel):
Returns:
bool: True if the execution status is completed, False otherwise.
"""
return self.status is not None and self.status == ExecutionStatus.COMPLETED
return self.status is not None and self.status == ExecutionStatus.COMPLETED.value
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
cache_key = models.CharField(

View File

@@ -21,6 +21,7 @@ from backend.serializers import AuditSerializer
from workflow_manager.workflow_v2.constants import WorkflowExecutionKey, WorkflowKey
from workflow_manager.workflow_v2.models.execution import WorkflowExecution
from workflow_manager.workflow_v2.models.execution_log import ExecutionLog
from workflow_manager.workflow_v2.models.file_history import FileHistory
from workflow_manager.workflow_v2.models.workflow import Workflow
logger = logging.getLogger(__name__)
@@ -129,6 +130,12 @@ class WorkflowExecutionLogSerializer(ModelSerializer):
fields = "__all__"
class FileHistorySerializer(ModelSerializer):
class Meta:
model = FileHistory
fields = "__all__"
class SharedUserListSerializer(ModelSerializer):
"""Serializer for returning workflow with shared user details."""

File diff suppressed because it is too large

View File

@@ -483,18 +483,13 @@ class WorkflowHelper:
},
queue=queue,
)
# Log task_id for debugging
logger.info(
f"[{org_schema}] AsyncResult created with task_id: '{async_execution.id}' "
f"(type: {type(async_execution.id).__name__})"
f"[{org_schema}] Job '{async_execution}' has been enqueued for "
f"execution_id '{execution_id}', '{len(hash_values_of_files)}' files"
)
workflow_execution: WorkflowExecution = WorkflowExecution.objects.get(
id=execution_id
)
# Handle empty task_id gracefully using existing validation logic
if not async_execution.id:
logger.warning(
f"[{org_schema}] Celery returned empty task_id for execution_id '{execution_id}'. "
@@ -509,6 +504,7 @@ class WorkflowHelper:
f"[{org_schema}] Job '{async_execution.id}' has been enqueued for "
f"execution_id '{execution_id}', '{len(hash_values_of_files)}' files"
)
execution_status = workflow_execution.status
if timeout > -1:
while not ExecutionStatus.is_completed(execution_status) and timeout > 0:
@@ -779,17 +775,16 @@ class WorkflowHelper:
# Normal Workflow page execution
workflow_execution = WorkflowExecution.objects.get(pk=execution_id)
if (
workflow_execution.status != ExecutionStatus.PENDING
workflow_execution.status != ExecutionStatus.PENDING.value
or workflow_execution.execution_type != WorkflowExecution.Type.COMPLETE
):
raise InvalidRequest(WorkflowErrors.INVALID_EXECUTION_ID)
organization_identifier = UserContext.get_organization_identifier()
result: ExecutionResponse = WorkflowHelper.run_workflow(
workflow=workflow,
workflow_execution=workflow_execution,
result: ExecutionResponse = WorkflowHelper.execute_workflow_async(
workflow_id=str(workflow.id) if workflow else None,
pipeline_id=str(pipeline_id) if pipeline_id else None,
execution_id=str(execution_id) if execution_id else None,
hash_values_of_files=hash_values_of_files,
use_file_history=use_file_history,
organization_id=organization_identifier,
)
result = WorkflowHelper.wait_for_execution(result, timeout=timeout)
return result
@@ -813,7 +808,8 @@ class WorkflowHelper:
ExecutionResponse: The execution response.
"""
if (
result.execution_status in [ExecutionStatus.COMPLETED, ExecutionStatus.ERROR]
result.execution_status
in [ExecutionStatus.COMPLETED.value, ExecutionStatus.ERROR.value]
or not timeout
):
return result
@@ -879,7 +875,7 @@ class WorkflowHelper:
except WorkflowExecution.DoesNotExist:
raise WorkflowExecutionNotExist(WorkflowErrors.INVALID_EXECUTION_ID)
if (
workflow_execution.status != ExecutionStatus.PENDING
workflow_execution.status != ExecutionStatus.PENDING.value
or workflow_execution.execution_type != WorkflowExecution.Type.STEP
):
raise InvalidRequest(WorkflowErrors.INVALID_EXECUTION_ID)

View File

@@ -30,6 +30,23 @@ VERSION=dev docker compose -f docker-compose.yaml --profile optional up -d
Now access frontend at http://frontend.unstract.localhost
## V2 Workers (Optional)
V2 workers use a unified container architecture and are **disabled by default**.
```bash
# Default: Run with legacy workers only
VERSION=dev docker compose -f docker-compose.yaml up -d
# Enable V2 workers (unified container)
VERSION=dev docker compose -f docker-compose.yaml --profile workers-v2 up -d
# Or use the platform script
./run-platform.sh --workers-v2
```
V2 workers available: `api-deployment`, `callback`, `file-processing`, `general`, `notification`, `log-consumer`, `scheduler`
## Overriding a service's config
By making use of the [merge compose files](https://docs.docker.com/compose/how-tos/multiple-compose-files/merge/) feature, it's possible to override some of the configuration used by the services.

View File

@@ -50,3 +50,11 @@ services:
build:
dockerfile: docker/dockerfiles/x2text.Dockerfile
context: ..
# Unified worker image (replaces all individual worker images)
worker-unified:
image: unstract/worker-unified:${VERSION}
build:
dockerfile: docker/dockerfiles/worker-unified.Dockerfile
context: ..
args:
MINIMAL_BUILD: ${MINIMAL_BUILD:-0} # Set to 1 for faster dev builds

View File

@@ -239,6 +239,313 @@ services:
labels:
- traefik.enable=false
# ====================================================================
# V2 DEDICATED WORKER SERVICES (opt-in with --workers-v2 flag)
# ====================================================================
worker-api-deployment-v2:
image: unstract/worker-unified:${VERSION}
container_name: unstract-worker-api-deployment-v2
restart: unless-stopped
command: ["api-deployment"]
ports:
- "8085:8090"
env_file:
- ../workers/.env
- ./essentials.env
depends_on:
- db
- redis
- rabbitmq
environment:
- ENVIRONMENT=development
- APPLICATION_NAME=unstract-worker-api-deployment-v2
- WORKER_TYPE=api_deployment
- CELERY_QUEUES_API_DEPLOYMENT=${CELERY_QUEUES_API_DEPLOYMENT:-celery_api_deployments}
- CELERY_POOL=${WORKER_API_DEPLOYMENT_POOL:-threads}
- CELERY_PREFETCH_MULTIPLIER=${WORKER_API_DEPLOYMENT_PREFETCH_MULTIPLIER:-1}
- CELERY_CONCURRENCY=${WORKER_API_DEPLOYMENT_CONCURRENCY:-4}
- CELERY_EXTRA_ARGS=${WORKER_API_DEPLOYMENT_EXTRA_ARGS:-}
- WORKER_NAME=api-deployment-worker-v2
- API_DEPLOYMENT_METRICS_PORT=8090
- HEALTH_PORT=8090
labels:
- traefik.enable=false
volumes:
- ./workflow_data:/data
- ${TOOL_REGISTRY_CONFIG_SRC_PATH}:/data/tool_registry_config
profiles:
- workers-v2
worker-callback-v2:
image: unstract/worker-unified:${VERSION}
container_name: unstract-worker-callback-v2
restart: unless-stopped
command: ["callback"]
ports:
- "8086:8083"
env_file:
- ../workers/.env
- ./essentials.env
depends_on:
- db
- redis
- rabbitmq
environment:
- ENVIRONMENT=development
- APPLICATION_NAME=unstract-worker-callback-v2
- WORKER_TYPE=callback
- WORKER_NAME=callback-worker-v2
- CALLBACK_METRICS_PORT=8083
labels:
- traefik.enable=false
volumes:
- ./workflow_data:/data
- ${TOOL_REGISTRY_CONFIG_SRC_PATH}:/data/tool_registry_config
profiles:
- workers-v2
worker-file-processing-v2:
image: unstract/worker-unified:${VERSION}
container_name: unstract-worker-file-processing-v2
restart: unless-stopped
# command: ["file-processing"]
command: [".venv/bin/celery", "-A", "worker", "worker", "--queues=file_processing,api_file_processing,file_processing_priority", "--loglevel=INFO", "--pool=prefork", "--concurrency=4", "--prefetch-multiplier=1", "--without-gossip", "--without-mingle", "--without-heartbeat"]
ports:
- "8087:8082"
env_file:
- ../workers/.env
- ./essentials.env
depends_on:
- db
- redis
- rabbitmq
environment:
- ENVIRONMENT=development
- APPLICATION_NAME=unstract-worker-file-processing-v2
- WORKER_TYPE=file_processing
- WORKER_MODE=oss
- WORKER_NAME=file-processing-worker-v2
- FILE_PROCESSING_METRICS_PORT=8082
# OSS Configuration - Enterprise features disabled
- MANUAL_REVIEW_ENABLED=false
- ENTERPRISE_FEATURES_ENABLED=false
- PLUGIN_REGISTRY_MODE=oss
# Configurable Celery options
- CELERY_QUEUES_FILE_PROCESSING=${CELERY_QUEUES_FILE_PROCESSING:-file_processing,api_file_processing}
- CELERY_POOL=${WORKER_FILE_PROCESSING_POOL:-threads}
- CELERY_PREFETCH_MULTIPLIER=${WORKER_FILE_PROCESSING_PREFETCH_MULTIPLIER:-1}
- CELERY_CONCURRENCY=${WORKER_FILE_PROCESSING_CONCURRENCY:-4}
- CELERY_EXTRA_ARGS=${WORKER_FILE_PROCESSING_EXTRA_ARGS:-}
labels:
- traefik.enable=false
volumes:
- ./workflow_data:/data
- ${TOOL_REGISTRY_CONFIG_SRC_PATH}:/data/tool_registry_config
profiles:
- workers-v2
worker-general-v2:
image: unstract/worker-unified:${VERSION}
container_name: unstract-worker-general-v2
restart: unless-stopped
command: ["general"]
ports:
- "8088:8082"
env_file:
- ../workers/.env
- ./essentials.env
depends_on:
- db
- redis
- rabbitmq
environment:
- ENVIRONMENT=development
- APPLICATION_NAME=unstract-worker-general-v2
- WORKER_TYPE=general
- WORKER_NAME=general-worker-v2
- GENERAL_METRICS_PORT=8081
- HEALTH_PORT=8082
labels:
- traefik.enable=false
volumes:
- ./workflow_data:/data
- ${TOOL_REGISTRY_CONFIG_SRC_PATH}:/data/tool_registry_config
profiles:
- workers-v2
worker-notification-v2:
image: unstract/worker-unified:${VERSION}
container_name: unstract-worker-notification-v2
restart: unless-stopped
command: ["notification"]
ports:
- "8089:8085"
env_file:
- ../workers/.env
- ./essentials.env
depends_on:
- db
- redis
- rabbitmq
environment:
- ENVIRONMENT=development
- APPLICATION_NAME=unstract-worker-notification-v2
- WORKER_TYPE=notification
- WORKER_NAME=notification-worker-v2
- NOTIFICATION_METRICS_PORT=8085
- HEALTH_PORT=8085
# Notification specific configs
- NOTIFICATION_QUEUE_NAME=notifications
- WEBHOOK_QUEUE_NAME=notifications_webhook
- EMAIL_QUEUE_NAME=notifications_email
- SMS_QUEUE_NAME=notifications_sms
- PRIORITY_QUEUE_NAME=notifications_priority
# Configurable Celery options
- CELERY_QUEUES_NOTIFICATION=${CELERY_QUEUES_NOTIFICATION:-notifications,notifications_webhook,notifications_email,notifications_sms,notifications_priority}
- CELERY_POOL=${WORKER_NOTIFICATION_POOL:-prefork}
- CELERY_PREFETCH_MULTIPLIER=${WORKER_NOTIFICATION_PREFETCH_MULTIPLIER:-1}
- CELERY_CONCURRENCY=${WORKER_NOTIFICATION_CONCURRENCY:-4}
- CELERY_EXTRA_ARGS=${WORKER_NOTIFICATION_EXTRA_ARGS:-}
# Complete command override (if set, ignores all other options)
- CELERY_COMMAND_OVERRIDE=${WORKER_NOTIFICATION_COMMAND_OVERRIDE:-}
# Individual argument overrides
- CELERY_APP_MODULE=${WORKER_NOTIFICATION_APP_MODULE:-worker}
- CELERY_LOG_LEVEL=${WORKER_NOTIFICATION_LOG_LEVEL:-INFO}
- CELERY_HOSTNAME=${WORKER_NOTIFICATION_HOSTNAME:-}
- CELERY_MAX_TASKS_PER_CHILD=${WORKER_NOTIFICATION_MAX_TASKS_PER_CHILD:-}
- CELERY_TIME_LIMIT=${WORKER_NOTIFICATION_TIME_LIMIT:-}
- CELERY_SOFT_TIME_LIMIT=${WORKER_NOTIFICATION_SOFT_TIME_LIMIT:-}
labels:
- traefik.enable=false
volumes:
- ./workflow_data:/data
- ${TOOL_REGISTRY_CONFIG_SRC_PATH}:/data/tool_registry_config
profiles:
- workers-v2
worker-log-consumer-v2:
image: unstract/worker-unified:${VERSION}
container_name: unstract-worker-log-consumer-v2
restart: unless-stopped
command: ["log-consumer"]
ports:
- "8090:8084"
env_file:
- ../workers/.env
- ./essentials.env
depends_on:
- db
- redis
- rabbitmq
environment:
- ENVIRONMENT=development
- APPLICATION_NAME=unstract-worker-log-consumer-v2
- WORKER_TYPE=log_consumer
- WORKER_NAME=log-consumer-worker-v2
- LOG_CONSUMER_METRICS_PORT=8084
- HEALTH_PORT=8084
# Log consumer specific configs
- LOG_CONSUMER_QUEUE_NAME=celery_log_task_queue
# Multiple queue support - supports comma-separated queue names
- CELERY_QUEUES_LOG_CONSUMER=${CELERY_QUEUES_LOG_CONSUMER:-celery_log_task_queue,celery_periodic_logs}
- PERIODIC_LOGS_QUEUE_NAME=${PERIODIC_LOGS_QUEUE_NAME:-celery_periodic_logs}
# Log history configuration
- LOG_HISTORY_QUEUE_NAME=${LOG_HISTORY_QUEUE_NAME:-log_history_queue}
- LOGS_BATCH_LIMIT=${LOGS_BATCH_LIMIT:-100}
- ENABLE_LOG_HISTORY=${ENABLE_LOG_HISTORY:-true}
- CELERY_POOL=${WORKER_LOG_CONSUMER_POOL:-prefork}
- CELERY_PREFETCH_MULTIPLIER=${WORKER_LOG_CONSUMER_PREFETCH_MULTIPLIER:-1}
- CELERY_CONCURRENCY=${WORKER_LOG_CONSUMER_CONCURRENCY:-2}
- CELERY_EXTRA_ARGS=${WORKER_LOG_CONSUMER_EXTRA_ARGS:-}
# Complete command override (if set, ignores all other options)
- CELERY_COMMAND_OVERRIDE=${WORKER_LOG_CONSUMER_COMMAND_OVERRIDE:-}
# Individual argument overrides
- CELERY_APP_MODULE=${WORKER_LOG_CONSUMER_APP_MODULE:-worker}
- CELERY_LOG_LEVEL=${WORKER_LOG_CONSUMER_LOG_LEVEL:-INFO}
- CELERY_HOSTNAME=${WORKER_LOG_CONSUMER_HOSTNAME:-}
- CELERY_MAX_TASKS_PER_CHILD=${WORKER_LOG_CONSUMER_MAX_TASKS_PER_CHILD:-}
- CELERY_TIME_LIMIT=${WORKER_LOG_CONSUMER_TIME_LIMIT:-}
- CELERY_SOFT_TIME_LIMIT=${WORKER_LOG_CONSUMER_SOFT_TIME_LIMIT:-}
labels:
- traefik.enable=false
volumes:
- ./workflow_data:/data
- ${TOOL_REGISTRY_CONFIG_SRC_PATH}:/data/tool_registry_config
profiles:
- workers-v2
worker-log-history-scheduler-v2:
image: unstract/worker-unified:${VERSION}
container_name: unstract-worker-log-history-scheduler-v2
restart: unless-stopped
entrypoint: ["/bin/bash"]
command: ["/app/log_consumer/scheduler.sh"]
env_file:
- ../workers/.env
- ./essentials.env
depends_on:
- db
- redis
- rabbitmq
environment:
- ENVIRONMENT=development
- APPLICATION_NAME=unstract-worker-log-history-scheduler-v2
# Scheduler interval in seconds
- LOG_HISTORY_CONSUMER_INTERVAL=${LOG_HISTORY_CONSUMER_INTERVAL:-5}
# Override example: TASK_TRIGGER_COMMAND=/custom/trigger/script.sh
- TASK_TRIGGER_COMMAND=${TASK_TRIGGER_COMMAND:-}
labels:
- traefik.enable=false
profiles:
- workers-v2
worker-scheduler-v2:
image: unstract/worker-unified:${VERSION}
container_name: unstract-worker-scheduler-v2
restart: unless-stopped
command: ["scheduler"]
ports:
- "8091:8087"
env_file:
- ../workers/.env
- ./essentials.env
depends_on:
- db
- redis
- rabbitmq
environment:
- ENVIRONMENT=development
- APPLICATION_NAME=unstract-worker-scheduler-v2
- WORKER_TYPE=scheduler
- WORKER_NAME=scheduler-worker-v2
- SCHEDULER_METRICS_PORT=8087
- HEALTH_PORT=8087
# Scheduler specific configs
- SCHEDULER_QUEUE_NAME=scheduler
# Configurable Celery options
- CELERY_QUEUES_SCHEDULER=${CELERY_QUEUES_SCHEDULER:-scheduler}
- CELERY_POOL=${WORKER_SCHEDULER_POOL:-prefork}
- CELERY_PREFETCH_MULTIPLIER=${WORKER_SCHEDULER_PREFETCH_MULTIPLIER:-1}
- CELERY_CONCURRENCY=${WORKER_SCHEDULER_CONCURRENCY:-2}
- CELERY_EXTRA_ARGS=${WORKER_SCHEDULER_EXTRA_ARGS:-}
# Complete command override (if set, ignores all other options)
- CELERY_COMMAND_OVERRIDE=${WORKER_SCHEDULER_COMMAND_OVERRIDE:-}
# Individual argument overrides
- CELERY_APP_MODULE=${WORKER_SCHEDULER_APP_MODULE:-worker}
- CELERY_LOG_LEVEL=${WORKER_SCHEDULER_LOG_LEVEL:-INFO}
- CELERY_HOSTNAME=${WORKER_SCHEDULER_HOSTNAME:-}
- CELERY_MAX_TASKS_PER_CHILD=${WORKER_SCHEDULER_MAX_TASKS_PER_CHILD:-}
- CELERY_TIME_LIMIT=${WORKER_SCHEDULER_TIME_LIMIT:-}
- CELERY_SOFT_TIME_LIMIT=${WORKER_SCHEDULER_SOFT_TIME_LIMIT:-}
labels:
- traefik.enable=false
volumes:
- ./workflow_data:/data
- ${TOOL_REGISTRY_CONFIG_SRC_PATH}:/data/tool_registry_config
profiles:
- workers-v2
volumes:
prompt_studio_data:
unstract_data:

View File

@@ -0,0 +1,85 @@
# Unified Worker Dockerfile - Optimized for fast builds
FROM python:3.12.9-slim AS base
ARG VERSION=dev
LABEL maintainer="Zipstack Inc." \
description="Unified Worker Container for All Worker Types" \
version="${VERSION}"
# Set environment variables (CRITICAL: PYTHONPATH makes paths work!)
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
PYTHONPATH=/app:/unstract \
BUILD_CONTEXT_PATH=workers \
BUILD_PACKAGES_PATH=unstract \
APP_HOME=/app \
# OpenTelemetry configuration (disabled by default, enable in docker-compose)
OTEL_TRACES_EXPORTER=none \
OTEL_LOGS_EXPORTER=none \
OTEL_SERVICE_NAME=unstract_workers
# Install system dependencies (minimal for workers)
RUN apt-get update \
&& apt-get --no-install-recommends install -y \
curl \
gcc \
libmagic-dev \
libssl-dev \
pkg-config \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
# Install uv package manager
COPY --from=ghcr.io/astral-sh/uv:0.6.14 /uv /uvx /bin/
# Create non-root user early to avoid ownership issues
RUN groupadd -r worker && useradd -r -g worker worker && \
mkdir -p /home/worker && chown -R worker:worker /home/worker
# Create working directory
WORKDIR ${APP_HOME}
# -----------------------------------------------
# EXTERNAL DEPENDENCIES STAGE - This layer gets cached
# -----------------------------------------------
FROM base AS ext-dependencies
# Copy dependency files (including README.md like backend)
COPY ${BUILD_CONTEXT_PATH}/pyproject.toml ${BUILD_CONTEXT_PATH}/uv.lock ./
# Create empty README.md if it doesn't exist in the copy
RUN touch README.md
# Copy local package dependencies to /unstract directory
# This provides the unstract packages for imports
COPY ${BUILD_PACKAGES_PATH}/ /unstract/
# Install external dependencies with --locked for FAST builds
# No symlinks needed - PYTHONPATH handles the paths
RUN uv sync --group deploy --locked --no-install-project --no-dev
# -----------------------------------------------
# FINAL STAGE - Minimal image for production
# -----------------------------------------------
FROM ext-dependencies AS production
# Copy application code (this layer changes most frequently)
COPY ${BUILD_CONTEXT_PATH}/ ./
# Set shell with pipefail for proper error handling in pipes
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
# Install project and OpenTelemetry instrumentation (as root to avoid permission issues)
# No symlinks needed - PYTHONPATH handles the paths correctly
RUN uv sync --group deploy --locked && \
uv run opentelemetry-bootstrap -a requirements | uv pip install --requirement - && \
{ chmod +x ./run-worker.sh ./run-worker-docker.sh 2>/dev/null || true; } && \
touch requirements.txt && \
{ chown -R worker:worker ./run-worker.sh ./run-worker-docker.sh 2>/dev/null || true; }
# Switch to worker user
USER worker
# Default command - runs the Docker-optimized worker script
ENTRYPOINT ["/app/run-worker-docker.sh"]
CMD ["general"]

View File

@@ -0,0 +1,74 @@
# Unified Worker Docker ignore file
# Based on worker-base.dockerignore but unified for all worker types
# Virtual environments
**/.venv/
**/venv/
**/__pycache__/
**/.pytest_cache/
**/.mypy_cache/
# IDE and editor files
**/.vscode/
**/.idea/
**/*.swp
**/*.swo
**/*~
# OS files
.DS_Store
Thumbs.db
# Git
.git/
.gitignore
# Docker files (avoid recursion)
**/Dockerfile*
**/*.dockerignore
# Build artifacts
**/dist/
**/build/
**/*.egg-info/
# Logs
**/*.log
**/logs/
# Test files
**/tests/
**/test_*.py
**/*_test.py
# Development files
**/dev-*
**/sample.*
**/example.*
# Node modules (if any)
**/node_modules/
# Documentation
**/docs/
**/*.md
!README.md
# Configuration that shouldn't be in containers
**/.env*
**/local_settings.py
# Coverage reports
**/htmlcov/
**/.coverage
**/coverage.xml
# Backup files
**/*.bak
**/*.backup
**/*.orig
# Temporary files
**/tmp/
**/temp/
**/.tmp/

View File

@@ -11,3 +11,103 @@ WORKER_LOGGING_AUTOSCALE=4,1
WORKER_AUTOSCALE=4,1
WORKER_FILE_PROCESSING_AUTOSCALE=4,1
WORKER_FILE_PROCESSING_CALLBACK_AUTOSCALE=4,1
# New unified worker autoscaling (matches hierarchical configuration below)
WORKER_API_DEPLOYMENT_AUTOSCALE=4,1 # API deployment worker autoscale
WORKER_CALLBACK_AUTOSCALE=4,1 # Callback worker autoscale
WORKER_GENERAL_AUTOSCALE=6,2 # General worker autoscale (enhanced)
WORKER_FILE_PROCESSING_NEW_AUTOSCALE=8,2 # File processing unified worker autoscale
WORKER_NOTIFICATION_AUTOSCALE=4,1 # Notification worker autoscale
WORKER_LOG_CONSUMER_AUTOSCALE=2,1 # Log consumer worker autoscale
WORKER_SCHEDULER_AUTOSCALE=2,1 # Scheduler worker autoscale
# Worker-specific configurations
API_DEPLOYMENT_WORKER_NAME=api-deployment-worker
API_DEPLOYMENT_HEALTH_PORT=8080
API_DEPLOYMENT_MAX_CONCURRENT_TASKS=5
CALLBACK_WORKER_NAME=callback-worker
CALLBACK_HEALTH_PORT=8083
CALLBACK_MAX_CONCURRENT_TASKS=3
FILE_PROCESSING_WORKER_NAME=file-processing-worker
FILE_PROCESSING_HEALTH_PORT=8082
FILE_PROCESSING_MAX_CONCURRENT_TASKS=4
GENERAL_WORKER_NAME=general-worker
GENERAL_HEALTH_PORT=8081
GENERAL_MAX_CONCURRENT_TASKS=10
# =============================================================================
# HIERARCHICAL CELERY CONFIGURATION SYSTEM
# =============================================================================
#
# This system uses a 3-tier hierarchy for all Celery settings (most specific wins):
# 1. {WORKER_TYPE}_{SETTING_NAME} - Worker-specific override (highest priority)
# 2. CELERY_{SETTING_NAME} - Global override (medium priority)
# 3. Code default - Celery standard default (lowest priority)
#
# Examples:
# - CALLBACK_TASK_TIME_LIMIT=3600 (callback worker only)
# - CELERY_TASK_TIME_LIMIT=300 (all workers)
# - Code provides default if neither is set
#
# Worker types: API_DEPLOYMENT, GENERAL, FILE_PROCESSING, CALLBACK,
# NOTIFICATION, LOG_CONSUMER, SCHEDULER
# =============================================================================
# Global Celery Configuration (applies to all workers unless overridden)
CELERY_RESULT_CHORD_RETRY_INTERVAL=3 # Global chord unlock retry interval
CELERY_TASK_TIME_LIMIT=7200 # Global task timeout (2 hours)
CELERY_TASK_SOFT_TIME_LIMIT=6300 # Global soft timeout (1h 45m)
CELERY_PREFETCH_MULTIPLIER=1 # Global prefetch multiplier
CELERY_MAX_TASKS_PER_CHILD=1000 # Global max tasks per child process
CELERY_TASK_ACKS_LATE=true # Global acks late setting
CELERY_TASK_DEFAULT_RETRY_DELAY=60 # Global retry delay (1 minute)
CELERY_TASK_MAX_RETRIES=3 # Global max retries
# Worker-Specific Configuration Examples
# Callback Worker - Chord settings and extended timeouts
CALLBACK_RESULT_CHORD_RETRY_INTERVAL=3 # Callback-specific chord retry interval
CALLBACK_TASK_TIME_LIMIT=7200 # Callback tasks need more time (2 hours)
CALLBACK_TASK_SOFT_TIME_LIMIT=6300 # Callback soft timeout (1h 45m)
# File Processing Worker - Thread pool and optimized settings
FILE_PROCESSING_POOL_TYPE=threads # Use threads instead of prefork
FILE_PROCESSING_CONCURRENCY=4 # Fixed concurrency for file processing
FILE_PROCESSING_TASK_TIME_LIMIT=10800 # File processing timeout (3 hours)
# API Deployment Worker - Autoscaling and timeout configuration
API_DEPLOYMENT_AUTOSCALE=4,1 # Max 4, min 1 workers
API_DEPLOYMENT_TASK_TIME_LIMIT=3600 # API deployment timeout (1 hour)
# General Worker - Enhanced scaling for high-throughput tasks
GENERAL_AUTOSCALE=6,2 # Max 6, min 2 workers
# Docker Worker-Specific Concurrency Settings (for docker-compose.yaml)
WORKER_API_DEPLOYMENT_CONCURRENCY=4 # API deployment fixed concurrency
WORKER_FILE_PROCESSING_CONCURRENCY=8 # File processing fixed concurrency
WORKER_NOTIFICATION_CONCURRENCY=4 # Notification worker concurrency
WORKER_LOG_CONSUMER_CONCURRENCY=2 # Log consumer worker concurrency
WORKER_SCHEDULER_CONCURRENCY=2 # Scheduler worker concurrency
# Notification Worker - Optimized for quick message processing
NOTIFICATION_AUTOSCALE=4,1 # Max 4, min 1 workers
NOTIFICATION_TASK_TIME_LIMIT=120 # Quick timeout for notifications
# Scheduler Worker - Conservative settings for scheduled tasks
SCHEDULER_AUTOSCALE=2,1 # Max 2, min 1 workers
SCHEDULER_TASK_TIME_LIMIT=1800 # Scheduler timeout (30 minutes)
# Log Consumer Worker - Optimized for log processing
LOG_CONSUMER_AUTOSCALE=2,1 # Max 2, min 1 workers
LOG_CONSUMER_TASK_TIME_LIMIT=600 # Log processing timeout (10 minutes)
# Worker Circuit Breaker Settings
CIRCUIT_BREAKER_FAILURE_THRESHOLD=5
CIRCUIT_BREAKER_RECOVERY_TIMEOUT=60
# Worker Health Check Settings
HEALTH_CHECK_INTERVAL=30
HEALTH_CHECK_TIMEOUT=10
ENABLE_METRICS=true
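A minimal sketch of how a worker might resolve one of these settings through the 3-tier hierarchy; the helper name is hypothetical and the real lookup lives in the workers' shared configuration code:

```python
# Hypothetical resolver for the hierarchy documented above.
import os

def resolve_celery_setting(worker_type: str, setting: str, default: str) -> str:
    """Most specific wins: {WORKER_TYPE}_{SETTING} > CELERY_{SETTING} > default."""
    for key in (f"{worker_type}_{setting}", f"CELERY_{setting}"):
        value = os.environ.get(key)
        if value is not None:
            return value
    return default

# e.g. CALLBACK_TASK_TIME_LIMIT overrides CELERY_TASK_TIME_LIMIT for callback workers only.
time_limit = int(resolve_celery_setting("CALLBACK", "TASK_TIME_LIMIT", "7200"))
```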

View File

@@ -22,6 +22,10 @@ dev = [
"types-tzlocal~=5.1.0.1",
]
workers = [
"unstract-workers",
]
hook-check-django-migrations = [
"celery>=5.3.4",
"cron-descriptor==1.4.0",
@@ -55,6 +59,8 @@ unstract-tool-registry = { path = "./unstract/tool-registry", editable = true }
unstract-flags = { path = "./unstract/flags", editable = true }
unstract-core = { path = "./unstract/core", editable = true }
unstract-connectors = { path = "./unstract/connectors", editable = true }
# Workers
unstract-workers = { path = "./workers", editable = true }
# === Development tool configurations ===
[tool.ruff]

View File

@@ -71,6 +71,7 @@ display_help() {
echo -e " -p, --only-pull Only do docker images pull"
echo -e " -b, --build-local Build docker images locally"
echo -e " -u, --update Update services version"
echo -e " -w, --workers-v2 Use v2 dedicated worker containers"
echo -e " -x, --trace Enables trace mode"
echo -e " -V, --verbose Print verbose logs"
echo -e " -v, --version Docker images version tag (default \"latest\")"
@@ -97,6 +98,9 @@ parse_args() {
-u | --update)
opt_update=true
;;
-w | --workers-v2)
opt_workers_v2=true
;;
-x | --trace)
set -o xtrace # display every line before execution; enables PS4
;;
@@ -128,6 +132,7 @@ parse_args() {
debug "OPTION only_pull: $opt_only_pull"
debug "OPTION build_local: $opt_build_local"
debug "OPTION upgrade: $opt_update"
debug "OPTION workers_v2: $opt_workers_v2"
debug "OPTION verbose: $opt_verbose"
debug "OPTION version: $opt_version"
}
@@ -280,8 +285,13 @@ build_services() {
run_services() {
pushd "$script_dir/docker" 1>/dev/null
echo -e "$blue_text""Starting docker containers in detached mode""$default_text"
VERSION=$opt_version $docker_compose_cmd up -d
if [ "$opt_workers_v2" = true ]; then
echo -e "$blue_text""Starting docker containers with V2 dedicated workers in detached mode""$default_text"
VERSION=$opt_version $docker_compose_cmd --profile workers-v2 up -d
else
echo -e "$blue_text""Starting docker containers with existing backend-based workers in detached mode""$default_text"
VERSION=$opt_version $docker_compose_cmd up -d
fi
if [ "$opt_update" = true ]; then
echo ""
@@ -324,6 +334,7 @@ opt_only_env=false
opt_only_pull=false
opt_build_local=false
opt_update=false
opt_workers_v2=false
opt_verbose=false
opt_version="latest"
@@ -331,6 +342,8 @@ script_dir=$(dirname "$(readlink -f "$BASH_SOURCE")")
first_setup=false
# Extract service names from docker compose file
services=($(VERSION=$opt_version $docker_compose_cmd -f "$script_dir/docker/docker-compose.build.yaml" config --services))
# Add workers manually for env setup
services+=("workers")
spawned_services=("tool-structure" "tool-sidecar")
current_version=""
target_branch=""

View File

@@ -1,6 +1,6 @@
import logging
from flask import Blueprint, Response, jsonify
from flask import Blueprint
logger = logging.getLogger(__name__)
@@ -9,7 +9,6 @@ health_bp = Blueprint("health", __name__)
# Define a route to ping test
@health_bp.route("/ping", methods=["GET"])
def ping() -> Response:
logger.info("Ping request received")
return jsonify({"message": "pong!!!"})
@health_bp.route("/health", methods=["GET"])
def health_check() -> str:
return "OK"

View File

@@ -296,7 +296,7 @@ class UnstractRunner:
settings_json = json.dumps(settings).replace("'", "\\'")
# Prepare the tool execution command
tool_cmd = (
f"opentelemetry-instrument python main.py --command RUN "
f"python main.py --command RUN "
f"--settings '{settings_json}' --log-level DEBUG"
)

View File

@@ -17,7 +17,8 @@ ENV \
OTEL_METRICS_EXPORTER=none \
OTEL_LOGS_EXPORTER=none \
# Enable context propagation
OTEL_PROPAGATORS="tracecontext"
OTEL_PROPAGATORS="tracecontext" \
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
# Install system dependencies in one layer
RUN apt-get update && \

View File

@@ -0,0 +1,2 @@
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
# Unstract namespace package

View File

@@ -1,7 +1,10 @@
import logging
from logging import NullHandler
from typing import Any
from unstract.connectors.connection_types import ConnectionType
logging.getLogger(__name__).addHandler(NullHandler())
ConnectorDict = dict[str, dict[str, Any]]
__all__ = [
"ConnectionType",
]

View File

@@ -0,0 +1,70 @@
"""Unified Connection Types for Unstract Platform
This module provides a centralized definition of connection types used across
the entire Unstract platform to ensure consistency and prevent duplication.
"""
from enum import Enum
class ConnectionType(str, Enum):
"""Core connection types for workflow endpoints and connectors.
This enum provides the fundamental connection types used across:
- workers/shared/enums.py
- workers/shared/workflow/source_connector.py
- workers/shared/workflow/destination_connector.py
- unstract/core/src/unstract/core/data_models.py
- unstract/core/src/unstract/core/workflow_utils.py
"""
FILESYSTEM = "FILESYSTEM"
DATABASE = "DATABASE"
API = "API"
MANUALREVIEW = "MANUALREVIEW"
def __str__(self):
return self.value
@property
def is_filesystem(self) -> bool:
"""Check if this is a filesystem connection type."""
return self == ConnectionType.FILESYSTEM
@property
def is_database(self) -> bool:
"""Check if this is a database connection type."""
return self == ConnectionType.DATABASE
@property
def is_api(self) -> bool:
"""Check if this is an API connection type."""
return self == ConnectionType.API
@property
def is_manual_review(self) -> bool:
"""Check if this is a manual review connection type."""
return self == ConnectionType.MANUALREVIEW
@classmethod
def from_string(cls, connection_type: str) -> "ConnectionType":
"""Create ConnectionType from string, with validation.
Args:
connection_type: Connection type string
Returns:
ConnectionType enum value
Raises:
ValueError: If connection type is not recognized or is empty
"""
if not connection_type:
raise ValueError("Connection type cannot be empty")
connection_type_upper = connection_type.upper()
try:
return cls(connection_type_upper)
except ValueError:
raise ValueError(f"Unknown connection type: {connection_type}")

View File

@@ -3,9 +3,8 @@ from typing import Any
from singleton_decorator import singleton
from unstract.connectors import ConnectorDict # type: ignore
from unstract.connectors.base import UnstractConnector
from unstract.connectors.constants import Common
from unstract.connectors.constants import Common, ConnectorDict
from unstract.connectors.databases import connectors as db_connectors
from unstract.connectors.enums import ConnectorMode
from unstract.connectors.filesystems import connectors as fs_connectors

View File

@@ -1,9 +1,16 @@
from typing import Any
class Common:
METADATA = "metadata"
MODULE = "module"
CONNECTOR = "connector"
# Type definitions
ConnectorDict = dict[str, dict[str, Any]]
class DatabaseTypeConstants:
"""Central location for all database-specific type constants."""

View File

@@ -1,4 +1,4 @@
from unstract.connectors import ConnectorDict # type: ignore
from unstract.connectors.constants import ConnectorDict
from unstract.connectors.databases.register import register_connectors
connectors: ConnectorDict = {}

View File

@@ -7,8 +7,6 @@ from enum import Enum
from typing import Any
import google.api_core.exceptions
from google.cloud import bigquery
from google.cloud.bigquery import Client
from unstract.connectors.constants import DatabaseTypeConstants
from unstract.connectors.databases.exceptions import (
@@ -28,6 +26,9 @@ BIG_QUERY_TABLE_SIZE = 3
class BigQuery(UnstractDB):
def __init__(self, settings: dict[str, Any]):
super().__init__("BigQuery")
from google.cloud import bigquery
self.bigquery = bigquery
self.json_credentials = json.loads(settings.get("json_credentials", "{}"))
self.big_query_table_size = BIG_QUERY_TABLE_SIZE
@@ -62,8 +63,8 @@ class BigQuery(UnstractDB):
def can_read() -> bool:
return True
def get_engine(self) -> Client:
return bigquery.Client.from_service_account_info( # type: ignore
def get_engine(self) -> Any:
return self.bigquery.Client.from_service_account_info( # type: ignore
info=self.json_credentials
)
@@ -208,21 +209,23 @@ class BigQuery(UnstractDB):
f"@`{key}`", f"PARSE_JSON(@`{key}`)"
)
query_parameters.append(
bigquery.ScalarQueryParameter(key, "STRING", json_str)
self.bigquery.ScalarQueryParameter(key, "STRING", json_str)
)
elif isinstance(value, (dict, list)):
# For dict/list values in STRING columns, serialize to JSON string
json_str = json.dumps(value) if value else None
query_parameters.append(
bigquery.ScalarQueryParameter(key, "STRING", json_str)
self.bigquery.ScalarQueryParameter(key, "STRING", json_str)
)
else:
# For other values, use STRING as before
query_parameters.append(
bigquery.ScalarQueryParameter(key, "STRING", value)
self.bigquery.ScalarQueryParameter(key, "STRING", value)
)
query_params = bigquery.QueryJobConfig(query_parameters=query_parameters)
query_params = self.bigquery.QueryJobConfig(
query_parameters=query_parameters
)
query_job = engine.query(modified_sql, job_config=query_params)
else:
query_job = engine.query(sql_query)
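
The hunks above defer the `google.cloud.bigquery` import to `__init__` and widen `get_engine()` to return `Any`, so workers that never touch BigQuery do not pay for the SDK at import time. A condensed sketch of the same deferred-import shape (the class and variable names here are illustrative, not the shipped connector):

```python
from typing import Any


class DeferredBigQueryEngine:
    """Illustrative: pays the google-cloud-bigquery import cost only on construction."""

    def __init__(self, json_credentials: dict[str, Any]):
        # Importing inside __init__ keeps module import time (and worker boot) cheap
        from google.cloud import bigquery

        self.bigquery = bigquery
        self.json_credentials = json_credentials

    def get_engine(self) -> Any:
        # Annotated as Any because bigquery.Client is no longer a module-level name
        return self.bigquery.Client.from_service_account_info(
            info=self.json_credentials
        )
```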

View File

@@ -144,8 +144,9 @@ class PostgreSQL(UnstractDB, PsycoPgHandler):
Returns:
str: generates a create sql base query with the constant columns
"""
quoted_table = self._quote_identifier(table)
sql_query = (
f"CREATE TABLE IF NOT EXISTS {table} "
f"CREATE TABLE IF NOT EXISTS {quoted_table} "
f"(id TEXT, "
f"created_by TEXT, created_at TIMESTAMP, "
f"metadata JSONB, "
@@ -158,8 +159,9 @@ class PostgreSQL(UnstractDB, PsycoPgHandler):
return sql_query
def prepare_multi_column_migration(self, table_name: str, column_name: str) -> str:
quoted_table = self._quote_identifier(table_name)
sql_query = (
f"ALTER TABLE {table_name} "
f"ALTER TABLE {quoted_table} "
f"ADD COLUMN {column_name}_v2 JSONB, "
f"ADD COLUMN metadata JSONB, "
f"ADD COLUMN user_field_1 BOOLEAN DEFAULT FALSE, "
@@ -182,3 +184,41 @@ class PostgreSQL(UnstractDB, PsycoPgHandler):
schema=self.schema,
table_name=table_name,
)
@staticmethod
def _quote_identifier(identifier: str) -> str:
"""Quote PostgreSQL identifier to handle special characters like hyphens.
PostgreSQL identifiers with special characters must be enclosed in double quotes.
This method adds proper quoting for table names containing hyphens, spaces,
or other special characters.
Args:
identifier (str): Table name or column name to quote
Returns:
str: Properly quoted identifier safe for PostgreSQL
"""
# Always quote the identifier to handle special characters like hyphens.
# Escaping embedded double quotes keeps the quoting itself from being broken out of.
return '"' + identifier.replace('"', '""') + '"'
def get_sql_insert_query(
self, table_name: str, sql_keys: list[str], sql_values: list[str] | None = None
) -> str:
"""Override base method to add PostgreSQL-specific table name quoting.
Generates INSERT query with properly quoted table name for PostgreSQL.
Args:
table_name (str): Name of the table
sql_keys (list[str]): List of column names
sql_values (list[str], optional): SQL values for database-specific handling (ignored for PostgreSQL)
Returns:
str: INSERT query with properly quoted table name
"""
quoted_table = self._quote_identifier(table_name)
keys_str = ", ".join(sql_keys)
values_placeholder = ", ".join(["%s"] * len(sql_keys))
return f"INSERT INTO {quoted_table} ({keys_str}) VALUES ({values_placeholder})"

View File

@@ -33,7 +33,10 @@ def register_connectors(connectors: dict[str, Any]) -> None:
Common.METADATA: metadata,
}
except ModuleNotFoundError as exception:
logger.error(f"Error while importing connectors : {exception}")
logger.error(
f"Error while importing connectors {connector} : {exception}",
exc_info=True,
)
if len(connectors) == 0:
logger.warning("No connector found.")

View File

@@ -6,10 +6,6 @@ import uuid
from enum import Enum
from typing import Any
import snowflake.connector
import snowflake.connector.errors as SnowflakeError
from snowflake.connector.connection import SnowflakeConnection
from unstract.connectors.constants import DatabaseTypeConstants
from unstract.connectors.databases.exceptions import SnowflakeProgrammingException
from unstract.connectors.databases.unstract_db import UnstractDB
@@ -88,8 +84,10 @@ class SnowflakeDB(UnstractDB):
}
return str(mapping.get(data_type, DatabaseTypeConstants.SNOWFLAKE_TEXT))
def get_engine(self) -> SnowflakeConnection:
con = snowflake.connector.connect(
def get_engine(self) -> Any:
from snowflake.connector import connect
con = connect(
user=self.user,
password=self.password,
account=self.account,
@@ -134,6 +132,8 @@ class SnowflakeDB(UnstractDB):
def execute_query(
self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
) -> None:
import snowflake.connector.errors as SnowflakeError
table_name = kwargs.get("table_name", None)
logger.debug(f"Snowflake execute_query called with sql_query: {sql_query}")
logger.debug(f"sql_values: {sql_values}")
@@ -169,6 +169,8 @@ class SnowflakeDB(UnstractDB):
) from e
def get_information_schema(self, table_name: str) -> dict[str, str]:
import snowflake.connector.errors as SnowflakeError
query = f"describe table {table_name}"
column_types: dict[str, str] = {}
try:

View File

@@ -1,4 +1,4 @@
from unstract.connectors import ConnectorDict # type: ignore
from unstract.connectors.constants import ConnectorDict
from unstract.connectors.filesystems.register import register_connectors
from .local_storage.local_storage import * # noqa: F401, F403

View File

@@ -5,7 +5,7 @@ from email.utils import parsedate_to_datetime
from typing import Any
import azure.core.exceptions as AzureException
from adlfs import AzureBlobFileSystem
from fsspec import AbstractFileSystem
from unstract.connectors.exceptions import AzureHttpError
from unstract.connectors.filesystems.azure_cloud_storage.exceptions import (
@@ -14,7 +14,17 @@ from unstract.connectors.filesystems.azure_cloud_storage.exceptions import (
from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem
from unstract.filesystem import FileStorageType, FileSystem
# Suppress verbose Azure SDK HTTP request/response logging
logging.getLogger("azurefs").setLevel(logging.ERROR)
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(
logging.WARNING
)
logging.getLogger("azure.storage.blob").setLevel(logging.WARNING)
logging.getLogger("azure.storage").setLevel(logging.WARNING)
logging.getLogger("azure.core").setLevel(logging.WARNING)
# Keep ADLFS filesystem errors visible but suppress HTTP noise
logging.getLogger("adlfs").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
@@ -23,6 +33,8 @@ class AzureCloudStorageFS(UnstractFileSystem):
INVALID_PATH = "The specifed resource name contains invalid characters."
def __init__(self, settings: dict[str, Any]):
from adlfs import AzureBlobFileSystem
super().__init__("AzureCloudStorageFS")
account_name = settings.get("account_name", "")
access_key = settings.get("access_key", "")
@@ -70,7 +82,7 @@ class AzureCloudStorageFS(UnstractFileSystem):
def can_read() -> bool:
return True
def get_fsspec_fs(self) -> AzureBlobFileSystem:
def get_fsspec_fs(self) -> AbstractFileSystem:
return self.azure_fs
def extract_metadata_file_hash(self, metadata: dict[str, Any]) -> str | None:
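
The same pattern as the database connectors is applied here: `adlfs` is imported inside `__init__`, and `get_fsspec_fs()` is annotated with the lightweight `fsspec.AbstractFileSystem` instead of the provider class. A small sketch of that shape (the settings keys follow the hunk above; the class name is illustrative):

```python
from typing import Any

from fsspec import AbstractFileSystem


class DeferredAzureFS:
    """Illustrative: keeps adlfs out of module import time."""

    def __init__(self, settings: dict[str, Any]):
        from adlfs import AzureBlobFileSystem  # deferred provider import

        self.azure_fs = AzureBlobFileSystem(
            account_name=settings.get("account_name", ""),
            account_key=settings.get("access_key", ""),
        )

    def get_fsspec_fs(self) -> AbstractFileSystem:
        # Only fsspec is needed for the annotation, so type checking still works
        return self.azure_fs
```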

View File

@@ -5,7 +5,7 @@ import os
from datetime import UTC, datetime
from typing import Any
from gcsfs import GCSFileSystem
from fsspec import AbstractFileSystem
from unstract.connectors.exceptions import ConnectorError
from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem
@@ -26,6 +26,8 @@ class GoogleCloudStorageFS(UnstractFileSystem):
project_id = settings.get("project_id", "")
json_credentials_str = settings.get("json_credentials", "{}")
try:
from gcsfs import GCSFileSystem
json_credentials = json.loads(json_credentials_str)
self.gcs_fs = GCSFileSystem(
token=json_credentials,
@@ -81,7 +83,7 @@ class GoogleCloudStorageFS(UnstractFileSystem):
def can_read() -> bool:
return True
def get_fsspec_fs(self) -> GCSFileSystem:
def get_fsspec_fs(self) -> AbstractFileSystem:
return self.gcs_fs
def extract_metadata_file_hash(self, metadata: dict[str, Any]) -> str | None:
@@ -109,8 +111,46 @@ class GoogleCloudStorageFS(UnstractFileSystem):
Returns:
bool: True if the path is a directory, False otherwise.
"""
# Note: Here Metadata type seems to be always "file" even for directories
return metadata.get("type") == "directory"
# Primary check: Standard directory type
if metadata.get("type") == "directory":
return True
# GCS-specific directory detection
# In GCS, folders are represented as objects with specific characteristics
object_name = metadata.get("name", "")
size = metadata.get("size", 0)
content_type = metadata.get("contentType", "")
# GCS folder indicators:
# 1. Object name ends with "/" (most reliable indicator)
if object_name.endswith("/"):
logger.debug(
f"[GCS Directory Check] '{object_name}' identified as directory: name ends with '/'"
)
return True
# 2. Zero-size object with text/plain content type (common for GCS folders)
if size == 0 and content_type == "text/plain":
logger.debug(
f"[GCS Directory Check] '{object_name}' identified as directory: zero-size with text/plain content type"
)
return True
# 3. Check for GCS-specific folder metadata
# Some GCS folder objects have no contentType or have application/x-www-form-urlencoded
if size == 0 and (
not content_type
or content_type
in ["application/x-www-form-urlencoded", "binary/octet-stream"]
):
# Additional validation: check if this looks like a folder path
if "/" in object_name and not object_name.split("/")[-1]: # Path ends with /
logger.debug(
f"[GCS Directory Check] '{object_name}' identified as directory: zero-size folder-like object"
)
return True
return False
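
The heuristics above can be hard to follow inline, so here is a standalone restatement with a couple of placeholder metadata dicts shaped like typical gcsfs `info()` output (the helper name and sample values are illustrative, not captured responses):

```python
def looks_like_gcs_directory(metadata: dict) -> bool:
    """Illustrative restatement of the is_dir heuristics above."""
    if metadata.get("type") == "directory":
        return True
    name = metadata.get("name", "")
    size = metadata.get("size", 0)
    content_type = metadata.get("contentType", "")
    if name.endswith("/"):                           # 1. trailing slash
        return True
    if size == 0 and content_type == "text/plain":   # 2. zero-size text/plain marker
        return True
    if size == 0 and (                               # 3. zero-size folder-like object
        not content_type
        or content_type in ["application/x-www-form-urlencoded", "binary/octet-stream"]
    ):
        return "/" in name and not name.split("/")[-1]
    return False


print(looks_like_gcs_directory({"name": "bucket/input/", "size": 0}))        # True
print(looks_like_gcs_directory(
    {"name": "bucket/input/a.pdf", "size": 5120, "contentType": "application/pdf"}
))  # False
```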
def extract_modified_date(self, metadata: dict[str, Any]) -> datetime | None:
"""Extract the last modified date from GCS metadata.

View File

@@ -32,7 +32,10 @@ def register_connectors(connectors: dict[str, Any]) -> None:
Common.METADATA: metadata,
}
except ModuleNotFoundError as exception:
logger.error(f"Error while importing connectors : {exception}")
logger.error(
f"Error while importing connectors {connector} : {exception}",
exc_info=True,
)
if len(connectors) == 0:
logger.warning("No connector found.")

View File

@@ -3,20 +3,18 @@ import os
from datetime import UTC, datetime
from typing import Any
from dropbox.exceptions import ApiError as DropBoxApiError
from dropbox.exceptions import DropboxException
from dropboxdrivefs import DropboxDriveFileSystem
from fsspec import AbstractFileSystem
from unstract.connectors.exceptions import ConnectorError
from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem
from .exceptions import handle_dropbox_exception
logger = logging.getLogger(__name__)
class DropboxFS(UnstractFileSystem):
def __init__(self, settings: dict[str, Any]):
from dropboxdrivefs import DropboxDriveFileSystem
super().__init__("Dropbox")
self.dropbox_fs = DropboxDriveFileSystem(token=settings["token"])
self.path = "///"
@@ -61,7 +59,7 @@ class DropboxFS(UnstractFileSystem):
def can_read() -> bool:
return True
def get_fsspec_fs(self) -> DropboxDriveFileSystem:
def get_fsspec_fs(self) -> AbstractFileSystem:
return self.dropbox_fs
def extract_metadata_file_hash(self, metadata: dict[str, Any]) -> str | None:
@@ -132,10 +130,14 @@ class DropboxFS(UnstractFileSystem):
def test_credentials(self) -> bool:
"""To test credentials for Dropbox."""
from dropbox.exceptions import DropboxException
try:
# self.get_fsspec_fs().connect()
self.get_fsspec_fs().ls("")
except DropboxException as e:
from .exceptions import handle_dropbox_exception
raise handle_dropbox_exception(e) from e
except Exception as e:
raise ConnectorError(f"Error while connecting to Dropbox: {str(e)}") from e
@@ -143,11 +145,23 @@ class DropboxFS(UnstractFileSystem):
@staticmethod
def get_connector_root_dir(input_dir: str, **kwargs: Any) -> str:
"""Get roor dir of zs dropbox."""
return f"/{input_dir.strip('/')}"
"""Get root dir of zs dropbox with backward compatibility.
Dropbox requires leading slashes, so we override the base class behavior.
"""
# Call base class implementation
result = super().get_connector_root_dir(input_dir, **kwargs)
# Dropbox needs leading slash - ensure it's present
if not result.startswith("/"):
result = f"/{result}"
return result
def create_dir_if_not_exists(self, input_dir: str) -> None:
"""Create roor dir of zs dropbox if not exists."""
from dropbox.exceptions import ApiError as DropBoxApiError
fs_fsspec = self.get_fsspec_fs()
try:
fs_fsspec.isdir(input_dir)
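
A tiny sketch of the guarantee the overridden `get_connector_root_dir()` is after, with a stand-in for the base-class result (the `_base_root_dir` helper is hypothetical and exists only to show the leading-slash repair):

```python
def _base_root_dir(input_dir: str) -> str:
    # Hypothetical stand-in for the parent implementation, assumed to return
    # a normalised path without forcing a leading slash.
    return f"{input_dir.strip('/')}/"


def dropbox_root_dir(input_dir: str) -> str:
    result = _base_root_dir(input_dir)
    if not result.startswith("/"):   # Dropbox paths must be absolute
        result = f"/{result}"
    return result


print(dropbox_root_dir("unstract/input"))   # /unstract/input/
print(dropbox_root_dir("/unstract/input"))  # /unstract/input/
```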

View File

@@ -4,8 +4,6 @@ import logging
import os
from typing import Any
from google.cloud import secretmanager
from google.cloud.storage import Client
from google.oauth2 import service_account
from google.oauth2.credentials import Credentials
@@ -20,6 +18,9 @@ logger = logging.getLogger(__name__)
class GCSHelper:
def __init__(self) -> None:
from google.cloud.storage import Client
self.client = Client
self.google_service_json = os.environ.get("GDRIVE_GOOGLE_SERVICE_ACCOUNT")
self.google_project_id = os.environ.get("GDRIVE_GOOGLE_PROJECT_ID")
if self.google_service_json is None:
@@ -39,6 +40,8 @@ class GCSHelper:
return self.google_credentials
def get_secret(self, secret_name: str) -> str:
from google.cloud import secretmanager
google_secrets_client = secretmanager.SecretManagerServiceClient(
credentials=self.google_credentials
)
@@ -50,7 +53,7 @@ class GCSHelper:
return s.payload.data.decode("UTF-8")
def get_object_checksum(self, bucket_name: str, object_name: str) -> str:
client = Client(credentials=self.google_credentials)
client = self.client(credentials=self.google_credentials)
bucket = client.bucket(bucket_name)
md5_hash_hex = ""
try:
@@ -62,26 +65,26 @@ class GCSHelper:
return md5_hash_hex
def upload_file(self, bucket_name: str, object_name: str, file_path: str) -> None:
client = Client(credentials=self.google_credentials)
client = self.client(credentials=self.google_credentials)
bucket = client.bucket(bucket_name)
blob = bucket.blob(object_name)
blob.upload_from_filename(file_path)
def upload_text(self, bucket_name: str, object_name: str, text: str) -> None:
client = Client(credentials=self.google_credentials)
client = self.client(credentials=self.google_credentials)
bucket = client.bucket(bucket_name)
blob = bucket.blob(object_name)
blob.upload_from_string(text)
def upload_object(self, bucket_name: str, object_name: str, object: Any) -> None:
client = Client(credentials=self.google_credentials)
client = self.client(credentials=self.google_credentials)
bucket = client.bucket(bucket_name)
blob = bucket.blob(object_name)
blob.upload_from_string(object, content_type="application/octet-stream")
def read_file(self, bucket_name: str, object_name: str) -> Any:
logger.info(f"Reading file {object_name} from bucket {bucket_name}")
client = Client(credentials=self.google_credentials)
client = self.client(credentials=self.google_credentials)
bucket = client.bucket(bucket_name)
logger.info(f"Reading file {object_name} from bucket {bucket_name}")
try:

View File

@@ -0,0 +1,146 @@
"""Connector Operations for Unstract Platform
This module provides core connector operations for filesystem connectors
and connector health checks.
Used by:
- workers/shared/workflow/connectors/service.py (for worker-native operations)
"""
import logging
from typing import Any
# Import internal connector components (no try/except needed; these are declared dependencies)
from unstract.connectors.constants import Common
from unstract.connectors.filesystems import connectors as fs_connectors
from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem
logger = logging.getLogger(__name__)
class ConnectorOperations:
"""Common connector operations shared between backend and workers with strict error handling"""
@staticmethod
def test_connector_connection(
connector_id: str, settings: dict[str, Any]
) -> dict[str, Any]:
"""Test connection to connector before attempting operations.
Args:
connector_id: Connector ID
settings: Connector settings
Returns:
Dictionary with connection test results: {'is_connected': bool, 'error': str}
"""
try:
# Get connector instance
connector = ConnectorOperations.get_fs_connector(connector_id, settings)
# Test basic connectivity by getting fsspec filesystem
fs = connector.get_fsspec_fs()
# For filesystem connectors, try to check if root path exists
test_path = settings.get("path", "/")
try:
fs.exists(test_path)
return {"is_connected": True, "error": None}
except Exception as path_error:
return {
"is_connected": False,
"error": f"Cannot access path '{test_path}': {str(path_error)}",
}
except Exception as e:
return {"is_connected": False, "error": str(e)}
@staticmethod
def get_fs_connector(
connector_id: str, settings: dict[str, Any]
) -> "UnstractFileSystem":
"""Get filesystem connector instance using exact backend BaseConnector logic.
This replicates backend/workflow_manager/endpoint_v2/base_connector.py:get_fs_connector()
Args:
connector_id: Connector ID from the registry
settings: Connector-specific settings
Returns:
UnstractFileSystem instance
Raises:
RuntimeError: If the connector registry or constants are not initialized (critical error)
ValueError: If connector_id is not supported
"""
if not fs_connectors:
raise RuntimeError("Filesystem connectors registry not initialized")
if connector_id not in fs_connectors:
available_ids = list(fs_connectors.keys())
raise ValueError(
f"Connector '{connector_id}' is not supported. "
f"Available connectors: {available_ids}"
)
if not Common:
raise RuntimeError("Common connector constants not initialized")
# Use exact same pattern as backend BaseConnector
connector_class = fs_connectors[connector_id][Common.METADATA][Common.CONNECTOR]
return connector_class(settings)
@staticmethod
def get_connector_health(source_config: dict[str, Any]) -> dict[str, Any]:
"""Get health status of a source connector.
Args:
source_config: Source configuration dictionary
Returns:
Dictionary with health status and metadata
"""
try:
connector_id = source_config.get("connector_id") or source_config.get(
"connection_type"
)
settings = source_config.get("settings", {})
if not connector_id or not settings:
return {
"is_healthy": False,
"connection_type": connector_id,
"errors": ["Missing connector configuration"],
"response_time_ms": None,
}
import time
start_time = time.time()
# Test connection
connection_result = ConnectorOperations.test_connector_connection(
connector_id, settings
)
response_time = int(
(time.time() - start_time) * 1000
) # Convert to milliseconds
return {
"is_healthy": connection_result["is_connected"],
"connection_type": connector_id,
"errors": [connection_result["error"]]
if connection_result["error"]
else [],
"response_time_ms": response_time,
}
except Exception as e:
return {
"is_healthy": False,
"connection_type": source_config.get("connector_id", "unknown"),
"errors": [str(e)],
"response_time_ms": None,
}
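
A usage sketch for the two entry points above; the module path in the import, the connector id, and the settings are placeholders for illustration:

```python
from unstract.connectors.operations import ConnectorOperations  # assumed module path

# Hypothetical connector id and settings, purely for illustration
source_config = {
    "connector_id": "minio|00000000-0000-0000-0000-000000000000",
    "settings": {"key": "...", "secret": "...", "endpoint_url": "http://localhost:9000"},
}

health = ConnectorOperations.get_connector_health(source_config)
if not health["is_healthy"]:
    print(f"Connector unhealthy ({health['response_time_ms']} ms): {health['errors']}")

# Or test the connection directly when the id and settings are already in hand
result = ConnectorOperations.test_connector_connection(
    source_config["connector_id"], source_config["settings"]
)
print(result)  # {'is_connected': bool, 'error': str | None}
```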

View File

@@ -1,4 +1,4 @@
from unstract.connectors import ConnectorDict
from unstract.connectors.constants import ConnectorDict
from unstract.connectors.queues.register import register_connectors
connectors: ConnectorDict = {}

View File

@@ -32,7 +32,10 @@ def register_connectors(connectors: dict[str, Any]) -> None:
Common.METADATA: metadata,
}
except ModuleNotFoundError as exception:
logger.error(f"Error while importing connectors : {exception}")
logger.error(
f"Error while importing connectors {connector} : {exception}",
exc_info=True,
)
if len(connectors) == 0:
logger.warning("No connector found.")

View File

@@ -0,0 +1,2 @@
# Unstract namespace package
__path__ = __import__("pkgutil").extend_path(__path__, __name__)

View File

@@ -0,0 +1,99 @@
"""Unstract Core Library
Core data models, utilities, and base classes for the Unstract platform.
Provides shared functionality between backend and worker services.
"""
# Export core data models and enums
# Export existing utilities and constants
from .constants import LogEventArgument, LogFieldName, LogProcessingTask
from .data_models import (
ConnectionType,
ExecutionStatus,
FileHashData,
SourceConnectionType,
WorkflowExecutionData,
WorkflowFileExecutionData,
WorkflowType,
serialize_dataclass_to_dict,
)
# Export worker base classes
from .worker_base import (
CallbackTaskBase,
FileProcessingTaskBase,
WorkerTaskBase,
circuit_breaker,
create_callback_task,
create_file_processing_task,
create_task_decorator,
monitor_performance,
with_task_context,
)
# Note: Worker constants moved to workers/shared/ to remove Django dependency
# These are now available directly from workers.shared.constants and workers.shared.worker_patterns
# Export worker-specific models and enums
from .worker_models import (
BatchExecutionResult,
CallbackExecutionData,
FileExecutionResult,
NotificationMethod,
NotificationRequest,
PipelineStatus,
PipelineStatusUpdateRequest,
QueueName,
StatusMappings,
TaskError,
TaskExecutionContext,
TaskName,
WebhookResult,
WebhookStatus,
WorkerTaskStatus,
WorkflowExecutionUpdateRequest,
)
__version__ = "1.0.0"
__all__ = [
# Core data models and enums
"ExecutionStatus",
"WorkflowType",
"ConnectionType",
"FileHashData",
"WorkflowFileExecutionData",
"WorkflowExecutionData",
"SourceConnectionType",
"serialize_dataclass_to_dict",
# Worker models and enums
"TaskName",
"QueueName",
"WorkerTaskStatus",
"PipelineStatus",
"WebhookStatus",
"NotificationMethod",
"StatusMappings",
"WebhookResult",
"FileExecutionResult",
"BatchExecutionResult",
"CallbackExecutionData",
"WorkflowExecutionUpdateRequest",
"PipelineStatusUpdateRequest",
"NotificationRequest",
"TaskExecutionContext",
"TaskError",
# Worker base classes
"WorkerTaskBase",
"FileProcessingTaskBase",
"CallbackTaskBase",
"create_task_decorator",
"monitor_performance",
"with_task_context",
"circuit_breaker",
"create_file_processing_task",
"create_callback_task",
# Existing utilities
"LogFieldName",
"LogEventArgument",
"LogProcessingTask",
]
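
A minimal sketch of the import surface this file consolidates for the Django-free workers; only the imports and the exported version string are exercised here:

```python
import unstract.core as core
from unstract.core import ExecutionStatus, FileHashData, WorkerTaskBase

# Workers can log the shared-library version at startup for debugging
print(f"unstract-core {core.__version__} loaded")
print(ExecutionStatus, FileHashData, WorkerTaskBase)
```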

View File

@@ -18,3 +18,114 @@ class LogEventArgument:
class LogProcessingTask:
TASK_NAME = "logs_consumer"
QUEUE_NAME = "celery_log_task_queue"
class FileProcessingConstants:
"""Constants for file processing operations."""
# File chunk size for reading/writing (4MB default)
READ_CHUNK_SIZE = 4194304 # 4MB chunks for file reading
# Log preview size for truncating file content in logs
LOG_PREVIEW_SIZE = 500 # 500 bytes for log preview
# File processing timeout in seconds
DEFAULT_PROCESSING_TIMEOUT = 300 # 5 minutes
# Maximum file size in bytes for validation
MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024 # 100MB
@classmethod
def get_chunk_size(cls) -> int:
"""Get the configured chunk size for file operations."""
return cls.READ_CHUNK_SIZE
@classmethod
def get_log_preview_size(cls) -> int:
"""Get the configured log preview size."""
return cls.LOG_PREVIEW_SIZE
class WorkerConstants:
"""General worker operation constants."""
# Default retry attempts for worker operations
DEFAULT_RETRY_ATTEMPTS = 3
# Default timeout for API calls
API_TIMEOUT = 30
# Health check interval
HEALTH_CHECK_INTERVAL = 30
class FilePatternConstants:
"""Constants for file pattern matching and translation."""
# Display name to file pattern mappings
# Maps UI-friendly display names to actual file matching patterns
DISPLAY_NAME_TO_PATTERNS = {
"pdf documents": ["*.pdf"],
"word documents": ["*.doc", "*.docx"],
"excel documents": ["*.xls", "*.xlsx"],
"powerpoint documents": ["*.ppt", "*.pptx"],
"text files": ["*.txt"],
"image files": ["*.jpg", "*.jpeg", "*.png", "*.gif", "*.bmp", "*.tiff", "*.tif"],
"csv files": ["*.csv"],
"json files": ["*.json"],
"xml files": ["*.xml"],
"all files": ["*"],
"office documents": ["*.doc", "*.docx", "*.xls", "*.xlsx", "*.ppt", "*.pptx"],
"document files": ["*.pdf", "*.doc", "*.docx", "*.txt"],
"spreadsheet files": ["*.xls", "*.xlsx", "*.csv"],
"presentation files": ["*.ppt", "*.pptx"],
"archive files": ["*.zip", "*.rar", "*.7z", "*.tar", "*.gz"],
"video files": ["*.mp4", "*.avi", "*.mov", "*.wmv", "*.flv", "*.mkv"],
"audio files": ["*.mp3", "*.wav", "*.flac", "*.aac", "*.ogg"],
}
# Common file extension categories for inference
EXTENSION_CATEGORIES = {
"pdf": ["*.pdf"],
"doc": ["*.doc", "*.docx"],
"excel": ["*.xls", "*.xlsx"],
"image": ["*.jpg", "*.jpeg", "*.png", "*.gif", "*.bmp", "*.tiff", "*.tif"],
"text": ["*.txt"],
"csv": ["*.csv"],
"json": ["*.json"],
"xml": ["*.xml"],
"office": ["*.doc", "*.docx", "*.xls", "*.xlsx", "*.ppt", "*.pptx"],
"archive": ["*.zip", "*.rar", "*.7z", "*.tar", "*.gz"],
"video": ["*.mp4", "*.avi", "*.mov", "*.wmv", "*.flv", "*.mkv"],
"audio": ["*.mp3", "*.wav", "*.flac", "*.aac", "*.ogg"],
}
@classmethod
def get_patterns_for_display_name(cls, display_name: str) -> list[str] | None:
"""Get file patterns for a given display name.
Args:
display_name: UI display name (e.g., "PDF documents")
Returns:
List of file patterns or None if not found
"""
return cls.DISPLAY_NAME_TO_PATTERNS.get(display_name.strip().lower())
@classmethod
def infer_patterns_from_keyword(cls, keyword: str) -> list[str] | None:
"""Infer file patterns from a keyword.
Args:
keyword: Keyword to search for (e.g., "pdf", "excel")
Returns:
List of file patterns or None if not found
"""
keyword_lower = keyword.strip().lower()
for category, patterns in cls.EXTENSION_CATEGORIES.items():
if category in keyword_lower:
return patterns
return None
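
A quick usage sketch for the pattern helpers above, assuming the class lives in `unstract.core.constants` next to `LogFieldName` (display names are case-insensitive, keyword inference is substring based):

```python
from fnmatch import fnmatch

from unstract.core.constants import FilePatternConstants  # assumed module path

patterns = FilePatternConstants.get_patterns_for_display_name("PDF Documents")
print(patterns)  # ['*.pdf']

# Keyword inference matches category names as substrings
print(FilePatternConstants.infer_patterns_from_keyword("excel-exports"))
# ['*.xls', '*.xlsx']

# The patterns are plain globs, so fnmatch can apply them to file names
print(any(fnmatch("invoice.pdf", p) for p in patterns))  # True
```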

Some files were not shown because too many files have changed in this diff.