UN-2470 [FEAT] Remove Django dependency from Celery workers with internal APIs (#1494)
* UN-2470 [MISC] Remove Django dependency from Celery workers
  This commit introduces a new worker architecture that decouples Celery workers from Django where possible, enabling support for gevent/eventlet pool types and reducing worker startup overhead. Key changes:
  - Created separate worker modules (api-deployment, callback, file_processing, general)
  - Added internal API endpoints for worker communication
  - Implemented Django-free task execution where appropriate
  - Added shared utilities and client facades
  - Updated container configurations for the new worker architecture
  🤖 Generated with [Claude Code](https://claude.ai/code)
  Co-Authored-By: Claude <noreply@anthropic.com>
* Fix pre-commit issues (file permissions and ruff errors) and set up Docker for the new workers
  - Add executable permissions to worker entrypoint files
  - Fix import order in namespace package __init__.py
  - Remove unused variable api_status in general worker
  - Address ruff E402 and F841 errors
  🤖 Generated with [Claude Code](https://claude.ai/code)
  Co-Authored-By: Claude <noreply@anthropic.com>
* refactored Dockerfiles, fixes
* flexibility on celery run commands
* added debug logs
* handled file history for API
* cleanup
* cleanup
* cloud plugin structure
* minor changes in import plugin
* added notification and logger workers under the new worker module
* add docker compatibility for new workers
* handled docker issues
* log consumer worker fixes
* added scheduler worker
* minor env changes
* cleanup the logs
* minor changes in logs
* resolved scheduler worker issues
* cleanup and refactor
* ensuring backward compatibility for existing workers
* added configuration internal APIs and cache utils
* optimization
* Fix API client singleton pattern to share HTTP sessions
  - Fix flawed singleton implementation that was trying to share BaseAPIClient instances
  - Now properly shares HTTP sessions between specialized clients
  - Eliminates 6x BaseAPIClient initialization by reusing the same underlying session
  - Should reduce API deployment orchestration time by ~135ms (from 6 clients to 1 session)
  - Added debug logging to verify singleton pattern activation
* cleanup and structuring
* cleanup in callback
* file system connectors issue
* celery env value changes
* optional gossip
* variables for sync, mingle and gossip
* fix for file type check
* task pipeline issue resolving
* API deployment failed response handled
* task pipeline fixes
* updated file history cleanup with active file execution
* pipeline status update and workflow UI page execution
* cleanup and resolving conflicts
* remove unstract-core from connectors
* Commit uv.lock changes
* uv lock updates
* resolve migration issues
* defer connector-metadata
* Fix connector migration for production scale
  - Add encryption key handling with defer() to prevent decryption failures
  - Add final cleanup step to fix duplicate connector names
  - Optimize for large datasets with batch processing and bulk operations
  - Ensure unique constraint in migration 0004 can be created successfully
  🤖 Generated with [Claude Code](https://claude.ai/code)
  Co-Authored-By: Claude <noreply@anthropic.com>
* HITL fixes
* minor fixes on HITL
* api_hub related changes
* Dockerfile fixes
* API client cache fixes with actual response class
* fix: tags and llm_profile_id
* optimized clear cache
* cleanup
* enhanced logs
* added more handling on is-file-dir check and added loggers
* cleanup the run-platform script
* internal APIs are exempted from CSRF
* SonarCloud issues
* SonarCloud issues
* resolving SonarCloud issues
* resolving SonarCloud issues
* Delta: added batch size fix in workers
* comments addressed
* celery configuration changes for new workers
* fixes in callback regarding the pipeline type check
* change internal URL registry logic
* gitignore changes
* gitignore changes
* addressing PR comments and cleaning up the code
* adding missed profiles for v2
* SonarCloud blocker issues resolved
* implement OTel
* Commit uv.lock changes
* handle execution time and some cleanup
* adding user_data in metadata (PR: https://github.com/Zipstack/unstract/pull/1544)
* scheduler backward compatibility
* replace user_data with custom_data
* Commit uv.lock changes
* celery worker command issue resolved
* enhance package imports in connectors by changing to lazy imports
* Update runner.py by removing the OTel from it
  Signed-off-by: ali <117142933+muhammad-ali-e@users.noreply.github.com>
* added delta changes
* handle error to destination DB
* resolve tool instance ID validation and HITL queue name in API
* handled direct execution from workflow page to worker and logs
* handle cost logs
* Update health.py
  Signed-off-by: Ritwik G <100672805+ritwik-g@users.noreply.github.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* minor log changes
* introducing log consumer scheduler to bulk create, and socket.emit from worker for WS
* Commit uv.lock changes
* time limit or timeout celery config cleanup
* implemented Redis client class in worker
* pipeline status enum mismatch
* notification worker fixes
* resolve uv lock conflicts
* workflow log fixes
* WS channel name issue resolved; handling Redis down in status tracker and removing Redis keys
* default TTL changed for unified logs
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

---------

Signed-off-by: ali <117142933+muhammad-ali-e@users.noreply.github.com>
Signed-off-by: Ritwik G <100672805+ritwik-g@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Ritwik G <100672805+ritwik-g@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
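As a quick illustration of the worker-to-backend contract this PR introduces, a Django-free worker reaches the backend over the internal API with a Bearer token and an organization header, both validated by the new InternalAPIAuthMiddleware further down in this diff. The sketch below is illustrative only: the base-URL environment variable name is an assumption, and the real workers go through the shared client facades rather than raw requests calls.

import os

import requests

BACKEND_BASE_URL = os.environ.get("INTERNAL_API_BASE_URL", "http://backend:8000")  # assumed env var name
API_KEY = os.environ["INTERNAL_SERVICE_API_KEY"]  # matches the backend setting of the same name


def get_organization_context(org_id: str) -> dict:
    """Fetch organization context over the internal API instead of the Django ORM."""
    response = requests.get(
        f"{BACKEND_BASE_URL}/internal/v1/organization/{org_id}/context/",
        headers={
            "Authorization": f"Bearer {API_KEY}",  # checked by InternalAPIAuthMiddleware
            "X-Organization-ID": org_id,  # sets the organization context for the request
        },
        timeout=10,
    )
    response.raise_for_status()
    return response.json()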
This commit is contained in:

10  .gitignore  (vendored)
@@ -1,6 +1,14 @@
# Created by https://www.toptal.com/developers/gitignore/api/windows,macos,linux,pycharm,pycharm+all,pycharm+iml,python,visualstudiocode,react,django
# Edit at https://www.toptal.com/developers/gitignore?templates=windows,macos,linux,pycharm,pycharm+all,pycharm+iml,python,visualstudiocode,react,django

# Development helper scripts
*.sh
# list Exceptional files with ! like !fix-and-test.sh
!run-platform.sh
!workers/run-worker.sh
!workers/run-worker-docker.sh
!workers/log_consumer/scheduler.sh

### Django ###
*.log
*.pot
@@ -622,6 +630,7 @@ backend/plugins/processor/*
# Subscription Plugins
backend/plugins/subscription/*


# API Deployment Plugins
backend/plugins/api/**

@@ -685,6 +694,7 @@ backend/requirements.txt
backend/backend/*_urls.py
!backend/backend/base_urls.py
!backend/backend/public_urls.py
!backend/backend/internal_base_urls.py
# TODO: Remove after v2 migration is completed
backend/backend/*_urls_v2.py
!backend/backend/public_urls_v2.py
@@ -8,6 +8,7 @@ from account_v2.authentication_plugin_registry import AuthenticationPluginRegist
from account_v2.authentication_service import AuthenticationService
from account_v2.constants import Common
from backend.constants import RequestHeader
from backend.internal_api_constants import INTERNAL_API_PREFIX


class CustomAuthMiddleware:
@@ -22,6 +23,10 @@ class CustomAuthMiddleware:
        if any(request.path.startswith(path) for path in settings.WHITELISTED_PATHS):
            return self.get_response(request)

        # Skip internal API paths - they are handled by InternalAPIAuthMiddleware
        if request.path.startswith(f"{INTERNAL_API_PREFIX}/"):
            return self.get_response(request)

        # Authenticating With API_KEY
        x_api_key = request.headers.get(RequestHeader.X_API_KEY)
        if (
15  backend/account_v2/internal_serializers.py  (new file)
@@ -0,0 +1,15 @@
"""Account Internal API Serializers
Handles serialization for organization context related endpoints.
"""

from rest_framework import serializers


class OrganizationContextSerializer(serializers.Serializer):
    """Serializer for organization context information."""

    organization_id = serializers.CharField()
    organization_name = serializers.CharField()
    organization_slug = serializers.CharField(required=False, allow_blank=True)
    created_at = serializers.CharField(required=False, allow_blank=True)
    settings = serializers.DictField(required=False)
20  backend/account_v2/internal_urls.py  (new file)
@@ -0,0 +1,20 @@
"""Internal API URLs for Organization Context
URL patterns for organization-related internal APIs.
"""

from django.urls import path

from .internal_views import OrganizationContextAPIView

urlpatterns = [
    # Organization context endpoint (backward compatibility)
    path(
        "<str:org_id>/", OrganizationContextAPIView.as_view(), name="organization-context"
    ),
    # Organization context endpoint (explicit path)
    path(
        "<str:org_id>/context/",
        OrganizationContextAPIView.as_view(),
        name="organization-context-explicit",
    ),
]
40  backend/account_v2/internal_views.py  (new file)
@@ -0,0 +1,40 @@
"""Account Internal API Views
Handles organization context related endpoints for internal services.
"""

import logging

from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView
from utils.organization_utils import get_organization_context, resolve_organization

from .internal_serializers import OrganizationContextSerializer

logger = logging.getLogger(__name__)


class OrganizationContextAPIView(APIView):
    """Internal API endpoint for getting organization context."""

    def get(self, request, org_id):
        """Get organization context information."""
        try:
            # Use shared utility to resolve organization
            organization = resolve_organization(org_id, raise_on_not_found=True)

            # Use shared utility to get context data
            context_data = get_organization_context(organization)

            serializer = OrganizationContextSerializer(context_data)

            logger.info(f"Retrieved organization context for {org_id}")

            return Response(serializer.data)

        except Exception as e:
            logger.error(f"Failed to get organization context for {org_id}: {str(e)}")
            return Response(
                {"error": "Failed to get organization context", "detail": str(e)},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
16  backend/account_v2/organization_internal_urls.py  (new file)
@@ -0,0 +1,16 @@
"""Account Internal API URLs
Defines internal API endpoints for organization operations.
"""

from django.urls import path

from .internal_views import OrganizationContextAPIView

urlpatterns = [
    # Organization context API
    path(
        "<str:org_id>/context/",
        OrganizationContextAPIView.as_view(),
        name="organization-context",
    ),
]
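For reference, a successful response from the organization context endpoint above carries the serializer fields defined in internal_serializers.py, i.e. a body shaped roughly like this (values are placeholders):

{
    "organization_id": "org_abc123",
    "organization_name": "Example Org",
    "organization_slug": "example-org",
    "created_at": "2024-01-01T00:00:00Z",
    "settings": {}
}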
@@ -21,8 +21,19 @@ class SubscriptionConfig:
|
||||
METADATA_IS_ACTIVE = "is_active"
|
||||
|
||||
|
||||
# Cache for loaded plugins to avoid repeated loading
|
||||
_subscription_plugins_cache: list[Any] = []
|
||||
_plugins_loaded = False
|
||||
|
||||
|
||||
def load_plugins() -> list[Any]:
|
||||
"""Iterate through the subscription plugins and register them."""
|
||||
global _subscription_plugins_cache, _plugins_loaded
|
||||
|
||||
# Return cached plugins if already loaded
|
||||
if _plugins_loaded:
|
||||
return _subscription_plugins_cache
|
||||
|
||||
plugins_app = apps.get_app_config(SubscriptionConfig.PLUGINS_APP)
|
||||
package_path = plugins_app.module.__package__
|
||||
subscription_dir = os.path.join(plugins_app.path, SubscriptionConfig.PLUGIN_DIR)
|
||||
@@ -30,6 +41,8 @@ def load_plugins() -> list[Any]:
|
||||
subscription_plugins: list[Any] = []
|
||||
|
||||
if not os.path.exists(subscription_dir):
|
||||
_subscription_plugins_cache = subscription_plugins
|
||||
_plugins_loaded = True
|
||||
return subscription_plugins
|
||||
|
||||
for item in os.listdir(subscription_dir):
|
||||
@@ -56,10 +69,13 @@ def load_plugins() -> list[Any]:
|
||||
SubscriptionConfig.METADATA: module.metadata,
|
||||
}
|
||||
)
|
||||
name = metadata.get(
|
||||
SubscriptionConfig.METADATA_NAME,
|
||||
getattr(module, "__name__", "unknown"),
|
||||
)
|
||||
is_active = metadata.get(SubscriptionConfig.METADATA_IS_ACTIVE, False)
|
||||
logger.info(
|
||||
"Loaded subscription plugin: %s, is_active: %s",
|
||||
module.metadata[SubscriptionConfig.METADATA_NAME],
|
||||
module.metadata[SubscriptionConfig.METADATA_IS_ACTIVE],
|
||||
"Loaded subscription plugin: %s, is_active: %s", name, is_active
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
@@ -75,6 +91,10 @@ def load_plugins() -> list[Any]:
|
||||
if len(subscription_plugins) == 0:
|
||||
logger.info("No subscription plugins found.")
|
||||
|
||||
# Cache the results for future requests
|
||||
_subscription_plugins_cache = subscription_plugins
|
||||
_plugins_loaded = True
|
||||
|
||||
return subscription_plugins
|
||||
|
||||
|
||||
|
||||
74  backend/api_v2/internal_api_views.py  (new file)
@@ -0,0 +1,74 @@
|
||||
"""Internal API Views for API v2
|
||||
|
||||
This module provides internal API endpoints for worker communication,
|
||||
specifically optimized for type-aware pipeline data fetching.
|
||||
|
||||
Since we know the context from worker function calls:
|
||||
- process_batch_callback_api -> APIDeployment model
|
||||
- process_batch_callback -> Pipeline model (handled in workflow_manager)
|
||||
|
||||
This provides direct access to APIDeployment model data without
|
||||
the overhead of checking both Pipeline and APIDeployment models.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from rest_framework import status
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from api_v2.models import APIDeployment
|
||||
from api_v2.serializers import APIDeploymentSerializer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APIDeploymentDataView(APIView):
|
||||
"""Internal API endpoint for fetching APIDeployment data.
|
||||
|
||||
This endpoint is optimized for callback workers that know they're dealing
|
||||
with API deployments. It directly queries the APIDeployment model without
|
||||
checking the Pipeline model, improving performance.
|
||||
|
||||
Endpoint: GET /v2/api-deployments/{api_id}/data/
|
||||
"""
|
||||
|
||||
def get(self, request, api_id):
|
||||
"""Get APIDeployment model data by API ID.
|
||||
|
||||
Args:
|
||||
request: HTTP request object
|
||||
api_id: APIDeployment UUID
|
||||
|
||||
Returns:
|
||||
Response with APIDeployment model data
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Fetching APIDeployment data for ID: {api_id}")
|
||||
|
||||
# Query APIDeployment model directly (organization-scoped via DefaultOrganizationMixin)
|
||||
api_deployment = APIDeployment.objects.get(id=api_id)
|
||||
|
||||
# Serialize the APIDeployment model
|
||||
serializer = APIDeploymentSerializer(api_deployment)
|
||||
|
||||
# Use consistent response format with pipeline endpoint
|
||||
response_data = {"status": "success", "pipeline": serializer.data}
|
||||
|
||||
logger.info(
|
||||
f"Found APIDeployment {api_id}: name='{api_deployment.api_name}', display_name='{api_deployment.display_name}'"
|
||||
)
|
||||
return Response(response_data, status=status.HTTP_200_OK)
|
||||
|
||||
except APIDeployment.DoesNotExist:
|
||||
logger.warning(f"APIDeployment not found for ID: {api_id}")
|
||||
return Response(
|
||||
{"error": f"APIDeployment with ID {api_id} not found"},
|
||||
status=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching APIDeployment data for {api_id}: {str(e)}")
|
||||
return Response(
|
||||
{"error": f"Failed to fetch APIDeployment data: {str(e)}"},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
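A callback worker that already knows it is handling an API deployment can call this view directly; a minimal sketch under the URL registered in backend/api_v2/internal_urls.py and backend/backend/internal_base_urls.py below (the helper itself is hypothetical, the real workers use the shared API client facades):

import requests


def fetch_api_deployment(base_url: str, api_key: str, api_id: str) -> dict:
    """Fetch APIDeployment data for a known API deployment ID over the internal API."""
    resp = requests.get(
        f"{base_url}/internal/v1/api-deployments/{api_id}/",
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=10,
    )
    resp.raise_for_status()
    # The response format mirrors the pipeline endpoint: {"status": "success", "pipeline": {...}}
    return resp.json()["pipeline"]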
20  backend/api_v2/internal_urls.py  (new file)
@@ -0,0 +1,20 @@
|
||||
"""Internal API URLs for API v2
|
||||
|
||||
Internal endpoints for worker communication, specifically optimized
|
||||
for type-aware pipeline data fetching.
|
||||
"""
|
||||
|
||||
from django.urls import path
|
||||
from rest_framework.urlpatterns import format_suffix_patterns
|
||||
|
||||
from api_v2.internal_api_views import APIDeploymentDataView
|
||||
|
||||
urlpatterns = format_suffix_patterns(
|
||||
[
|
||||
path(
|
||||
"<uuid:api_id>/",
|
||||
APIDeploymentDataView.as_view(),
|
||||
name="api_deployment_data_internal",
|
||||
),
|
||||
]
|
||||
)
|
||||
@@ -23,4 +23,6 @@ urlpatterns = [
|
||||
include("pipeline_v2.public_api_urls"),
|
||||
),
|
||||
path("", include("health.urls")),
|
||||
# Internal API for worker communication
|
||||
path("internal/", include("backend.internal_base_urls")),
|
||||
]
|
||||
|
||||
100  backend/backend/internal_api_constants.py  (new file)
@@ -0,0 +1,100 @@
|
||||
"""Internal API Constants
|
||||
|
||||
Centralized constants for internal API paths, versions, and configuration.
|
||||
These constants can be overridden via environment variables for flexibility.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
# Default constant for SonarCloud compliance
|
||||
DEFAULT_INTERNAL_PREFIX = "/internal"
|
||||
|
||||
# Internal API Configuration
|
||||
INTERNAL_API_PREFIX = os.getenv("INTERNAL_API_PREFIX", DEFAULT_INTERNAL_PREFIX)
|
||||
INTERNAL_API_VERSION = os.getenv("INTERNAL_API_VERSION", "v1")
|
||||
|
||||
# Computed full prefix
|
||||
INTERNAL_API_BASE_PATH = f"{INTERNAL_API_PREFIX}/{INTERNAL_API_VERSION}"
|
||||
|
||||
|
||||
def build_internal_endpoint(path: str) -> str:
|
||||
"""Build a complete internal API endpoint path.
|
||||
|
||||
Args:
|
||||
path: The endpoint path without the internal prefix (e.g., "health/")
|
||||
|
||||
Returns:
|
||||
Complete internal API path (e.g., "/internal/v1/health/")
|
||||
"""
|
||||
# Ensure path starts and ends with /
|
||||
if not path.startswith("/"):
|
||||
path = f"/{path}"
|
||||
if not path.endswith("/"):
|
||||
path = f"{path}/"
|
||||
|
||||
return f"{INTERNAL_API_BASE_PATH}{path}"
|
||||
|
||||
|
||||
# Common endpoint builder shortcuts
|
||||
class InternalEndpoints:
|
||||
"""Convenience class for building internal API endpoints."""
|
||||
|
||||
@staticmethod
|
||||
def health() -> str:
|
||||
"""Health check endpoint."""
|
||||
return build_internal_endpoint("health")
|
||||
|
||||
@staticmethod
|
||||
def workflow(workflow_id: str = "{id}") -> str:
|
||||
"""Workflow endpoint."""
|
||||
return build_internal_endpoint(f"workflow/{workflow_id}")
|
||||
|
||||
@staticmethod
|
||||
def workflow_status(workflow_id: str = "{id}") -> str:
|
||||
"""Workflow status endpoint."""
|
||||
return build_internal_endpoint(f"workflow/{workflow_id}/status")
|
||||
|
||||
@staticmethod
|
||||
def file_execution(file_execution_id: str = "{id}") -> str:
|
||||
"""File execution endpoint."""
|
||||
return build_internal_endpoint(f"file-execution/{file_execution_id}")
|
||||
|
||||
@staticmethod
|
||||
def file_execution_status(file_execution_id: str = "{id}") -> str:
|
||||
"""File execution status endpoint."""
|
||||
return build_internal_endpoint(f"file-execution/{file_execution_id}/status")
|
||||
|
||||
@staticmethod
|
||||
def webhook_send() -> str:
|
||||
"""Webhook send endpoint."""
|
||||
return build_internal_endpoint("webhook/send")
|
||||
|
||||
@staticmethod
|
||||
def organization(org_id: str = "{org_id}") -> str:
|
||||
"""Organization endpoint."""
|
||||
return build_internal_endpoint(f"organization/{org_id}")
|
||||
|
||||
|
||||
# Environment variable documentation
|
||||
ENVIRONMENT_VARIABLES = {
|
||||
"INTERNAL_API_PREFIX": {
|
||||
"description": "Base prefix for internal API endpoints",
|
||||
"default": DEFAULT_INTERNAL_PREFIX,
|
||||
"example": DEFAULT_INTERNAL_PREFIX,
|
||||
},
|
||||
"INTERNAL_API_VERSION": {
|
||||
"description": "API version for internal endpoints",
|
||||
"default": "v1",
|
||||
"example": "v1",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_api_info() -> dict:
|
||||
"""Get current internal API configuration info."""
|
||||
return {
|
||||
"prefix": INTERNAL_API_PREFIX,
|
||||
"version": INTERNAL_API_VERSION,
|
||||
"base_path": INTERNAL_API_BASE_PATH,
|
||||
"environment_variables": ENVIRONMENT_VARIABLES,
|
||||
}
|
||||
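A quick sanity check of how these helpers compose, assuming the environment variables above are left at their defaults:

from backend.internal_api_constants import InternalEndpoints, build_internal_endpoint

# With INTERNAL_API_PREFIX="/internal" and INTERNAL_API_VERSION="v1" (the defaults):
assert build_internal_endpoint("health") == "/internal/v1/health/"
assert InternalEndpoints.workflow("1234") == "/internal/v1/workflow/1234/"
assert InternalEndpoints.organization("org_abc") == "/internal/v1/organization/org_abc/"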
266  backend/backend/internal_base_urls.py  (new file)
@@ -0,0 +1,266 @@
|
||||
"""Internal API URL Configuration - OSS Base.
|
||||
|
||||
Base internal URL patterns for OSS deployment. This file contains
|
||||
the foundational internal APIs available in all deployments.
|
||||
|
||||
Cloud deployments extend this via cloud_internal_urls.py following
|
||||
the same pattern as base_urls.py / cloud_base_urls.py.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
from django.conf import settings
|
||||
from django.http import Http404, JsonResponse
|
||||
from django.urls import include, path
|
||||
from django.views.decorators.http import require_http_methods
|
||||
from utils.websocket_views import emit_websocket
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@require_http_methods(["GET"])
|
||||
def internal_api_root(request):
|
||||
"""Internal API root endpoint with comprehensive documentation."""
|
||||
return JsonResponse(
|
||||
{
|
||||
"message": "Unstract Internal API",
|
||||
"version": "1.0.0",
|
||||
"description": "Internal service-to-service API for Celery workers",
|
||||
"documentation": "https://docs.unstract.com/internal-api",
|
||||
"endpoints": {
|
||||
"description": "Various v1 endpoints for workflow execution, pipeline, organization, and other services",
|
||||
"base_path": "/internal/v1/",
|
||||
},
|
||||
"authentication": {
|
||||
"type": "Bearer Token",
|
||||
"header": "Authorization: Bearer <internal_service_api_key>",
|
||||
"organization": "X-Organization-ID header (optional for scoped requests)",
|
||||
"requirements": [
|
||||
"All requests must include Authorization header",
|
||||
"API key must match INTERNAL_SERVICE_API_KEY setting",
|
||||
"Organization ID header required for org-scoped operations",
|
||||
],
|
||||
},
|
||||
"response_format": {
|
||||
"success": {"status": "success", "data": "..."},
|
||||
"error": {"error": "Error message", "detail": "Additional details"},
|
||||
},
|
||||
"rate_limits": {
|
||||
"default": "No rate limits for internal services",
|
||||
"note": "Monitor usage through application logs",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@require_http_methods(["GET"])
|
||||
def internal_health_check(request):
|
||||
"""Health check endpoint for internal API."""
|
||||
try:
|
||||
# Debug information (sanitized for security)
|
||||
debug_info = {
|
||||
"has_internal_service": hasattr(request, "internal_service"),
|
||||
"internal_service_value": getattr(request, "internal_service", None),
|
||||
"auth_header_present": bool(request.META.get("HTTP_AUTHORIZATION")),
|
||||
"auth_scheme": (
|
||||
request.META.get("HTTP_AUTHORIZATION", "").split()[0]
|
||||
if request.META.get("HTTP_AUTHORIZATION", "").strip()
|
||||
else "None"
|
||||
),
|
||||
"path": request.path,
|
||||
"method": request.method,
|
||||
}
|
||||
|
||||
# Check authentication - first check middleware, then fallback to direct key check
|
||||
authenticated = False
|
||||
|
||||
if hasattr(request, "internal_service") and request.internal_service:
|
||||
authenticated = True
|
||||
else:
|
||||
# Fallback: check API key directly if middleware didn't run
|
||||
auth_header = request.META.get("HTTP_AUTHORIZATION", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
api_key = auth_header[7:] # Remove 'Bearer ' prefix
|
||||
internal_api_key = getattr(settings, "INTERNAL_SERVICE_API_KEY", None)
|
||||
if internal_api_key and secrets.compare_digest(api_key, internal_api_key):
|
||||
authenticated = True
|
||||
# Set the flag manually since middleware didn't run
|
||||
request.internal_service = True
|
||||
elif internal_api_key:
|
||||
# Log authentication failure (without exposing the key)
|
||||
logger.warning(
|
||||
"Internal API authentication failed",
|
||||
extra={
|
||||
"path": request.path,
|
||||
"method": request.method,
|
||||
"remote_addr": request.META.get("REMOTE_ADDR"),
|
||||
},
|
||||
)
|
||||
|
||||
if not authenticated:
|
||||
return JsonResponse(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Not authenticated as internal service",
|
||||
"debug": debug_info,
|
||||
},
|
||||
status=401,
|
||||
)
|
||||
|
||||
# Basic health checks
|
||||
health_data = {
|
||||
"status": "healthy",
|
||||
"service": "internal_api",
|
||||
"version": "1.0.0",
|
||||
"timestamp": request.META.get("HTTP_DATE"),
|
||||
"authenticated": True,
|
||||
"organization_id": getattr(request, "organization_id", None),
|
||||
"debug": debug_info,
|
||||
}
|
||||
|
||||
return JsonResponse(health_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("internal_health_check failed")
|
||||
return JsonResponse(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Health check failed",
|
||||
"error": str(e),
|
||||
"debug": {
|
||||
"has_internal_service": hasattr(request, "internal_service"),
|
||||
"auth_header_present": bool(request.META.get("HTTP_AUTHORIZATION")),
|
||||
"auth_scheme": (
|
||||
request.META.get("HTTP_AUTHORIZATION", "").split()[0]
|
||||
if request.META.get("HTTP_AUTHORIZATION", "").strip()
|
||||
else "None"
|
||||
),
|
||||
"path": request.path,
|
||||
},
|
||||
},
|
||||
status=500,
|
||||
)
|
||||
|
||||
|
||||
# Test endpoint to debug middleware (only available in DEBUG mode)
|
||||
@require_http_methods(["GET"])
|
||||
def test_middleware_debug(request):
|
||||
"""Debug endpoint to check middleware execution - only in DEBUG mode."""
|
||||
# Only available in DEBUG mode or with explicit flag
|
||||
if not (settings.DEBUG or getattr(settings, "INTERNAL_API_DEBUG", False)):
|
||||
raise Http404("Debug endpoint not available")
|
||||
|
||||
return JsonResponse(
|
||||
{
|
||||
"middleware_debug": {
|
||||
"path": request.path,
|
||||
"method": request.method,
|
||||
"auth_header_present": bool(request.META.get("HTTP_AUTHORIZATION")),
|
||||
"auth_scheme": (
|
||||
request.META.get("HTTP_AUTHORIZATION", "").split()[0]
|
||||
if request.META.get("HTTP_AUTHORIZATION", "").strip()
|
||||
else "None"
|
||||
),
|
||||
"has_internal_service": hasattr(request, "internal_service"),
|
||||
"internal_service_value": getattr(request, "internal_service", None),
|
||||
"authenticated_via": getattr(request, "authenticated_via", None),
|
||||
"organization_id": getattr(request, "organization_id", None),
|
||||
"internal_api_key_configured": bool(
|
||||
getattr(settings, "INTERNAL_SERVICE_API_KEY", None)
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# Internal API URL patterns - OSS Base
|
||||
urlpatterns = [
|
||||
# Internal API root and utilities
|
||||
path("", internal_api_root, name="internal_api_root"),
|
||||
path("debug/", test_middleware_debug, name="test_middleware_debug"),
|
||||
path("v1/health/", internal_health_check, name="internal_health"),
|
||||
# WebSocket emission endpoint for workers
|
||||
path("emit-websocket/", emit_websocket, name="emit_websocket"),
|
||||
# ========================================
|
||||
# CORE OSS INTERNAL API MODULES
|
||||
# ========================================
|
||||
# Workflow execution management APIs
|
||||
path(
|
||||
"v1/workflow-execution/",
|
||||
include("workflow_manager.workflow_execution_internal_urls"),
|
||||
name="workflow_execution_internal",
|
||||
),
|
||||
# Workflow management and pipeline APIs
|
||||
path(
|
||||
"v1/workflow-manager/",
|
||||
include("workflow_manager.internal_urls"),
|
||||
name="workflow_manager_internal",
|
||||
),
|
||||
# Pipeline APIs
|
||||
path(
|
||||
"v1/pipeline/",
|
||||
include("pipeline_v2.internal_urls"),
|
||||
name="pipeline_internal",
|
||||
),
|
||||
# Organization context and management APIs
|
||||
path(
|
||||
"v1/organization/",
|
||||
include("account_v2.organization_internal_urls"),
|
||||
name="organization_internal",
|
||||
),
|
||||
# File execution and batch processing APIs
|
||||
path(
|
||||
"v1/file-execution/",
|
||||
include("workflow_manager.file_execution.internal_urls"),
|
||||
name="file_execution_internal",
|
||||
),
|
||||
# Tool instance execution APIs
|
||||
path(
|
||||
"v1/tool-execution/",
|
||||
include("tool_instance_v2.internal_urls"),
|
||||
name="tool_execution_internal",
|
||||
),
|
||||
# File processing history and caching APIs
|
||||
path(
|
||||
"v1/file-history/",
|
||||
include("workflow_manager.workflow_v2.file_history_internal_urls"),
|
||||
name="file_history_internal",
|
||||
),
|
||||
# Webhook notification APIs
|
||||
path(
|
||||
"v1/webhook/",
|
||||
include("notification_v2.internal_urls"),
|
||||
name="webhook_internal",
|
||||
),
|
||||
# API deployment data APIs for type-aware worker optimization
|
||||
path(
|
||||
"v1/api-deployments/",
|
||||
include("api_v2.internal_urls"),
|
||||
name="api_deployments_internal",
|
||||
),
|
||||
# Platform configuration and settings APIs
|
||||
path(
|
||||
"v1/platform-settings/",
|
||||
include("platform_settings_v2.internal_urls"),
|
||||
name="platform_settings_internal",
|
||||
),
|
||||
# Execution log management and cache operations APIs
|
||||
path(
|
||||
"v1/execution-logs/",
|
||||
include("workflow_manager.workflow_v2.execution_log_internal_urls"),
|
||||
name="execution_logs_internal",
|
||||
),
|
||||
# Organization configuration management APIs
|
||||
path(
|
||||
"v1/configuration/",
|
||||
include("configuration.internal_urls"),
|
||||
name="configuration_internal",
|
||||
),
|
||||
# Usage data and token count APIs
|
||||
path(
|
||||
"v1/usage/",
|
||||
include("usage_v2.internal_urls"),
|
||||
name="usage_internal",
|
||||
),
|
||||
]
|
||||
@@ -562,7 +562,6 @@ SOCIAL_AUTH_GOOGLE_OAUTH2_AUTH_EXTRA_ARGUMENTS = {
|
||||
}
|
||||
SOCIAL_AUTH_GOOGLE_OAUTH2_USE_UNIQUE_USER_ID = True
|
||||
|
||||
|
||||
# Always keep this line at the bottom of the file.
|
||||
if missing_settings:
|
||||
ERROR_MESSAGE = "Below required settings are missing.\n" + ",\n".join(
|
||||
|
||||
@@ -61,4 +61,5 @@ urlpatterns = [
|
||||
include("prompt_studio.prompt_studio_index_manager.urls"),
|
||||
),
|
||||
path("notifications/", include("notification.urls")),
|
||||
path("internal/", include("backend.internal_base_urls")),
|
||||
]
|
||||
|
||||
15  backend/configuration/internal_urls.py  (new file)
@@ -0,0 +1,15 @@
|
||||
"""Internal API URLs for Configuration access by workers."""
|
||||
|
||||
from django.urls import path
|
||||
|
||||
from . import internal_views
|
||||
|
||||
app_name = "configuration_internal"
|
||||
|
||||
urlpatterns = [
|
||||
path(
|
||||
"<str:config_key>/",
|
||||
internal_views.ConfigurationInternalView.as_view(),
|
||||
name="configuration-detail",
|
||||
),
|
||||
]
|
||||
122  backend/configuration/internal_views.py  (new file)
@@ -0,0 +1,122 @@
|
||||
"""Internal API views for Configuration access by workers."""
|
||||
|
||||
import logging
|
||||
|
||||
from account_v2.models import Organization
|
||||
from django.http import JsonResponse
|
||||
from rest_framework import status
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from .models import Configuration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConfigurationInternalView(APIView):
|
||||
"""Internal API view for workers to access organization configurations.
|
||||
|
||||
This endpoint allows workers to get organization-specific configuration
|
||||
values without direct database access, maintaining the same logic as
|
||||
Configuration.get_value_by_organization() but over HTTP.
|
||||
|
||||
Workers can call this to get configs like MAX_PARALLEL_FILE_BATCHES
|
||||
with proper organization-specific overrides and fallbacks.
|
||||
"""
|
||||
|
||||
def get(self, request: Request, config_key: str) -> JsonResponse:
|
||||
"""Get configuration value for an organization.
|
||||
|
||||
Args:
|
||||
request: HTTP request with organization_id parameter
|
||||
config_key: Configuration key name (e.g., "MAX_PARALLEL_FILE_BATCHES")
|
||||
|
||||
Returns:
|
||||
JSON response with configuration value and metadata
|
||||
"""
|
||||
try:
|
||||
organization_id = request.query_params.get("organization_id")
|
||||
|
||||
if not organization_id:
|
||||
return JsonResponse(
|
||||
{
|
||||
"success": False,
|
||||
"error": "organization_id parameter is required",
|
||||
"config_key": config_key,
|
||||
},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Get the organization - handle both ID (int) and organization_id (string)
|
||||
try:
|
||||
# Try to get organization by primary key ID first (for backward compatibility)
|
||||
if organization_id.isdigit():
|
||||
organization = Organization.objects.get(id=int(organization_id))
|
||||
else:
|
||||
# Otherwise, lookup by organization_id field (string identifier)
|
||||
organization = Organization.objects.get(
|
||||
organization_id=organization_id
|
||||
)
|
||||
except (Organization.DoesNotExist, ValueError):
|
||||
return JsonResponse(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"Organization {organization_id} not found",
|
||||
"config_key": config_key,
|
||||
},
|
||||
status=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
# Get the configuration value using the same logic as the backend
|
||||
try:
|
||||
config_value = Configuration.get_value_by_organization(
|
||||
config_key=config_key, organization=organization
|
||||
)
|
||||
|
||||
# Check if we found an organization-specific override
|
||||
has_override = False
|
||||
try:
|
||||
Configuration.objects.get(
|
||||
organization=organization, key=config_key, enabled=True
|
||||
)
|
||||
has_override = True
|
||||
except Configuration.DoesNotExist:
|
||||
has_override = False
|
||||
|
||||
return JsonResponse(
|
||||
{
|
||||
"success": True,
|
||||
"data": {
|
||||
"config_key": config_key,
|
||||
"value": config_value,
|
||||
"organization_id": organization_id,
|
||||
"has_organization_override": has_override,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
# Configuration key not found in registry
|
||||
return JsonResponse(
|
||||
{
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"config_key": config_key,
|
||||
"organization_id": organization_id,
|
||||
},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting configuration {config_key} for organization {organization_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return JsonResponse(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Internal server error",
|
||||
"config_key": config_key,
|
||||
},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
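On the worker side, the call corresponding to this view would look roughly like the following; the URL prefix comes from internal_base_urls.py and configuration/internal_urls.py above, while the helper name is hypothetical (the real workers layer cache utils on top of this):

import requests


def get_org_config(base_url: str, api_key: str, organization_id: str, config_key: str):
    """Fetch an org-scoped configuration value (e.g. MAX_PARALLEL_FILE_BATCHES) over the internal API."""
    resp = requests.get(
        f"{base_url}/internal/v1/configuration/{config_key}/",
        params={"organization_id": organization_id},
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=10,
    )
    resp.raise_for_status()
    payload = resp.json()
    # The view returns {"success": true, "data": {"config_key": ..., "value": ..., ...}}
    return payload["data"]["value"]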
@@ -42,10 +42,22 @@ def _group_connectors(
|
||||
) -> dict[tuple[Any, str, str | None], list[Any]]:
|
||||
"""Group connectors by organization, connector type, and metadata hash."""
|
||||
connector_groups = {}
|
||||
skipped_connectors = 0
|
||||
|
||||
for connector in connector_instances:
|
||||
try:
|
||||
metadata_hash = _compute_metadata_hash(connector.connector_metadata)
|
||||
# Try to access connector_metadata - this may fail due to encryption key mismatch
|
||||
try:
|
||||
metadata_hash = _compute_metadata_hash(connector.connector_metadata)
|
||||
except Exception as decrypt_error:
|
||||
# Log the encryption error and skip this connector
|
||||
logger.warning(
|
||||
f"Skipping connector {connector.id} due to encryption error: {str(decrypt_error)}. "
|
||||
f"This is likely due to a changed ENCRYPTION_KEY."
|
||||
)
|
||||
skipped_connectors += 1
|
||||
continue
|
||||
|
||||
connector_sys_name = _extract_connector_sys_name(connector.connector_id)
|
||||
|
||||
group_key = (
|
||||
@@ -62,6 +74,11 @@ def _group_connectors(
|
||||
logger.error(f"Error processing connector {connector.id}: {str(e)}")
|
||||
raise
|
||||
|
||||
if skipped_connectors > 0:
|
||||
logger.warning(
|
||||
f"Skipped {skipped_connectors} connectors due to encryption key issues"
|
||||
)
|
||||
|
||||
return connector_groups
|
||||
|
||||
|
||||
@@ -70,9 +87,16 @@ def _process_single_connector(
|
||||
processed_groups: int,
|
||||
total_groups: int,
|
||||
short_group_key: tuple[Any, str, str],
|
||||
connector_instance_model: Any,
|
||||
) -> None:
|
||||
"""Process a group with only one connector."""
|
||||
connector.connector_name = f"{connector.connector_name}-{uuid.uuid4().hex[:8]}"
|
||||
base_name = connector.connector_name
|
||||
new_name = f"{base_name}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# For performance with large datasets, UUID collisions are extremely rare
|
||||
# If uniqueness becomes critical, we can add collision detection later
|
||||
|
||||
connector.connector_name = new_name
|
||||
logger.info(
|
||||
f"[Group {processed_groups}/{total_groups}] {short_group_key}: "
|
||||
f"Only 1 connector present, renaming to '{connector.connector_name}'"
|
||||
@@ -85,6 +109,7 @@ def _centralize_connector_group(
|
||||
processed_groups: int,
|
||||
total_groups: int,
|
||||
short_group_key: tuple[Any, str, str],
|
||||
connector_instance_model: Any,
|
||||
) -> tuple[Any, dict[Any, Any], set[Any]]:
|
||||
"""Centralize a group of multiple connectors."""
|
||||
logger.info(
|
||||
@@ -95,7 +120,12 @@ def _centralize_connector_group(
|
||||
# First connector becomes the centralized one
|
||||
centralized_connector = connectors[0]
|
||||
original_name = centralized_connector.connector_name
|
||||
centralized_connector.connector_name = f"{original_name}-{uuid.uuid4().hex[:8]}"
|
||||
new_name = f"{original_name}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# For performance with large datasets, UUID collisions are extremely rare
|
||||
# If uniqueness becomes critical, we can add collision detection later
|
||||
|
||||
centralized_connector.connector_name = new_name
|
||||
|
||||
logger.info(
|
||||
f"[Group {processed_groups}/{total_groups}] {short_group_key}: "
|
||||
@@ -164,6 +194,88 @@ def _delete_redundant_connectors(
|
||||
raise
|
||||
|
||||
|
||||
def _fix_remaining_duplicate_names(connector_instance_model: Any) -> int:
|
||||
"""Fix any remaining duplicate connector names within organizations."""
|
||||
from django.db.models import Count
|
||||
|
||||
# Find all organizations with duplicate connector names (optimized query)
|
||||
duplicates = list(
|
||||
connector_instance_model.objects.values("connector_name", "organization_id")
|
||||
.annotate(count=Count("id"))
|
||||
.filter(count__gt=1)
|
||||
.order_by("organization_id", "connector_name")
|
||||
)
|
||||
|
||||
total_duplicates = len(duplicates)
|
||||
if total_duplicates == 0:
|
||||
logger.info("No duplicate connector names found after migration")
|
||||
return 0
|
||||
|
||||
logger.info(
|
||||
f"Found {total_duplicates} groups with duplicate connector names - fixing"
|
||||
)
|
||||
fixed_count = 0
|
||||
|
||||
# Process in batches to avoid memory issues
|
||||
batch_size = 20
|
||||
for i in range(0, len(duplicates), batch_size):
|
||||
batch = duplicates[i : i + batch_size]
|
||||
logger.info(
|
||||
f"Processing batch {i//batch_size + 1}/{(len(duplicates)-1)//batch_size + 1}"
|
||||
)
|
||||
|
||||
for dup_info in batch:
|
||||
connector_name = dup_info["connector_name"]
|
||||
org_id = dup_info["organization_id"]
|
||||
|
||||
# Get all connectors with this name in this organization (select only needed fields)
|
||||
duplicate_connectors = list(
|
||||
connector_instance_model.objects.filter(
|
||||
connector_name=connector_name, organization_id=org_id
|
||||
)
|
||||
.only("id", "connector_name", "organization_id")
|
||||
.order_by("id")
|
||||
)
|
||||
|
||||
if len(duplicate_connectors) <= 1:
|
||||
continue # Skip if no longer duplicates
|
||||
|
||||
# Prepare batch updates (keep first, rename others)
|
||||
updates = []
|
||||
existing_names_in_org = set(
|
||||
connector_instance_model.objects.filter(
|
||||
organization_id=org_id
|
||||
).values_list("connector_name", flat=True)
|
||||
)
|
||||
|
||||
for j, connector in enumerate(duplicate_connectors[1:], 1): # Skip first
|
||||
base_name = connector_name
|
||||
new_name = f"{base_name}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Simple collision check against existing names in this org
|
||||
attempt = 0
|
||||
while new_name in existing_names_in_org and attempt < 5:
|
||||
new_name = f"{base_name}-{uuid.uuid4().hex[:8]}"
|
||||
attempt += 1
|
||||
|
||||
existing_names_in_org.add(new_name) # Track new names
|
||||
connector.connector_name = new_name
|
||||
updates.append(connector)
|
||||
fixed_count += 1
|
||||
|
||||
# Bulk update for better performance
|
||||
if updates:
|
||||
connector_instance_model.objects.bulk_update(
|
||||
updates, ["connector_name"], batch_size=100
|
||||
)
|
||||
logger.info(
|
||||
f" Fixed {len(updates)} duplicates of '{connector_name}' in org {org_id}"
|
||||
)
|
||||
|
||||
logger.info(f"Fixed {fixed_count} duplicate connector names")
|
||||
return fixed_count
|
||||
|
||||
|
||||
def migrate_to_centralized_connectors(apps, schema_editor): # noqa: ARG001
|
||||
"""Migrate existing workflow-specific connectors to centralized connectors.
|
||||
|
||||
@@ -176,10 +288,15 @@ def migrate_to_centralized_connectors(apps, schema_editor): # noqa: ARG001
|
||||
ConnectorInstance = apps.get_model("connector_v2", "ConnectorInstance") # NOSONAR
|
||||
WorkflowEndpoint = apps.get_model("endpoint_v2", "WorkflowEndpoint") # NOSONAR
|
||||
|
||||
# Get all connector instances with select_related for performance
|
||||
connector_instances = ConnectorInstance.objects.select_related(
|
||||
"organization", "created_by", "modified_by"
|
||||
).all()
|
||||
# Get all connector instances, but defer the encrypted metadata field to avoid
|
||||
# automatic decryption failures when the encryption key has changed
|
||||
connector_instances = (
|
||||
ConnectorInstance.objects.select_related(
|
||||
"organization", "created_by", "modified_by"
|
||||
)
|
||||
.defer("connector_metadata")
|
||||
.all()
|
||||
)
|
||||
|
||||
total_connectors = connector_instances.count()
|
||||
logger.info(f"Processing {total_connectors} connector instances for centralization")
|
||||
@@ -187,6 +304,17 @@ def migrate_to_centralized_connectors(apps, schema_editor): # noqa: ARG001
|
||||
# Group connectors by organization and unique credential fingerprint
|
||||
connector_groups = _group_connectors(connector_instances)
|
||||
|
||||
# Safety check: If we have connectors but all were skipped, this indicates a serious issue
|
||||
if total_connectors > 0 and len(connector_groups) == 0:
|
||||
error_msg = (
|
||||
f"CRITICAL: All {total_connectors} connectors were skipped due to encryption errors. "
|
||||
f"This likely means the ENCRYPTION_KEY has changed. Please restore the correct "
|
||||
f"ENCRYPTION_KEY and retry the migration. The migration has been aborted to prevent "
|
||||
f"data loss."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
# Process each group and centralize connectors
|
||||
processed_groups = 0
|
||||
centralized_count = 0
|
||||
@@ -202,13 +330,21 @@ def migrate_to_centralized_connectors(apps, schema_editor): # noqa: ARG001
|
||||
# Process single connector groups differently
|
||||
if len(connectors) == 1:
|
||||
_process_single_connector(
|
||||
connectors[0], processed_groups, total_groups, short_group_key
|
||||
connectors[0],
|
||||
processed_groups,
|
||||
total_groups,
|
||||
short_group_key,
|
||||
ConnectorInstance,
|
||||
)
|
||||
continue
|
||||
|
||||
# Centralize multiple connectors
|
||||
_, connector_mapping, connectors_to_delete = _centralize_connector_group(
|
||||
connectors, processed_groups, total_groups, short_group_key
|
||||
connectors,
|
||||
processed_groups,
|
||||
total_groups,
|
||||
short_group_key,
|
||||
ConnectorInstance,
|
||||
)
|
||||
|
||||
centralized_count += 1
|
||||
@@ -232,6 +368,9 @@ def migrate_to_centralized_connectors(apps, schema_editor): # noqa: ARG001
|
||||
# Delete redundant connectors
|
||||
_delete_redundant_connectors(all_connectors_to_delete, ConnectorInstance)
|
||||
|
||||
# Final cleanup: Fix any remaining duplicate names within organizations
|
||||
_fix_remaining_duplicate_names(ConnectorInstance)
|
||||
|
||||
logger.info(
|
||||
f"Migration completed: {centralized_count} centralized connectors created"
|
||||
)
|
||||
@@ -273,19 +412,28 @@ def _create_workflow_specific_connector(
|
||||
connector_instance_model: Any,
|
||||
) -> Any:
|
||||
"""Create a new workflow-specific connector from a centralized one."""
|
||||
return connector_instance_model.objects.create(
|
||||
connector_name=centralized_connector.connector_name,
|
||||
connector_id=centralized_connector.connector_id,
|
||||
connector_metadata=centralized_connector.connector_metadata,
|
||||
connector_version=centralized_connector.connector_version,
|
||||
connector_type=connector_type,
|
||||
connector_auth=centralized_connector.connector_auth,
|
||||
connector_mode=centralized_connector.connector_mode,
|
||||
workflow=workflow,
|
||||
organization=centralized_connector.organization,
|
||||
created_by=centralized_connector.created_by,
|
||||
modified_by=centralized_connector.modified_by,
|
||||
)
|
||||
try:
|
||||
# Try to access connector_metadata to ensure it's readable
|
||||
metadata = centralized_connector.connector_metadata
|
||||
return connector_instance_model.objects.create(
|
||||
connector_name=centralized_connector.connector_name,
|
||||
connector_id=centralized_connector.connector_id,
|
||||
connector_metadata=metadata,
|
||||
connector_version=centralized_connector.connector_version,
|
||||
connector_type=connector_type,
|
||||
connector_auth=centralized_connector.connector_auth,
|
||||
connector_mode=centralized_connector.connector_mode,
|
||||
workflow=workflow,
|
||||
organization=centralized_connector.organization,
|
||||
created_by=centralized_connector.created_by,
|
||||
modified_by=centralized_connector.modified_by,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Skipping creation of workflow-specific connector from {centralized_connector.id} "
|
||||
f"due to encryption error: {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def _process_connector_endpoints(
|
||||
@@ -359,10 +507,13 @@ def reverse_centralized_connectors(apps, schema_editor): # noqa: ARG001
|
||||
ConnectorInstance = apps.get_model("connector_v2", "ConnectorInstance") # NOSONAR
|
||||
WorkflowEndpoint = apps.get_model("endpoint_v2", "WorkflowEndpoint") # NOSONAR
|
||||
|
||||
# Get all centralized connectors with prefetch for better performance
|
||||
centralized_connectors = ConnectorInstance.objects.prefetch_related(
|
||||
"workflow_endpoints"
|
||||
).all()
|
||||
# Get all centralized connectors, but defer the encrypted metadata field to avoid
|
||||
# automatic decryption failures when the encryption key has changed
|
||||
centralized_connectors = (
|
||||
ConnectorInstance.objects.prefetch_related("workflow_endpoints")
|
||||
.defer("connector_metadata")
|
||||
.all()
|
||||
)
|
||||
|
||||
total_connectors = centralized_connectors.count()
|
||||
logger.info(f"Processing {total_connectors} centralized connectors for reversal")
|
||||
@@ -375,6 +526,7 @@ def reverse_centralized_connectors(apps, schema_editor): # noqa: ARG001
|
||||
# Process connectors with endpoints to create workflow-specific copies
|
||||
added_connector_count = 0
|
||||
processed_connectors = 0
|
||||
skipped_reverse_connectors = 0
|
||||
|
||||
for centralized_connector in centralized_connectors:
|
||||
processed_connectors += 1
|
||||
@@ -384,6 +536,17 @@ def reverse_centralized_connectors(apps, schema_editor): # noqa: ARG001
|
||||
continue
|
||||
|
||||
try:
|
||||
# Test if we can access encrypted fields before processing
|
||||
try:
|
||||
_ = centralized_connector.connector_metadata
|
||||
except Exception as decrypt_error:
|
||||
logger.warning(
|
||||
f"Skipping reverse migration for connector {centralized_connector.id} "
|
||||
f"due to encryption error: {str(decrypt_error)}"
|
||||
)
|
||||
skipped_reverse_connectors += 1
|
||||
continue
|
||||
|
||||
endpoints = WorkflowEndpoint.objects.filter(
|
||||
connector_instance=centralized_connector
|
||||
)
|
||||
@@ -404,6 +567,22 @@ def reverse_centralized_connectors(apps, schema_editor): # noqa: ARG001
|
||||
)
|
||||
raise
|
||||
|
||||
if skipped_reverse_connectors > 0:
|
||||
logger.warning(
|
||||
f"Skipped {skipped_reverse_connectors} connectors during reverse migration due to encryption issues"
|
||||
)
|
||||
|
||||
# Safety check for reverse migration: if we skipped everything, abort
|
||||
if skipped_reverse_connectors == total_connectors and total_connectors > 0:
|
||||
error_msg = (
|
||||
f"CRITICAL: All {total_connectors} connectors were skipped during reverse migration "
|
||||
f"due to encryption errors. This likely means the ENCRYPTION_KEY has changed. "
|
||||
f"Please restore the correct ENCRYPTION_KEY and retry the reverse migration. "
|
||||
f"The reverse migration has been aborted to prevent data loss."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
# Delete unused centralized connectors
|
||||
_delete_unused_centralized_connectors(unused_connectors, ConnectorInstance)
|
||||
|
||||
|
||||
241  backend/middleware/internal_api_auth.py  (new file)
@@ -0,0 +1,241 @@
|
||||
"""Internal API Service Authentication Middleware
|
||||
Handles service-to-service authentication for internal APIs.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
from django.http import HttpRequest, HttpResponse, JsonResponse
|
||||
from django.utils.deprecation import MiddlewareMixin
|
||||
from utils.constants import Account
|
||||
from utils.local_context import StateStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InternalAPIAuthMiddleware(MiddlewareMixin):
|
||||
"""Middleware for authenticating internal service API requests.
|
||||
|
||||
This middleware:
|
||||
1. Checks for internal service API key in Authorization header
|
||||
2. Validates the key against INTERNAL_SERVICE_API_KEY setting
|
||||
3. Sets up organization context for requests
|
||||
4. Bypasses normal user authentication for internal services
|
||||
"""
|
||||
|
||||
def process_request(self, request: HttpRequest) -> HttpResponse | None:
|
||||
"""Enhanced request processing with improved debugging and organization context handling."""
|
||||
# Enhanced request logging with more context
|
||||
request_info = {
|
||||
"path": request.path,
|
||||
"method": request.method,
|
||||
"content_type": request.META.get("CONTENT_TYPE", "unknown"),
|
||||
"user_agent": request.META.get("HTTP_USER_AGENT", "unknown")[:100],
|
||||
"remote_addr": request.META.get("REMOTE_ADDR", "unknown"),
|
||||
"auth_header_present": bool(request.META.get("HTTP_AUTHORIZATION")),
|
||||
"org_header_present": bool(request.headers.get("X-Organization-ID")),
|
||||
}
|
||||
|
||||
logger.debug(f"InternalAPIAuthMiddleware processing request: {request_info}")
|
||||
|
||||
# Only apply to internal API endpoints
|
||||
if not request.path.startswith("/internal/"):
|
||||
logger.debug(f"Skipping middleware for non-internal path: {request.path}")
|
||||
return None
|
||||
|
||||
logger.info(f"Processing internal API request: {request.method} {request.path}")
|
||||
|
||||
# Enhanced authentication handling
|
||||
auth_result = self._authenticate_request(request)
|
||||
if auth_result["error"]:
|
||||
logger.warning(
|
||||
f"Authentication failed for {request.path}: {auth_result['message']}"
|
||||
)
|
||||
return JsonResponse(
|
||||
{
|
||||
"error": auth_result["message"],
|
||||
"detail": auth_result["detail"],
|
||||
"debug_info": auth_result.get("debug_info", {})
|
||||
if settings.DEBUG
|
||||
else {},
|
||||
},
|
||||
status=auth_result["status"],
|
||||
)
|
||||
|
||||
# Enhanced organization context handling
|
||||
org_result = self._setup_organization_context(request)
|
||||
if org_result["warning"]:
|
||||
logger.warning(
|
||||
f"Organization context issue for {request.path}: {org_result['warning']}"
|
||||
)
|
||||
|
||||
# Mark request as authenticated
|
||||
request.internal_service = True
|
||||
request.authenticated_via = "internal_service_api_key"
|
||||
|
||||
# Enhanced organization context logging
|
||||
final_context = {
|
||||
"path": request.path,
|
||||
"request_org_id": getattr(request, "organization_id", "None"),
|
||||
"statestore_org_id": StateStore.get(Account.ORGANIZATION_ID),
|
||||
"org_context_set": org_result["context_set"],
|
||||
"org_validated": org_result.get("organization_validated", False),
|
||||
}
|
||||
logger.info(f"Internal API request authenticated successfully: {final_context}")
|
||||
return None # Continue with request processing
|
||||
|
||||
def _authenticate_request(self, request: HttpRequest) -> dict[str, Any]:
|
||||
"""Enhanced authentication with detailed error reporting."""
|
||||
auth_header = request.META.get("HTTP_AUTHORIZATION", "")
|
||||
|
||||
if not auth_header:
|
||||
return {
|
||||
"error": True,
|
||||
"status": 401,
|
||||
"message": "Authorization header required for internal APIs",
|
||||
"detail": "Missing Authorization header",
|
||||
"debug_info": {
|
||||
"headers_present": list(request.META.keys()),
|
||||
"expected_format": "Authorization: Bearer <api_key>",
|
||||
},
|
||||
}
|
||||
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return {
|
||||
"error": True,
|
||||
"status": 401,
|
||||
"message": "Bearer token required for internal APIs",
|
||||
"detail": f"Invalid authorization format: {auth_header[:20]}...",
|
||||
"debug_info": {
|
||||
"provided_format": auth_header.split(" ")[0]
|
||||
if " " in auth_header
|
||||
else auth_header[:10],
|
||||
"expected_format": "Bearer <api_key>",
|
||||
},
|
||||
}
|
||||
|
||||
# Extract and validate API key
|
||||
api_key = auth_header[7:] # Remove 'Bearer ' prefix
|
||||
internal_api_key = getattr(settings, "INTERNAL_SERVICE_API_KEY", None)
|
||||
|
||||
if not internal_api_key:
|
||||
logger.error("INTERNAL_SERVICE_API_KEY not configured in Django settings")
|
||||
return {
|
||||
"error": True,
|
||||
"status": 500,
|
||||
"message": "Internal API authentication not configured",
|
||||
"detail": "INTERNAL_SERVICE_API_KEY setting missing",
|
||||
}
|
||||
|
||||
if api_key != internal_api_key:
|
||||
# Enhanced logging for key mismatch debugging
|
||||
key_comparison = {
|
||||
"provided_key_length": len(api_key),
|
||||
"expected_key_length": len(internal_api_key),
|
||||
"keys_match": api_key == internal_api_key,
|
||||
"provided_key_prefix": api_key[:8] + "..."
|
||||
if len(api_key) > 8
|
||||
else api_key,
|
||||
"expected_key_prefix": internal_api_key[:8] + "..."
|
||||
if len(internal_api_key) > 8
|
||||
else internal_api_key,
|
||||
}
|
||||
logger.warning(f"API key validation failed: {key_comparison}")
|
||||
|
||||
return {
|
||||
"error": True,
|
||||
"status": 401,
|
||||
"message": "Invalid internal service API key",
|
||||
"detail": "API key does not match configured value",
|
||||
"debug_info": key_comparison if settings.DEBUG else {},
|
||||
}
|
||||
|
||||
return {"error": False, "message": "Authentication successful"}
|
||||
|
||||
def _setup_organization_context(self, request: HttpRequest) -> dict[str, Any]:
|
||||
"""Enhanced organization context setup with validation."""
|
||||
org_id = request.headers.get("X-Organization-ID")
|
||||
|
||||
if not org_id:
|
||||
return {
|
||||
"warning": "No organization ID provided in X-Organization-ID header",
|
||||
"context_set": False,
|
||||
}
|
||||
|
||||
try:
|
||||
# Validate organization ID format
|
||||
if not org_id.strip():
|
||||
return {"warning": "Empty organization ID provided", "context_set": False}
|
||||
|
||||
# Enhanced organization context validation
|
||||
from utils.organization_utils import resolve_organization
|
||||
|
||||
try:
|
||||
organization = resolve_organization(org_id, raise_on_not_found=False)
|
||||
if organization:
|
||||
# Use organization.organization_id (string field) for StateStore consistency
|
||||
# This ensures UserContext.get_organization() can properly retrieve the organization
|
||||
request.organization_id = organization.organization_id
|
||||
request.organization_context = {
|
||||
"id": str(organization.id),
|
||||
"organization_id": organization.organization_id,
|
||||
"name": organization.display_name,
|
||||
"validated": True,
|
||||
}
|
||||
# Store the organization_id string field in StateStore for UserContext compatibility
|
||||
StateStore.set(Account.ORGANIZATION_ID, organization.organization_id)
|
||||
|
||||
logger.debug(
|
||||
f"Organization context validated and set: {organization.display_name} (org_id: {organization.organization_id}, pk: {organization.id})"
|
||||
)
|
||||
return {
|
||||
"warning": None,
|
||||
"context_set": True,
|
||||
"organization_validated": True,
|
||||
}
|
||||
else:
|
||||
logger.warning(f"Organization {org_id} not found in database")
|
||||
# Still set the context for backward compatibility
|
||||
request.organization_id = org_id
|
||||
StateStore.set(Account.ORGANIZATION_ID, org_id)
|
||||
return {
|
||||
"warning": f"Organization {org_id} not found in database, using raw value",
|
||||
"context_set": True,
|
||||
"organization_validated": False,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to validate organization {org_id}: {str(e)}")
|
||||
# Fallback to raw organization ID
|
||||
request.organization_id = org_id
|
||||
StateStore.set(Account.ORGANIZATION_ID, org_id)
|
||||
return {
|
||||
"warning": f"Organization validation failed: {str(e)}, using raw value",
|
||||
"context_set": True,
|
||||
"organization_validated": False,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error setting organization context: {str(e)}")
|
||||
return {
|
||||
"warning": f"Failed to set organization context: {str(e)}",
|
||||
"context_set": False,
|
||||
}
|
||||
|
||||
def process_response(
|
||||
self, request: HttpRequest, response: HttpResponse
|
||||
) -> HttpResponse:
|
||||
# Clean up organization context if we set it
|
||||
if hasattr(request, "internal_service") and request.internal_service:
|
||||
try:
|
||||
org_id_before_clear = StateStore.get(Account.ORGANIZATION_ID)
|
||||
if org_id_before_clear is not None:
|
||||
StateStore.clear(Account.ORGANIZATION_ID)
|
||||
logger.debug(
|
||||
f"Cleaned up organization context for {request.path}: {org_id_before_clear}"
|
||||
)
|
||||
except AttributeError:
|
||||
# StateStore key doesn't exist, which is fine
|
||||
logger.debug(f"No organization context to clean up for {request.path}")
|
||||
return response
|
||||
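The middleware above gates every internal endpoint behind a shared Bearer key, copies X-Organization-ID onto the request and into the StateStore, and clears that state again in process_response. A minimal test sketch of the authentication step, assuming InternalAPIAuthMiddleware takes the usual get_response callable and lives at the import path shown (both are assumptions, not taken from this diff):

from django.test import RequestFactory, override_settings

# Assumed import path for the middleware shown above.
from middleware.internal_api_auth import InternalAPIAuthMiddleware


@override_settings(INTERNAL_SERVICE_API_KEY="test-key")
def check_bearer_authentication():
    factory = RequestFactory()
    middleware = InternalAPIAuthMiddleware(get_response=lambda request: None)

    # Missing Authorization header -> 401-style error dict.
    request = factory.get("/internal/v1/health/")
    result = middleware._authenticate_request(request)
    assert result["error"] is True and result["status"] == 401

    # Matching Bearer key -> success.
    request = factory.get(
        "/internal/v1/health/",
        HTTP_AUTHORIZATION="Bearer test-key",
    )
    assert middleware._authenticate_request(request) == {
        "error": False,
        "message": "Authentication successful",
    }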
backend/notification_v2/internal_api_views.py (new file, 252 lines)
@@ -0,0 +1,252 @@
|
||||
"""Internal API views for notification data access by workers.
|
||||
|
||||
These endpoints provide notification configuration data to workers
|
||||
without exposing full Django models or requiring Django dependencies.
|
||||
|
||||
Security Note:
|
||||
- CSRF protection is disabled for internal service-to-service communication
|
||||
- Authentication is handled by InternalAPIAuthMiddleware using Bearer tokens
|
||||
- These endpoints are not accessible from browsers and don't use session cookies
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from api_v2.models import APIDeployment
|
||||
from django.http import JsonResponse
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from django.views.decorators.http import require_http_methods
|
||||
from pipeline_v2.models import Pipeline
|
||||
from utils.organization_utils import filter_queryset_by_organization
|
||||
|
||||
from notification_v2.models import Notification
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants for error messages
|
||||
INTERNAL_SERVER_ERROR_MSG = "Internal server error"
|
||||
|
||||
|
||||
@csrf_exempt # Safe: Internal API with Bearer token auth, service-to-service only
|
||||
@require_http_methods(["GET"])
|
||||
def get_pipeline_notifications(request, pipeline_id):
|
||||
"""Get active notifications for a pipeline or API deployment.
|
||||
|
||||
Used by callback worker to fetch notification configuration.
|
||||
"""
|
||||
try:
|
||||
# Try to find the pipeline ID in Pipeline model first
|
||||
pipeline_queryset = Pipeline.objects.filter(id=pipeline_id)
|
||||
pipeline_queryset = filter_queryset_by_organization(
|
||||
pipeline_queryset, request, "organization"
|
||||
)
|
||||
|
||||
if pipeline_queryset.exists():
|
||||
pipeline = pipeline_queryset.first()
|
||||
|
||||
# Get active notifications for this pipeline
|
||||
notifications = Notification.objects.filter(pipeline=pipeline, is_active=True)
|
||||
|
||||
notifications_data = []
|
||||
for notification in notifications:
|
||||
notifications_data.append(
|
||||
{
|
||||
"id": str(notification.id),
|
||||
"notification_type": notification.notification_type,
|
||||
"platform": notification.platform,
|
||||
"url": notification.url,
|
||||
"authorization_type": notification.authorization_type,
|
||||
"authorization_key": notification.authorization_key,
|
||||
"authorization_header": notification.authorization_header,
|
||||
"max_retries": notification.max_retries,
|
||||
"is_active": notification.is_active,
|
||||
}
|
||||
)
|
||||
|
||||
return JsonResponse(
|
||||
{
|
||||
"status": "success",
|
||||
"pipeline_id": str(pipeline.id),
|
||||
"pipeline_name": pipeline.pipeline_name,
|
||||
"pipeline_type": pipeline.pipeline_type,
|
||||
"notifications": notifications_data,
|
||||
}
|
||||
)
|
||||
else:
|
||||
# If not found in Pipeline, try APIDeployment model
|
||||
api_queryset = APIDeployment.objects.filter(id=pipeline_id)
|
||||
api_queryset = filter_queryset_by_organization(
|
||||
api_queryset, request, "organization"
|
||||
)
|
||||
|
||||
if api_queryset.exists():
|
||||
api = api_queryset.first()
|
||||
|
||||
# Get active notifications for this API deployment
|
||||
notifications = Notification.objects.filter(api=api, is_active=True)
|
||||
|
||||
notifications_data = []
|
||||
for notification in notifications:
|
||||
notifications_data.append(
|
||||
{
|
||||
"id": str(notification.id),
|
||||
"notification_type": notification.notification_type,
|
||||
"platform": notification.platform,
|
||||
"url": notification.url,
|
||||
"authorization_type": notification.authorization_type,
|
||||
"authorization_key": notification.authorization_key,
|
||||
"authorization_header": notification.authorization_header,
|
||||
"max_retries": notification.max_retries,
|
||||
"is_active": notification.is_active,
|
||||
}
|
||||
)
|
||||
|
||||
return JsonResponse(
|
||||
{
|
||||
"status": "success",
|
||||
"pipeline_id": str(api.id),
|
||||
"pipeline_name": api.api_name,
|
||||
"pipeline_type": "API",
|
||||
"notifications": notifications_data,
|
||||
}
|
||||
)
|
||||
else:
|
||||
return JsonResponse(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Pipeline or API deployment not found",
|
||||
},
|
||||
status=404,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting pipeline notifications for {pipeline_id}: {e}")
|
||||
return JsonResponse(
|
||||
{"status": "error", "message": INTERNAL_SERVER_ERROR_MSG}, status=500
|
||||
)
|
||||
|
||||
|
||||
@csrf_exempt # Safe: Internal API with Bearer token auth, service-to-service only
|
||||
@require_http_methods(["GET"])
|
||||
def get_api_notifications(request, api_id):
|
||||
"""Get active notifications for an API deployment.
|
||||
|
||||
Used by callback worker to fetch notification configuration.
|
||||
"""
|
||||
try:
|
||||
# Get API deployment with organization filtering
|
||||
api_queryset = APIDeployment.objects.filter(id=api_id)
|
||||
api_queryset = filter_queryset_by_organization(
|
||||
api_queryset, request, "organization"
|
||||
)
|
||||
api = get_object_or_404(api_queryset)
|
||||
|
||||
# Get active notifications for this API
|
||||
notifications = Notification.objects.filter(api=api, is_active=True)
|
||||
|
||||
notifications_data = []
|
||||
for notification in notifications:
|
||||
notifications_data.append(
|
||||
{
|
||||
"id": str(notification.id),
|
||||
"notification_type": notification.notification_type,
|
||||
"platform": notification.platform,
|
||||
"url": notification.url,
|
||||
"authorization_type": notification.authorization_type,
|
||||
"authorization_key": notification.authorization_key,
|
||||
"authorization_header": notification.authorization_header,
|
||||
"max_retries": notification.max_retries,
|
||||
"is_active": notification.is_active,
|
||||
}
|
||||
)
|
||||
|
||||
return JsonResponse(
|
||||
{
|
||||
"status": "success",
|
||||
"api_id": str(api.id),
|
||||
"api_name": api.api_name,
|
||||
"display_name": api.display_name,
|
||||
"notifications": notifications_data,
|
||||
}
|
||||
)
|
||||
|
||||
except APIDeployment.DoesNotExist:
|
||||
return JsonResponse(
|
||||
{"status": "error", "message": "API deployment not found"}, status=404
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting API notifications for {api_id}: {e}")
|
||||
return JsonResponse(
|
||||
{"status": "error", "message": INTERNAL_SERVER_ERROR_MSG}, status=500
|
||||
)
|
||||
|
||||
|
||||
@csrf_exempt # Safe: Internal API with Bearer token auth, service-to-service only
|
||||
@require_http_methods(["GET"])
|
||||
def get_pipeline_data(request, pipeline_id):
|
||||
"""Get basic pipeline data for notification purposes.
|
||||
|
||||
Used by callback worker to determine pipeline type and name.
|
||||
"""
|
||||
try:
|
||||
# Get pipeline with organization filtering
|
||||
pipeline_queryset = Pipeline.objects.filter(id=pipeline_id)
|
||||
pipeline_queryset = filter_queryset_by_organization(
|
||||
pipeline_queryset, request, "organization"
|
||||
)
|
||||
pipeline = get_object_or_404(pipeline_queryset)
|
||||
|
||||
return JsonResponse(
|
||||
{
|
||||
"status": "success",
|
||||
"pipeline_id": str(pipeline.id),
|
||||
"pipeline_name": pipeline.pipeline_name,
|
||||
"pipeline_type": pipeline.pipeline_type,
|
||||
"last_run_status": pipeline.last_run_status,
|
||||
}
|
||||
)
|
||||
|
||||
except Pipeline.DoesNotExist:
|
||||
return JsonResponse(
|
||||
{"status": "error", "message": "Pipeline not found"}, status=404
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting pipeline data for {pipeline_id}: {e}")
|
||||
return JsonResponse(
|
||||
{"status": "error", "message": INTERNAL_SERVER_ERROR_MSG}, status=500
|
||||
)
|
||||
|
||||
|
||||
@csrf_exempt # Safe: Internal API with Bearer token auth, service-to-service only
|
||||
@require_http_methods(["GET"])
|
||||
def get_api_data(request, api_id):
|
||||
"""Get basic API deployment data for notification purposes.
|
||||
|
||||
Used by callback worker to determine API name and details.
|
||||
"""
|
||||
try:
|
||||
# Get API deployment with organization filtering
|
||||
api_queryset = APIDeployment.objects.filter(id=api_id)
|
||||
api_queryset = filter_queryset_by_organization(
|
||||
api_queryset, request, "organization"
|
||||
)
|
||||
api = get_object_or_404(api_queryset)
|
||||
|
||||
return JsonResponse(
|
||||
{
|
||||
"status": "success",
|
||||
"api_id": str(api.id),
|
||||
"api_name": api.api_name,
|
||||
"display_name": api.display_name,
|
||||
"is_active": api.is_active,
|
||||
}
|
||||
)
|
||||
|
||||
except APIDeployment.DoesNotExist:
|
||||
return JsonResponse(
|
||||
{"status": "error", "message": "API deployment not found"}, status=404
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting API data for {api_id}: {e}")
|
||||
return JsonResponse(
|
||||
{"status": "error", "message": INTERNAL_SERVER_ERROR_MSG}, status=500
|
||||
)
|
||||
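The three handlers above build the same nine-field dict for every Notification they return. A small module-level helper would remove the triplication; the sketch below only rearranges fields already shown in this file (the helper name is illustrative, not part of the diff):

def _serialize_notification(notification) -> dict:
    """Payload shape returned by the notification internal endpoints above."""
    return {
        "id": str(notification.id),
        "notification_type": notification.notification_type,
        "platform": notification.platform,
        "url": notification.url,
        "authorization_type": notification.authorization_type,
        "authorization_key": notification.authorization_key,
        "authorization_header": notification.authorization_header,
        "max_retries": notification.max_retries,
        "is_active": notification.is_active,
    }

# Each loop above then collapses to:
# notifications_data = [_serialize_notification(n) for n in notifications]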
backend/notification_v2/internal_serializers.py (new file, 128 lines)
@@ -0,0 +1,128 @@
|
||||
"""Internal API Serializers for Notification/Webhook Operations
|
||||
Used by Celery workers for service-to-service communication.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from notification_v2.enums import AuthorizationType, NotificationType, PlatformType
|
||||
from notification_v2.models import Notification
|
||||
|
||||
|
||||
class NotificationSerializer(serializers.ModelSerializer):
|
||||
"""Serializer for Notification model."""
|
||||
|
||||
class Meta:
|
||||
model = Notification
|
||||
fields = [
|
||||
"id",
|
||||
"url",
|
||||
"authorization_type",
|
||||
"authorization_key",
|
||||
"authorization_header",
|
||||
"notification_type",
|
||||
"platform",
|
||||
"max_retries",
|
||||
"is_active",
|
||||
"created_at",
|
||||
"modified_at",
|
||||
"pipeline",
|
||||
"api",
|
||||
]
|
||||
|
||||
|
||||
class WebhookNotificationRequestSerializer(serializers.Serializer):
|
||||
"""Serializer for webhook notification requests."""
|
||||
|
||||
notification_id = serializers.UUIDField(required=False)
|
||||
url = serializers.URLField(required=True)
|
||||
payload = serializers.JSONField(required=True)
|
||||
authorization_type = serializers.ChoiceField(
|
||||
choices=AuthorizationType.choices(), default=AuthorizationType.NONE.value
|
||||
)
|
||||
authorization_key = serializers.CharField(required=False, allow_blank=True)
|
||||
authorization_header = serializers.CharField(required=False, allow_blank=True)
|
||||
headers = serializers.DictField(required=False, default=dict)
|
||||
timeout = serializers.IntegerField(default=30, min_value=1, max_value=300)
|
||||
max_retries = serializers.IntegerField(default=3, min_value=0, max_value=10)
|
||||
retry_delay = serializers.IntegerField(default=60, min_value=1, max_value=3600)
|
||||
|
||||
|
||||
class WebhookNotificationResponseSerializer(serializers.Serializer):
|
||||
"""Serializer for webhook notification responses."""
|
||||
|
||||
task_id = serializers.CharField()
|
||||
notification_id = serializers.UUIDField(required=False)
|
||||
url = serializers.URLField()
|
||||
status = serializers.CharField()
|
||||
queued_at = serializers.DateTimeField()
|
||||
|
||||
|
||||
class WebhookStatusSerializer(serializers.Serializer):
|
||||
"""Serializer for webhook delivery status."""
|
||||
|
||||
task_id = serializers.CharField()
|
||||
status = serializers.CharField()
|
||||
url = serializers.CharField()
|
||||
attempts = serializers.IntegerField()
|
||||
success = serializers.BooleanField()
|
||||
error_message = serializers.CharField(required=False, allow_null=True)
|
||||
|
||||
|
||||
class WebhookBatchRequestSerializer(serializers.Serializer):
|
||||
"""Serializer for batch webhook requests."""
|
||||
|
||||
batch_name = serializers.CharField(required=False, max_length=255)
|
||||
webhooks = serializers.ListField(
|
||||
child=WebhookNotificationRequestSerializer(), min_length=1, max_length=100
|
||||
)
|
||||
delay_between_requests = serializers.IntegerField(
|
||||
default=0, min_value=0, max_value=60
|
||||
)
|
||||
|
||||
|
||||
class WebhookBatchResponseSerializer(serializers.Serializer):
|
||||
"""Serializer for batch webhook responses."""
|
||||
|
||||
batch_id = serializers.CharField()
|
||||
batch_name = serializers.CharField()
|
||||
total_webhooks = serializers.IntegerField()
|
||||
queued_webhooks = serializers.ListField(child=WebhookNotificationResponseSerializer())
|
||||
failed_webhooks = serializers.ListField(child=serializers.DictField())
|
||||
|
||||
|
||||
class WebhookConfigurationSerializer(serializers.Serializer):
|
||||
"""Serializer for webhook configuration."""
|
||||
|
||||
notification_id = serializers.UUIDField()
|
||||
url = serializers.URLField()
|
||||
authorization_type = serializers.ChoiceField(choices=AuthorizationType.choices())
|
||||
authorization_key = serializers.CharField(required=False, allow_blank=True)
|
||||
authorization_header = serializers.CharField(required=False, allow_blank=True)
|
||||
max_retries = serializers.IntegerField()
|
||||
is_active = serializers.BooleanField()
|
||||
|
||||
|
||||
class NotificationListSerializer(serializers.Serializer):
|
||||
"""Serializer for notification list filters."""
|
||||
|
||||
pipeline_id = serializers.UUIDField(required=False)
|
||||
api_deployment_id = serializers.UUIDField(required=False)
|
||||
notification_type = serializers.ChoiceField(
|
||||
choices=NotificationType.choices(), required=False
|
||||
)
|
||||
platform = serializers.ChoiceField(choices=PlatformType.choices(), required=False)
|
||||
is_active = serializers.BooleanField(required=False)
|
||||
|
||||
|
||||
class WebhookTestSerializer(serializers.Serializer):
|
||||
"""Serializer for webhook testing."""
|
||||
|
||||
url = serializers.URLField(required=True)
|
||||
payload = serializers.JSONField(required=True)
|
||||
authorization_type = serializers.ChoiceField(
|
||||
choices=AuthorizationType.choices(), default=AuthorizationType.NONE.value
|
||||
)
|
||||
authorization_key = serializers.CharField(required=False, allow_blank=True)
|
||||
authorization_header = serializers.CharField(required=False, allow_blank=True)
|
||||
headers = serializers.DictField(required=False, default=dict)
|
||||
timeout = serializers.IntegerField(default=30, min_value=1, max_value=300)
|
||||
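To show how these serializers behave, the snippet below validates a minimal webhook request and relies on the defaults declared above (timeout 30, max_retries 3, retry_delay 60); the URL and token values are placeholders:

from notification_v2.enums import AuthorizationType
from notification_v2.internal_serializers import WebhookNotificationRequestSerializer

serializer = WebhookNotificationRequestSerializer(
    data={
        "url": "https://example.com/hooks/workflow-complete",
        "payload": {"execution_id": "abc-123", "status": "SUCCESS"},
        "authorization_type": AuthorizationType.BEARER.value,
        "authorization_key": "secret-token",
    }
)
serializer.is_valid(raise_exception=True)

# Omitted fields are filled from the declared defaults.
assert serializer.validated_data["timeout"] == 30
assert serializer.validated_data["max_retries"] == 3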
backend/notification_v2/internal_urls.py (new file, 56 lines)
@@ -0,0 +1,56 @@
|
||||
"""Internal API URLs for Notification/Webhook Operations
|
||||
URL patterns for webhook notification internal APIs.
|
||||
"""
|
||||
|
||||
from django.urls import include, path
|
||||
from rest_framework.routers import DefaultRouter
|
||||
|
||||
from . import internal_api_views
|
||||
from .internal_views import (
|
||||
WebhookBatchAPIView,
|
||||
WebhookBatchStatusAPIView,
|
||||
WebhookInternalViewSet,
|
||||
WebhookMetricsAPIView,
|
||||
WebhookSendAPIView,
|
||||
WebhookStatusAPIView,
|
||||
WebhookTestAPIView,
|
||||
)
|
||||
|
||||
# Create router for webhook viewsets
|
||||
router = DefaultRouter()
|
||||
router.register(r"", WebhookInternalViewSet, basename="webhook-internal")
|
||||
|
||||
urlpatterns = [
|
||||
# Notification data endpoints for workers
|
||||
path(
|
||||
"pipeline/<str:pipeline_id>/notifications/",
|
||||
internal_api_views.get_pipeline_notifications,
|
||||
name="get_pipeline_notifications",
|
||||
),
|
||||
path(
|
||||
"pipeline/<str:pipeline_id>/",
|
||||
internal_api_views.get_pipeline_data,
|
||||
name="get_pipeline_data",
|
||||
),
|
||||
path(
|
||||
"api/<str:api_id>/notifications/",
|
||||
internal_api_views.get_api_notifications,
|
||||
name="get_api_notifications",
|
||||
),
|
||||
path(
|
||||
"api/<str:api_id>/",
|
||||
internal_api_views.get_api_data,
|
||||
name="get_api_data",
|
||||
),
|
||||
# Webhook operation endpoints
|
||||
path("send/", WebhookSendAPIView.as_view(), name="webhook-send"),
|
||||
path("batch/", WebhookBatchAPIView.as_view(), name="webhook-batch"),
|
||||
path("test/", WebhookTestAPIView.as_view(), name="webhook-test"),
|
||||
path("status/<str:task_id>/", WebhookStatusAPIView.as_view(), name="webhook-status"),
|
||||
path(
|
||||
"batch-status/", WebhookBatchStatusAPIView.as_view(), name="webhook-batch-status"
|
||||
),
|
||||
path("metrics/", WebhookMetricsAPIView.as_view(), name="webhook-metrics"),
|
||||
# Webhook configuration CRUD (via router)
|
||||
path("", include(router.urls)),
|
||||
]
|
||||
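A worker consuming these routes needs only the shared Bearer key and the organization header. A minimal client-side sketch, assuming this urlconf is mounted under an /internal/webhook/ prefix and that the base URL and key come from environment variables (all three are assumptions, not taken from this diff):

import os

import requests


def fetch_pipeline_notifications(pipeline_id: str, org_id: str) -> dict:
    """Call get_pipeline_notifications the way the callback worker would."""
    response = requests.get(
        f"{os.environ['INTERNAL_API_BASE_URL']}/internal/webhook/pipeline/{pipeline_id}/notifications/",
        headers={
            "Authorization": f"Bearer {os.environ['INTERNAL_SERVICE_API_KEY']}",
            "X-Organization-ID": org_id,
        },
        timeout=30,
    )
    response.raise_for_status()
    return response.json()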
backend/notification_v2/internal_views.py (new file, 559 lines)
@@ -0,0 +1,559 @@
|
||||
"""Internal API Views for Webhook Operations
|
||||
Handles webhook notification related endpoints for internal services.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from celery import current_app as celery_app
|
||||
from celery.result import AsyncResult
|
||||
from django.utils import timezone
|
||||
from rest_framework import status, viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
from utils.organization_utils import filter_queryset_by_organization
|
||||
|
||||
from notification_v2.enums import AuthorizationType, NotificationType, PlatformType
|
||||
|
||||
# Import serializers from notification_v2 internal API
|
||||
from notification_v2.internal_serializers import (
|
||||
NotificationListSerializer,
|
||||
NotificationSerializer,
|
||||
WebhookBatchRequestSerializer,
|
||||
WebhookBatchResponseSerializer,
|
||||
WebhookConfigurationSerializer,
|
||||
WebhookNotificationRequestSerializer,
|
||||
WebhookNotificationResponseSerializer,
|
||||
WebhookStatusSerializer,
|
||||
WebhookTestSerializer,
|
||||
)
|
||||
from notification_v2.models import Notification
|
||||
from notification_v2.provider.webhook.webhook import send_webhook_notification
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
APPLICATION_JSON = "application/json"
|
||||
|
||||
|
||||
class WebhookInternalViewSet(viewsets.ReadOnlyModelViewSet):
|
||||
"""Internal API ViewSet for Webhook/Notification operations."""
|
||||
|
||||
serializer_class = NotificationSerializer
|
||||
lookup_field = "id"
|
||||
|
||||
def get_queryset(self):
|
||||
"""Get notifications filtered by organization context."""
|
||||
queryset = Notification.objects.all()
|
||||
return filter_queryset_by_organization(queryset, self.request)
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
"""List notifications with filtering options."""
|
||||
try:
|
||||
serializer = NotificationListSerializer(data=request.query_params)
|
||||
if not serializer.is_valid():
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
filters = serializer.validated_data
|
||||
queryset = self.get_queryset()
|
||||
|
||||
# Apply filters
|
||||
if filters.get("pipeline_id"):
|
||||
queryset = queryset.filter(pipeline_id=filters["pipeline_id"])
|
||||
if filters.get("api_deployment_id"):
|
||||
queryset = queryset.filter(api_id=filters["api_deployment_id"])
|
||||
if filters.get("notification_type"):
|
||||
queryset = queryset.filter(notification_type=filters["notification_type"])
|
||||
if filters.get("platform"):
|
||||
queryset = queryset.filter(platform=filters["platform"])
|
||||
if filters.get("is_active") is not None:
|
||||
queryset = queryset.filter(is_active=filters["is_active"])
|
||||
|
||||
notifications = NotificationSerializer(queryset, many=True).data
|
||||
|
||||
return Response({"count": len(notifications), "notifications": notifications})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list notifications: {str(e)}")
|
||||
return Response(
|
||||
{"error": "Failed to list notifications", "detail": str(e)},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
@action(detail=True, methods=["get"])
|
||||
def configuration(self, request, id=None):
|
||||
"""Get webhook configuration for a notification."""
|
||||
try:
|
||||
notification = self.get_object()
|
||||
|
||||
config_data = {
|
||||
"notification_id": notification.id,
|
||||
"url": notification.url,
|
||||
"authorization_type": notification.authorization_type,
|
||||
"authorization_key": notification.authorization_key,
|
||||
"authorization_header": notification.authorization_header,
|
||||
"max_retries": notification.max_retries,
|
||||
"is_active": notification.is_active,
|
||||
}
|
||||
|
||||
serializer = WebhookConfigurationSerializer(config_data)
|
||||
return Response(serializer.data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get webhook configuration {id}: {str(e)}")
|
||||
return Response(
|
||||
{"error": "Failed to get webhook configuration", "detail": str(e)},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
class WebhookSendAPIView(APIView):
|
||||
"""Internal API endpoint for sending webhook notifications."""
|
||||
|
||||
def post(self, request):
|
||||
"""Send a webhook notification."""
|
||||
try:
|
||||
serializer = WebhookNotificationRequestSerializer(data=request.data)
|
||||
|
||||
if not serializer.is_valid():
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
validated_data = serializer.validated_data
|
||||
|
||||
# Build headers based on authorization type
|
||||
headers = self._build_headers(validated_data)
|
||||
|
||||
# Send webhook notification task
|
||||
task = send_webhook_notification.delay(
|
||||
url=validated_data["url"],
|
||||
payload=validated_data["payload"],
|
||||
headers=headers,
|
||||
timeout=validated_data["timeout"],
|
||||
max_retries=validated_data["max_retries"],
|
||||
retry_delay=validated_data["retry_delay"],
|
||||
)
|
||||
|
||||
# Prepare response
|
||||
response_data = {
|
||||
"task_id": task.id,
|
||||
"notification_id": validated_data.get("notification_id"),
|
||||
"url": validated_data["url"],
|
||||
"status": "queued",
|
||||
"queued_at": timezone.now(),
|
||||
}
|
||||
|
||||
response_serializer = WebhookNotificationResponseSerializer(response_data)
|
||||
|
||||
logger.info(
|
||||
f"Queued webhook notification task {task.id} for URL {validated_data['url']}"
|
||||
)
|
||||
|
||||
return Response(response_serializer.data, status=status.HTTP_202_ACCEPTED)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send webhook notification: {str(e)}")
|
||||
return Response(
|
||||
{"error": "Failed to send webhook notification", "detail": str(e)},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
def _build_headers(self, validated_data: dict[str, Any]) -> dict[str, str]:
|
||||
"""Build headers based on authorization configuration."""
|
||||
headers = {"Content-Type": APPLICATION_JSON}
|
||||
|
||||
auth_type = validated_data.get("authorization_type", AuthorizationType.NONE.value)
|
||||
auth_key = validated_data.get("authorization_key")
|
||||
auth_header = validated_data.get("authorization_header")
|
||||
|
||||
if validated_data.get("headers"):
|
||||
headers.update(validated_data["headers"])
|
||||
|
||||
if auth_type == AuthorizationType.BEARER.value and auth_key:
|
||||
headers["Authorization"] = f"Bearer {auth_key}"
|
||||
elif auth_type == AuthorizationType.API_KEY.value and auth_key:
|
||||
headers["Authorization"] = auth_key
|
||||
elif (
|
||||
auth_type == AuthorizationType.CUSTOM_HEADER.value
|
||||
and auth_header
|
||||
and auth_key
|
||||
):
|
||||
headers[auth_header] = auth_key
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
class WebhookStatusAPIView(APIView):
|
||||
"""Internal API endpoint for checking webhook delivery status."""
|
||||
|
||||
def get(self, request, task_id):
|
||||
"""Get webhook delivery status by task ID."""
|
||||
try:
|
||||
task_result = AsyncResult(task_id, app=celery_app)
|
||||
|
||||
status_data = {
|
||||
"task_id": task_id,
|
||||
"status": task_result.status,
|
||||
"url": "unknown",
|
||||
"attempts": 0,
|
||||
"success": task_result.successful(),
|
||||
"error_message": None,
|
||||
}
|
||||
|
||||
if task_result.failed():
|
||||
status_data["error_message"] = str(task_result.result)
|
||||
elif task_result.successful():
|
||||
status_data["attempts"] = getattr(task_result.result, "attempts", 1)
|
||||
|
||||
serializer = WebhookStatusSerializer(status_data)
|
||||
return Response(serializer.data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get webhook status for task {task_id}: {str(e)}")
|
||||
return Response(
|
||||
{"error": "Failed to get webhook status", "detail": str(e)},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
class WebhookBatchAPIView(APIView):
|
||||
"""Internal API endpoint for sending batch webhook notifications."""
|
||||
|
||||
def post(self, request):
|
||||
"""Send multiple webhook notifications in batch."""
|
||||
try:
|
||||
serializer = WebhookBatchRequestSerializer(data=request.data)
|
||||
|
||||
if not serializer.is_valid():
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
validated_data = serializer.validated_data
|
||||
webhooks = validated_data["webhooks"]
|
||||
delay_between = validated_data.get("delay_between_requests", 0)
|
||||
|
||||
batch_id = str(uuid.uuid4())
|
||||
queued_webhooks = []
|
||||
failed_webhooks = []
|
||||
|
||||
for i, webhook_data in enumerate(webhooks):
|
||||
try:
|
||||
headers = self._build_headers(webhook_data)
|
||||
countdown = i * delay_between if delay_between > 0 else 0
|
||||
|
||||
task = send_webhook_notification.apply_async(
|
||||
args=[
|
||||
webhook_data["url"],
|
||||
webhook_data["payload"],
|
||||
headers,
|
||||
webhook_data["timeout"],
|
||||
],
|
||||
kwargs={
|
||||
"max_retries": webhook_data["max_retries"],
|
||||
"retry_delay": webhook_data["retry_delay"],
|
||||
},
|
||||
countdown=countdown,
|
||||
)
|
||||
|
||||
queued_webhooks.append(
|
||||
{
|
||||
"task_id": task.id,
|
||||
"notification_id": webhook_data.get("notification_id"),
|
||||
"url": webhook_data["url"],
|
||||
"status": "queued",
|
||||
"queued_at": timezone.now(),
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failed_webhooks.append({"url": webhook_data["url"], "error": str(e)})
|
||||
|
||||
response_data = {
|
||||
"batch_id": batch_id,
|
||||
"batch_name": validated_data.get("batch_name", f"Batch-{batch_id[:8]}"),
|
||||
"total_webhooks": len(webhooks),
|
||||
"queued_webhooks": queued_webhooks,
|
||||
"failed_webhooks": failed_webhooks,
|
||||
}
|
||||
|
||||
response_serializer = WebhookBatchResponseSerializer(response_data)
|
||||
|
||||
logger.info(
|
||||
f"Queued batch {batch_id} with {len(queued_webhooks)} webhooks, {len(failed_webhooks)} failed"
|
||||
)
|
||||
|
||||
return Response(response_serializer.data, status=status.HTTP_202_ACCEPTED)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send webhook batch: {str(e)}")
|
||||
return Response(
|
||||
{"error": "Failed to send webhook batch", "detail": str(e)},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
def _build_headers(self, webhook_data: dict[str, Any]) -> dict[str, str]:
|
||||
"""Build headers for webhook request."""
|
||||
headers = {"Content-Type": APPLICATION_JSON}
|
||||
|
||||
auth_type = webhook_data.get("authorization_type", AuthorizationType.NONE.value)
|
||||
auth_key = webhook_data.get("authorization_key")
|
||||
auth_header = webhook_data.get("authorization_header")
|
||||
|
||||
if webhook_data.get("headers"):
|
||||
headers.update(webhook_data["headers"])
|
||||
|
||||
if auth_type == AuthorizationType.BEARER.value and auth_key:
|
||||
headers["Authorization"] = f"Bearer {auth_key}"
|
||||
elif auth_type == AuthorizationType.API_KEY.value and auth_key:
|
||||
headers["Authorization"] = auth_key
|
||||
elif (
|
||||
auth_type == AuthorizationType.CUSTOM_HEADER.value
|
||||
and auth_header
|
||||
and auth_key
|
||||
):
|
||||
headers[auth_header] = auth_key
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
class WebhookTestAPIView(APIView):
|
||||
"""Internal API endpoint for testing webhook configurations."""
|
||||
|
||||
def post(self, request):
|
||||
"""Test a webhook configuration without queuing."""
|
||||
try:
|
||||
serializer = WebhookTestSerializer(data=request.data)
|
||||
|
||||
if not serializer.is_valid():
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
validated_data = serializer.validated_data
|
||||
headers = self._build_headers(validated_data)
|
||||
|
||||
import requests
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
url=validated_data["url"],
|
||||
json=validated_data["payload"],
|
||||
headers=headers,
|
||||
timeout=validated_data["timeout"],
|
||||
)
|
||||
|
||||
test_result = {
|
||||
"success": response.status_code < 400,
|
||||
"status_code": response.status_code,
|
||||
"response_headers": dict(response.headers),
|
||||
"response_body": response.text[:1000],
|
||||
"url": validated_data["url"],
|
||||
"request_headers": headers,
|
||||
"request_payload": validated_data["payload"],
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Webhook test to {validated_data['url']} completed with status {response.status_code}"
|
||||
)
|
||||
|
||||
return Response(test_result)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
test_result = {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"url": validated_data["url"],
|
||||
"request_headers": headers,
|
||||
"request_payload": validated_data["payload"],
|
||||
}
|
||||
|
||||
return Response(test_result, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test webhook: {str(e)}")
|
||||
return Response(
|
||||
{"error": "Failed to test webhook", "detail": str(e)},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
def _build_headers(self, validated_data: dict[str, Any]) -> dict[str, str]:
|
||||
"""Build headers for webhook test."""
|
||||
headers = {"Content-Type": APPLICATION_JSON}
|
||||
|
||||
auth_type = validated_data.get("authorization_type", AuthorizationType.NONE.value)
|
||||
auth_key = validated_data.get("authorization_key")
|
||||
auth_header = validated_data.get("authorization_header")
|
||||
|
||||
if validated_data.get("headers"):
|
||||
headers.update(validated_data["headers"])
|
||||
|
||||
if auth_type == AuthorizationType.BEARER.value and auth_key:
|
||||
headers["Authorization"] = f"Bearer {auth_key}"
|
||||
elif auth_type == AuthorizationType.API_KEY.value and auth_key:
|
||||
headers["Authorization"] = auth_key
|
||||
elif (
|
||||
auth_type == AuthorizationType.CUSTOM_HEADER.value
|
||||
and auth_header
|
||||
and auth_key
|
||||
):
|
||||
headers[auth_header] = auth_key
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
class WebhookBatchStatusAPIView(APIView):
|
||||
"""Internal API endpoint for checking batch webhook delivery status."""
|
||||
|
||||
def get(self, request):
|
||||
"""Get batch webhook delivery status."""
|
||||
try:
|
||||
batch_id = request.query_params.get("batch_id")
|
||||
task_ids = request.query_params.get("task_ids", "").split(",")
|
||||
|
||||
if not batch_id and not task_ids:
|
||||
return Response(
|
||||
{"error": "Either batch_id or task_ids parameter is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
batch_results = []
|
||||
|
||||
if task_ids and task_ids[0]: # task_ids is not empty
|
||||
for task_id in task_ids:
|
||||
if task_id.strip():
|
||||
try:
|
||||
task_result = AsyncResult(task_id.strip(), app=celery_app)
|
||||
|
||||
batch_results.append(
|
||||
{
|
||||
"task_id": task_id.strip(),
|
||||
"status": task_result.status,
|
||||
"success": task_result.successful(),
|
||||
"error_message": str(task_result.result)
|
||||
if task_result.failed()
|
||||
else None,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
batch_results.append(
|
||||
{
|
||||
"task_id": task_id.strip(),
|
||||
"status": "ERROR",
|
||||
"success": False,
|
||||
"error_message": f"Failed to get task status: {str(e)}",
|
||||
}
|
||||
)
|
||||
|
||||
response_data = {
|
||||
"batch_id": batch_id,
|
||||
"total_tasks": len(batch_results),
|
||||
"results": batch_results,
|
||||
"summary": {
|
||||
"completed": sum(
|
||||
1 for r in batch_results if r["status"] == "SUCCESS"
|
||||
),
|
||||
"failed": sum(1 for r in batch_results if r["status"] == "FAILURE"),
|
||||
"pending": sum(1 for r in batch_results if r["status"] == "PENDING"),
|
||||
"running": sum(1 for r in batch_results if r["status"] == "STARTED"),
|
||||
},
|
||||
}
|
||||
|
||||
return Response(response_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get batch webhook status: {str(e)}")
|
||||
return Response(
|
||||
{"error": "Failed to get batch webhook status", "detail": str(e)},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
class WebhookMetricsAPIView(APIView):
|
||||
"""Internal API endpoint for webhook delivery metrics."""
|
||||
|
||||
def get(self, request):
|
||||
"""Get webhook delivery metrics."""
|
||||
try:
|
||||
# Get query parameters
|
||||
organization_id = request.query_params.get("organization_id")
|
||||
start_date = request.query_params.get("start_date")
|
||||
end_date = request.query_params.get("end_date")
|
||||
|
||||
# Get base queryset
|
||||
queryset = Notification.objects.all()
|
||||
queryset = filter_queryset_by_organization(queryset, request)
|
||||
|
||||
# Apply filters
|
||||
if organization_id:
|
||||
queryset = queryset.filter(organization_id=organization_id)
|
||||
|
||||
if start_date:
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
start_dt = datetime.fromisoformat(start_date.replace("Z", "+00:00"))
|
||||
queryset = queryset.filter(created_at__gte=start_dt)
|
||||
except ValueError:
|
||||
return Response(
|
||||
{"error": "Invalid start_date format. Use ISO format."},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
if end_date:
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
end_dt = datetime.fromisoformat(end_date.replace("Z", "+00:00"))
|
||||
queryset = queryset.filter(created_at__lte=end_dt)
|
||||
except ValueError:
|
||||
return Response(
|
||||
{"error": "Invalid end_date format. Use ISO format."},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Calculate metrics
|
||||
total_webhooks = queryset.count()
|
||||
active_webhooks = queryset.filter(is_active=True).count()
|
||||
inactive_webhooks = queryset.filter(is_active=False).count()
|
||||
|
||||
# Group by notification type
|
||||
type_breakdown = {}
|
||||
for notification_type in NotificationType:
|
||||
count = queryset.filter(notification_type=notification_type.value).count()
|
||||
if count > 0:
|
||||
type_breakdown[notification_type.value] = count
|
||||
|
||||
# Group by platform
|
||||
platform_breakdown = {}
|
||||
for platform_type in PlatformType:
|
||||
count = queryset.filter(platform=platform_type.value).count()
|
||||
if count > 0:
|
||||
platform_breakdown[platform_type.value] = count
|
||||
|
||||
# Group by authorization type
|
||||
auth_breakdown = {}
|
||||
for auth_type in AuthorizationType:
|
||||
count = queryset.filter(authorization_type=auth_type.value).count()
|
||||
if count > 0:
|
||||
auth_breakdown[auth_type.value] = count
|
||||
|
||||
metrics = {
|
||||
"total_webhooks": total_webhooks,
|
||||
"active_webhooks": active_webhooks,
|
||||
"inactive_webhooks": inactive_webhooks,
|
||||
"type_breakdown": type_breakdown,
|
||||
"platform_breakdown": platform_breakdown,
|
||||
"authorization_breakdown": auth_breakdown,
|
||||
"filters_applied": {
|
||||
"organization_id": organization_id,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
},
|
||||
}
|
||||
|
||||
return Response(metrics)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get webhook metrics: {str(e)}")
|
||||
return Response(
|
||||
{"error": "Failed to get webhook metrics", "detail": str(e)},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
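WebhookSendAPIView, WebhookBatchAPIView and WebhookTestAPIView each carry an identical _build_headers method. The mapping could live once at module level; a sketch under that assumption (the function name and placement are illustrative, the logic mirrors the methods above):

def build_webhook_headers(data: dict) -> dict:
    """Merge caller-supplied headers with the configured authorization scheme."""
    headers = {"Content-Type": APPLICATION_JSON}
    headers.update(data.get("headers") or {})

    auth_type = data.get("authorization_type", AuthorizationType.NONE.value)
    auth_key = data.get("authorization_key")
    auth_header = data.get("authorization_header")

    if auth_type == AuthorizationType.BEARER.value and auth_key:
        headers["Authorization"] = f"Bearer {auth_key}"
    elif auth_type == AuthorizationType.API_KEY.value and auth_key:
        headers["Authorization"] = auth_key
    elif auth_type == AuthorizationType.CUSTOM_HEADER.value and auth_header and auth_key:
        headers[auth_header] = auth_key
    return headers

Each view's _build_headers then becomes a one-line delegation to this helper.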
backend/pipeline_v2/internal_api_views.py (new file, 167 lines)
@@ -0,0 +1,167 @@
|
||||
import logging
|
||||
|
||||
from api_v2.models import APIDeployment
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ViewSet
|
||||
from utils.organization_utils import filter_queryset_by_organization
|
||||
|
||||
from pipeline_v2.models import Pipeline
|
||||
|
||||
from .serializers.internal import APIDeploymentSerializer, PipelineSerializer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PipelineInternalViewSet(ViewSet):
|
||||
def retrieve(self, request, pk=None):
|
||||
logger.info(f"[PipelineInternalViewSet] Retrieving data for ID: {pk}")
|
||||
|
||||
try:
|
||||
# 1️⃣ Try in Pipeline
|
||||
pipeline_data = self._fetch_single_record(
|
||||
pk,
|
||||
request,
|
||||
Pipeline.objects.filter(id=pk),
|
||||
PipelineSerializer,
|
||||
"Pipeline",
|
||||
)
|
||||
if isinstance(pipeline_data, dict): # Found successfully
|
||||
return Response({"status": "success", "pipeline": pipeline_data})
|
||||
elif isinstance(pipeline_data, Response): # Integrity error
|
||||
return pipeline_data
|
||||
|
||||
# 2️⃣ Try in APIDeployment
|
||||
api_data = self._fetch_single_record(
|
||||
pk,
|
||||
request,
|
||||
APIDeployment.objects.filter(id=pk),
|
||||
APIDeploymentSerializer,
|
||||
"APIDeployment",
|
||||
)
|
||||
if isinstance(api_data, dict):
|
||||
return Response({"status": "success", "pipeline": api_data})
|
||||
elif isinstance(api_data, Response):
|
||||
return api_data
|
||||
|
||||
# 3️⃣ Not found anywhere
|
||||
logger.warning(f"⚠️ No Pipeline or APIDeployment found for {pk}")
|
||||
return Response(
|
||||
{"status": "error", "message": "Pipeline not found"}, status=404
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"💥 Error retrieving pipeline or deployment for {pk}")
|
||||
return Response(
|
||||
{"status": "error", "message": "Internal server error"}, status=500
|
||||
)
|
||||
|
||||
# Helper function for DRY logic
|
||||
def _fetch_single_record(self, pk, request, qs, serializer_cls, model_name):
|
||||
qs = filter_queryset_by_organization(qs, request, "organization")
|
||||
count = qs.count()
|
||||
|
||||
if count == 1:
|
||||
obj = qs.first()
|
||||
logger.info(f"✅ Found {model_name} entry: {obj}")
|
||||
return serializer_cls(obj).data
|
||||
elif count > 1:
|
||||
logger.error(f"❌ Multiple {model_name} entries found for {pk}")
|
||||
return Response(
|
||||
{
|
||||
"status": "error",
|
||||
"message": f"Data integrity error: multiple {model_name} entries found",
|
||||
},
|
||||
status=500,
|
||||
)
|
||||
|
||||
return None # Not found in this model
|
||||
|
||||
def update(self, request, pk=None):
|
||||
"""Update pipeline status with support for completion states."""
|
||||
try:
|
||||
new_status = request.data.get("status")
|
||||
if not new_status:
|
||||
return Response(
|
||||
{"status": "error", "message": "Status is required"}, status=400
|
||||
)
|
||||
|
||||
# Extract additional parameters for completion states
|
||||
is_end = request.data.get("is_end", False)
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from pipeline_v2.pipeline_processor import PipelineProcessor
|
||||
|
||||
# Try to update pipeline first
|
||||
try:
|
||||
pipeline_qs = Pipeline.objects.filter(id=pk)
|
||||
pipeline_qs = filter_queryset_by_organization(
|
||||
pipeline_qs, request, "organization"
|
||||
)
|
||||
pipeline = pipeline_qs.first()
|
||||
|
||||
if pipeline:
|
||||
# Use PipelineProcessor.update_pipeline() without execution_id and error_message
|
||||
# This will update status but skip notifications (since execution_id=None)
|
||||
PipelineProcessor.update_pipeline(
|
||||
pipeline_guid=pk,
|
||||
status=new_status,
|
||||
is_end=is_end,
|
||||
)
|
||||
|
||||
return Response(
|
||||
{
|
||||
"status": "success",
|
||||
"pipeline_id": pk,
|
||||
"new_status": new_status,
|
||||
"is_end": is_end,
|
||||
"message": "Pipeline status updated successfully",
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating pipeline status: {e}")
|
||||
return Response(
|
||||
{"status": "error", "message": f"Failed to update pipeline: {e}"},
|
||||
status=500,
|
||||
)
|
||||
|
||||
# Try API deployment if pipeline not found
|
||||
try:
|
||||
api_qs = APIDeployment.objects.filter(id=pk)
|
||||
api_qs = filter_queryset_by_organization(api_qs, request, "organization")
|
||||
api_deployment = api_qs.first()
|
||||
|
||||
if api_deployment:
|
||||
# For API deployments, log the status update
|
||||
logger.info(f"Updated API deployment {pk} status to {new_status}")
|
||||
|
||||
return Response(
|
||||
{
|
||||
"status": "success",
|
||||
"pipeline_id": pk,
|
||||
"new_status": new_status,
|
||||
"message": "API deployment status updated successfully",
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating API deployment status: {e}")
|
||||
return Response(
|
||||
{
|
||||
"status": "error",
|
||||
"message": f"Failed to update API deployment: {e}",
|
||||
},
|
||||
status=500,
|
||||
)
|
||||
|
||||
# Not found in either model
|
||||
return Response(
|
||||
{"status": "error", "message": "Pipeline or API deployment not found"},
|
||||
status=404,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating pipeline/API deployment status for {pk}: {e}")
|
||||
return Response(
|
||||
{"status": "error", "message": "Internal server error"}, status=500
|
||||
)
|
||||
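The update() action above is what a worker calls to flip a pipeline's last-run status without importing Django. A hedged client-side sketch, assuming the PipelineInternalViewSet router (registered in internal_urls.py below) is mounted under an /internal/pipeline/ prefix and that the base URL and key come from environment variables (all assumptions):

import os

import requests


def update_pipeline_status(
    pipeline_id: str, org_id: str, status: str, is_end: bool = False
) -> dict:
    """PUT the new status to PipelineInternalViewSet.update()."""
    response = requests.put(
        f"{os.environ['INTERNAL_API_BASE_URL']}/internal/pipeline/{pipeline_id}/",
        json={"status": status, "is_end": is_end},
        headers={
            "Authorization": f"Bearer {os.environ['INTERNAL_SERVICE_API_KEY']}",
            "X-Organization-ID": org_id,
        },
        timeout=30,
    )
    response.raise_for_status()
    return response.json()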
backend/pipeline_v2/internal_urls.py (new file, 17 lines)
@@ -0,0 +1,17 @@
|
||||
"""Internal API URLs for Pipeline Operations"""
|
||||
|
||||
from django.urls import include, path
|
||||
from rest_framework.routers import DefaultRouter
|
||||
|
||||
from .internal_api_views import (
|
||||
PipelineInternalViewSet,
|
||||
)
|
||||
|
||||
# Create router for pipeline viewsets
|
||||
router = DefaultRouter()
|
||||
router.register(r"", PipelineInternalViewSet, basename="pipeline-internal")
|
||||
|
||||
urlpatterns = [
|
||||
# Pipeline internal APIs
|
||||
path("", include(router.urls)),
|
||||
]
|
||||
@@ -113,7 +113,12 @@ class PipelineProcessor:
|
||||
pipeline = PipelineProcessor._update_pipeline_status(
|
||||
pipeline=pipeline, is_end=is_end, status=status, is_active=is_active
|
||||
)
|
||||
PipelineProcessor._send_notification(
|
||||
pipeline=pipeline, execution_id=execution_id, error_message=error_message
|
||||
)
|
||||
|
||||
# Only send notifications if execution_id is provided
|
||||
# This avoids duplicate notifications when called from workers (which handle notifications separately)
|
||||
if execution_id:
|
||||
PipelineProcessor._send_notification(
|
||||
pipeline=pipeline, execution_id=execution_id, error_message=error_message
|
||||
)
|
||||
|
||||
logger.info(f"Updated pipeline {pipeline_guid} status: {status}")
|
||||
|
||||
backend/pipeline_v2/serializers/internal.py (new file, 59 lines)
@@ -0,0 +1,59 @@
|
||||
from api_v2.models import APIDeployment
|
||||
from pipeline_v2.models import Pipeline
|
||||
from rest_framework import serializers
|
||||
|
||||
|
||||
class PipelineSerializer(serializers.ModelSerializer):
|
||||
# Add computed fields for callback worker
|
||||
is_api = serializers.SerializerMethodField()
|
||||
resolved_pipeline_type = serializers.SerializerMethodField()
|
||||
resolved_pipeline_name = serializers.SerializerMethodField()
|
||||
pipeline_name = serializers.SerializerMethodField()
|
||||
|
||||
class Meta:
|
||||
model = Pipeline
|
||||
fields = "__all__"
|
||||
|
||||
def get_is_api(self, obj):
|
||||
"""Returns False for Pipeline model entries."""
|
||||
return False
|
||||
|
||||
def get_resolved_pipeline_type(self, obj):
|
||||
"""Returns the pipeline type from the Pipeline model."""
|
||||
return obj.pipeline_type
|
||||
|
||||
def get_resolved_pipeline_name(self, obj):
|
||||
"""Returns the pipeline name from the Pipeline model."""
|
||||
return obj.pipeline_name
|
||||
|
||||
def get_pipeline_name(self, obj):
|
||||
"""Returns the pipeline name for callback worker compatibility."""
|
||||
return obj.pipeline_name
|
||||
|
||||
|
||||
class APIDeploymentSerializer(serializers.ModelSerializer):
|
||||
# Add computed fields for callback worker
|
||||
is_api = serializers.SerializerMethodField()
|
||||
resolved_pipeline_type = serializers.SerializerMethodField()
|
||||
resolved_pipeline_name = serializers.SerializerMethodField()
|
||||
pipeline_name = serializers.SerializerMethodField()
|
||||
|
||||
class Meta:
|
||||
model = APIDeployment
|
||||
fields = "__all__"
|
||||
|
||||
def get_is_api(self, obj):
|
||||
"""Returns True for APIDeployment model entries."""
|
||||
return True
|
||||
|
||||
def get_resolved_pipeline_type(self, obj):
|
||||
"""Returns 'API' for APIDeployment model entries."""
|
||||
return "API"
|
||||
|
||||
def get_resolved_pipeline_name(self, obj):
|
||||
"""Returns the api_name from the APIDeployment model."""
|
||||
return obj.api_name
|
||||
|
||||
def get_pipeline_name(self, obj):
|
||||
"""Returns the api_name for callback worker compatibility."""
|
||||
return obj.api_name
|
||||
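Both serializers expose the same computed fields (is_api, resolved_pipeline_type, resolved_pipeline_name, pipeline_name), so the callback worker can handle one payload shape regardless of whether the ID matched a Pipeline or an APIDeployment. An illustrative consumer of the retrieve() response:

def describe_pipeline(payload: dict) -> str:
    """Summarize the 'pipeline' object returned by PipelineInternalViewSet.retrieve()."""
    data = payload["pipeline"]
    kind = "API deployment" if data["is_api"] else data["resolved_pipeline_type"]
    return f"{data['resolved_pipeline_name']} ({kind})"

# Works for both model types because the computed fields are shared.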
backend/platform_settings_v2/internal_urls.py (new file, 18 lines)
@@ -0,0 +1,18 @@
|
||||
"""Internal URLs for platform settings
|
||||
|
||||
Routes for internal API endpoints used by workers.
|
||||
"""
|
||||
|
||||
from django.urls import path
|
||||
|
||||
from .internal_views import InternalPlatformKeyView
|
||||
|
||||
app_name = "platform_settings_internal"
|
||||
|
||||
urlpatterns = [
|
||||
path(
|
||||
"platform-key/",
|
||||
InternalPlatformKeyView.as_view(),
|
||||
name="platform_key",
|
||||
),
|
||||
]
|
||||
backend/platform_settings_v2/internal_views.py (new file, 76 lines)
@@ -0,0 +1,76 @@
|
||||
"""Internal API views for platform settings
|
||||
|
||||
Provides internal endpoints for workers to access platform settings
|
||||
without direct database access.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from account_v2.models import PlatformKey
|
||||
from account_v2.organization import OrganizationService
|
||||
from rest_framework import status
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from platform_settings_v2.platform_auth_service import PlatformAuthenticationService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InternalPlatformKeyView(APIView):
|
||||
"""Internal API to get active platform key for an organization."""
|
||||
|
||||
def get(self, request):
|
||||
"""Get active platform key for organization.
|
||||
|
||||
Uses X-Organization-ID header to identify the organization.
|
||||
|
||||
Args:
|
||||
request: HTTP request with X-Organization-ID header
|
||||
|
||||
Returns:
|
||||
Response with platform key
|
||||
"""
|
||||
try:
|
||||
# Get organization ID from header
|
||||
org_id = request.headers.get("X-Organization-ID")
|
||||
if not org_id:
|
||||
return Response(
|
||||
{"error": "X-Organization-ID header is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Get organization
|
||||
organization = OrganizationService.get_organization_by_org_id(org_id=org_id)
|
||||
|
||||
if not organization:
|
||||
return Response(
|
||||
{"error": f"Organization {org_id} not found"},
|
||||
status=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
# Get active platform key
|
||||
platform_key = PlatformAuthenticationService.get_active_platform_key(
|
||||
organization_id=org_id
|
||||
)
|
||||
|
||||
return Response(
|
||||
{
|
||||
"platform_key": str(platform_key.key),
|
||||
"key_name": platform_key.key_name,
|
||||
"organization_id": org_id,
|
||||
},
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
except PlatformKey.DoesNotExist:
|
||||
return Response(
|
||||
{"error": f"No active platform key found for organization {org_id}"},
|
||||
status=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting platform key for org {org_id}: {str(e)}")
|
||||
return Response(
|
||||
{"error": "Internal server error"},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
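Workers resolve their platform key through this endpoint instead of reading the database directly. A compact client sketch in the same Bearer + X-Organization-ID style as the earlier ones; the mount prefix and environment variable names are assumptions:

import os

import requests


def get_platform_key(org_id: str) -> str:
    """Fetch the active platform key for an organization via InternalPlatformKeyView."""
    base = os.environ["INTERNAL_API_BASE_URL"]  # assumed env var
    key = os.environ["INTERNAL_SERVICE_API_KEY"]
    response = requests.get(
        f"{base}/internal/platform-settings/platform-key/",  # assumed mount prefix
        headers={"Authorization": f"Bearer {key}", "X-Organization-ID": org_id},
        timeout=30,
    )
    response.raise_for_status()
    return response.json()["platform_key"]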
@@ -20,8 +20,19 @@ class ModifierConfig:
|
||||
METADATA_IS_ACTIVE = "is_active"
|
||||
|
||||
|
||||
# Cache for loaded plugins to avoid repeated loading
|
||||
_modifier_plugins_cache: list[Any] = []
|
||||
_plugins_loaded = False
|
||||
|
||||
|
||||
def load_plugins() -> list[Any]:
|
||||
"""Iterate through the extraction plugins and register them."""
|
||||
global _modifier_plugins_cache, _plugins_loaded
|
||||
|
||||
# Return cached plugins if already loaded
|
||||
if _plugins_loaded:
|
||||
return _modifier_plugins_cache
|
||||
|
||||
plugins_app = apps.get_app_config(ModifierConfig.PLUGINS_APP)
|
||||
package_path = plugins_app.module.__package__
|
||||
modifier_dir = os.path.join(plugins_app.path, ModifierConfig.PLUGIN_DIR)
|
||||
@@ -29,6 +40,8 @@ def load_plugins() -> list[Any]:
|
||||
modifier_plugins: list[Any] = []
|
||||
|
||||
if not os.path.exists(modifier_dir):
|
||||
_modifier_plugins_cache = modifier_plugins
|
||||
_plugins_loaded = True
|
||||
return modifier_plugins
|
||||
|
||||
for item in os.listdir(modifier_dir):
|
||||
@@ -69,4 +82,8 @@ def load_plugins() -> list[Any]:
|
||||
if len(modifier_plugins) == 0:
|
||||
logger.info("No modifier plugins found.")
|
||||
|
||||
# Cache the results for future requests
|
||||
_modifier_plugins_cache = modifier_plugins
|
||||
_plugins_loaded = True
|
||||
|
||||
return modifier_plugins
|
||||
|
||||
@@ -20,14 +20,29 @@ class ProcessorConfig:
|
||||
METADATA_IS_ACTIVE = "is_active"
|
||||
|
||||
|
||||
# Cache for loaded plugins to avoid repeated loading
|
||||
_processor_plugins_cache: list[Any] = []
|
||||
_plugins_loaded = False
|
||||
|
||||
|
||||
def load_plugins() -> list[Any]:
|
||||
"""Iterate through the processor plugins and register them."""
|
||||
global _processor_plugins_cache, _plugins_loaded
|
||||
|
||||
# Return cached plugins if already loaded
|
||||
if _plugins_loaded:
|
||||
return _processor_plugins_cache
|
||||
|
||||
plugins_app = apps.get_app_config(ProcessorConfig.PLUGINS_APP)
|
||||
package_path = plugins_app.module.__package__
|
||||
processor_dir = os.path.join(plugins_app.path, ProcessorConfig.PLUGIN_DIR)
|
||||
processor_package_path = f"{package_path}.{ProcessorConfig.PLUGIN_DIR}"
|
||||
processor_plugins: list[Any] = []
|
||||
|
||||
if not os.path.exists(processor_dir):
|
||||
logger.info("No processor directory found at %s.", processor_dir)
|
||||
return []
|
||||
|
||||
for item in os.listdir(processor_dir):
|
||||
# Loads a plugin if it is in a directory.
|
||||
if os.path.isdir(os.path.join(processor_dir, item)):
|
||||
@@ -71,6 +86,10 @@ def load_plugins() -> list[Any]:
|
||||
if len(processor_plugins) == 0:
|
||||
logger.info("No processor plugins found.")
|
||||
|
||||
# Cache the results for future requests
|
||||
_processor_plugins_cache = processor_plugins
|
||||
_plugins_loaded = True
|
||||
|
||||
return processor_plugins
|
||||
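Both plugin loaders now memoize their result in module-level globals, so only the first call pays for the directory scan; the processor variant appears to return early without updating its cache when the plugin directory is missing, unlike the modifier variant, so that path re-scans on every call. The same memoization can be written more compactly with functools.cache; this is a sketch of the pattern, not of the diff's code:

from functools import cache
from typing import Any


@cache
def load_plugins_cached() -> list[Any]:
    """The wrapped scan runs once per process; later calls return the same list."""
    return _scan_plugin_directory()


def _scan_plugin_directory() -> list[Any]:
    # Placeholder for the directory walk performed by load_plugins() above.
    return []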
|
||||
|
||||
|
||||
@@ -72,7 +72,10 @@ dev = [
|
||||
# For file watching
|
||||
"inotify>=0.2.10",
|
||||
"poethepoet>=0.33.1",
|
||||
"debugpy>=1.8.14"
|
||||
"debugpy>=1.8.14",
|
||||
"pytest>=8.3.5",
|
||||
"responses>=0.25.7",
|
||||
"psutil>=7.0.0",
|
||||
]
|
||||
test = ["pytest>=8.0.1", "pytest-dotenv==0.5.2"]
|
||||
deploy = [
|
||||
|
||||
@@ -65,7 +65,7 @@ PLATFORM_SERVICE_PORT=3001
|
||||
# Tool Runner
|
||||
UNSTRACT_RUNNER_HOST=http://unstract-runner
|
||||
UNSTRACT_RUNNER_PORT=5002
|
||||
UNSTRACT_RUNNER_API_TIMEOUT=120 # (in seconds) 2 mins
|
||||
UNSTRACT_RUNNER_API_TIMEOUT=240 # (in seconds) 4 mins
|
||||
UNSTRACT_RUNNER_API_RETRY_COUNT=5 # Number of retries for failed requests
|
||||
UNSTRACT_RUNNER_API_BACKOFF_FACTOR=3 # Exponential backoff factor for retries
|
||||
|
||||
|
||||
backend/tool_instance_v2/internal_urls.py (new file, 16 lines)
@@ -0,0 +1,16 @@
|
||||
"""Internal API URLs for tool instance operations."""
|
||||
|
||||
from django.urls import path
|
||||
|
||||
from .internal_views import tool_by_id_internal, validate_tool_instances_internal
|
||||
|
||||
urlpatterns = [
|
||||
# Tool by ID endpoint - critical for worker functionality
|
||||
path("tool/<str:tool_id>/", tool_by_id_internal, name="tool-by-id-internal"),
|
||||
# Tool instance validation endpoint - used by workers before execution
|
||||
path(
|
||||
"validate/",
|
||||
validate_tool_instances_internal,
|
||||
name="validate-tool-instances-internal",
|
||||
),
|
||||
]
|
||||
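For reference, a worker could consume these tool-instance endpoints roughly as below. This is a minimal sketch: the internal base URL, environment variable names, and the Authorization/X-Organization-ID headers are assumptions for illustration and are not defined by this file.

# Hypothetical worker-side client for the tool-instance internal endpoints above.
import os

import requests

INTERNAL_API_BASE = os.getenv("INTERNAL_API_BASE_URL", "http://unstract-backend:8000/internal")  # assumed
HEADERS = {
    "Authorization": f"Bearer {os.getenv('INTERNAL_SERVICE_API_KEY', '')}",  # assumed auth scheme
    "X-Organization-ID": os.getenv("ORGANIZATION_ID", ""),  # assumed org header
}


def fetch_tool(tool_id: str) -> dict:
    """Fetch tool properties for a tool UID via the internal API (URL prefix assumed)."""
    resp = requests.get(
        f"{INTERNAL_API_BASE}/tool-instance/tool/{tool_id}/", headers=HEADERS, timeout=30
    )
    resp.raise_for_status()
    return resp.json()["tool"]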
backend/tool_instance_v2/internal_views.py (new file, 403 lines)
@@ -0,0 +1,403 @@
"""Internal API Views for Tool Instance Operations
|
||||
|
||||
This module contains internal API endpoints used by workers for tool execution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from rest_framework import status, viewsets
|
||||
from rest_framework.decorators import api_view
|
||||
from rest_framework.response import Response
|
||||
from utils.organization_utils import filter_queryset_by_organization
|
||||
|
||||
from tool_instance_v2.models import ToolInstance
|
||||
from tool_instance_v2.serializers import ToolInstanceSerializer
|
||||
from tool_instance_v2.tool_instance_helper import ToolInstanceHelper
|
||||
from tool_instance_v2.tool_processor import ToolProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolExecutionInternalViewSet(viewsets.ModelViewSet):
|
||||
"""Internal API for tool execution operations used by lightweight workers."""
|
||||
|
||||
serializer_class = ToolInstanceSerializer
|
||||
|
||||
def get_queryset(self):
|
||||
# Filter by organization context set by internal API middleware
|
||||
# Use relationship path: ToolInstance -> Workflow -> Organization
|
||||
queryset = ToolInstance.objects.all()
|
||||
return filter_queryset_by_organization(
|
||||
queryset, self.request, "workflow__organization"
|
||||
)
|
||||
|
||||
def execute_tool(self, request, pk=None):
|
||||
"""Execute a specific tool with provided input data.
|
||||
|
||||
This replaces the direct tool execution that was previously done
|
||||
in the heavy Django workers.
|
||||
"""
|
||||
try:
|
||||
tool_instance = self.get_object()
|
||||
|
||||
# Extract execution parameters from request
|
||||
input_data = request.data.get("input_data", {})
|
||||
file_data = request.data.get("file_data", {})
|
||||
execution_context = request.data.get("execution_context", {})
|
||||
|
||||
# Execute tool using existing tool processor
|
||||
execution_result = ToolProcessor.execute_tool(
|
||||
tool_instance=tool_instance,
|
||||
input_data=input_data,
|
||||
file_data=file_data,
|
||||
context=execution_context,
|
||||
user=request.user,
|
||||
)
|
||||
|
||||
return Response(
|
||||
{
|
||||
"status": "success",
|
||||
"tool_instance_id": str(tool_instance.id),
|
||||
"execution_result": execution_result,
|
||||
"tool_function": tool_instance.tool_function,
|
||||
"step": tool_instance.step,
|
||||
},
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution failed for tool {pk}: {e}")
|
||||
return Response(
|
||||
{
|
||||
"status": "error",
|
||||
"error_message": str(e),
|
||||
"tool_instance_id": str(pk) if pk else None,
|
||||
},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
|
||||
@api_view(["GET"])
|
||||
def tool_execution_status_internal(request, execution_id):
|
||||
"""Get tool execution status for internal API calls."""
|
||||
try:
|
||||
# This would track tool execution status
|
||||
# For now, return a basic status structure
|
||||
return Response(
|
||||
{
|
||||
"execution_id": execution_id,
|
||||
"status": "completed", # Could be: pending, running, completed, failed
|
||||
"progress": 100,
|
||||
"results": [],
|
||||
"error_message": None,
|
||||
},
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get tool execution status for {execution_id}: {e}")
|
||||
return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
||||
|
||||
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
|
||||
@api_view(["GET"])
|
||||
def tool_by_id_internal(request, tool_id):
|
||||
"""Get tool information by tool ID for internal API calls."""
|
||||
try:
|
||||
logger.info(f"Getting tool information for tool ID: {tool_id}")
|
||||
|
||||
# Get tool from registry using ToolProcessor
|
||||
try:
|
||||
tool = ToolProcessor.get_tool_by_uid(tool_id)
|
||||
logger.info(f"Successfully retrieved tool from ToolProcessor: {tool_id}")
|
||||
except Exception as tool_fetch_error:
|
||||
logger.error(
|
||||
f"Failed to fetch tool {tool_id} from ToolProcessor: {tool_fetch_error}"
|
||||
)
|
||||
# Return fallback using Structure Tool image (which actually exists)
|
||||
from django.conf import settings
|
||||
|
||||
return Response(
|
||||
{
|
||||
"tool": {
|
||||
"tool_id": tool_id,
|
||||
"properties": {
|
||||
"displayName": f"Missing Tool ({tool_id[:8]}...)",
|
||||
"functionName": tool_id,
|
||||
"description": "Tool not found in registry or Prompt Studio",
|
||||
"toolVersion": "unknown",
|
||||
},
|
||||
"image_name": settings.STRUCTURE_TOOL_IMAGE_NAME,
|
||||
"image_tag": settings.STRUCTURE_TOOL_IMAGE_TAG,
|
||||
"name": f"Missing Tool ({tool_id[:8]}...)",
|
||||
"description": "Tool not found in registry or Prompt Studio",
|
||||
"version": "unknown",
|
||||
"note": "Fallback data for missing tool",
|
||||
}
|
||||
},
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
# Convert Properties object to dict for JSON serialization
|
||||
properties_dict = {}
|
||||
try:
|
||||
if hasattr(tool.properties, "to_dict"):
|
||||
# Use the to_dict method if available (which handles Adapter serialization)
|
||||
properties_dict = tool.properties.to_dict()
|
||||
logger.info(f"Properties serialized using to_dict() for tool {tool_id}")
|
||||
elif hasattr(tool.properties, "dict"):
|
||||
properties_dict = tool.properties.dict()
|
||||
logger.info(f"Properties serialized using dict() for tool {tool_id}")
|
||||
elif hasattr(tool.properties, "__dict__"):
|
||||
properties_dict = tool.properties.__dict__
|
||||
logger.info(f"Properties serialized using __dict__ for tool {tool_id}")
|
||||
else:
|
||||
# Try to convert to dict if it's iterable
|
||||
try:
|
||||
properties_dict = dict(tool.properties)
|
||||
logger.info(
|
||||
f"Properties serialized using dict conversion for tool {tool_id}"
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
properties_dict = {"default": "true"} # Fallback
|
||||
logger.warning(f"Using fallback properties for tool {tool_id}")
|
||||
except Exception as props_error:
|
||||
logger.error(
|
||||
f"Failed to serialize properties for tool {tool_id}: {props_error}"
|
||||
)
|
||||
properties_dict = {"error": "serialization_failed"}
|
||||
|
||||
# Handle spec serialization if needed
|
||||
if hasattr(tool, "spec") and tool.spec:
|
||||
if hasattr(tool.spec, "to_dict"):
|
||||
tool.spec.to_dict()
|
||||
elif hasattr(tool.spec, "__dict__"):
|
||||
pass
|
||||
|
||||
# Return tool information with essential fields only to avoid serialization issues
|
||||
return Response(
|
||||
{
|
||||
"tool": {
|
||||
"tool_id": tool_id,
|
||||
"properties": properties_dict,
|
||||
"image_name": str(tool.image_name)
|
||||
if tool.image_name
|
||||
else "default-tool",
|
||||
"image_tag": str(tool.image_tag) if tool.image_tag else "latest",
|
||||
"name": getattr(tool, "name", tool_id),
|
||||
"description": getattr(tool, "description", ""),
|
||||
"version": getattr(tool, "version", "latest"),
|
||||
}
|
||||
},
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get tool information for {tool_id}: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(f"Full traceback: {traceback.format_exc()}")
|
||||
|
||||
# Always return fallback data instead of error to allow workflow to continue
|
||||
from django.conf import settings
|
||||
|
||||
return Response(
|
||||
{
|
||||
"tool": {
|
||||
"tool_id": tool_id,
|
||||
"properties": {
|
||||
"displayName": f"Error Tool ({tool_id[:8]}...)",
|
||||
"functionName": tool_id,
|
||||
"description": f"Error processing tool: {str(e)[:100]}",
|
||||
"toolVersion": "error",
|
||||
},
|
||||
"image_name": settings.STRUCTURE_TOOL_IMAGE_NAME,
|
||||
"image_tag": settings.STRUCTURE_TOOL_IMAGE_TAG,
|
||||
"name": f"Error Tool ({tool_id[:8]}...)",
|
||||
"description": f"Error: {str(e)[:100]}",
|
||||
"version": "error",
|
||||
"error": str(e),
|
||||
"note": "Fallback data for tool processing error",
|
||||
}
|
||||
},
|
||||
status=status.HTTP_200_OK, # Return 200 to allow workflow to continue
|
||||
)
|
||||
|
||||
|
||||
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
|
||||
@api_view(["GET"])
|
||||
def tool_instances_by_workflow_internal(request, workflow_id):
|
||||
"""Get tool instances for a workflow for internal API calls."""
|
||||
try:
|
||||
from workflow_manager.workflow_v2.models.workflow import Workflow
|
||||
|
||||
logger.info(f"Getting tool instances for workflow: {workflow_id}")
|
||||
|
||||
# Get workflow with organization filtering first (via DefaultOrganizationManagerMixin)
|
||||
try:
|
||||
workflow = Workflow.objects.get(id=workflow_id)
|
||||
logger.info(f"Found workflow: {workflow.id}")
|
||||
except Workflow.DoesNotExist:
|
||||
logger.error(f"Workflow not found: {workflow_id}")
|
||||
return Response(
|
||||
{"error": "Workflow not found or access denied"},
|
||||
status=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
# Get tool instances for the workflow with organization filtering
|
||||
# Filter through the relationship: ToolInstance -> Workflow -> Organization
|
||||
tool_instances_queryset = ToolInstance.objects.filter(workflow=workflow)
|
||||
tool_instances_queryset = filter_queryset_by_organization(
|
||||
tool_instances_queryset, request, "workflow__organization"
|
||||
)
|
||||
tool_instances = tool_instances_queryset.order_by("step")
|
||||
logger.info(f"Found {len(tool_instances)} tool instances")
|
||||
|
||||
# Serialize the tool instances
|
||||
try:
|
||||
logger.info("Starting serialization of tool instances")
|
||||
serializer = ToolInstanceSerializer(tool_instances, many=True)
|
||||
logger.info("Accessing serializer.data")
|
||||
serializer_data = serializer.data
|
||||
logger.info(f"Serialization completed, got {len(serializer_data)} items")
|
||||
except Exception as serializer_error:
|
||||
logger.error(f"Serialization error: {serializer_error}")
|
||||
# Try to return basic data without enhanced tool information
|
||||
basic_data = []
|
||||
for instance in tool_instances:
|
||||
basic_data.append(
|
||||
{
|
||||
"id": str(instance.id),
|
||||
"tool_id": instance.tool_id,
|
||||
"step": instance.step,
|
||||
"metadata": instance.metadata,
|
||||
}
|
||||
)
|
||||
logger.info(f"Returning {len(basic_data)} basic tool instances")
|
||||
return Response(
|
||||
{
|
||||
"workflow_id": workflow_id,
|
||||
"tool_instances": basic_data,
|
||||
"total_count": len(tool_instances),
|
||||
"note": "Basic data returned due to serialization error",
|
||||
},
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
return Response(
|
||||
{
|
||||
"workflow_id": workflow_id,
|
||||
"tool_instances": serializer_data,
|
||||
"total_count": len(tool_instances),
|
||||
},
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to get tool instances for workflow {workflow_id}: {e}")
|
||||
return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
||||
|
||||
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
|
||||
@api_view(["POST"])
|
||||
def validate_tool_instances_internal(request):
|
||||
"""Validate tool instances and ensure adapter IDs are migrated.
|
||||
|
||||
This internal endpoint validates tool instances for a workflow, ensuring:
|
||||
1. Adapter names are migrated to IDs
|
||||
2. User has permissions to access tools and adapters
|
||||
3. Tool settings match JSON schema requirements
|
||||
|
||||
Used by workers to validate tools before execution.
|
||||
|
||||
Args:
|
||||
request: Request containing:
|
||||
- workflow_id: ID of the workflow
|
||||
- tool_instances: List of tool instance IDs
|
||||
|
||||
Returns:
|
||||
Response with validation results and migrated metadata
|
||||
"""
|
||||
workflow_id = request.data.get("workflow_id")
|
||||
tool_instance_ids = request.data.get("tool_instances", [])
|
||||
|
||||
if not workflow_id:
|
||||
return Response(
|
||||
{"error": "workflow_id is required"}, status=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
validated_instances = []
|
||||
validation_errors = []
|
||||
|
||||
try:
|
||||
# Get tool instances from database with organization filtering
|
||||
tool_instances_queryset = ToolInstance.objects.filter(
|
||||
workflow_id=workflow_id, id__in=tool_instance_ids
|
||||
).select_related("workflow", "workflow__created_by")
|
||||
|
||||
# Apply organization filtering
|
||||
tool_instances_queryset = filter_queryset_by_organization(
|
||||
tool_instances_queryset, request, "workflow__organization"
|
||||
)
|
||||
tool_instances = list(tool_instances_queryset)
|
||||
|
||||
# Validate each tool instance
|
||||
for tool in tool_instances:
|
||||
try:
|
||||
# Get the user who created the workflow
|
||||
user = tool.workflow.created_by
|
||||
|
||||
# Ensure adapter IDs are migrated from names to IDs
|
||||
migrated_metadata = ToolInstanceHelper.ensure_adapter_ids_in_metadata(
|
||||
tool, user=user
|
||||
)
|
||||
|
||||
# Validate tool settings
|
||||
ToolInstanceHelper.validate_tool_settings(
|
||||
user=user,
|
||||
tool_uid=tool.tool_id,
|
||||
tool_meta=migrated_metadata,
|
||||
)
|
||||
|
||||
# Add to validated list with migrated metadata
|
||||
validated_instances.append(
|
||||
{
|
||||
"id": str(tool.id),
|
||||
"tool_id": tool.tool_id,
|
||||
"metadata": migrated_metadata,
|
||||
"step": tool.step,
|
||||
"status": "valid",
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
validation_errors.append(
|
||||
{
|
||||
"tool_id": tool.tool_id,
|
||||
"tool_instance_id": str(tool.id),
|
||||
"error": str(e),
|
||||
}
|
||||
)
|
||||
logger.error(f"Tool validation failed for {tool.tool_id}: {e}")
|
||||
|
||||
# Return validation results
|
||||
response_data = {
|
||||
"success": len(validation_errors) == 0,
|
||||
"validated_instances": validated_instances,
|
||||
"errors": validation_errors,
|
||||
"workflow_id": workflow_id,
|
||||
}
|
||||
|
||||
if validation_errors:
|
||||
return Response(response_data, status=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
|
||||
return Response(response_data, status=status.HTTP_200_OK)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool validation failed: {e}", exc_info=True)
|
||||
return Response(
|
||||
{"error": f"Tool validation failed: {str(e)}"},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
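A worker could call the validation endpoint above before dispatching a workflow along these lines. The URL prefix, header names, and token handling are assumptions; only the request and response shape comes from the view itself.

# Hypothetical worker-side validation call before execution.
import requests


def validate_tools(base_url: str, token: str, org_id: str, workflow_id: str, instance_ids: list[str]) -> dict:
    """POST the workflow's tool instance IDs and return the validation results."""
    resp = requests.post(
        f"{base_url}/tool-instance/validate/",  # assumed URL prefix
        json={"workflow_id": workflow_id, "tool_instances": instance_ids},
        headers={"Authorization": f"Bearer {token}", "X-Organization-ID": org_id},
        timeout=60,
    )
    # 422 carries per-instance validation errors; 200 means every instance validated.
    if resp.status_code == 422:
        return resp.json()  # inspect "errors" for the failing tool instances
    resp.raise_for_status()
    return resp.json()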
@@ -63,7 +63,6 @@ class ToolInstanceSerializer(AuditSerializer):
        rep[TIKey.METADATA] = self._transform_adapter_ids_to_names_for_display(
            metadata, tool_function
        )

        return rep

    def _transform_adapter_ids_to_names_for_display(
backend/usage_v2/internal_urls.py (new file, 15 lines)
@@ -0,0 +1,15 @@
"""Internal API URLs for Usage access by workers."""
|
||||
|
||||
from django.urls import path
|
||||
|
||||
from . import internal_views
|
||||
|
||||
app_name = "usage_internal"
|
||||
|
||||
urlpatterns = [
|
||||
path(
|
||||
"aggregated-token-count/<str:file_execution_id>/",
|
||||
internal_views.UsageInternalView.as_view(),
|
||||
name="aggregated-token-count",
|
||||
),
|
||||
]
|
||||
backend/usage_v2/internal_views.py (new file, 79 lines)
@@ -0,0 +1,79 @@
"""Internal API views for Usage access by workers."""
|
||||
|
||||
import logging
|
||||
|
||||
from django.http import JsonResponse
|
||||
from rest_framework import status
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from unstract.core.data_models import UsageResponseData
|
||||
|
||||
from .helper import UsageHelper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UsageInternalView(APIView):
|
||||
"""Internal API view for workers to access usage data.
|
||||
|
||||
This endpoint allows workers to get aggregated token usage data
|
||||
for a specific file execution without direct database access.
|
||||
"""
|
||||
|
||||
def get(self, request: Request, file_execution_id: str) -> JsonResponse:
|
||||
"""Get aggregated token usage for a file execution.
|
||||
|
||||
Args:
|
||||
request: HTTP request (no additional parameters needed)
|
||||
file_execution_id: File execution ID to get usage data for
|
||||
|
||||
Returns:
|
||||
JSON response with aggregated usage data using core data models
|
||||
"""
|
||||
try:
|
||||
if not file_execution_id:
|
||||
return JsonResponse(
|
||||
{
|
||||
"success": False,
|
||||
"error": "file_execution_id parameter is required",
|
||||
},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Get aggregated token count using the existing helper
|
||||
result = UsageHelper.get_aggregated_token_count(run_id=file_execution_id)
|
||||
|
||||
# Create UsageResponseData for type safety and consistency
|
||||
usage_data = UsageResponseData(
|
||||
file_execution_id=file_execution_id,
|
||||
embedding_tokens=result.get("embedding_tokens"),
|
||||
prompt_tokens=result.get("prompt_tokens"),
|
||||
completion_tokens=result.get("completion_tokens"),
|
||||
total_tokens=result.get("total_tokens"),
|
||||
cost_in_dollars=result.get("cost_in_dollars"),
|
||||
)
|
||||
|
||||
return JsonResponse(
|
||||
{
|
||||
"success": True,
|
||||
"data": {
|
||||
"file_execution_id": file_execution_id,
|
||||
"usage": usage_data.to_dict(),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting usage data for file_execution_id {file_execution_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return JsonResponse(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Internal server error",
|
||||
"file_execution_id": file_execution_id,
|
||||
},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
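For illustration, a worker could fetch the aggregated token usage for a file execution like this. The "/usage/" prefix and the auth headers are assumptions; the response envelope ("success", "data", "usage") follows the view above.

# Hypothetical worker-side lookup of aggregated token usage for a file execution.
import requests


def get_file_usage(base_url: str, token: str, org_id: str, file_execution_id: str) -> dict:
    """Return embedding/prompt/completion/total token counts and cost for one file execution."""
    resp = requests.get(
        f"{base_url}/usage/aggregated-token-count/{file_execution_id}/",  # assumed URL prefix
        headers={"Authorization": f"Bearer {token}", "X-Organization-ID": org_id},
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json()["data"]["usage"]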
@@ -1,4 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
@@ -7,6 +9,8 @@ from django_redis import get_redis_connection
|
||||
|
||||
redis_cache = get_redis_connection("default")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CacheService:
|
||||
@staticmethod
|
||||
@@ -38,6 +42,72 @@ class CacheService:
|
||||
"""Delete keys in bulk based on the key pattern."""
|
||||
cache.delete_pattern(key_pattern)
|
||||
|
||||
@staticmethod
|
||||
def clear_cache_optimized(key_pattern: str) -> Any:
|
||||
"""Delete keys in bulk using optimized SCAN approach for large datasets.
|
||||
|
||||
Uses Redis SCAN instead of KEYS to avoid blocking Redis during deletion.
|
||||
Safe for production with large key sets. Use this for heavy operations
|
||||
like workflow history clearing.
|
||||
"""
|
||||
TIMEOUT_SECONDS = 90 # Generous but bounded timeout
|
||||
BATCH_SIZE = 1000
|
||||
|
||||
start_time = time.time()
|
||||
deleted_count = 0
|
||||
cursor = 0
|
||||
completed_naturally = False
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Check timeout first
|
||||
if time.time() - start_time > TIMEOUT_SECONDS:
|
||||
logger.warning(
|
||||
f"Cache clearing timed out after {TIMEOUT_SECONDS}s, "
|
||||
f"deleted {deleted_count} keys matching '{key_pattern}'"
|
||||
)
|
||||
break
|
||||
|
||||
# SCAN returns (next_cursor, keys_list)
|
||||
cursor, keys = redis_cache.scan(
|
||||
cursor=cursor, match=key_pattern, count=BATCH_SIZE
|
||||
)
|
||||
|
||||
if keys:
|
||||
# Delete keys in pipeline for efficiency
|
||||
pipe = redis_cache.pipeline()
|
||||
for key in keys:
|
||||
pipe.delete(key)
|
||||
pipe.execute()
|
||||
deleted_count += len(keys)
|
||||
|
||||
# SCAN is complete when cursor returns to 0
|
||||
if cursor == 0:
|
||||
completed_naturally = True
|
||||
break
|
||||
|
||||
# Log completion status
|
||||
if completed_naturally:
|
||||
logger.info(
|
||||
f"Cache clearing completed: deleted {deleted_count} keys matching '{key_pattern}'"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Cache clearing incomplete: deleted {deleted_count} keys before timeout"
|
||||
)
|
||||
|
||||
except (ConnectionError, TimeoutError, OSError) as e:
|
||||
logger.error(f"Failed to clear cache pattern '{key_pattern}': {str(e)}")
|
||||
# Fallback to old method for backward compatibility
|
||||
try:
|
||||
cache.delete_pattern(key_pattern)
|
||||
logger.warning(f"Used fallback delete_pattern for '{key_pattern}'")
|
||||
except (ConnectionError, TimeoutError, OSError) as fallback_error:
|
||||
logger.error(
|
||||
f"Fallback cache clearing also failed: {str(fallback_error)}"
|
||||
)
|
||||
raise e
|
||||
|
||||
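The same SCAN-plus-pipeline pattern can be expressed as a standalone sketch against a plain redis-py client; the client construction below is an assumption, since the backend code above uses the django_redis connection.

# Standalone sketch of SCAN + pipelined deletes, which avoids the Redis-blocking
# behaviour of KEYS on large key sets. Assumes a plain redis-py client.
import time

import redis


def delete_pattern_scan(client: redis.Redis, pattern: str, batch_size: int = 1000, timeout_s: int = 90) -> int:
    """Delete keys matching `pattern` in batches, bounded by a wall-clock timeout."""
    start = time.time()
    deleted = 0
    cursor = 0
    while True:
        if time.time() - start > timeout_s:
            break  # bounded runtime; caller can retry the remainder later
        # SCAN returns (next_cursor, keys) for one iteration step
        cursor, keys = client.scan(cursor=cursor, match=pattern, count=batch_size)
        if keys:
            pipe = client.pipeline()
            for key in keys:
                pipe.delete(key)
            pipe.execute()
            deleted += len(keys)
        if cursor == 0:  # a full SCAN iteration has completed
            break
    return deleted

# Example: delete_pattern_scan(redis.Redis(), "workflow_history:*")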
@staticmethod
|
||||
def check_a_key_exist(key: str, version: Any = None) -> bool:
|
||||
data: bool = cache.has_key(key, version)
|
||||
@@ -70,6 +140,10 @@ class CacheService:
|
||||
def lpop(key: str) -> Any:
|
||||
return redis_cache.lpop(key)
|
||||
|
||||
@staticmethod
|
||||
def llen(key: str) -> int:
|
||||
return redis_cache.llen(key)
|
||||
|
||||
@staticmethod
|
||||
def lrem(key: str, value: str) -> None:
|
||||
redis_cache.lrem(key, value)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import http
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
@@ -9,10 +8,9 @@ import socketio
|
||||
from django.conf import settings
|
||||
from django.core.wsgi import WSGIHandler
|
||||
|
||||
from unstract.core.constants import LogFieldName
|
||||
from unstract.workflow_execution.enums import LogType
|
||||
from unstract.core.data_models import LogDataDTO
|
||||
from unstract.core.log_utils import get_validated_log_data, store_execution_log
|
||||
from utils.constants import ExecutionLogConstants
|
||||
from utils.dto import LogDataDTO
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -79,71 +77,23 @@ def _get_user_session_id_from_cookies(sid: str, environ: Any) -> str | None:
|
||||
return session_id.value
|
||||
|
||||
|
||||
# Functions moved to unstract.core.log_utils for sharing with workers
|
||||
# Keep these as wrapper functions for backward compatibility
|
||||
|
||||
|
||||
def _get_validated_log_data(json_data: Any) -> LogDataDTO | None:
|
||||
"""Validate log data to persist history. This function takes log data in
|
||||
JSON format, validates it, and returns a `LogDataDTO` object if the data is
|
||||
valid. The validation process includes decoding bytes to string, parsing
|
||||
the string as JSON, and checking for required fields and log type.
|
||||
|
||||
Args:
|
||||
json_data (Any): Log data in JSON format
|
||||
Returns:
|
||||
Optional[LogDataDTO]: Log data DTO object
|
||||
"""
|
||||
if isinstance(json_data, bytes):
|
||||
json_data = json_data.decode("utf-8")
|
||||
|
||||
if isinstance(json_data, str):
|
||||
try:
|
||||
# Parse the string as JSON
|
||||
json_data = json.loads(json_data)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Error decoding JSON data while validating {json_data}")
|
||||
return
|
||||
|
||||
if not isinstance(json_data, dict):
|
||||
logger.warning(f"Getting invalid data type while validating {json_data}")
|
||||
return
|
||||
|
||||
# Extract required fields from the JSON data
|
||||
execution_id = json_data.get(LogFieldName.EXECUTION_ID)
|
||||
organization_id = json_data.get(LogFieldName.ORGANIZATION_ID)
|
||||
timestamp = json_data.get(LogFieldName.TIMESTAMP)
|
||||
log_type = json_data.get(LogFieldName.TYPE)
|
||||
file_execution_id = json_data.get(LogFieldName.FILE_EXECUTION_ID)
|
||||
|
||||
# Ensure the log type is LogType.LOG
|
||||
if log_type != LogType.LOG.value:
|
||||
return
|
||||
|
||||
# Check if all required fields are present
|
||||
if not all((execution_id, organization_id, timestamp)):
|
||||
logger.debug(f"Missing required fields while validating {json_data}")
|
||||
return
|
||||
|
||||
return LogDataDTO(
|
||||
execution_id=execution_id,
|
||||
file_execution_id=file_execution_id,
|
||||
organization_id=organization_id,
|
||||
timestamp=timestamp,
|
||||
log_type=log_type,
|
||||
data=json_data,
|
||||
)
|
||||
"""Validate log data to persist history (backward compatibility wrapper)."""
|
||||
return get_validated_log_data(json_data)
|
||||
|
||||
|
||||
def _store_execution_log(data: dict[str, Any]) -> None:
|
||||
"""Store execution log in database
|
||||
Args:
|
||||
data (dict[str, Any]): Execution log data
|
||||
"""
|
||||
if not ExecutionLogConstants.IS_ENABLED:
|
||||
return
|
||||
try:
|
||||
log_data = _get_validated_log_data(json_data=data)
|
||||
if log_data:
|
||||
redis_conn.rpush(ExecutionLogConstants.LOG_QUEUE_NAME, log_data.to_json())
|
||||
except Exception as e:
|
||||
logger.error(f"Error storing execution log: {e}")
|
||||
"""Store execution log in database (backward compatibility wrapper)."""
|
||||
store_execution_log(
|
||||
data=data,
|
||||
redis_client=redis_conn,
|
||||
log_queue_name=ExecutionLogConstants.LOG_QUEUE_NAME,
|
||||
is_enabled=ExecutionLogConstants.IS_ENABLED,
|
||||
)
|
||||
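For illustration, a Django-free worker could push a validated log entry through the shared helper with something like the following. The Redis connection settings and queue name are assumptions; the store_execution_log signature matches the wrapper above.

# Hypothetical worker-side use of the shared log helper referenced above.
import os

import redis

from unstract.core.log_utils import store_execution_log

redis_conn = redis.Redis(
    host=os.getenv("REDIS_HOST", "unstract-redis"),  # assumed env var
    port=int(os.getenv("REDIS_PORT", "6379")),
)


def push_log(log_message: dict) -> None:
    """Validate the payload and enqueue it for the log-history consumer."""
    store_execution_log(
        data=log_message,
        redis_client=redis_conn,
        log_queue_name=os.getenv("LOG_HISTORY_QUEUE_NAME", "log_history_queue"),  # assumed name
        is_enabled=True,
    )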
|
||||
|
||||
def _emit_websocket_event(room: str, event: str, data: dict[str, Any]) -> None:
|
||||
|
||||
backend/utils/organization_utils.py (new file, 95 lines)
@@ -0,0 +1,95 @@
"""Organization utilities for internal APIs.
|
||||
Provides shared functions for organization context resolution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from account_v2.models import Organization
|
||||
from django.shortcuts import get_object_or_404
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def resolve_organization(
|
||||
org_id: str, raise_on_not_found: bool = False
|
||||
) -> Organization | None:
|
||||
"""Resolve organization by either organization.id (int) or organization.organization_id (string).
|
||||
|
||||
Args:
|
||||
org_id: Organization identifier - can be either the primary key (numeric string)
|
||||
or the organization_id field (string)
|
||||
raise_on_not_found: If True, raises Http404 on not found. If False, returns None.
|
||||
|
||||
Returns:
|
||||
Organization instance if found, None if not found and raise_on_not_found=False
|
||||
|
||||
Raises:
|
||||
Http404: If organization not found and raise_on_not_found=True
|
||||
"""
|
||||
try:
|
||||
if org_id.isdigit():
|
||||
# If it's numeric, treat as primary key
|
||||
if raise_on_not_found:
|
||||
return get_object_or_404(Organization, id=org_id)
|
||||
else:
|
||||
return Organization.objects.get(id=org_id)
|
||||
else:
|
||||
# If it's string, treat as organization_id field
|
||||
if raise_on_not_found:
|
||||
return get_object_or_404(Organization, organization_id=org_id)
|
||||
else:
|
||||
return Organization.objects.get(organization_id=org_id)
|
||||
except Organization.DoesNotExist:
|
||||
if raise_on_not_found:
|
||||
raise
|
||||
logger.warning(f"Organization {org_id} not found")
|
||||
return None
|
||||
|
||||
|
||||
def get_organization_context(organization: Organization) -> dict[str, Any]:
|
||||
"""Get standardized organization context data.
|
||||
|
||||
Args:
|
||||
organization: Organization instance
|
||||
|
||||
Returns:
|
||||
Dictionary with organization context information
|
||||
"""
|
||||
return {
|
||||
"organization_id": str(organization.id),
|
||||
"organization_name": organization.display_name,
|
||||
"organization_slug": getattr(organization, "slug", ""),
|
||||
"created_at": organization.created_at.isoformat()
|
||||
if hasattr(organization, "created_at")
|
||||
else None,
|
||||
"settings": {
|
||||
# Add organization-specific settings here
|
||||
"subscription_active": True, # This would come from subscription model
|
||||
"features_enabled": [], # This would come from feature flags
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def filter_queryset_by_organization(queryset, request, organization_field="organization"):
|
||||
"""Filter a Django queryset by organization context from request.
|
||||
|
||||
Args:
|
||||
queryset: Django QuerySet to filter
|
||||
request: HTTP request object with organization_id attribute
|
||||
organization_field: Field name for organization relationship (default: 'organization')
|
||||
|
||||
Returns:
|
||||
Filtered queryset or empty queryset if organization not found
|
||||
"""
|
||||
org_id = getattr(request, "organization_id", None)
|
||||
if org_id:
|
||||
organization = resolve_organization(org_id, raise_on_not_found=False)
|
||||
if organization:
|
||||
# Use dynamic field lookup
|
||||
filter_kwargs = {organization_field: organization}
|
||||
return queryset.filter(**filter_kwargs)
|
||||
else:
|
||||
# Return empty queryset if organization not found
|
||||
return queryset.none()
|
||||
return queryset
|
||||
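As a usage sketch, an internal view can scope any queryset with these helpers; the model and relationship path below mirror the tool-instance views in this PR and are illustrative only.

# Illustrative: scope a queryset to the organization resolved from the request.
from tool_instance_v2.models import ToolInstance
from utils.organization_utils import filter_queryset_by_organization


def get_scoped_tool_instances(request):
    # request.organization_id is set by the internal API auth middleware.
    queryset = ToolInstance.objects.all()
    return filter_queryset_by_organization(queryset, request, "workflow__organization")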
backend/utils/websocket_views.py (new file, 87 lines)
@@ -0,0 +1,87 @@
"""WebSocket emission views for internal API.
|
||||
|
||||
This module provides endpoints for workers to trigger WebSocket events
|
||||
through the backend's SocketIO server.
|
||||
|
||||
Security Note:
|
||||
- CSRF protection is disabled for internal service-to-service communication
|
||||
- Authentication is handled by InternalAPIAuthMiddleware using Bearer tokens
|
||||
- This endpoint is for worker → backend WebSocket event triggering only
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from django.http import JsonResponse
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from django.views.decorators.http import require_http_methods
|
||||
|
||||
from utils.log_events import _emit_websocket_event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# CSRF exemption is safe here because:
|
||||
# 1. Internal service-to-service communication (workers → backend)
|
||||
# 2. Protected by InternalAPIAuthMiddleware Bearer token authentication
|
||||
# 3. No browser sessions or cookies involved
|
||||
# 4. Used for WebSocket event triggering, not state modification
|
||||
@csrf_exempt
|
||||
@require_http_methods(["POST"])
|
||||
def emit_websocket(request):
|
||||
"""Internal API endpoint for workers to emit WebSocket events.
|
||||
|
||||
Expected payload:
|
||||
{
|
||||
"room": "session_id",
|
||||
"event": "logs:session_id",
|
||||
"data": {...}
|
||||
}
|
||||
|
||||
Returns:
|
||||
JSON response with success/error status
|
||||
"""
|
||||
try:
|
||||
# Parse request data (standard Django view)
|
||||
data = json.loads(request.body.decode("utf-8"))
|
||||
|
||||
# Extract required fields
|
||||
room = data.get("room")
|
||||
event = data.get("event")
|
||||
message_data = data.get("data", {})
|
||||
|
||||
# Validate required fields
|
||||
if not room or not event:
|
||||
return JsonResponse(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Missing required fields: room and event are required",
|
||||
},
|
||||
status=400,
|
||||
)
|
||||
|
||||
# Emit the WebSocket event
|
||||
_emit_websocket_event(room=room, event=event, data=message_data)
|
||||
|
||||
logger.debug(f"WebSocket event emitted: room={room}, event={event}")
|
||||
|
||||
return JsonResponse(
|
||||
{
|
||||
"status": "success",
|
||||
"message": "WebSocket event emitted successfully",
|
||||
"room": room,
|
||||
"event": event,
|
||||
}
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Invalid JSON in WebSocket emission request: {e}")
|
||||
return JsonResponse(
|
||||
{"status": "error", "message": "Invalid JSON payload"}, status=400
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error emitting WebSocket event: {e}")
|
||||
return JsonResponse(
|
||||
{"status": "error", "message": f"Failed to emit WebSocket event: {str(e)}"},
|
||||
status=500,
|
||||
)
|
||||
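A worker would trigger a WebSocket event through this endpoint roughly as follows; the URL path, header names, and token handling are assumptions, while the payload shape ("room", "event", "data") comes from the view above.

# Hypothetical worker-side trigger for a WebSocket event via the internal API.
import requests


def emit_log_event(base_url: str, token: str, session_id: str, message: dict) -> None:
    """Ask the backend's SocketIO server to emit a log event to the user's session room."""
    resp = requests.post(
        f"{base_url}/emit-websocket/",  # assumed URL path
        json={"room": session_id, "event": f"logs:{session_id}", "data": message},
        headers={"Authorization": f"Bearer {token}"},
        timeout=10,
    )
    resp.raise_for_status()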
backend/uv.lock (generated, 35 lines changed)
@@ -2975,6 +2975,21 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/0c/c1/6aece0ab5209981a70cd186f164c133fdba2f51e124ff92b73de7fd24d78/protobuf-4.25.8-py3-none-any.whl", hash = "sha256:15a0af558aa3b13efef102ae6e4f3efac06f1eea11afb3a57db2901447d9fb59", size = 156757, upload-time = "2025-05-28T14:22:24.135Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "psutil"
|
||||
version = "7.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/2a/80/336820c1ad9286a4ded7e845b2eccfcb27851ab8ac6abece774a6ff4d3de/psutil-7.0.0.tar.gz", hash = "sha256:7be9c3eba38beccb6495ea33afd982a44074b78f28c434a1f51cc07fd315c456", size = 497003 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ed/e6/2d26234410f8b8abdbf891c9da62bee396583f713fb9f3325a4760875d22/psutil-7.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:101d71dc322e3cffd7cea0650b09b3d08b8e7c4109dd6809fe452dfd00e58b25", size = 238051 },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/8b/30f930733afe425e3cbfc0e1468a30a18942350c1a8816acfade80c005c4/psutil-7.0.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:39db632f6bb862eeccf56660871433e111b6ea58f2caea825571951d4b6aa3da", size = 239535 },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/ed/d362e84620dd22876b55389248e522338ed1bf134a5edd3b8231d7207f6d/psutil-7.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fcee592b4c6f146991ca55919ea3d1f8926497a713ed7faaf8225e174581e91", size = 275004 },
|
||||
{ url = "https://files.pythonhosted.org/packages/bf/b9/b0eb3f3cbcb734d930fdf839431606844a825b23eaf9a6ab371edac8162c/psutil-7.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b1388a4f6875d7e2aff5c4ca1cc16c545ed41dd8bb596cefea80111db353a34", size = 277986 },
|
||||
{ url = "https://files.pythonhosted.org/packages/eb/a2/709e0fe2f093556c17fbafda93ac032257242cabcc7ff3369e2cb76a97aa/psutil-7.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5f098451abc2828f7dc6b58d44b532b22f2088f4999a937557b603ce72b1993", size = 279544 },
|
||||
{ url = "https://files.pythonhosted.org/packages/50/e6/eecf58810b9d12e6427369784efe814a1eec0f492084ce8eb8f4d89d6d61/psutil-7.0.0-cp37-abi3-win32.whl", hash = "sha256:ba3fcef7523064a6c9da440fc4d6bd07da93ac726b5733c29027d7dc95b39d99", size = 241053 },
|
||||
{ url = "https://files.pythonhosted.org/packages/50/1b/6921afe68c74868b4c9fa424dad3be35b095e16687989ebbb50ce4fceb7c/psutil-7.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553", size = 244885 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "psycopg2-binary"
|
||||
version = "2.9.9"
|
||||
@@ -3502,6 +3517,20 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3f/51/d4db610ef29373b879047326cbf6fa98b6c1969d6f6dc423279de2b1be2c/requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06", size = 54481, upload-time = "2023-05-01T04:11:28.427Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "responses"
|
||||
version = "0.25.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pyyaml" },
|
||||
{ name = "requests" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/81/7e/2345ac3299bd62bd7163216702bbc88976c099cfceba5b889f2a457727a1/responses-0.25.7.tar.gz", hash = "sha256:8ebae11405d7a5df79ab6fd54277f6f2bc29b2d002d0dd2d5c632594d1ddcedb", size = 79203 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e4/fc/1d20b64fa90e81e4fa0a34c9b0240a6cfb1326b7e06d18a5432a9917c316/responses-0.25.7-py3-none-any.whl", hash = "sha256:92ca17416c90fe6b35921f52179bff29332076bb32694c0df02dcac2c6bc043c", size = 34732 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rpds-py"
|
||||
version = "0.27.1"
|
||||
@@ -4042,6 +4071,9 @@ dev = [
|
||||
{ name = "debugpy" },
|
||||
{ name = "inotify" },
|
||||
{ name = "poethepoet" },
|
||||
{ name = "psutil" },
|
||||
{ name = "pytest" },
|
||||
{ name = "responses" },
|
||||
{ name = "unstract-connectors" },
|
||||
{ name = "unstract-core" },
|
||||
{ name = "unstract-filesystem" },
|
||||
@@ -4109,6 +4141,9 @@ dev = [
|
||||
{ name = "debugpy", specifier = ">=1.8.14" },
|
||||
{ name = "inotify", specifier = ">=0.2.10" },
|
||||
{ name = "poethepoet", specifier = ">=0.33.1" },
|
||||
{ name = "psutil", specifier = ">=7.0.0" },
|
||||
{ name = "pytest", specifier = ">=8.3.5" },
|
||||
{ name = "responses", specifier = ">=0.25.7" },
|
||||
{ name = "unstract-connectors", editable = "../unstract/connectors" },
|
||||
{ name = "unstract-core", editable = "../unstract/core" },
|
||||
{ name = "unstract-filesystem", editable = "../unstract/filesystem" },
|
||||
|
||||
@@ -871,7 +871,6 @@ class DestinationConnector(BaseConnector):
|
||||
).to_dict()
|
||||
|
||||
queue_result_json = json.dumps(queue_result)
|
||||
|
||||
conn = QueueUtils.get_queue_inst()
|
||||
conn.enqueue(queue_name=q_name, message=queue_result_json)
|
||||
logger.info(f"Pushed {file_name} to queue {q_name} with file content")
|
||||
@@ -891,11 +890,13 @@ class DestinationConnector(BaseConnector):
|
||||
q_name = self._get_review_queue_name()
|
||||
if meta_data:
|
||||
whisper_hash = meta_data.get("whisper-hash")
|
||||
extracted_text = meta_data.get("extracted_text")
|
||||
else:
|
||||
whisper_hash = None
|
||||
extracted_text = None
|
||||
|
||||
# Get extracted text from metadata (added by structure tool)
|
||||
extracted_text = meta_data.get("extracted_text") if meta_data else None
|
||||
# Get TTL from workflow settings
|
||||
ttl_seconds = WorkflowUtil.get_hitl_ttl_seconds(workflow)
|
||||
|
||||
# Create QueueResult with TTL metadata
|
||||
queue_result_obj = QueueResult(
|
||||
@@ -907,8 +908,8 @@ class DestinationConnector(BaseConnector):
|
||||
whisper_hash=whisper_hash,
|
||||
file_execution_id=file_execution_id,
|
||||
extracted_text=extracted_text,
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
|
||||
# Add TTL metadata based on HITLSettings
|
||||
queue_result_obj.ttl_seconds = WorkflowUtil.get_hitl_ttl_seconds(workflow)
|
||||
|
||||
|
||||
@@ -28,3 +28,27 @@ class WorkflowEndpointUtils:
            workflow=workflow
        )
        return endpoints

    @staticmethod
    def get_endpoint_for_workflow_by_type(
        workflow_id: str, endpoint_type: WorkflowEndpoint.EndpointType
    ) -> WorkflowEndpoint:
        """Get endpoint for a given workflow by type.

        Args:
            workflow_id (str): The ID of the workflow.
            endpoint_type (WorkflowEndpoint.EndpointType): The type of the endpoint.

        Returns:
            WorkflowEndpoint: The endpoint for the given workflow and type.
        """
        workflow = WorkflowHelper.get_workflow_by_id(workflow_id)
        endpoint: WorkflowEndpoint = WorkflowEndpoint.objects.get(
            workflow=workflow,
            endpoint_type=endpoint_type,
        )
        if endpoint.connector_instance:
            endpoint.connector_instance.connector_metadata = (
                endpoint.connector_instance.metadata
            )
        return endpoint
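An illustrative call of the new helper follows; the module paths and the SOURCE endpoint type are assumptions based on naming used elsewhere in the codebase, not confirmed by this hunk.

# Illustrative lookup of a workflow's source endpoint with decrypted connector metadata.
from workflow_manager.endpoint_v2.models import WorkflowEndpoint  # assumed module path
from workflow_manager.endpoint_v2.endpoint_utils import WorkflowEndpointUtils  # assumed module path

endpoint = WorkflowEndpointUtils.get_endpoint_for_workflow_by_type(
    workflow_id="<workflow-uuid>",  # placeholder
    endpoint_type=WorkflowEndpoint.EndpointType.SOURCE,  # assumed enum member
)
# connector_metadata is populated from the decrypted `metadata` property,
# so callers do not need to decrypt it themselves.
metadata = endpoint.connector_instance.connector_metadata if endpoint.connector_instance else None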
@@ -578,7 +578,7 @@ class SourceConnector(BaseConnector):
         return WorkflowExecution.objects.filter(
             workflow=self.workflow,
             workflow__organization_id=organization.id,  # Security: Organization isolation
-            status__in=[ExecutionStatus.EXECUTING, ExecutionStatus.PENDING],
+            status__in=[ExecutionStatus.EXECUTING.value, ExecutionStatus.PENDING.value],
         )

     def _has_blocking_file_execution(self, execution, file_hash: FileHash) -> bool:
@@ -616,7 +616,10 @@ class SourceConnector(BaseConnector):
|
||||
workflow_execution=execution,
|
||||
file_hash=file_hash.file_hash,
|
||||
file_path=file_hash.file_path,
|
||||
status__in=ExecutionStatus.get_skip_processing_statuses(),
|
||||
status__in=[
|
||||
status.value
|
||||
for status in ExecutionStatus.get_skip_processing_statuses()
|
||||
],
|
||||
)
|
||||
except WorkflowFileExecution.DoesNotExist:
|
||||
return None
|
||||
@@ -633,7 +636,10 @@ class SourceConnector(BaseConnector):
|
||||
workflow_execution=execution,
|
||||
provider_file_uuid=file_hash.provider_file_uuid,
|
||||
file_path=file_hash.file_path,
|
||||
status__in=ExecutionStatus.get_skip_processing_statuses(),
|
||||
status__in=[
|
||||
status.value
|
||||
for status in ExecutionStatus.get_skip_processing_statuses()
|
||||
],
|
||||
)
|
||||
except WorkflowFileExecution.DoesNotExist:
|
||||
return None
|
||||
|
||||
@@ -26,8 +26,8 @@ class ExecutionSerializer(serializers.ModelSerializer):

     def get_successful_files(self, obj: WorkflowExecution) -> int:
         """Return the count of successfully executed files"""
-        return obj.file_executions.filter(status=ExecutionStatus.COMPLETED).count()
+        return obj.file_executions.filter(status=ExecutionStatus.COMPLETED.value).count()

     def get_failed_files(self, obj: WorkflowExecution) -> int:
         """Return the count of failed executed files"""
-        return obj.file_executions.filter(status=ExecutionStatus.ERROR).count()
+        return obj.file_executions.filter(status=ExecutionStatus.ERROR.value).count()
backend/workflow_manager/file_execution/internal_urls.py (new file, 42 lines)
@@ -0,0 +1,42 @@
"""Internal API URLs for File Execution
|
||||
URL patterns for file execution internal APIs.
|
||||
"""
|
||||
|
||||
from django.urls import include, path
|
||||
from rest_framework.routers import DefaultRouter
|
||||
|
||||
from .internal_views import (
|
||||
FileExecutionBatchCreateAPIView,
|
||||
FileExecutionBatchHashUpdateAPIView,
|
||||
FileExecutionBatchStatusUpdateAPIView,
|
||||
FileExecutionInternalViewSet,
|
||||
FileExecutionMetricsAPIView,
|
||||
)
|
||||
|
||||
# Create router for file execution viewsets
|
||||
router = DefaultRouter()
|
||||
router.register(r"", FileExecutionInternalViewSet, basename="file-execution-internal")
|
||||
|
||||
urlpatterns = [
|
||||
# Batch operations
|
||||
path(
|
||||
"batch-create/",
|
||||
FileExecutionBatchCreateAPIView.as_view(),
|
||||
name="file-execution-batch-create",
|
||||
),
|
||||
path(
|
||||
"batch-status-update/",
|
||||
FileExecutionBatchStatusUpdateAPIView.as_view(),
|
||||
name="file-execution-batch-status-update",
|
||||
),
|
||||
path(
|
||||
"batch-hash-update/",
|
||||
FileExecutionBatchHashUpdateAPIView.as_view(),
|
||||
name="file-execution-batch-hash-update",
|
||||
),
|
||||
path(
|
||||
"metrics/", FileExecutionMetricsAPIView.as_view(), name="file-execution-metrics"
|
||||
),
|
||||
# File execution CRUD (via router)
|
||||
path("", include(router.urls)),
|
||||
]
|
||||
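A worker could use the batch-create endpoint registered above roughly as follows; the URL prefix and auth headers are assumptions, while the payload keys ("file_executions", "execution_id", "file_hash") match the view implementation below.

# Hypothetical worker-side batch creation of file execution records.
import requests


def batch_create_file_executions(
    base_url: str, token: str, org_id: str, execution_id: str, file_hashes: list[dict]
) -> dict:
    """Create (or fetch existing) file execution records for a list of FileHash payloads."""
    payload = {
        "file_executions": [
            {"execution_id": execution_id, "file_hash": fh} for fh in file_hashes
        ]
    }
    resp = requests.post(
        f"{base_url}/file-execution/batch-create/",  # assumed URL prefix
        json=payload,
        headers={"Authorization": f"Bearer {token}", "X-Organization-ID": org_id},
        timeout=60,
    )
    resp.raise_for_status()
    return resp.json()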
backend/workflow_manager/file_execution/internal_views.py (new file, 777 lines)
@@ -0,0 +1,777 @@
"""Internal API Views for File Execution
|
||||
Handles file execution related endpoints for internal services.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from django.db import transaction
|
||||
from rest_framework import status, viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
from utils.organization_utils import filter_queryset_by_organization
|
||||
|
||||
from workflow_manager.endpoint_v2.dto import FileHash
|
||||
from workflow_manager.file_execution.models import WorkflowFileExecution
|
||||
|
||||
# Import serializers from workflow_manager internal API
|
||||
from workflow_manager.internal_serializers import (
|
||||
FileExecutionStatusUpdateSerializer,
|
||||
WorkflowFileExecutionSerializer,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileExecutionInternalViewSet(viewsets.ModelViewSet):
|
||||
"""Internal API ViewSet for File Execution operations."""
|
||||
|
||||
serializer_class = WorkflowFileExecutionSerializer
|
||||
lookup_field = "id"
|
||||
|
||||
def get_queryset(self):
|
||||
"""Get file executions filtered by organization context and query parameters."""
|
||||
queryset = WorkflowFileExecution.objects.all()
|
||||
|
||||
# Filter through the relationship: WorkflowFileExecution -> WorkflowExecution -> Workflow -> Organization
|
||||
queryset = filter_queryset_by_organization(
|
||||
queryset, self.request, "workflow_execution__workflow__organization"
|
||||
)
|
||||
|
||||
# Debug: Log initial queryset count after organization filtering
|
||||
org_filtered_count = queryset.count()
|
||||
logger.debug(
|
||||
f"After organization filtering: {org_filtered_count} file executions"
|
||||
)
|
||||
|
||||
# Support filtering by query parameters for get-or-create operations
|
||||
execution_id = self.request.query_params.get("execution_id")
|
||||
file_hash = self.request.query_params.get("file_hash")
|
||||
provider_file_uuid = self.request.query_params.get("provider_file_uuid")
|
||||
workflow_id = self.request.query_params.get("workflow_id")
|
||||
file_path = self.request.query_params.get(
|
||||
"file_path"
|
||||
) # CRITICAL: Add file_path parameter
|
||||
|
||||
logger.debug(
|
||||
f"Query parameters: execution_id={execution_id}, file_hash={file_hash}, provider_file_uuid={provider_file_uuid}, workflow_id={workflow_id}, file_path={file_path}"
|
||||
)
|
||||
|
||||
# Apply filters step by step with debugging
|
||||
if execution_id:
|
||||
queryset = queryset.filter(workflow_execution_id=execution_id)
|
||||
logger.info(
|
||||
f"DEBUG: After execution_id filter: {queryset.count()} file executions"
|
||||
)
|
||||
|
||||
# CRITICAL FIX: Include file_path filter to match unique constraints
|
||||
if file_path:
|
||||
queryset = queryset.filter(file_path=file_path)
|
||||
logger.debug(f"After file_path filter: {queryset.count()} file executions")
|
||||
|
||||
# CRITICAL FIX: Match backend manager logic - use file_hash OR provider_file_uuid (not both)
|
||||
if file_hash:
|
||||
queryset = queryset.filter(file_hash=file_hash)
|
||||
logger.info(
|
||||
f"DEBUG: After file_hash filter: {queryset.count()} file executions"
|
||||
)
|
||||
elif provider_file_uuid:
|
||||
queryset = queryset.filter(provider_file_uuid=provider_file_uuid)
|
||||
logger.info(
|
||||
f"DEBUG: After provider_file_uuid filter: {queryset.count()} file executions"
|
||||
)
|
||||
|
||||
if workflow_id:
|
||||
queryset = queryset.filter(workflow_execution__workflow_id=workflow_id)
|
||||
logger.info(
|
||||
f"DEBUG: After workflow_id filter: {queryset.count()} file executions"
|
||||
)
|
||||
|
||||
final_count = queryset.count()
|
||||
logger.info(
|
||||
f"Final queryset count: {final_count} file executions for params: execution_id={execution_id}, file_hash={file_hash}, provider_file_uuid={provider_file_uuid}, workflow_id={workflow_id}, file_path={file_path}"
|
||||
)
|
||||
|
||||
# If we still have too many results, something is wrong with the filtering
|
||||
if final_count > 10: # Reasonable threshold
|
||||
logger.warning(
|
||||
f"Query returned {final_count} file executions - filtering may not be working correctly"
|
||||
)
|
||||
logger.warning(
|
||||
f"Query params: execution_id={execution_id}, file_hash={file_hash}, workflow_id={workflow_id}"
|
||||
)
|
||||
|
||||
return queryset
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
"""List file executions with enhanced filtering validation."""
|
||||
queryset = self.get_queryset()
|
||||
count = queryset.count()
|
||||
|
||||
# If we get too many results, it means the filtering failed
|
||||
if count > 50: # Conservative threshold
|
||||
logger.error(
|
||||
f"GET request returned {count} file executions - this suggests broken query parameter filtering"
|
||||
)
|
||||
logger.error(f"Request query params: {dict(request.query_params)}")
|
||||
|
||||
# For debugging, show a sample of what we're returning
|
||||
sample_ids = list(queryset.values_list("id", flat=True)[:5])
|
||||
logger.error(f"Sample file execution IDs: {sample_ids}")
|
||||
|
||||
# Return error response instead of broken list
|
||||
return Response(
|
||||
{
|
||||
"error": "Query returned too many results",
|
||||
"detail": f"Expected 0-1 file executions but got {count}. Check query parameters.",
|
||||
"count": count,
|
||||
"query_params": dict(request.query_params),
|
||||
},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Continue with normal list behavior for reasonable result counts
|
||||
logger.info(f"GET request successfully filtered to {count} file executions")
|
||||
return super().list(request, *args, **kwargs)
|
||||
|
||||
@action(detail=True, methods=["post"])
|
||||
def status(self, request, id=None):
|
||||
"""Update file execution status."""
|
||||
try:
|
||||
# Get file execution by ID with organization filtering
|
||||
# Don't use self.get_object() as it applies query parameter filtering
|
||||
base_queryset = WorkflowFileExecution.objects.all()
|
||||
base_queryset = filter_queryset_by_organization(
|
||||
base_queryset, request, "workflow_execution__workflow__organization"
|
||||
)
|
||||
|
||||
try:
|
||||
file_execution = base_queryset.get(id=id)
|
||||
except WorkflowFileExecution.DoesNotExist:
|
||||
logger.warning(f"WorkflowFileExecution {id} not found for status update")
|
||||
return Response(
|
||||
{
|
||||
"error": "WorkflowFileExecution not found",
|
||||
"detail": f"No file execution record found with ID {id}",
|
||||
},
|
||||
status=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
serializer = FileExecutionStatusUpdateSerializer(data=request.data)
|
||||
|
||||
if serializer.is_valid():
|
||||
validated_data = serializer.validated_data
|
||||
|
||||
# Update file execution using the model's update_status method
|
||||
file_execution.update_status(
|
||||
status=validated_data["status"],
|
||||
execution_error=validated_data.get("error_message"),
|
||||
execution_time=validated_data.get("execution_time"),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Updated file execution {id} status to {validated_data['status']}"
|
||||
)
|
||||
|
||||
# Return consistent dataclass response
|
||||
from unstract.core.data_models import FileExecutionStatusUpdateRequest
|
||||
|
||||
response_data = FileExecutionStatusUpdateRequest(
|
||||
status=file_execution.status,
|
||||
error_message=file_execution.execution_error,
|
||||
result=getattr(file_execution, "result", None),
|
||||
)
|
||||
|
||||
return Response(
|
||||
{
|
||||
"status": "updated",
|
||||
"file_execution_id": str(file_execution.id),
|
||||
"data": response_data.to_dict(),
|
||||
},
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update file execution status {id}: {str(e)}")
|
||||
return Response(
|
||||
{"error": "Failed to update file execution status", "detail": str(e)},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
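For completeness, a worker-side status update against the action above could look like this; the URL prefix, header names, and the "COMPLETED" status string are assumptions, while the field names ("status", "error_message", "execution_time") follow the serializer used by the view.

# Hypothetical worker-side status update for a single file execution record.
import requests


def update_file_execution_status(base_url: str, token: str, org_id: str, file_execution_id: str) -> dict:
    """Mark a file execution as completed and report its execution time."""
    resp = requests.post(
        f"{base_url}/file-execution/{file_execution_id}/status/",  # assumed URL prefix
        json={"status": "COMPLETED", "execution_time": 12.5, "error_message": None},  # assumed status value
        headers={"Authorization": f"Bearer {token}", "X-Organization-ID": org_id},
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json()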
|
||||
def create(self, request, *args, **kwargs):
|
||||
"""Create or get existing workflow file execution using existing manager method."""
|
||||
try:
|
||||
from workflow_manager.workflow_v2.models.execution import WorkflowExecution
|
||||
|
||||
data = request.data
|
||||
execution_id = data.get("execution_id")
|
||||
file_hash_data = data.get("file_hash", {})
|
||||
workflow_id = data.get("workflow_id")
|
||||
|
||||
if not execution_id:
|
||||
return Response(
|
||||
{"error": "execution_id is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Get workflow execution with organization filtering
|
||||
try:
|
||||
workflow_execution = WorkflowExecution.objects.get(id=execution_id)
|
||||
# Verify organization access
|
||||
filter_queryset_by_organization(
|
||||
WorkflowExecution.objects.filter(id=execution_id),
|
||||
request,
|
||||
"workflow__organization",
|
||||
).get()
|
||||
except WorkflowExecution.DoesNotExist:
|
||||
return Response(
|
||||
{"error": "WorkflowExecution not found or access denied"},
|
||||
status=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
# Convert request data to FileHash object that the manager expects
|
||||
file_hash = FileHash(
|
||||
file_path=file_hash_data.get("file_path", ""),
|
||||
file_name=file_hash_data.get("file_name", ""),
|
||||
source_connection_type=file_hash_data.get("source_connection_type", ""),
|
||||
file_hash=file_hash_data.get("file_hash"),
|
||||
file_size=file_hash_data.get("file_size"),
|
||||
provider_file_uuid=file_hash_data.get("provider_file_uuid"),
|
||||
mime_type=file_hash_data.get("mime_type"),
|
||||
fs_metadata=file_hash_data.get("fs_metadata"),
|
||||
file_destination=file_hash_data.get("file_destination"),
|
||||
is_executed=file_hash_data.get("is_executed", False),
|
||||
file_number=file_hash_data.get("file_number"),
|
||||
)
|
||||
|
||||
# Determine if this is an API request (affects file_path handling in manager)
|
||||
is_api = file_hash_data.get("source_connection_type", "") == "API"
|
||||
|
||||
# Use existing manager method - this handles get_or_create logic properly
|
||||
file_execution = WorkflowFileExecution.objects.get_or_create_file_execution(
|
||||
workflow_execution=workflow_execution, file_hash=file_hash, is_api=is_api
|
||||
)
|
||||
|
||||
# Return single object (not list!) using serializer
|
||||
serializer = self.get_serializer(file_execution)
|
||||
response_data = serializer.data
|
||||
|
||||
# ROOT CAUSE FIX: Ensure file_path is always present in API response
|
||||
# The backend model sets file_path to None for API files, but workers require it
|
||||
if not response_data.get("file_path") and file_hash.file_path:
|
||||
logger.info(
|
||||
f"Backend stored null file_path for API file, including original: {file_hash.file_path}"
|
||||
)
|
||||
response_data["file_path"] = file_hash.file_path
|
||||
|
||||
logger.info(
|
||||
f"Retrieved/created file execution {file_execution.id} for workflow {workflow_id}"
|
||||
)
|
||||
logger.debug(f"Response data: {response_data}")
|
||||
|
||||
# Determine status code based on whether it was created or retrieved
|
||||
# Note: We can't easily tell if it was created or retrieved from the manager,
|
||||
# but 201 is fine for both cases in this context
|
||||
return Response(response_data, status=status.HTTP_201_CREATED)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get/create file execution: {str(e)}")
|
||||
return Response(
|
||||
{"error": "Failed to get/create file execution", "detail": str(e)},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
@action(detail=True, methods=["patch"])
|
||||
def update_hash(self, request, id=None):
|
||||
"""Update file execution with computed file hash."""
|
||||
try:
|
||||
# Get file execution by ID with organization filtering
|
||||
base_queryset = WorkflowFileExecution.objects.all()
|
||||
base_queryset = filter_queryset_by_organization(
|
||||
base_queryset, request, "workflow_execution__workflow__organization"
|
||||
)
|
||||
|
||||
try:
|
||||
file_execution = base_queryset.get(id=id)
|
||||
except WorkflowFileExecution.DoesNotExist:
|
||||
logger.warning(f"WorkflowFileExecution {id} not found for hash update")
|
||||
return Response(
|
||||
{
|
||||
"error": "WorkflowFileExecution not found",
|
||||
"detail": f"No file execution record found with ID {id}",
|
||||
},
|
||||
status=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
# Extract update data
|
||||
file_hash = request.data.get("file_hash")
|
||||
fs_metadata = request.data.get("fs_metadata")
|
||||
mime_type = request.data.get("mime_type")
|
||||
|
||||
if not file_hash and not fs_metadata and not mime_type:
|
||||
return Response(
|
||||
{"error": "file_hash, fs_metadata, or mime_type is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Use the model's update method for efficient field-specific updates
|
||||
file_execution.update(
|
||||
file_hash=file_hash, fs_metadata=fs_metadata, mime_type=mime_type
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Updated file execution {id} with file_hash: {file_hash[:16] if file_hash else 'none'}..."
|
||||
)
|
||||
|
||||
# Return updated record
|
||||
serializer = self.get_serializer(file_execution)
|
||||
return Response(
|
||||
{
|
||||
"status": "updated",
|
||||
"file_execution_id": str(file_execution.id),
|
||||
"data": serializer.data,
|
||||
},
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update file execution hash {id}: {str(e)}")
|
||||
return Response(
|
||||
{"error": "Failed to update file execution hash", "detail": str(e)},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
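A worker-side sketch of how this update_hash action can be called. This is illustrative only: the route follows the usual DRF extra-action pattern for this viewset, and the base URL, auth token, and organization header value below are assumptions rather than values taken from this diff.

import os
import requests

# Assumed wiring: internal base URL and service token come from worker configuration.
BASE_URL = os.environ.get("INTERNAL_API_BASE_URL", "http://backend:8000/internal")
HEADERS = {
    "Authorization": f"Bearer {os.environ.get('INTERNAL_SERVICE_API_KEY', '')}",
    "X-Organization-ID": "org-123",  # hypothetical organization id
}

def update_file_hash(file_execution_id: str, file_hash: str, mime_type: str | None = None) -> dict:
    """PATCH the computed hash back to the backend (route path is an assumption)."""
    payload: dict = {"file_hash": file_hash}
    if mime_type:
        payload["mime_type"] = mime_type
    resp = requests.patch(
        f"{BASE_URL}/file-execution/{file_execution_id}/update_hash/",
        json=payload,
        headers=HEADERS,
        timeout=30,
    )
    resp.raise_for_status()
    # Expected shape: {"status": "updated", "file_execution_id": "...", "data": {...}}
    return resp.json()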
|
||||
|
||||
|
||||
class FileExecutionBatchCreateAPIView(APIView):
|
||||
"""Internal API endpoint for creating multiple file executions in a single batch."""
|
||||
|
||||
def post(self, request):
|
||||
"""Create multiple file executions in a single batch request."""
|
||||
try:
|
||||
file_executions = request.data.get("file_executions", [])
|
||||
|
||||
if not file_executions:
|
||||
return Response(
|
||||
{"error": "file_executions list is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
successful_creations = []
|
||||
failed_creations = []
|
||||
|
||||
with transaction.atomic():
|
||||
for file_execution_data in file_executions:
|
||||
try:
|
||||
from workflow_manager.workflow_v2.models.execution import (
|
||||
WorkflowExecution,
|
||||
)
|
||||
|
||||
execution_id = file_execution_data.get("execution_id")
|
||||
file_hash_data = file_execution_data.get("file_hash", {})
|
||||
|
||||
if not execution_id:
|
||||
failed_creations.append(
|
||||
{
|
||||
"file_name": file_hash_data.get(
|
||||
"file_name", "unknown"
|
||||
),
|
||||
"error": "execution_id is required",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Get workflow execution with organization filtering
|
||||
try:
|
||||
workflow_execution = WorkflowExecution.objects.get(
|
||||
id=execution_id
|
||||
)
|
||||
# Verify organization access
|
||||
filter_queryset_by_organization(
|
||||
WorkflowExecution.objects.filter(id=execution_id),
|
||||
request,
|
||||
"workflow__organization",
|
||||
).get()
|
||||
except WorkflowExecution.DoesNotExist:
|
||||
failed_creations.append(
|
||||
{
|
||||
"file_name": file_hash_data.get(
|
||||
"file_name", "unknown"
|
||||
),
|
||||
"error": "WorkflowExecution not found or access denied",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Convert request data to FileHash object
|
||||
file_hash = FileHash(
|
||||
file_path=file_hash_data.get("file_path", ""),
|
||||
file_name=file_hash_data.get("file_name", ""),
|
||||
source_connection_type=file_hash_data.get(
|
||||
"source_connection_type", ""
|
||||
),
|
||||
file_hash=file_hash_data.get("file_hash"),
|
||||
file_size=file_hash_data.get("file_size"),
|
||||
provider_file_uuid=file_hash_data.get("provider_file_uuid"),
|
||||
mime_type=file_hash_data.get("mime_type"),
|
||||
fs_metadata=file_hash_data.get("fs_metadata"),
|
||||
file_destination=file_hash_data.get("file_destination"),
|
||||
is_executed=file_hash_data.get("is_executed", False),
|
||||
file_number=file_hash_data.get("file_number"),
|
||||
)
|
||||
|
||||
# Determine if this is an API request
|
||||
is_api = file_hash_data.get("source_connection_type", "") == "API"
|
||||
|
||||
# Use existing manager method
|
||||
file_execution = (
|
||||
WorkflowFileExecution.objects.get_or_create_file_execution(
|
||||
workflow_execution=workflow_execution,
|
||||
file_hash=file_hash,
|
||||
is_api=is_api,
|
||||
)
|
||||
)
|
||||
|
||||
# ROOT CAUSE FIX: Ensure file_path is always present in batch response
|
||||
# The backend model sets file_path to None for API files, but workers require it
|
||||
response_file_path = file_execution.file_path
|
||||
if not response_file_path and file_hash.file_path:
|
||||
response_file_path = file_hash.file_path
|
||||
|
||||
successful_creations.append(
|
||||
{
|
||||
"id": str(file_execution.id),
|
||||
"file_name": file_execution.file_name,
|
||||
"file_path": response_file_path,
|
||||
"status": file_execution.status,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failed_creations.append(
|
||||
{
|
||||
"file_name": file_execution_data.get("file_hash", {}).get(
|
||||
"file_name", "unknown"
|
||||
),
|
||||
"error": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Batch file execution creation: {len(successful_creations)} successful, {len(failed_creations)} failed"
|
||||
)
|
||||
|
||||
return Response(
|
||||
{
|
||||
"successful_creations": successful_creations,
|
||||
"failed_creations": failed_creations,
|
||||
"total_processed": len(file_executions),
|
||||
},
|
||||
status=status.HTTP_201_CREATED,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process batch file execution creation: {str(e)}")
|
||||
return Response(
|
||||
{
|
||||
"error": "Failed to process batch file execution creation",
|
||||
"detail": str(e),
|
||||
},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
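For clarity, a hedged sketch of the batch payload this view expects. The URL is a placeholder because the route for this class is not visible in this hunk, and the token and organization values are illustrative only.

import requests

# Placeholder URL and credentials; the real route and auth are configured elsewhere.
BATCH_CREATE_URL = "http://backend:8000/internal/<batch-create-route>/"
headers = {"Authorization": "Bearer <internal-api-key>", "X-Organization-ID": "org-123"}

payload = {
    "file_executions": [
        {
            "execution_id": "<workflow-execution-uuid>",
            "file_hash": {
                "file_name": "invoice.pdf",
                "file_path": "input/invoice.pdf",
                "source_connection_type": "FILESYSTEM",  # "API" switches on is_api handling
                "file_hash": "<sha256-of-content>",
                "file_size": 18432,
                "mime_type": "application/pdf",
            },
        }
    ]
}

result = requests.post(BATCH_CREATE_URL, json=payload, headers=headers, timeout=30).json()
# result["successful_creations"] -> [{"id", "file_name", "file_path", "status"}, ...]
# result["failed_creations"]     -> [{"file_name", "error"}, ...]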
|
||||
|
||||
|
||||
class FileExecutionBatchStatusUpdateAPIView(APIView):
|
||||
"""Internal API endpoint for updating multiple file execution statuses in a single batch."""
|
||||
|
||||
def post(self, request):
|
||||
"""Update multiple file execution statuses in a single batch request."""
|
||||
try:
|
||||
status_updates = request.data.get("status_updates", [])
|
||||
|
||||
if not status_updates:
|
||||
return Response(
|
||||
{"error": "status_updates list is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
successful_updates = []
|
||||
failed_updates = []
|
||||
|
||||
with transaction.atomic():
|
||||
for update_data in status_updates:
|
||||
try:
|
||||
file_execution_id = update_data.get("file_execution_id")
|
||||
status_value = update_data.get("status")
|
||||
|
||||
if not file_execution_id or not status_value:
|
||||
failed_updates.append(
|
||||
{
|
||||
"file_execution_id": file_execution_id,
|
||||
"error": "file_execution_id and status are required",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Get file execution with organization filtering
|
||||
base_queryset = WorkflowFileExecution.objects.all()
|
||||
base_queryset = filter_queryset_by_organization(
|
||||
base_queryset,
|
||||
request,
|
||||
"workflow_execution__workflow__organization",
|
||||
)
|
||||
|
||||
try:
|
||||
file_execution = base_queryset.get(id=file_execution_id)
|
||||
except WorkflowFileExecution.DoesNotExist:
|
||||
failed_updates.append(
|
||||
{
|
||||
"file_execution_id": file_execution_id,
|
||||
"error": "WorkflowFileExecution not found",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Update file execution using the model's update_status method
|
||||
file_execution.update_status(
|
||||
status=status_value,
|
||||
execution_error=update_data.get("error_message"),
|
||||
execution_time=update_data.get("execution_time"),
|
||||
)
|
||||
|
||||
successful_updates.append(
|
||||
{
|
||||
"file_execution_id": str(file_execution.id),
|
||||
"status": file_execution.status,
|
||||
"file_name": file_execution.file_name,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failed_updates.append(
|
||||
{"file_execution_id": file_execution_id, "error": str(e)}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Batch file execution status update: {len(successful_updates)} successful, {len(failed_updates)} failed"
|
||||
)
|
||||
|
||||
return Response(
|
||||
{
|
||||
"successful_updates": successful_updates,
|
||||
"failed_updates": failed_updates,
|
||||
"total_processed": len(status_updates),
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to process batch file execution status update: {str(e)}"
|
||||
)
|
||||
return Response(
|
||||
{
|
||||
"error": "Failed to process batch file execution status update",
|
||||
"detail": str(e),
|
||||
},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
class FileExecutionBatchHashUpdateAPIView(APIView):
|
||||
"""Internal API endpoint for updating multiple file execution hashes in a single batch."""
|
||||
|
||||
def post(self, request):
|
||||
"""Update multiple file execution hashes in a single batch request."""
|
||||
try:
|
||||
hash_updates = request.data.get("hash_updates", [])
|
||||
|
||||
if not hash_updates:
|
||||
return Response(
|
||||
{"error": "hash_updates list is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
successful_updates = []
|
||||
failed_updates = []
|
||||
|
||||
with transaction.atomic():
|
||||
for update_data in hash_updates:
|
||||
try:
|
||||
file_execution_id = update_data.get("file_execution_id")
|
||||
file_hash = update_data.get("file_hash")
|
||||
|
||||
if not file_execution_id or not file_hash:
|
||||
failed_updates.append(
|
||||
{
|
||||
"file_execution_id": file_execution_id,
|
||||
"error": "file_execution_id and file_hash are required",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Get file execution with organization filtering
|
||||
base_queryset = WorkflowFileExecution.objects.all()
|
||||
base_queryset = filter_queryset_by_organization(
|
||||
base_queryset,
|
||||
request,
|
||||
"workflow_execution__workflow__organization",
|
||||
)
|
||||
|
||||
try:
|
||||
file_execution = base_queryset.get(id=file_execution_id)
|
||||
except WorkflowFileExecution.DoesNotExist:
|
||||
failed_updates.append(
|
||||
{
|
||||
"file_execution_id": file_execution_id,
|
||||
"error": "WorkflowFileExecution not found",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Update file execution hash using the model's update method
|
||||
file_execution.update(
|
||||
file_hash=file_hash,
|
||||
fs_metadata=update_data.get("fs_metadata"),
|
||||
)
|
||||
|
||||
successful_updates.append(
|
||||
{
|
||||
"file_execution_id": str(file_execution.id),
|
||||
"file_hash": file_hash[:16] + "..."
|
||||
if file_hash
|
||||
else None,
|
||||
"file_name": file_execution.file_name,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failed_updates.append(
|
||||
{"file_execution_id": file_execution_id, "error": str(e)}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Batch file execution hash update: {len(successful_updates)} successful, {len(failed_updates)} failed"
|
||||
)
|
||||
|
||||
return Response(
|
||||
{
|
||||
"successful_updates": successful_updates,
|
||||
"failed_updates": failed_updates,
|
||||
"total_processed": len(hash_updates),
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process batch file execution hash update: {str(e)}")
|
||||
return Response(
|
||||
{
|
||||
"error": "Failed to process batch file execution hash update",
|
||||
"detail": str(e),
|
||||
},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
class FileExecutionMetricsAPIView(APIView):
|
||||
"""Internal API endpoint for getting file execution metrics."""
|
||||
|
||||
def get(self, request):
|
||||
"""Get file execution metrics with optional filtering."""
|
||||
try:
|
||||
# Get query parameters
|
||||
start_date = request.query_params.get("start_date")
|
||||
end_date = request.query_params.get("end_date")
|
||||
workflow_id = request.query_params.get("workflow_id")
|
||||
execution_id = request.query_params.get("execution_id")
|
||||
status = request.query_params.get("status")
|
||||
|
||||
# Build base queryset with organization filtering
|
||||
file_executions = WorkflowFileExecution.objects.all()
|
||||
file_executions = filter_queryset_by_organization(
|
||||
file_executions, request, "workflow_execution__workflow__organization"
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if start_date:
|
||||
from datetime import datetime
|
||||
|
||||
file_executions = file_executions.filter(
|
||||
created_at__gte=datetime.fromisoformat(start_date)
|
||||
)
|
||||
if end_date:
|
||||
from datetime import datetime
|
||||
|
||||
file_executions = file_executions.filter(
|
||||
created_at__lte=datetime.fromisoformat(end_date)
|
||||
)
|
||||
if workflow_id:
|
||||
file_executions = file_executions.filter(
|
||||
workflow_execution__workflow_id=workflow_id
|
||||
)
|
||||
if execution_id:
|
||||
file_executions = file_executions.filter(
|
||||
workflow_execution_id=execution_id
|
||||
)
|
||||
if status:
|
||||
file_executions = file_executions.filter(status=status)
|
||||
|
||||
# Calculate metrics
|
||||
from django.db.models import Avg, Count, Sum
|
||||
|
||||
total_file_executions = file_executions.count()
|
||||
|
||||
# Status breakdown
|
||||
status_counts = file_executions.values("status").annotate(count=Count("id"))
|
||||
status_breakdown = {item["status"]: item["count"] for item in status_counts}
|
||||
|
||||
# Success rate
|
||||
completed_count = status_breakdown.get("COMPLETED", 0)
|
||||
success_rate = (
|
||||
(completed_count / total_file_executions)
|
||||
if total_file_executions > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
# Average execution time
|
||||
avg_execution_time = (
|
||||
file_executions.aggregate(avg_time=Avg("execution_time"))["avg_time"] or 0
|
||||
)
|
||||
|
||||
# File size statistics
|
||||
total_file_size = (
|
||||
file_executions.aggregate(total_size=Sum("file_size"))["total_size"] or 0
|
||||
)
|
||||
|
||||
avg_file_size = (
|
||||
file_executions.aggregate(avg_size=Avg("file_size"))["avg_size"] or 0
|
||||
)
|
||||
|
||||
metrics = {
|
||||
"total_file_executions": total_file_executions,
|
||||
"status_breakdown": status_breakdown,
|
||||
"success_rate": success_rate,
|
||||
"average_execution_time": avg_execution_time,
|
||||
"total_file_size": total_file_size,
|
||||
"average_file_size": avg_file_size,
|
||||
"filters_applied": {
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"workflow_id": workflow_id,
|
||||
"execution_id": execution_id,
|
||||
"status": status,
|
||||
},
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Generated file execution metrics: {total_file_executions} executions, {success_rate:.2%} success rate"
|
||||
)
|
||||
|
||||
return Response(metrics)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get file execution metrics: {str(e)}")
|
||||
return Response(
|
||||
{"error": "Failed to get file execution metrics", "detail": str(e)},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
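A small sketch of querying this metrics endpoint. The URL prefix and credentials are placeholders; the query parameters and response keys mirror the view above.

import requests

METRICS_URL = "http://backend:8000/internal/<metrics-route>/"  # placeholder route
headers = {"Authorization": "Bearer <internal-api-key>", "X-Organization-ID": "org-123"}
params = {
    "start_date": "2024-01-01T00:00:00",  # parsed with datetime.fromisoformat
    "end_date": "2024-01-31T23:59:59",
    "status": "COMPLETED",                # optional; omit to include all statuses
}

metrics = requests.get(METRICS_URL, params=params, headers=headers, timeout=30).json()
print(metrics["total_file_executions"], metrics["success_rate"], metrics["status_breakdown"])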
|
||||
@@ -120,32 +120,30 @@ class WorkflowFileExecution(BaseModel):
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
status: ExecutionStatus,
|
||||
status: ExecutionStatus | str,
|
||||
execution_error: str = None,
|
||||
execution_time: float = None,
|
||||
) -> None:
|
||||
"""Updates the status and execution details of an input file.
|
||||
|
||||
Args:
|
||||
execution_file: The `WorkflowExecutionFile` object to update
|
||||
status: The new status of the file
|
||||
execution_time: The execution time for processing the file
|
||||
status: The new status of the file (ExecutionStatus enum or string)
|
||||
execution_time: The execution time for processing the file (optional)
|
||||
execution_error: (Optional) Error message if processing failed
|
||||
|
||||
Return:
|
||||
The updated `WorkflowExecutionInputFile` object
|
||||
"""
|
||||
self.status = status
|
||||
|
||||
if (
|
||||
status
|
||||
in [
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.ERROR,
|
||||
ExecutionStatus.STOPPED,
|
||||
]
|
||||
and not self.execution_time
|
||||
):
|
||||
self.execution_time = CommonUtils.time_since(self.created_at)
|
||||
# Set execution_time if provided, otherwise calculate it for final states
|
||||
status = ExecutionStatus(status)
|
||||
self.status = status.value
|
||||
if status in [
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.ERROR,
|
||||
ExecutionStatus.STOPPED,
|
||||
]:
|
||||
self.execution_time = CommonUtils.time_since(self.created_at, 3)
|
||||
|
||||
self.execution_error = execution_error
|
||||
self.save()
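After this change update_status accepts either an ExecutionStatus member or its raw string and normalizes it via ExecutionStatus(status) before persisting the .value. A minimal usage sketch (the instance lookup is assumed):

from workflow_manager.file_execution.models import WorkflowFileExecution
from workflow_manager.workflow_v2.enums import ExecutionStatus

file_execution = WorkflowFileExecution.objects.get(id="<file-execution-uuid>")  # assumed id
file_execution.update_status(status=ExecutionStatus.COMPLETED)                  # enum member
file_execution.update_status(status="ERROR", execution_error="tool timed out")  # plain string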
|
||||
@@ -210,6 +208,13 @@ class WorkflowFileExecution(BaseModel):
|
||||
fields=["workflow_execution", "provider_file_uuid", "file_path"],
|
||||
name="unique_workflow_provider_uuid_path",
|
||||
),
|
||||
# CRITICAL FIX: Add constraint for API files where file_path is None
|
||||
# This prevents duplicate entries for same file_hash
|
||||
models.UniqueConstraint(
|
||||
fields=["workflow_execution", "file_hash"],
|
||||
condition=models.Q(file_path__isnull=True),
|
||||
name="unique_workflow_api_file_hash",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -219,17 +224,20 @@ class WorkflowFileExecution(BaseModel):
|
||||
Returns:
|
||||
bool: True if the execution status is completed, False otherwise.
|
||||
"""
|
||||
return self.status is not None and self.status == ExecutionStatus.COMPLETED
|
||||
return self.status is not None and self.status == ExecutionStatus.COMPLETED.value
|
||||
|
||||
def update(
|
||||
self,
|
||||
file_hash: str = None,
|
||||
fs_metadata: dict[str, Any] = None,
|
||||
mime_type: str = None,
|
||||
) -> None:
|
||||
"""Updates the file execution details.
|
||||
|
||||
Args:
|
||||
file_hash: (Optional) Hash of the file content
|
||||
fs_metadata: (Optional) File system metadata
|
||||
mime_type: (Optional) MIME type of the file
|
||||
|
||||
Returns:
|
||||
None
|
||||
@@ -242,5 +250,8 @@ class WorkflowFileExecution(BaseModel):
|
||||
if fs_metadata is not None:
|
||||
self.fs_metadata = fs_metadata
|
||||
update_fields.append("fs_metadata")
|
||||
if mime_type is not None:
|
||||
self.mime_type = mime_type
|
||||
update_fields.append("mime_type")
|
||||
if update_fields: # Save only if there's an actual update
|
||||
self.save(update_fields=update_fields)
|
||||
|
||||
@@ -30,9 +30,9 @@ class FileCentricExecutionSerializer(serializers.ModelSerializer):
|
||||
exclude = ["file_hash"]
|
||||
|
||||
def get_status_msg(self, obj: FileExecution) -> dict[str, any] | None:
|
||||
if obj.status in [ExecutionStatus.PENDING]:
|
||||
if obj.status in [ExecutionStatus.PENDING.value]:
|
||||
return self.INIT_STATUS_MSG
|
||||
elif obj.status == ExecutionStatus.ERROR:
|
||||
elif obj.status == ExecutionStatus.ERROR.value:
|
||||
return obj.execution_error
|
||||
|
||||
latest_log = (
|
||||
|
||||
backend/workflow_manager/internal_api_views.py (new file, 447 lines)
@@ -0,0 +1,447 @@
"""Internal API Views for Worker Communication
|
||||
|
||||
This module provides internal API endpoints that workers use to communicate
|
||||
with Django backend for database operations only. All business logic has been
|
||||
moved to workers.
|
||||
|
||||
NOTE: Many sophisticated endpoints are now implemented in internal_views.py
|
||||
using class-based views. This file contains simpler function-based views
|
||||
for basic operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from account_v2.models import Organization
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from rest_framework import status
|
||||
from rest_framework.decorators import api_view
|
||||
from rest_framework.response import Response
|
||||
from tool_instance_v2.models import ToolInstance
|
||||
|
||||
from workflow_manager.workflow_v2.enums import ExecutionStatus
|
||||
from workflow_manager.workflow_v2.models import Workflow, WorkflowExecution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
|
||||
@api_view(["GET"])
|
||||
def get_workflow_execution_data(request, execution_id: str):
|
||||
"""Get workflow execution data for workers.
|
||||
|
||||
Args:
|
||||
execution_id: Workflow execution ID
|
||||
|
||||
Returns:
|
||||
JSON response with workflow and execution data
|
||||
"""
|
||||
try:
|
||||
# Get organization from header
|
||||
org_id = request.headers.get("X-Organization-ID")
|
||||
if not org_id:
|
||||
return Response(
|
||||
{"error": "X-Organization-ID header is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Get execution with organization filtering
|
||||
execution = WorkflowExecution.objects.select_related("workflow").get(
|
||||
id=execution_id, workflow__organization_id=org_id
|
||||
)
|
||||
|
||||
workflow = execution.workflow
|
||||
|
||||
# Prepare workflow data
|
||||
workflow_data = {
|
||||
"id": str(workflow.id),
|
||||
"workflow_name": workflow.workflow_name,
|
||||
"execution_details": workflow.execution_details,
|
||||
"organization_id": workflow.organization_id,
|
||||
}
|
||||
|
||||
# Prepare execution data
|
||||
execution_data = {
|
||||
"id": str(execution.id),
|
||||
"status": execution.status,
|
||||
"execution_mode": execution.execution_mode,
|
||||
"execution_method": execution.execution_method,
|
||||
"execution_type": execution.execution_type,
|
||||
"pipeline_id": execution.pipeline_id,
|
||||
"total_files": execution.total_files,
|
||||
"completed_files": execution.completed_files,
|
||||
"failed_files": execution.failed_files,
|
||||
"execution_log_id": execution.execution_log_id, # Include for WebSocket messaging
|
||||
}
|
||||
|
||||
return Response(
|
||||
{
|
||||
"workflow": workflow_data,
|
||||
"execution": execution_data,
|
||||
}
|
||||
)
|
||||
|
||||
except WorkflowExecution.DoesNotExist:
|
||||
return Response(
|
||||
{"error": f"Workflow execution {execution_id} not found"},
|
||||
status=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting workflow execution data: {e}")
|
||||
return Response(
|
||||
{"error": "Internal server error"},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
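Workers fetch this execution context over HTTP; the relative route ("execution/<execution_id>/") appears in internal_urls.py later in this diff, while the mount prefix and credentials below are placeholders.

import requests

BASE_URL = "http://backend:8000/internal/workflow-manager"  # placeholder mount prefix
headers = {"Authorization": "Bearer <internal-api-key>", "X-Organization-ID": "org-123"}

resp = requests.get(f"{BASE_URL}/execution/<execution-uuid>/", headers=headers, timeout=30)
resp.raise_for_status()
context = resp.json()
workflow, execution = context["workflow"], context["execution"]
# execution["execution_log_id"] is what workers reuse for WebSocket log messaging.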
|
||||
|
||||
|
||||
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
|
||||
@api_view(["GET"])
|
||||
def get_tool_instances_by_workflow(request, workflow_id: str):
|
||||
"""Get tool instances for a workflow.
|
||||
|
||||
Args:
|
||||
workflow_id: Workflow ID
|
||||
|
||||
Returns:
|
||||
JSON response with tool instances data
|
||||
"""
|
||||
try:
|
||||
# Get organization from header
|
||||
org_id = request.headers.get("X-Organization-ID")
|
||||
if not org_id:
|
||||
logger.error(f"Missing X-Organization-ID header for workflow {workflow_id}")
|
||||
return Response(
|
||||
{"error": "X-Organization-ID header is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
logger.info(f"Getting tool instances for workflow {workflow_id}, org {org_id}")
|
||||
|
||||
# Get tool instances with organization filtering
|
||||
# First check if workflow exists and belongs to organization
|
||||
try:
|
||||
# Get organization object first (org_id is the organization_id string field)
|
||||
logger.info(f"Looking up organization with organization_id: {org_id}")
|
||||
organization = Organization.objects.get(organization_id=org_id)
|
||||
logger.info(
|
||||
f"Found organization: {organization.id} - {organization.display_name}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Looking up workflow {workflow_id} for organization {organization.id}"
|
||||
)
|
||||
workflow = Workflow.objects.get(id=workflow_id, organization=organization)
|
||||
logger.info(f"Found workflow: {workflow.workflow_name}")
|
||||
|
||||
except Organization.DoesNotExist:
|
||||
logger.error(f"Organization not found: {org_id}")
|
||||
return Response(
|
||||
{"error": f"Organization {org_id} not found"},
|
||||
status=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
except Workflow.DoesNotExist:
|
||||
logger.error(f"Workflow {workflow_id} not found for organization {org_id}")
|
||||
return Response(
|
||||
{"error": f"Workflow {workflow_id} not found for organization {org_id}"},
|
||||
status=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error during organization/workflow lookup: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return Response(
|
||||
{"error": "Database lookup error", "detail": str(e)},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
# Get tool instances for the workflow
|
||||
tool_instances = ToolInstance.objects.filter(workflow=workflow).order_by("step")
|
||||
|
||||
# Prepare tool instances data
|
||||
instances_data = []
|
||||
for instance in tool_instances:
|
||||
instance_data = {
|
||||
"id": str(instance.id),
|
||||
"tool_id": instance.tool_id,
|
||||
"step": instance.step,
|
||||
"status": instance.status,
|
||||
"version": instance.version,
|
||||
"metadata": instance.metadata,
|
||||
"input": instance.input,
|
||||
"output": instance.output,
|
||||
}
|
||||
instances_data.append(instance_data)
|
||||
|
||||
return Response(
|
||||
{
|
||||
"tool_instances": instances_data,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting tool instances for workflow {workflow_id}: {e}", exc_info=True
|
||||
)
|
||||
return Response(
|
||||
{"error": "Internal server error", "detail": str(e)},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
|
||||
@api_view(["POST"])
|
||||
def create_file_execution_batch(request):
|
||||
"""Create a batch of file executions for workers.
|
||||
|
||||
Returns:
|
||||
JSON response with batch creation result
|
||||
"""
|
||||
try:
|
||||
# Get organization from header
|
||||
org_id = request.headers.get("X-Organization-ID")
|
||||
if not org_id:
|
||||
logger.error(
|
||||
"Missing X-Organization-ID header for file execution batch creation"
|
||||
)
|
||||
return Response(
|
||||
{"error": "X-Organization-ID header is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# For now, return a simple response indicating batch creation
|
||||
# This would be expanded based on actual requirements
|
||||
return Response(
|
||||
{
|
||||
"batch_id": "temp-batch-id",
|
||||
"status": "created",
|
||||
"organization_id": org_id,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating file execution batch: {e}", exc_info=True)
|
||||
return Response(
|
||||
{"error": "Internal server error", "detail": str(e)},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
|
||||
@api_view(["POST"])
|
||||
def update_file_execution_batch_status(request):
|
||||
"""Update file execution batch status for workers.
|
||||
|
||||
Returns:
|
||||
JSON response with batch status update result
|
||||
"""
|
||||
try:
|
||||
# Get organization from header
|
||||
org_id = request.headers.get("X-Organization-ID")
|
||||
if not org_id:
|
||||
logger.error(
|
||||
"Missing X-Organization-ID header for file execution batch status update"
|
||||
)
|
||||
return Response(
|
||||
{"error": "X-Organization-ID header is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# For now, return a simple response indicating status update
|
||||
# This would be expanded based on actual requirements
|
||||
return Response(
|
||||
{
|
||||
"status": "updated",
|
||||
"organization_id": org_id,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating file execution batch status: {e}", exc_info=True)
|
||||
return Response(
|
||||
{"error": "Internal server error", "detail": str(e)},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
|
||||
@api_view(["POST"])
|
||||
def create_workflow_execution(request):
|
||||
"""Create a new workflow execution.
|
||||
|
||||
Returns:
|
||||
JSON response with execution ID
|
||||
"""
|
||||
try:
|
||||
data = request.data
|
||||
|
||||
# Get organization from header
|
||||
org_id = request.headers.get("X-Organization-ID")
|
||||
if not org_id:
|
||||
return Response(
|
||||
{"error": "X-Organization-ID header is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Get workflow with organization filtering
|
||||
# First get organization object, then lookup workflow
|
||||
try:
|
||||
organization = Organization.objects.get(organization_id=org_id)
|
||||
workflow = Workflow.objects.get(
|
||||
id=data["workflow_id"], organization=organization
|
||||
)
|
||||
except Organization.DoesNotExist:
|
||||
return Response(
|
||||
{"error": f"Organization {org_id} not found"},
|
||||
status=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
# Create execution with log_events_id for WebSocket messaging
|
||||
log_events_id = data.get("log_events_id")
|
||||
# If log_events_id not provided, fall back to pipeline_id for backward compatibility
|
||||
execution_log_id = log_events_id if log_events_id else data.get("pipeline_id")
|
||||
|
||||
execution = WorkflowExecution.objects.create(
|
||||
workflow=workflow,
|
||||
pipeline_id=data.get("pipeline_id"),
|
||||
execution_mode=data.get("mode", WorkflowExecution.Mode.INSTANT),
|
||||
execution_method=WorkflowExecution.Method.SCHEDULED
|
||||
if data.get("scheduled")
|
||||
else WorkflowExecution.Method.DIRECT,
|
||||
execution_type=WorkflowExecution.Type.STEP
|
||||
if data.get("single_step")
|
||||
else WorkflowExecution.Type.COMPLETE,
|
||||
status=ExecutionStatus.PENDING.value,
|
||||
total_files=data.get("total_files", 0),
|
||||
execution_log_id=execution_log_id, # Set execution_log_id for WebSocket messaging
|
||||
)
|
||||
|
||||
# Set tags if provided
|
||||
if data.get("tags"):
|
||||
# Handle tags logic if needed
|
||||
pass
|
||||
|
||||
return Response(
|
||||
{
|
||||
"execution_id": str(execution.id),
|
||||
"status": execution.status,
|
||||
"execution_log_id": execution.execution_log_id, # Return for workers to use
|
||||
}
|
||||
)
|
||||
|
||||
except Workflow.DoesNotExist:
|
||||
return Response({"error": "Workflow not found"}, status=status.HTTP_404_NOT_FOUND)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating workflow execution: {e}")
|
||||
return Response(
|
||||
{"error": "Internal server error"},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
|
||||
@api_view(["POST"])
|
||||
def compile_workflow(request):
|
||||
"""Compile workflow for workers.
|
||||
|
||||
This is a database-only operation that workers need.
|
||||
|
||||
Returns:
|
||||
JSON response with compilation result
|
||||
"""
|
||||
try:
|
||||
data = request.data
|
||||
workflow_id = data.get("workflow_id")
|
||||
|
||||
# Get organization from header
|
||||
org_id = request.headers.get("X-Organization-ID")
|
||||
if not org_id:
|
||||
return Response(
|
||||
{"error": "X-Organization-ID header is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# For now, return success since compilation logic needs to be migrated
|
||||
# TODO: Implement actual compilation logic in workers
|
||||
|
||||
return Response(
|
||||
{
|
||||
"success": True,
|
||||
"workflow_id": workflow_id,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error compiling workflow: {e}")
|
||||
return Response(
|
||||
{"error": "Internal server error"},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@csrf_exempt # Safe: Internal API with Bearer token auth, no session/cookies
|
||||
@api_view(["POST"])
|
||||
def submit_file_batch_for_processing(request):
|
||||
"""Submit file batch for processing by workers.
|
||||
|
||||
This endpoint receives batch data and returns immediately,
|
||||
as actual processing is handled by Celery workers.
|
||||
|
||||
Returns:
|
||||
JSON response with batch submission status
|
||||
"""
|
||||
try:
|
||||
batch_data = request.data
|
||||
|
||||
# Get organization from header
|
||||
org_id = request.headers.get("X-Organization-ID")
|
||||
if not org_id:
|
||||
return Response(
|
||||
{"error": "X-Organization-ID header is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Add organization ID to file_data where WorkerFileData expects it
|
||||
if "file_data" in batch_data:
|
||||
batch_data["file_data"]["organization_id"] = org_id
|
||||
else:
|
||||
# Fallback: add at top level for backward compatibility
|
||||
batch_data["organization_id"] = org_id
|
||||
|
||||
# Submit to file processing worker queue using Celery
|
||||
try:
|
||||
from backend.celery_service import app as celery_app
|
||||
|
||||
# Submit the batch data to the file processing worker using send_task
|
||||
# This calls the task by name without needing to import it
|
||||
task_result = celery_app.send_task(
|
||||
"process_file_batch", # Task name as defined in workers/file_processing/tasks.py
|
||||
args=[batch_data], # Pass batch_data as first argument
|
||||
queue="file_processing", # Send to file processing queue
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Successfully submitted file batch {batch_data.get('batch_id')} to worker queue (task: {task_result.id})"
|
||||
)
|
||||
|
||||
return Response(
|
||||
{
|
||||
"success": True,
|
||||
"batch_id": batch_data.get("batch_id"),
|
||||
"celery_task_id": task_result.id,
|
||||
"message": "Batch submitted for processing",
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to submit batch to worker queue: {e}")
|
||||
return Response(
|
||||
{"error": f"Failed to submit batch for processing: {str(e)}"},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error submitting file batch: {e}")
|
||||
return Response(
|
||||
{"error": "Internal server error"},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
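Because send_task dispatches by task name, the file-processing worker only needs a task registered under "process_file_batch" and a worker process listening on the file_processing queue. A minimal sketch with a placeholder broker URL (the real task lives in workers/file_processing/tasks.py, which is not part of this hunk):

from celery import Celery

app = Celery("file_processing", broker="redis://redis:6379/0")  # placeholder broker

@app.task(name="process_file_batch")
def process_file_batch(batch_data: dict) -> dict:
    """Consume batches submitted by submit_file_batch_for_processing."""
    org_id = (batch_data.get("file_data") or {}).get("organization_id") or batch_data.get(
        "organization_id"
    )
    # Real per-file processing happens in the worker implementation.
    return {"batch_id": batch_data.get("batch_id"), "organization_id": org_id}

# Started with something like: celery -A <worker_module> worker -Q file_processing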
backend/workflow_manager/internal_serializers.py (new file, 220 lines)
@@ -0,0 +1,220 @@
|
||||
"""Workflow Manager Internal API Serializers
|
||||
Handles serialization for workflow execution related internal endpoints.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from pipeline_v2.models import Pipeline
|
||||
from rest_framework import serializers
|
||||
|
||||
# Import shared dataclasses for type safety and consistency
|
||||
from unstract.core.data_models import (
|
||||
FileExecutionStatusUpdateRequest,
|
||||
WorkflowFileExecutionData,
|
||||
)
|
||||
from workflow_manager.file_execution.models import WorkflowFileExecution
|
||||
from workflow_manager.workflow_v2.enums import ExecutionStatus
|
||||
from workflow_manager.workflow_v2.models.execution import WorkflowExecution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowExecutionSerializer(serializers.ModelSerializer):
|
||||
"""Serializer for WorkflowExecution model for internal API."""
|
||||
|
||||
workflow_id = serializers.CharField(source="workflow.id", read_only=True)
|
||||
workflow_name = serializers.CharField(source="workflow.workflow_name", read_only=True)
|
||||
pipeline_id = serializers.SerializerMethodField()
|
||||
tags = serializers.SerializerMethodField()
|
||||
|
||||
def get_pipeline_id(self, obj):
|
||||
"""ROOT CAUSE FIX: Return None for pipeline_id if the referenced pipeline doesn't exist.
|
||||
This prevents callback workers from attempting to update deleted pipelines.
|
||||
PERFORMANCE: Cache pipeline existence to avoid repeated DB queries.
|
||||
"""
|
||||
if not obj.pipeline_id:
|
||||
return None
|
||||
|
||||
# Use instance-level cache to avoid repeated DB queries within same request
|
||||
cache_key = f"_pipeline_exists_{obj.pipeline_id}"
|
||||
if hasattr(self, cache_key):
|
||||
return getattr(self, cache_key)
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from api_v2.models import APIDeployment
|
||||
|
||||
try:
|
||||
# First check if it's a Pipeline
|
||||
Pipeline.objects.get(id=obj.pipeline_id)
|
||||
result = str(obj.pipeline_id)
|
||||
setattr(self, cache_key, result)
|
||||
return result
|
||||
except Pipeline.DoesNotExist:
|
||||
# Not a Pipeline, check if it's an APIDeployment
|
||||
try:
|
||||
APIDeployment.objects.get(id=obj.pipeline_id)
|
||||
result = str(obj.pipeline_id)
|
||||
setattr(self, cache_key, result)
|
||||
return result
|
||||
except APIDeployment.DoesNotExist:
|
||||
# Neither Pipeline nor APIDeployment exists - return None to prevent stale reference usage
|
||||
setattr(self, cache_key, None)
|
||||
return None
|
||||
|
||||
def get_tags(self, obj):
|
||||
"""Serialize tags as full objects with id, name, and description.
|
||||
|
||||
This method ensures tags are serialized as:
|
||||
[{"id": "uuid", "name": "tag_name", "description": "..."}, ...]
|
||||
instead of just ["uuid1", "uuid2", ...]
|
||||
"""
|
||||
try:
|
||||
return [
|
||||
{
|
||||
"id": str(tag.id),
|
||||
"name": tag.name,
|
||||
"description": tag.description or "",
|
||||
}
|
||||
for tag in obj.tags.all()
|
||||
]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to serialize tags for execution {obj.id}: {str(e)}")
|
||||
return []
|
||||
|
||||
class Meta:
|
||||
model = WorkflowExecution
|
||||
fields = [
|
||||
"id",
|
||||
"workflow_id",
|
||||
"workflow_name",
|
||||
"pipeline_id",
|
||||
"task_id",
|
||||
"execution_mode",
|
||||
"execution_method",
|
||||
"execution_type",
|
||||
"execution_log_id",
|
||||
"status",
|
||||
"result_acknowledged",
|
||||
"total_files",
|
||||
"error_message",
|
||||
"attempts",
|
||||
"execution_time",
|
||||
"created_at",
|
||||
"modified_at",
|
||||
"tags",
|
||||
]
|
||||
read_only_fields = ["id", "created_at", "modified_at"]
|
||||
|
||||
|
||||
class WorkflowFileExecutionSerializer(serializers.ModelSerializer):
|
||||
"""Serializer for WorkflowFileExecution model for internal API.
|
||||
Enhanced with shared dataclass integration for type safety.
|
||||
"""
|
||||
|
||||
workflow_execution_id = serializers.CharField(
|
||||
source="workflow_execution.id", read_only=True
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = WorkflowFileExecution
|
||||
fields = [
|
||||
"id",
|
||||
"workflow_execution_id",
|
||||
"file_name",
|
||||
"file_path",
|
||||
"file_size",
|
||||
"file_hash",
|
||||
"provider_file_uuid",
|
||||
"mime_type",
|
||||
"fs_metadata",
|
||||
"status",
|
||||
"execution_error",
|
||||
"created_at",
|
||||
"modified_at",
|
||||
]
|
||||
read_only_fields = ["id", "created_at", "modified_at"]
|
||||
|
||||
def to_dataclass(self, instance=None) -> WorkflowFileExecutionData:
|
||||
"""Convert serialized data to shared dataclass."""
|
||||
if instance is None:
|
||||
instance = self.instance
|
||||
return WorkflowFileExecutionData.from_dict(self.to_representation(instance))
|
||||
|
||||
@classmethod
|
||||
def from_dataclass(cls, data: WorkflowFileExecutionData) -> dict:
|
||||
"""Convert shared dataclass to serializer-compatible dict."""
|
||||
return data.to_dict()
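A short usage sketch of the dataclass bridge above (the instance lookup is assumed):

from workflow_manager.file_execution.models import WorkflowFileExecution

file_execution = WorkflowFileExecution.objects.get(id="<file-execution-uuid>")  # assumed id
serializer = WorkflowFileExecutionSerializer(file_execution)
data = serializer.to_dataclass()                                # WorkflowFileExecutionData
payload = WorkflowFileExecutionSerializer.from_dataclass(data)  # plain dict for transport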
|
||||
|
||||
|
||||
class FileExecutionStatusUpdateSerializer(serializers.Serializer):
|
||||
"""Serializer for updating file execution status.
|
||||
Enhanced with shared dataclass integration for type safety.
|
||||
"""
|
||||
|
||||
status = serializers.ChoiceField(choices=ExecutionStatus.choices)
|
||||
error_message = serializers.CharField(required=False, allow_blank=True)
|
||||
result = serializers.CharField(required=False, allow_blank=True)
|
||||
execution_time = serializers.FloatField(required=False, min_value=0)
|
||||
|
||||
def to_dataclass(self) -> FileExecutionStatusUpdateRequest:
|
||||
"""Convert validated data to shared dataclass."""
|
||||
return FileExecutionStatusUpdateRequest(
|
||||
status=self.validated_data["status"],
|
||||
error_message=self.validated_data.get("error_message"),
|
||||
result=self.validated_data.get("result"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dataclass(cls, data: FileExecutionStatusUpdateRequest):
|
||||
"""Create serializer from shared dataclass."""
|
||||
return cls(data=data.to_dict())
|
||||
|
||||
|
||||
class WorkflowExecutionStatusUpdateSerializer(serializers.Serializer):
|
||||
"""Serializer for updating workflow execution status."""
|
||||
|
||||
status = serializers.ChoiceField(choices=ExecutionStatus.choices)
|
||||
error_message = serializers.CharField(required=False, allow_blank=True)
|
||||
total_files = serializers.IntegerField(
|
||||
required=False, min_value=0
|
||||
) # Allow 0 but backend will only update if > 0
|
||||
attempts = serializers.IntegerField(required=False, min_value=0)
|
||||
execution_time = serializers.FloatField(required=False, min_value=0)
|
||||
|
||||
|
||||
class OrganizationContextSerializer(serializers.Serializer):
|
||||
"""Serializer for organization context information."""
|
||||
|
||||
organization_id = serializers.CharField(allow_null=True, required=False)
|
||||
organization_name = serializers.CharField(required=False, allow_blank=True)
|
||||
settings = serializers.DictField(required=False)
|
||||
|
||||
|
||||
class WorkflowExecutionContextSerializer(serializers.Serializer):
|
||||
"""Serializer for complete workflow execution context."""
|
||||
|
||||
execution = WorkflowExecutionSerializer()
|
||||
workflow_definition = serializers.DictField()
|
||||
source_config = serializers.DictField()
|
||||
destination_config = serializers.DictField(required=False)
|
||||
organization_context = OrganizationContextSerializer()
|
||||
file_executions = serializers.ListField(required=False)
|
||||
aggregated_usage_cost = serializers.FloatField(required=False, allow_null=True)
|
||||
|
||||
|
||||
class FileBatchCreateSerializer(serializers.Serializer):
|
||||
"""Serializer for creating file execution batches."""
|
||||
|
||||
workflow_execution_id = serializers.UUIDField()
|
||||
files = serializers.ListField(child=serializers.DictField(), allow_empty=False)
|
||||
is_api = serializers.BooleanField(default=False)
|
||||
|
||||
|
||||
class FileBatchResponseSerializer(serializers.Serializer):
|
||||
"""Serializer for file batch creation response."""
|
||||
|
||||
batch_id = serializers.CharField()
|
||||
workflow_execution_id = serializers.CharField()
|
||||
total_files = serializers.IntegerField()
|
||||
created_file_executions = serializers.ListField()
|
||||
skipped_files = serializers.ListField(required=False)
backend/workflow_manager/internal_urls.py (new file, 139 lines)
@@ -0,0 +1,139 @@
|
||||
"""Internal API URLs for Workflow Manager
|
||||
|
||||
URLs for internal APIs that workers use to communicate with Django backend.
|
||||
These handle only database operations while business logic remains in workers.
|
||||
"""
|
||||
|
||||
from django.urls import path
|
||||
|
||||
from . import internal_api_views, internal_views
|
||||
|
||||
app_name = "workflow_manager_internal"
|
||||
|
||||
urlpatterns = [
|
||||
# Workflow execution endpoints - specific paths first
|
||||
path(
|
||||
"execution/create/",
|
||||
internal_api_views.create_workflow_execution,
|
||||
name="create_workflow_execution",
|
||||
),
|
||||
path(
|
||||
"execution/<str:execution_id>/",
|
||||
internal_api_views.get_workflow_execution_data,
|
||||
name="get_workflow_execution_data",
|
||||
),
|
||||
# Tool instance endpoints
|
||||
path(
|
||||
"workflow/<str:workflow_id>/tool-instances/",
|
||||
internal_api_views.get_tool_instances_by_workflow,
|
||||
name="get_tool_instances_by_workflow",
|
||||
),
|
||||
# Workflow compilation
|
||||
path(
|
||||
"workflow/compile/",
|
||||
internal_api_views.compile_workflow,
|
||||
name="compile_workflow",
|
||||
),
|
||||
# File batch processing
|
||||
path(
|
||||
"file-batch/submit/",
|
||||
internal_api_views.submit_file_batch_for_processing,
|
||||
name="submit_file_batch_for_processing",
|
||||
),
|
||||
# Workflow definition and type detection (using sophisticated class-based views)
|
||||
path(
|
||||
"workflow/<str:workflow_id>/",
|
||||
internal_views.WorkflowDefinitionAPIView.as_view(),
|
||||
name="get_workflow_definition",
|
||||
),
|
||||
path(
|
||||
"<str:workflow_id>/endpoint/",
|
||||
internal_views.WorkflowEndpointAPIView.as_view(),
|
||||
name="get_workflow_endpoints",
|
||||
),
|
||||
path(
|
||||
"pipeline-type/<str:pipeline_id>/",
|
||||
internal_views.PipelineTypeAPIView.as_view(),
|
||||
name="get_pipeline_type",
|
||||
),
|
||||
path(
|
||||
"pipeline-name/<str:pipeline_id>/",
|
||||
internal_views.PipelineNameAPIView.as_view(),
|
||||
name="get_pipeline_name",
|
||||
),
|
||||
# Batch operations (using sophisticated class-based views)
|
||||
path(
|
||||
"batch-status-update/",
|
||||
internal_views.BatchStatusUpdateAPIView.as_view(),
|
||||
name="batch_update_execution_status",
|
||||
),
|
||||
path(
|
||||
"file-batch/",
|
||||
internal_views.FileBatchCreateAPIView.as_view(),
|
||||
name="create_file_batch",
|
||||
),
|
||||
# File management (using sophisticated class-based views)
|
||||
path(
|
||||
"increment-files/",
|
||||
internal_views.FileCountIncrementAPIView.as_view(),
|
||||
name="increment_files",
|
||||
),
|
||||
path(
|
||||
"file-history/create/",
|
||||
internal_views.FileHistoryCreateView.as_view(),
|
||||
name="create_file_history_entry",
|
||||
),
|
||||
path(
|
||||
"file-history/check-batch/",
|
||||
internal_views.FileHistoryBatchCheckView.as_view(),
|
||||
name="check_file_history_batch",
|
||||
),
|
||||
# Additional endpoints available in internal_views.py
|
||||
path(
|
||||
"source-files/<str:workflow_id>/",
|
||||
internal_views.WorkflowSourceFilesAPIView.as_view(),
|
||||
name="get_workflow_source_files",
|
||||
),
|
||||
# path("execution/finalize/<str:execution_id>/", removed - ExecutionFinalizationAPIView was unused dead code
|
||||
path(
|
||||
"execution/cleanup/",
|
||||
internal_views.WorkflowExecutionCleanupAPIView.as_view(),
|
||||
name="cleanup_executions",
|
||||
),
|
||||
path(
|
||||
"execution/metrics/",
|
||||
internal_views.WorkflowExecutionMetricsAPIView.as_view(),
|
||||
name="get_execution_metrics",
|
||||
),
|
||||
path(
|
||||
"file-execution/",
|
||||
internal_views.WorkflowFileExecutionAPIView.as_view(),
|
||||
name="workflow_file_execution",
|
||||
),
|
||||
path(
|
||||
"file-execution/check-active",
|
||||
internal_views.WorkflowFileExecutionCheckActiveAPIView.as_view(),
|
||||
name="workflow_file_execution_check_active",
|
||||
),
|
||||
path(
|
||||
"execute-file/",
|
||||
internal_views.WorkflowExecuteFileAPIView.as_view(),
|
||||
name="execute_workflow_file",
|
||||
),
|
||||
path(
|
||||
"pipeline/<str:pipeline_id>/status/",
|
||||
internal_views.PipelineStatusUpdateAPIView.as_view(),
|
||||
name="update_pipeline_status",
|
||||
),
|
||||
# File execution batch operations (using simple function views for now)
|
||||
path(
|
||||
"file-execution/batch-create/",
|
||||
internal_api_views.create_file_execution_batch,
|
||||
name="file_execution_batch_create",
|
||||
),
|
||||
path(
|
||||
"file-execution/batch-status-update/",
|
||||
internal_api_views.update_file_execution_batch_status,
|
||||
name="file_execution_batch_status_update",
|
||||
),
|
||||
]
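Note that the "specific paths first" ordering above matters because Django resolves URLs in declaration order, so "execution/create/" must precede "execution/<str:execution_id>/". Below is a worker-side sketch of composing these routes; the mount prefix is an assumption, while the relative paths mirror the urlpatterns above.

from urllib.parse import urljoin

INTERNAL_BASE = "http://backend:8000/internal/workflow-manager/"  # assumed mount prefix

ENDPOINTS = {
    "create_execution": "execution/create/",
    "execution_data": "execution/{execution_id}/",
    "tool_instances": "workflow/{workflow_id}/tool-instances/",
    "compile_workflow": "workflow/compile/",
    "submit_file_batch": "file-batch/submit/",
    "file_execution_batch_create": "file-execution/batch-create/",
}

def internal_url(name: str, **kwargs: str) -> str:
    return urljoin(INTERNAL_BASE, ENDPOINTS[name].format(**kwargs))

print(internal_url("execution_data", execution_id="<execution-uuid>"))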
backend/workflow_manager/internal_views.py (new file, 2633 lines; file diff suppressed because it is too large)
backend/workflow_manager/workflow_execution_internal_urls.py (new file, 27 lines)
@@ -0,0 +1,27 @@
"""Workflow Manager Internal API URLs
|
||||
Defines internal API endpoints for workflow execution operations.
|
||||
"""
|
||||
|
||||
from django.urls import include, path
|
||||
from rest_framework.routers import DefaultRouter
|
||||
|
||||
from .internal_views import FileBatchCreateAPIView, WorkflowExecutionInternalViewSet
|
||||
|
||||
# Create router for internal API viewsets
|
||||
router = DefaultRouter()
|
||||
router.register(
|
||||
r"",
|
||||
WorkflowExecutionInternalViewSet,
|
||||
basename="workflow-execution-internal",
|
||||
)
|
||||
|
||||
urlpatterns = [
|
||||
# Workflow execution internal APIs
|
||||
path(
|
||||
"create-file-batch/",
|
||||
FileBatchCreateAPIView.as_view(),
|
||||
name="create-file-batch",
|
||||
),
|
||||
# Include router URLs for viewsets (this creates the CRUD endpoints)
|
||||
path("", include(router.urls)),
|
||||
]
|
||||
@@ -2,7 +2,10 @@ from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from django.db.models import TextChoices
|
||||
from unstract.core.data_models import ExecutionStatus as SharedExecutionStatus
|
||||
|
||||
# Alias shared ExecutionStatus to ensure consistency between backend and workers
|
||||
ExecutionStatus = SharedExecutionStatus
|
||||
|
||||
|
||||
class WorkflowExecutionMethod(Enum):
|
||||
@@ -10,105 +13,6 @@ class WorkflowExecutionMethod(Enum):
|
||||
QUEUED = "QUEUED"
|
||||
|
||||
|
||||
class ExecutionStatus(TextChoices):
|
||||
"""An enumeration representing the various statuses of an execution
|
||||
process.
|
||||
|
||||
Statuses:
|
||||
PENDING: The execution's entry has been created in the database.
|
||||
EXECUTING: The execution is currently in progress.
|
||||
COMPLETED: The execution has been successfully completed.
|
||||
STOPPED: The execution was stopped by the user
|
||||
(applicable to step executions).
|
||||
ERROR: An error occurred during the execution process.
|
||||
|
||||
Note:
|
||||
Intermediate statuses might not be experienced due to
|
||||
Django's query triggering once all processes are completed.
|
||||
"""
|
||||
|
||||
PENDING = "PENDING"
|
||||
EXECUTING = "EXECUTING"
|
||||
COMPLETED = "COMPLETED"
|
||||
STOPPED = "STOPPED"
|
||||
ERROR = "ERROR"
|
||||
|
||||
@classmethod
|
||||
def is_completed(cls, status: str | ExecutionStatus) -> bool:
|
||||
"""Check if the execution status is completed."""
|
||||
try:
|
||||
status_enum = cls(status)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid status: {status}. Must be a valid ExecutionStatus."
|
||||
)
|
||||
return status_enum in [cls.COMPLETED, cls.STOPPED, cls.ERROR]
|
||||
|
||||
@classmethod
|
||||
def is_active(cls, status: str | ExecutionStatus) -> bool:
|
||||
"""Check if the workflow execution status is active (in progress)."""
|
||||
try:
|
||||
status_enum = cls(status)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid status: {status}. Must be a valid ExecutionStatus."
|
||||
)
|
||||
return status_enum in [cls.PENDING, cls.EXECUTING]
|
||||
|
||||
@classmethod
|
||||
def get_skip_processing_statuses(cls) -> list[ExecutionStatus]:
|
||||
"""Get list of statuses that should skip file processing.
|
||||
|
||||
Skip processing if:
|
||||
- EXECUTING: File is currently being processed
|
||||
- PENDING: File is queued to be processed
|
||||
- COMPLETED: File has already been successfully processed
|
||||
|
||||
Returns:
|
||||
list[ExecutionStatus]: List of statuses where file processing should be skipped
|
||||
"""
|
||||
return [cls.EXECUTING, cls.PENDING, cls.COMPLETED]
|
||||
|
||||
@classmethod
|
||||
def should_skip_file_processing(cls, status: str | ExecutionStatus) -> bool:
|
||||
"""Check if file processing should be skipped based on status.
|
||||
|
||||
Allow processing (retry) if:
|
||||
- STOPPED: Processing was stopped, can retry
|
||||
- ERROR: Processing failed, can retry
|
||||
"""
|
||||
try:
|
||||
status_enum = cls(status)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid status: {status}. Must be a valid ExecutionStatus."
|
||||
)
|
||||
return status_enum in cls.get_skip_processing_statuses()
|
||||
|
||||
@classmethod
|
||||
def can_update_to_pending(cls, status: str | ExecutionStatus) -> bool:
|
||||
"""Check if a status can be updated to PENDING.
|
||||
|
||||
Allow updating to PENDING if:
|
||||
- Status is STOPPED or ERROR (can retry)
|
||||
- Status is None (new record)
|
||||
|
||||
Don't allow updating to PENDING if:
|
||||
- Status is EXECUTING (currently processing)
|
||||
- Status is COMPLETED (already done)
|
||||
- Status is already PENDING (no change needed)
|
||||
"""
|
||||
if status is None:
|
||||
return True
|
||||
|
||||
try:
|
||||
status_enum = cls(status)
|
||||
except ValueError:
|
||||
return True # Invalid status, allow update
|
||||
|
||||
return status_enum in [cls.STOPPED, cls.ERROR]
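With the local TextChoices class above removed in favor of the shared enum, the remaining hunks in this diff compare and store explicit .value strings. A minimal sketch of that pattern, assuming the shared enum keeps the same member values:

from workflow_manager.workflow_v2.enums import ExecutionStatus

status = ExecutionStatus("ERROR")               # lookup by raw string still works
db_value = ExecutionStatus.PENDING.value        # the plain string stored on the model
is_done = db_value == ExecutionStatus.COMPLETED.value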
|
||||
|
||||
|
||||
class SchemaType(Enum):
|
||||
"""Possible types for workflow module's JSON schema.
|
||||
|
||||
|
||||
@@ -78,12 +78,12 @@ class WorkflowExecutionServiceHelper(WorkflowExecutionService):
|
||||
log_events_id = StateStore.get(Common.LOG_EVENTS_ID)
|
||||
self.execution_log_id = log_events_id if log_events_id else pipeline_id
|
||||
self.execution_mode = mode
|
||||
self.execution_method: tuple[str, str] = (
|
||||
self.execution_method = (
|
||||
WorkflowExecution.Method.SCHEDULED
|
||||
if scheduled
|
||||
else WorkflowExecution.Method.DIRECT
|
||||
)
|
||||
self.execution_type: tuple[str, str] = (
|
||||
self.execution_type = (
|
||||
WorkflowExecution.Type.STEP
|
||||
if single_step
|
||||
else WorkflowExecution.Type.COMPLETE
|
||||
@@ -94,7 +94,7 @@ class WorkflowExecutionServiceHelper(WorkflowExecutionService):
|
||||
execution_mode=mode,
|
||||
execution_method=self.execution_method,
|
||||
execution_type=self.execution_type,
|
||||
status=ExecutionStatus.EXECUTING,
|
||||
status=ExecutionStatus.EXECUTING.value,
|
||||
execution_log_id=self.execution_log_id,
|
||||
)
|
||||
workflow_execution.save()
|
||||
@@ -140,12 +140,12 @@ class WorkflowExecutionServiceHelper(WorkflowExecutionService):
|
||||
if existing_execution:
|
||||
return existing_execution
|
||||
|
||||
execution_method: tuple[str, str] = (
|
||||
execution_method = (
|
||||
WorkflowExecution.Method.SCHEDULED
|
||||
if scheduled
|
||||
else WorkflowExecution.Method.DIRECT
|
||||
)
|
||||
execution_type: tuple[str, str] = (
|
||||
execution_type = (
|
||||
WorkflowExecution.Type.STEP
|
||||
if single_step
|
||||
else WorkflowExecution.Type.COMPLETE
|
||||
@@ -159,7 +159,7 @@ class WorkflowExecutionServiceHelper(WorkflowExecutionService):
|
||||
execution_mode=mode,
|
||||
execution_method=execution_method,
|
||||
execution_type=execution_type,
|
||||
status=ExecutionStatus.PENDING,
|
||||
status=ExecutionStatus.PENDING.value,
|
||||
execution_log_id=execution_log_id,
|
||||
total_files=total_files,
|
||||
)
|
||||
@@ -396,7 +396,7 @@ class WorkflowExecutionServiceHelper(WorkflowExecutionService):
|
||||
def update_execution_err(execution_id: str, err_msg: str = "") -> WorkflowExecution:
|
||||
try:
|
||||
execution = WorkflowExecution.objects.get(pk=execution_id)
|
||||
execution.status = ExecutionStatus.ERROR
|
||||
execution.status = ExecutionStatus.ERROR.value
|
||||
execution.error_message = err_msg[:EXECUTION_ERROR_LENGTH]
|
||||
execution.save()
|
||||
return execution
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Internal API URLs for Execution Log Operations
|
||||
|
||||
URLs for internal APIs that workers use to communicate with Django backend
|
||||
for execution log operations. These handle database operations while business
|
||||
logic remains in workers.
|
||||
"""
|
||||
|
||||
from django.urls import path
|
||||
|
||||
from . import execution_log_internal_views
|
||||
|
||||
app_name = "execution_log_internal"
|
||||
|
||||
urlpatterns = [
|
||||
# Execution log management endpoints
|
||||
path(
|
||||
"workflow-executions/by-ids/",
|
||||
execution_log_internal_views.GetWorkflowExecutionsByIdsAPIView.as_view(),
|
||||
name="get_workflow_executions_by_ids",
|
||||
),
|
||||
path(
|
||||
"file-executions/by-ids/",
|
||||
execution_log_internal_views.GetFileExecutionsByIdsAPIView.as_view(),
|
||||
name="get_file_executions_by_ids",
|
||||
),
|
||||
path(
|
||||
"executions/validate/",
|
||||
execution_log_internal_views.ValidateExecutionReferencesAPIView.as_view(),
|
||||
name="validate_execution_references",
|
||||
),
|
||||
path(
|
||||
"process-log-history/",
|
||||
execution_log_internal_views.ProcessLogHistoryAPIView.as_view(),
|
||||
name="process_log_history",
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,171 @@
|
||||
"""Internal API Views for Execution Log Operations
|
||||
|
||||
These views handle internal API requests from workers for execution log operations.
|
||||
They provide database access while keeping business logic in workers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from rest_framework import status
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from workflow_manager.file_execution.models import WorkflowFileExecution
|
||||
from workflow_manager.workflow_v2.execution_log_utils import (
|
||||
process_log_history_from_cache,
|
||||
)
|
||||
from workflow_manager.workflow_v2.models import WorkflowExecution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GetWorkflowExecutionsByIdsAPIView(APIView):
|
||||
"""API view for getting workflow executions by IDs."""
|
||||
|
||||
def post(self, request: Request) -> Response:
|
||||
"""Get workflow execution data for given IDs.
|
||||
|
||||
Args:
|
||||
request: HTTP request containing execution IDs
|
||||
|
||||
Returns:
|
||||
JSON response with execution data
|
||||
"""
|
||||
try:
|
||||
execution_ids = request.data.get("execution_ids", [])
|
||||
|
||||
executions = WorkflowExecution.objects.filter(id__in=execution_ids)
|
||||
execution_data = {}
|
||||
|
||||
for execution in executions:
|
||||
execution_data[str(execution.id)] = {
|
||||
"id": str(execution.id),
|
||||
"workflow_id": str(execution.workflow.id)
|
||||
if execution.workflow
|
||||
else None,
|
||||
"status": execution.status,
|
||||
"created_at": execution.created_at.isoformat()
|
||||
if execution.created_at
|
||||
else None,
|
||||
}
|
||||
|
||||
return Response({"executions": execution_data})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting workflow executions: {e}", exc_info=True)
|
||||
return Response(
|
||||
{"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
|
||||
class GetFileExecutionsByIdsAPIView(APIView):
|
||||
"""API view for getting file executions by IDs."""
|
||||
|
||||
def post(self, request: Request) -> Response:
|
||||
"""Get file execution data for given IDs.
|
||||
|
||||
Args:
|
||||
request: HTTP request containing file execution IDs
|
||||
|
||||
Returns:
|
||||
JSON response with file execution data
|
||||
"""
|
||||
try:
|
||||
file_execution_ids = request.data.get("file_execution_ids", [])
|
||||
|
||||
file_executions = WorkflowFileExecution.objects.filter(
|
||||
id__in=file_execution_ids
|
||||
)
|
||||
file_execution_data = {}
|
||||
|
||||
for file_execution in file_executions:
|
||||
file_execution_data[str(file_execution.id)] = {
|
||||
"id": str(file_execution.id),
|
||||
"workflow_execution_id": str(file_execution.workflow_execution.id)
|
||||
if file_execution.workflow_execution
|
||||
else None,
|
||||
"status": file_execution.status,
|
||||
"created_at": file_execution.created_at.isoformat()
|
||||
if file_execution.created_at
|
||||
else None,
|
||||
}
|
||||
|
||||
return Response({"file_executions": file_execution_data})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting file executions: {e}", exc_info=True)
|
||||
return Response(
|
||||
{"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
|
||||
class ValidateExecutionReferencesAPIView(APIView):
|
||||
"""API view for validating execution references exist."""
|
||||
|
||||
def post(self, request: Request) -> Response:
|
||||
"""Validate that execution references exist.
|
||||
|
||||
Args:
|
||||
request: HTTP request containing execution and file execution IDs
|
||||
|
||||
Returns:
|
||||
JSON response with validation results
|
||||
"""
|
||||
try:
|
||||
execution_ids = request.data.get("execution_ids", [])
|
||||
file_execution_ids = request.data.get("file_execution_ids", [])
|
||||
|
||||
# Check which executions exist
|
||||
existing_executions = {
|
||||
str(obj.id)
|
||||
for obj in WorkflowExecution.objects.filter(id__in=execution_ids)
|
||||
}
|
||||
|
||||
# Check which file executions exist
|
||||
existing_file_executions = {
|
||||
str(obj.id)
|
||||
for obj in WorkflowFileExecution.objects.filter(id__in=file_execution_ids)
|
||||
}
|
||||
|
||||
return Response(
|
||||
{
|
||||
"valid_executions": list(existing_executions),
|
||||
"valid_file_executions": list(existing_file_executions),
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating execution references: {e}", exc_info=True)
|
||||
return Response(
|
||||
{"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
|
||||
class ProcessLogHistoryAPIView(APIView):
|
||||
"""API view for processing log history from scheduler.
|
||||
|
||||
This endpoint is called by the log history scheduler when logs exist in Redis queue.
|
||||
It reuses the existing business logic from execution_log_utils.process_log_history_from_cache().
|
||||
"""
|
||||
|
||||
def post(self, request: Request) -> Response:
|
||||
"""Process log history batch from Redis cache.
|
||||
|
||||
Args:
|
||||
request: HTTP request (no parameters needed)
|
||||
|
||||
Returns:
|
||||
JSON response with processing results
|
||||
"""
|
||||
try:
|
||||
# Reuse existing business logic (uses ExecutionLogConstants for config)
|
||||
result = process_log_history_from_cache()
|
||||
|
||||
return Response(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing log history: {e}", exc_info=True)
|
||||
return Response(
|
||||
{"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
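The `process-log-history/` route is intended to be hit periodically by the log-history scheduler. A minimal Python trigger would look like the following; this is illustrative only — the shipped scheduler container drives the equivalent from `log_consumer/scheduler.sh`, and the URL prefix and auth header here are assumptions.

```python
# Illustrative periodic trigger for the process-log-history internal endpoint.
import requests


def trigger_log_history_processing(base_url: str, api_key: str) -> dict:
    response = requests.post(
        f"{base_url}/process-log-history/",
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=60,
    )
    response.raise_for_status()
    # e.g. {"processed_count": ..., "skipped_count": ..., "total_logs": ...}
    return response.json()
```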
@@ -16,15 +16,35 @@ from workflow_manager.workflow_v2.models import ExecutionLog, WorkflowExecution
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(name=ExecutionLogConstants.TASK_V2)
|
||||
def consume_log_history() -> None:
|
||||
def process_log_history_from_cache(
|
||||
queue_name: str = ExecutionLogConstants.LOG_QUEUE_NAME,
|
||||
batch_limit: int = ExecutionLogConstants.LOGS_BATCH_LIMIT,
|
||||
) -> dict:
|
||||
"""Process log history from Redis cache.
|
||||
|
||||
This function contains the core business logic for processing execution logs
|
||||
from Redis cache to database. It can be called by both the Celery task and
|
||||
internal API endpoints.
|
||||
|
||||
Args:
|
||||
queue_name: Redis queue name to process logs from
|
||||
batch_limit: Maximum number of logs to process in one batch
|
||||
|
||||
Returns:
|
||||
Dictionary with processing results:
|
||||
- processed_count: Number of logs successfully stored
|
||||
- skipped_count: Number of logs skipped (invalid references)
|
||||
- total_logs: Total number of logs retrieved from cache
|
||||
- organizations_processed: Number of organizations affected
|
||||
"""
|
||||
organization_logs = defaultdict(list)
|
||||
logs_count = 0
|
||||
logs_to_process = []
|
||||
skipped_count = 0
|
||||
|
||||
# Collect logs from cache (batch retrieval)
|
||||
while logs_count < ExecutionLogConstants.LOGS_BATCH_LIMIT:
|
||||
log = CacheService.lpop(ExecutionLogConstants.LOG_QUEUE_NAME)
|
||||
while logs_count < batch_limit:
|
||||
log = CacheService.lpop(queue_name)
|
||||
if not log:
|
||||
break
|
||||
|
||||
@@ -34,9 +54,14 @@ def consume_log_history() -> None:
|
||||
logs_count += 1
|
||||
|
||||
if not logs_to_process:
|
||||
return # No logs to process
|
||||
return {
|
||||
"processed_count": 0,
|
||||
"skipped_count": 0,
|
||||
"total_logs": 0,
|
||||
"organizations_processed": 0,
|
||||
}
|
||||
|
||||
logger.info(f"Logs count: {logs_count}")
|
||||
logger.info(f"Processing {logs_count} logs from queue '{queue_name}'")
|
||||
|
||||
# Preload required WorkflowExecution and WorkflowFileExecution objects
|
||||
execution_ids = {log.execution_id for log in logs_to_process}
|
||||
@@ -60,7 +85,8 @@ def consume_log_history() -> None:
|
||||
f"Execution not found for execution_id: {log_data.execution_id}, "
|
||||
"skipping log push"
|
||||
)
|
||||
continue # Skip logs with missing execution reference
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
execution_log = ExecutionLog(
|
||||
wf_execution=execution,
|
||||
@@ -69,16 +95,42 @@ def consume_log_history() -> None:
|
||||
)
|
||||
|
||||
if log_data.file_execution_id:
|
||||
execution_log.file_execution = file_execution_map.get(
|
||||
log_data.file_execution_id
|
||||
)
|
||||
file_execution = file_execution_map.get(log_data.file_execution_id)
|
||||
if file_execution:
|
||||
execution_log.file_execution = file_execution
|
||||
else:
|
||||
logger.warning(
|
||||
f"File execution not found for file_execution_id: {log_data.file_execution_id}, "
|
||||
"skipping log push"
|
||||
)
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
organization_logs[log_data.organization_id].append(execution_log)
|
||||
|
||||
# Bulk insert logs for each organization
|
||||
processed_count = 0
|
||||
for organization_id, logs in organization_logs.items():
|
||||
logger.info(f"Storing '{len(logs)}' logs for org: {organization_id}")
|
||||
logger.info(f"Storing {len(logs)} logs for org: {organization_id}")
|
||||
ExecutionLog.objects.bulk_create(objs=logs, ignore_conflicts=True)
|
||||
processed_count += len(logs)
|
||||
|
||||
return {
|
||||
"processed_count": processed_count,
|
||||
"skipped_count": skipped_count,
|
||||
"total_logs": logs_count,
|
||||
"organizations_processed": len(organization_logs),
|
||||
}
|
||||
|
||||
|
||||
@shared_task(name=ExecutionLogConstants.TASK_V2)
|
||||
def consume_log_history() -> None:
|
||||
"""Celery task to consume log history from Redis cache.
|
||||
|
||||
This task is a thin wrapper around process_log_history_from_cache() for
|
||||
backward compatibility with existing Celery Beat schedules.
|
||||
"""
|
||||
process_log_history_from_cache()
|
||||
|
||||
|
||||
def create_log_consumer_scheduler_if_not_exists() -> None:
|
||||
|
||||
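Since the internal API view and the Celery task both converge on the same helper, it can also be driven directly, for instance from a Django shell or a test:

```python
# Direct invocation of the shared helper (batch_limit chosen arbitrarily here).
from workflow_manager.workflow_v2.execution_log_utils import (
    process_log_history_from_cache,
)

result = process_log_history_from_cache(batch_limit=50)
print(
    f"stored={result['processed_count']} skipped={result['skipped_count']} "
    f"orgs={result['organizations_processed']}"
)
```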
@@ -5,6 +5,7 @@ from typing import Any
|
||||
from django.db.models import Q
|
||||
from django.db.utils import IntegrityError
|
||||
from django.utils import timezone
|
||||
from utils.cache_service import CacheService
|
||||
|
||||
from workflow_manager.endpoint_v2.dto import FileHash
|
||||
from workflow_manager.endpoint_v2.models import WorkflowEndpoint
|
||||
@@ -245,8 +246,8 @@ class FileHistoryHelper:
|
||||
metadata: str | None,
|
||||
error: str | None = None,
|
||||
is_api: bool = False,
|
||||
) -> None:
|
||||
"""Create a new file history record.
|
||||
) -> FileHistory:
|
||||
"""Create a new file history record or return existing one.
|
||||
|
||||
Args:
|
||||
file_hash (FileHash): The file hash for the file.
|
||||
@@ -255,35 +256,90 @@ class FileHistoryHelper:
|
||||
result (Any): The result from the execution.
|
||||
metadata (str | None): The metadata from the execution.
|
||||
error (str | None): The error from the execution.
|
||||
is_api (bool): Whether this is an API call.
|
||||
"""
|
||||
try:
|
||||
file_path = file_hash.file_path if not is_api else None
|
||||
is_api (bool): Whether this is for API workflow (affects file_path handling).
|
||||
|
||||
FileHistory.objects.create(
|
||||
Returns:
|
||||
FileHistory: Either newly created or existing file history record.
|
||||
"""
|
||||
file_path = file_hash.file_path if not is_api else None
|
||||
|
||||
# Prepare data for creation
|
||||
create_data = {
|
||||
"workflow": workflow,
|
||||
"cache_key": file_hash.file_hash,
|
||||
"provider_file_uuid": file_hash.provider_file_uuid,
|
||||
"status": status,
|
||||
"result": str(result),
|
||||
"metadata": str(metadata) if metadata else "",
|
||||
"error": str(error) if error else "",
|
||||
"file_path": file_path,
|
||||
}
|
||||
|
||||
try:
|
||||
# Try to create the file history record
|
||||
file_history = FileHistory.objects.create(**create_data)
|
||||
logger.info(
|
||||
f"Created new FileHistory record - "
|
||||
f"file_name='{file_hash.file_name}', file_path='{file_hash.file_path}', "
|
||||
f"file_hash='{file_hash.file_hash[:16] if file_hash.file_hash else 'None'}', "
|
||||
f"workflow={workflow}"
|
||||
)
|
||||
return file_history
|
||||
|
||||
except IntegrityError as e:
|
||||
# Race condition detected - another worker created the record
|
||||
# Try to retrieve the existing record
|
||||
logger.info(
|
||||
f"FileHistory constraint violation (expected in concurrent environment) - "
|
||||
f"file_name='{file_hash.file_name}', file_path='{file_hash.file_path}', "
|
||||
f"file_hash='{file_hash.file_hash[:16] if file_hash.file_hash else 'None'}', "
|
||||
f"workflow={workflow}. Error: {str(e)}"
|
||||
)
|
||||
|
||||
# Use the existing get_file_history method to retrieve the record
|
||||
existing_record = FileHistoryHelper.get_file_history(
|
||||
workflow=workflow,
|
||||
cache_key=file_hash.file_hash,
|
||||
provider_file_uuid=file_hash.provider_file_uuid,
|
||||
status=status,
|
||||
result=str(result),
|
||||
metadata=str(metadata) if metadata else "",
|
||||
error=str(error) if error else "",
|
||||
file_path=file_path,
|
||||
)
|
||||
except IntegrityError as e:
|
||||
# TODO: Need to find why duplicate insert is coming
|
||||
logger.warning(
|
||||
f"Trying to insert duplication data for filename {file_hash.file_name} "
|
||||
f"for workflow {workflow}. Error: {str(e)} with metadata {metadata}",
|
||||
)
|
||||
|
||||
if existing_record:
|
||||
logger.info(
|
||||
f"Retrieved existing FileHistory record after constraint violation - "
|
||||
f"ID: {existing_record.id}, workflow={workflow}"
|
||||
)
|
||||
return existing_record
|
||||
else:
|
||||
# This should rarely happen, but if we can't find the existing record,
|
||||
# log the issue and re-raise the original exception
|
||||
logger.error(
|
||||
f"Failed to retrieve existing FileHistory record after constraint violation - "
|
||||
f"file_name='{file_hash.file_name}', workflow={workflow}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def clear_history_for_workflow(
|
||||
workflow: Workflow,
|
||||
) -> None:
|
||||
"""Clear all file history records associated with a workflow.
|
||||
"""Clear all file history records and Redis caches associated with a workflow.
|
||||
|
||||
Args:
|
||||
workflow (Workflow): The workflow to clear the history for.
|
||||
"""
|
||||
# Clear database records
|
||||
FileHistory.objects.filter(workflow=workflow).delete()
|
||||
logger.info(f"Cleared database records for workflow {workflow.id}")
|
||||
|
||||
# Clear Redis caches for file_active entries
|
||||
pattern = f"file_active:{workflow.id}:*"
|
||||
|
||||
try:
|
||||
CacheService.clear_cache_optimized(pattern)
|
||||
logger.info(
|
||||
f"Cleared Redis cache entries for workflow {workflow.id} with pattern: {pattern}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to clear Redis caches for workflow {workflow.id}: {str(e)}"
|
||||
)
|
||||
|
||||
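The create path above is a create-then-fetch pattern for racing workers: attempt the insert, and on a unique-constraint violation read back the row the other worker won with. Stripped of the FileHistory specifics, the idea is roughly this sketch (model and field names are illustrative, not the helper itself):

```python
# Generic create-or-fetch under concurrent writers: try the insert, and on a
# unique-constraint violation fall back to reading the row created by the
# worker that won the race.
from django.db.utils import IntegrityError


def create_or_get(model, lookup: dict, defaults: dict):
    try:
        return model.objects.create(**lookup, **defaults), True
    except IntegrityError:
        # Another worker inserted the same row first; fetch it instead.
        return model.objects.get(**lookup), False
```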
@@ -0,0 +1,45 @@
|
||||
"""Internal API URLs for file history operations."""
|
||||
|
||||
from django.urls import path
|
||||
|
||||
from .views import (
|
||||
create_file_history_internal,
|
||||
file_history_batch_lookup_internal,
|
||||
file_history_by_cache_key_internal,
|
||||
file_history_status_internal,
|
||||
get_file_history_internal,
|
||||
reserve_file_processing_internal,
|
||||
)
|
||||
|
||||
urlpatterns = [
|
||||
# File history endpoints
|
||||
path(
|
||||
"cache-key/<str:cache_key>/",
|
||||
file_history_by_cache_key_internal,
|
||||
name="file-history-by-cache-key-internal",
|
||||
),
|
||||
# Flexible lookup endpoint (supports both cache_key and provider_file_uuid)
|
||||
path(
|
||||
"lookup/",
|
||||
file_history_by_cache_key_internal,
|
||||
name="file-history-lookup-internal",
|
||||
),
|
||||
# Batch lookup endpoint for multiple files
|
||||
path(
|
||||
"batch-lookup/",
|
||||
file_history_batch_lookup_internal,
|
||||
name="file-history-batch-lookup-internal",
|
||||
),
|
||||
path("create/", create_file_history_internal, name="create-file-history-internal"),
|
||||
path(
|
||||
"status/<str:file_history_id>/",
|
||||
file_history_status_internal,
|
||||
name="file-history-status-internal",
|
||||
),
|
||||
path(
|
||||
"reserve/",
|
||||
reserve_file_processing_internal,
|
||||
name="reserve-file-processing-internal",
|
||||
),
|
||||
path("get/", get_file_history_internal, name="get-file-history-internal"),
|
||||
]
|
||||

@@ -6,7 +6,6 @@ from api_v2.models import APIDeployment
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db import models
|
||||
from django.db.models import QuerySet, Sum
|
||||
from django.utils import timezone
|
||||
from pipeline_v2.models import Pipeline
|
||||
from tags.models import Tag
|
||||
from usage_v2.constants import UsageKeys
|
||||
@@ -226,6 +225,17 @@ class WorkflowExecution(BaseModel):
|
||||
def is_completed(self) -> bool:
|
||||
return ExecutionStatus.is_completed(self.status)
|
||||
|
||||
@property
|
||||
def organization_id(self) -> str | None:
|
||||
"""Get the organization ID from the associated workflow."""
|
||||
if (
|
||||
self.workflow
|
||||
and hasattr(self.workflow, "organization")
|
||||
and self.workflow.organization
|
||||
):
|
||||
return str(self.workflow.organization.organization_id)
|
||||
return None
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"Workflow execution: {self.id} ("
|
||||
@@ -250,19 +260,15 @@ class WorkflowExecution(BaseModel):
|
||||
increment_attempt (bool, optional): Whether to increment attempt counter. Defaults to False.
|
||||
"""
|
||||
if status is not None:
|
||||
status = ExecutionStatus(status)
|
||||
self.status = status.value
|
||||
if (
|
||||
status
|
||||
in [
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.ERROR,
|
||||
ExecutionStatus.STOPPED,
|
||||
]
|
||||
and not self.execution_time
|
||||
):
|
||||
self.execution_time = round(
|
||||
(timezone.now() - self.created_at).total_seconds(), 3
|
||||
)
|
||||
if status in [
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.ERROR,
|
||||
ExecutionStatus.STOPPED,
|
||||
]:
|
||||
self.execution_time = CommonUtils.time_since(self.created_at, 3)
|
||||
|
||||
if error:
|
||||
self.error_message = error[:EXECUTION_ERROR_LENGTH]
|
||||
if increment_attempt:
|
||||
|
||||
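Several hunks in this PR swap enum members for their `.value` when reading or comparing `status`. The motivation: the model field stores a plain string, and if `ExecutionStatus` is a plain `Enum` (rather than a `str`-subclassing choices class), a member-to-string comparison silently evaluates to `False`. A small stand-alone illustration with a stand-in enum:

```python
# Why the comparisons moved to `.value`: a plain Enum member never equals the
# raw string a CharField hands back.
from enum import Enum


class Status(Enum):
    PENDING = "PENDING"
    COMPLETED = "COMPLETED"


stored = "COMPLETED"                     # what the database field returns
print(stored == Status.COMPLETED)        # False - str vs. Enum member
print(stored == Status.COMPLETED.value)  # True  - str vs. str
```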
@@ -18,7 +18,7 @@ class FileHistory(BaseModel):
|
||||
Returns:
|
||||
bool: True if the execution status is completed, False otherwise.
|
||||
"""
|
||||
return self.status is not None and self.status == ExecutionStatus.COMPLETED
|
||||
return self.status is not None and self.status == ExecutionStatus.COMPLETED.value
|
||||
|
||||
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
|
||||
cache_key = models.CharField(
|
||||
|
||||
@@ -21,6 +21,7 @@ from backend.serializers import AuditSerializer
|
||||
from workflow_manager.workflow_v2.constants import WorkflowExecutionKey, WorkflowKey
|
||||
from workflow_manager.workflow_v2.models.execution import WorkflowExecution
|
||||
from workflow_manager.workflow_v2.models.execution_log import ExecutionLog
|
||||
from workflow_manager.workflow_v2.models.file_history import FileHistory
|
||||
from workflow_manager.workflow_v2.models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -129,6 +130,12 @@ class WorkflowExecutionLogSerializer(ModelSerializer):
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class FileHistorySerializer(ModelSerializer):
|
||||
class Meta:
|
||||
model = FileHistory
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class SharedUserListSerializer(ModelSerializer):
|
||||
"""Serializer for returning workflow with shared user details."""
|
||||
|
||||
|
||||
(File diff suppressed because it is too large.)
@@ -483,18 +483,13 @@ class WorkflowHelper:
|
||||
},
|
||||
queue=queue,
|
||||
)
|
||||
|
||||
# Log task_id for debugging
|
||||
logger.info(
|
||||
f"[{org_schema}] AsyncResult created with task_id: '{async_execution.id}' "
|
||||
f"(type: {type(async_execution.id).__name__})"
|
||||
f"[{org_schema}] Job '{async_execution}' has been enqueued for "
|
||||
f"execution_id '{execution_id}', '{len(hash_values_of_files)}' files"
|
||||
)
|
||||
|
||||
workflow_execution: WorkflowExecution = WorkflowExecution.objects.get(
|
||||
id=execution_id
|
||||
)
|
||||
|
||||
# Handle empty task_id gracefully using existing validation logic
|
||||
if not async_execution.id:
|
||||
logger.warning(
|
||||
f"[{org_schema}] Celery returned empty task_id for execution_id '{execution_id}'. "
|
||||
@@ -509,6 +504,7 @@ class WorkflowHelper:
|
||||
f"[{org_schema}] Job '{async_execution.id}' has been enqueued for "
|
||||
f"execution_id '{execution_id}', '{len(hash_values_of_files)}' files"
|
||||
)
|
||||
|
||||
execution_status = workflow_execution.status
|
||||
if timeout > -1:
|
||||
while not ExecutionStatus.is_completed(execution_status) and timeout > 0:
|
||||
@@ -779,17 +775,16 @@ class WorkflowHelper:
|
||||
# Normal Workflow page execution
|
||||
workflow_execution = WorkflowExecution.objects.get(pk=execution_id)
|
||||
if (
|
||||
workflow_execution.status != ExecutionStatus.PENDING
|
||||
workflow_execution.status != ExecutionStatus.PENDING.value
|
||||
or workflow_execution.execution_type != WorkflowExecution.Type.COMPLETE
|
||||
):
|
||||
raise InvalidRequest(WorkflowErrors.INVALID_EXECUTION_ID)
|
||||
organization_identifier = UserContext.get_organization_identifier()
|
||||
result: ExecutionResponse = WorkflowHelper.run_workflow(
|
||||
workflow=workflow,
|
||||
workflow_execution=workflow_execution,
|
||||
result: ExecutionResponse = WorkflowHelper.execute_workflow_async(
|
||||
workflow_id=str(workflow.id) if workflow else None,
|
||||
pipeline_id=str(pipeline_id) if pipeline_id else None,
|
||||
execution_id=str(execution_id) if execution_id else None,
|
||||
hash_values_of_files=hash_values_of_files,
|
||||
use_file_history=use_file_history,
|
||||
organization_id=organization_identifier,
|
||||
)
|
||||
result = WorkflowHelper.wait_for_execution(result, timeout=timeout)
|
||||
return result
|
||||
@@ -813,7 +808,8 @@ class WorkflowHelper:
|
||||
ExecutionResponse: The execution response.
|
||||
"""
|
||||
if (
|
||||
result.execution_status in [ExecutionStatus.COMPLETED, ExecutionStatus.ERROR]
|
||||
result.execution_status
|
||||
in [ExecutionStatus.COMPLETED.value, ExecutionStatus.ERROR.value]
|
||||
or not timeout
|
||||
):
|
||||
return result
|
||||
@@ -879,7 +875,7 @@ class WorkflowHelper:
|
||||
except WorkflowExecution.DoesNotExist:
|
||||
raise WorkflowExecutionNotExist(WorkflowErrors.INVALID_EXECUTION_ID)
|
||||
if (
|
||||
workflow_execution.status != ExecutionStatus.PENDING
|
||||
workflow_execution.status != ExecutionStatus.PENDING.value
|
||||
or workflow_execution.execution_type != WorkflowExecution.Type.STEP
|
||||
):
|
||||
raise InvalidRequest(WorkflowErrors.INVALID_EXECUTION_ID)
|
||||
|
||||
@@ -30,6 +30,23 @@ VERSION=dev docker compose -f docker-compose.yaml --profile optional up -d
|
||||
|
||||
Now access frontend at http://frontend.unstract.localhost
|
||||
|
||||
## V2 Workers (Optional)
|
||||
|
||||
V2 workers use a unified container architecture and are **disabled by default**.
|
||||
|
||||
```bash
|
||||
# Default: Run with legacy workers only
|
||||
VERSION=dev docker compose -f docker-compose.yaml up -d
|
||||
|
||||
# Enable V2 workers (unified container)
|
||||
VERSION=dev docker compose -f docker-compose.yaml --profile workers-v2 up -d
|
||||
|
||||
# Or use the platform script
|
||||
./run-platform.sh --workers-v2
|
||||
```
|
||||
|
||||
V2 workers available: `api-deployment`, `callback`, `file-processing`, `general`, `notification`, `log-consumer`, `scheduler`
|
||||
|
||||
## Overriding a service's config
|
||||
|
||||
By making use of the [merge compose files](https://docs.docker.com/compose/how-tos/multiple-compose-files/merge/) feature it's possible to override some configuration that's used by the services.
|
||||
|
||||
@@ -50,3 +50,11 @@ services:
|
||||
build:
|
||||
dockerfile: docker/dockerfiles/x2text.Dockerfile
|
||||
context: ..
|
||||
# Unified worker image (replaces all individual worker images)
|
||||
worker-unified:
|
||||
image: unstract/worker-unified:${VERSION}
|
||||
build:
|
||||
dockerfile: docker/dockerfiles/worker-unified.Dockerfile
|
||||
context: ..
|
||||
args:
|
||||
MINIMAL_BUILD: ${MINIMAL_BUILD:-0} # Set to 1 for faster dev builds
|
||||
|
||||
@@ -239,6 +239,313 @@ services:
|
||||
labels:
|
||||
- traefik.enable=false
|
||||
|
||||
# ====================================================================
|
||||
# V2 DEDICATED WORKER SERVICES (opt-in with --workers-v2 flag)
|
||||
# ====================================================================
|
||||
|
||||
worker-api-deployment-v2:
|
||||
image: unstract/worker-unified:${VERSION}
|
||||
container_name: unstract-worker-api-deployment-v2
|
||||
restart: unless-stopped
|
||||
command: ["api-deployment"]
|
||||
ports:
|
||||
- "8085:8090"
|
||||
env_file:
|
||||
- ../workers/.env
|
||||
- ./essentials.env
|
||||
depends_on:
|
||||
- db
|
||||
- redis
|
||||
- rabbitmq
|
||||
environment:
|
||||
- ENVIRONMENT=development
|
||||
- APPLICATION_NAME=unstract-worker-api-deployment-v2
|
||||
- WORKER_TYPE=api_deployment
|
||||
- CELERY_QUEUES_API_DEPLOYMENT=${CELERY_QUEUES_API_DEPLOYMENT:-celery_api_deployments}
|
||||
- CELERY_POOL=${WORKER_API_DEPLOYMENT_POOL:-threads}
|
||||
- CELERY_PREFETCH_MULTIPLIER=${WORKER_API_DEPLOYMENT_PREFETCH_MULTIPLIER:-1}
|
||||
- CELERY_CONCURRENCY=${WORKER_API_DEPLOYMENT_CONCURRENCY:-4}
|
||||
- CELERY_EXTRA_ARGS=${WORKER_API_DEPLOYMENT_EXTRA_ARGS:-}
|
||||
- WORKER_NAME=api-deployment-worker-v2
|
||||
- API_DEPLOYMENT_METRICS_PORT=8090
|
||||
- HEALTH_PORT=8090
|
||||
labels:
|
||||
- traefik.enable=false
|
||||
volumes:
|
||||
- ./workflow_data:/data
|
||||
- ${TOOL_REGISTRY_CONFIG_SRC_PATH}:/data/tool_registry_config
|
||||
profiles:
|
||||
- workers-v2
|
||||
|
||||
worker-callback-v2:
|
||||
image: unstract/worker-unified:${VERSION}
|
||||
container_name: unstract-worker-callback-v2
|
||||
restart: unless-stopped
|
||||
command: ["callback"]
|
||||
ports:
|
||||
- "8086:8083"
|
||||
env_file:
|
||||
- ../workers/.env
|
||||
- ./essentials.env
|
||||
depends_on:
|
||||
- db
|
||||
- redis
|
||||
- rabbitmq
|
||||
environment:
|
||||
- ENVIRONMENT=development
|
||||
- APPLICATION_NAME=unstract-worker-callback-v2
|
||||
- WORKER_TYPE=callback
|
||||
- WORKER_NAME=callback-worker-v2
|
||||
- CALLBACK_METRICS_PORT=8083
|
||||
labels:
|
||||
- traefik.enable=false
|
||||
volumes:
|
||||
- ./workflow_data:/data
|
||||
- ${TOOL_REGISTRY_CONFIG_SRC_PATH}:/data/tool_registry_config
|
||||
profiles:
|
||||
- workers-v2
|
||||
|
||||
worker-file-processing-v2:
|
||||
image: unstract/worker-unified:${VERSION}
|
||||
container_name: unstract-worker-file-processing-v2
|
||||
restart: unless-stopped
|
||||
# command: ["file-processing"]
|
||||
command: [".venv/bin/celery", "-A", "worker", "worker", "--queues=file_processing,api_file_processing,file_processing_priority", "--loglevel=INFO", "--pool=prefork", "--concurrency=4", "--prefetch-multiplier=1", "--without-gossip", "--without-mingle", "--without-heartbeat"]
|
||||
ports:
|
||||
- "8087:8082"
|
||||
env_file:
|
||||
- ../workers/.env
|
||||
- ./essentials.env
|
||||
depends_on:
|
||||
- db
|
||||
- redis
|
||||
- rabbitmq
|
||||
environment:
|
||||
- ENVIRONMENT=development
|
||||
- APPLICATION_NAME=unstract-worker-file-processing-v2
|
||||
- WORKER_TYPE=file_processing
|
||||
- WORKER_MODE=oss
|
||||
- WORKER_NAME=file-processing-worker-v2
|
||||
- FILE_PROCESSING_METRICS_PORT=8082
|
||||
# OSS Configuration - Enterprise features disabled
|
||||
- MANUAL_REVIEW_ENABLED=false
|
||||
- ENTERPRISE_FEATURES_ENABLED=false
|
||||
- PLUGIN_REGISTRY_MODE=oss
|
||||
# Configurable Celery options
|
||||
- CELERY_QUEUES_FILE_PROCESSING=${CELERY_QUEUES_FILE_PROCESSING:-file_processing,api_file_processing}
|
||||
- CELERY_POOL=${WORKER_FILE_PROCESSING_POOL:-threads}
|
||||
- CELERY_PREFETCH_MULTIPLIER=${WORKER_FILE_PROCESSING_PREFETCH_MULTIPLIER:-1}
|
||||
- CELERY_CONCURRENCY=${WORKER_FILE_PROCESSING_CONCURRENCY:-4}
|
||||
- CELERY_EXTRA_ARGS=${WORKER_FILE_PROCESSING_EXTRA_ARGS:-}
|
||||
labels:
|
||||
- traefik.enable=false
|
||||
volumes:
|
||||
- ./workflow_data:/data
|
||||
- ${TOOL_REGISTRY_CONFIG_SRC_PATH}:/data/tool_registry_config
|
||||
profiles:
|
||||
- workers-v2
|
||||
|
||||
worker-general-v2:
|
||||
image: unstract/worker-unified:${VERSION}
|
||||
container_name: unstract-worker-general-v2
|
||||
restart: unless-stopped
|
||||
command: ["general"]
|
||||
ports:
|
||||
- "8088:8082"
|
||||
env_file:
|
||||
- ../workers/.env
|
||||
- ./essentials.env
|
||||
depends_on:
|
||||
- db
|
||||
- redis
|
||||
- rabbitmq
|
||||
environment:
|
||||
- ENVIRONMENT=development
|
||||
- APPLICATION_NAME=unstract-worker-general-v2
|
||||
- WORKER_TYPE=general
|
||||
- WORKER_NAME=general-worker-v2
|
||||
- GENERAL_METRICS_PORT=8081
|
||||
- HEALTH_PORT=8082
|
||||
labels:
|
||||
- traefik.enable=false
|
||||
volumes:
|
||||
- ./workflow_data:/data
|
||||
- ${TOOL_REGISTRY_CONFIG_SRC_PATH}:/data/tool_registry_config
|
||||
profiles:
|
||||
- workers-v2
|
||||
|
||||
worker-notification-v2:
|
||||
image: unstract/worker-unified:${VERSION}
|
||||
container_name: unstract-worker-notification-v2
|
||||
restart: unless-stopped
|
||||
command: ["notification"]
|
||||
ports:
|
||||
- "8089:8085"
|
||||
env_file:
|
||||
- ../workers/.env
|
||||
- ./essentials.env
|
||||
depends_on:
|
||||
- db
|
||||
- redis
|
||||
- rabbitmq
|
||||
environment:
|
||||
- ENVIRONMENT=development
|
||||
- APPLICATION_NAME=unstract-worker-notification-v2
|
||||
- WORKER_TYPE=notification
|
||||
- WORKER_NAME=notification-worker-v2
|
||||
- NOTIFICATION_METRICS_PORT=8085
|
||||
- HEALTH_PORT=8085
|
||||
# Notification specific configs
|
||||
- NOTIFICATION_QUEUE_NAME=notifications
|
||||
- WEBHOOK_QUEUE_NAME=notifications_webhook
|
||||
- EMAIL_QUEUE_NAME=notifications_email
|
||||
- SMS_QUEUE_NAME=notifications_sms
|
||||
- PRIORITY_QUEUE_NAME=notifications_priority
|
||||
# Configurable Celery options
|
||||
- CELERY_QUEUES_NOTIFICATION=${CELERY_QUEUES_NOTIFICATION:-notifications,notifications_webhook,notifications_email,notifications_sms,notifications_priority}
|
||||
- CELERY_POOL=${WORKER_NOTIFICATION_POOL:-prefork}
|
||||
- CELERY_PREFETCH_MULTIPLIER=${WORKER_NOTIFICATION_PREFETCH_MULTIPLIER:-1}
|
||||
- CELERY_CONCURRENCY=${WORKER_NOTIFICATION_CONCURRENCY:-4}
|
||||
- CELERY_EXTRA_ARGS=${WORKER_NOTIFICATION_EXTRA_ARGS:-}
|
||||
# Complete command override (if set, ignores all other options)
|
||||
- CELERY_COMMAND_OVERRIDE=${WORKER_NOTIFICATION_COMMAND_OVERRIDE:-}
|
||||
# Individual argument overrides
|
||||
- CELERY_APP_MODULE=${WORKER_NOTIFICATION_APP_MODULE:-worker}
|
||||
- CELERY_LOG_LEVEL=${WORKER_NOTIFICATION_LOG_LEVEL:-INFO}
|
||||
- CELERY_HOSTNAME=${WORKER_NOTIFICATION_HOSTNAME:-}
|
||||
- CELERY_MAX_TASKS_PER_CHILD=${WORKER_NOTIFICATION_MAX_TASKS_PER_CHILD:-}
|
||||
- CELERY_TIME_LIMIT=${WORKER_NOTIFICATION_TIME_LIMIT:-}
|
||||
- CELERY_SOFT_TIME_LIMIT=${WORKER_NOTIFICATION_SOFT_TIME_LIMIT:-}
|
||||
labels:
|
||||
- traefik.enable=false
|
||||
volumes:
|
||||
- ./workflow_data:/data
|
||||
- ${TOOL_REGISTRY_CONFIG_SRC_PATH}:/data/tool_registry_config
|
||||
profiles:
|
||||
- workers-v2
|
||||
|
||||
worker-log-consumer-v2:
|
||||
image: unstract/worker-unified:${VERSION}
|
||||
container_name: unstract-worker-log-consumer-v2
|
||||
restart: unless-stopped
|
||||
command: ["log-consumer"]
|
||||
ports:
|
||||
- "8090:8084"
|
||||
env_file:
|
||||
- ../workers/.env
|
||||
- ./essentials.env
|
||||
depends_on:
|
||||
- db
|
||||
- redis
|
||||
- rabbitmq
|
||||
environment:
|
||||
- ENVIRONMENT=development
|
||||
- APPLICATION_NAME=unstract-worker-log-consumer-v2
|
||||
- WORKER_TYPE=log_consumer
|
||||
- WORKER_NAME=log-consumer-worker-v2
|
||||
- LOG_CONSUMER_METRICS_PORT=8084
|
||||
- HEALTH_PORT=8084
|
||||
# Log consumer specific configs
|
||||
- LOG_CONSUMER_QUEUE_NAME=celery_log_task_queue
|
||||
# Multiple queue support - supports comma-separated queue names
|
||||
- CELERY_QUEUES_LOG_CONSUMER=${CELERY_QUEUES_LOG_CONSUMER:-celery_log_task_queue,celery_periodic_logs}
|
||||
- PERIODIC_LOGS_QUEUE_NAME=${PERIODIC_LOGS_QUEUE_NAME:-celery_periodic_logs}
|
||||
# Log history configuration
|
||||
- LOG_HISTORY_QUEUE_NAME=${LOG_HISTORY_QUEUE_NAME:-log_history_queue}
|
||||
- LOGS_BATCH_LIMIT=${LOGS_BATCH_LIMIT:-100}
|
||||
- ENABLE_LOG_HISTORY=${ENABLE_LOG_HISTORY:-true}
|
||||
- CELERY_POOL=${WORKER_LOG_CONSUMER_POOL:-prefork}
|
||||
- CELERY_PREFETCH_MULTIPLIER=${WORKER_LOG_CONSUMER_PREFETCH_MULTIPLIER:-1}
|
||||
- CELERY_CONCURRENCY=${WORKER_LOG_CONSUMER_CONCURRENCY:-2}
|
||||
- CELERY_EXTRA_ARGS=${WORKER_LOG_CONSUMER_EXTRA_ARGS:-}
|
||||
# Complete command override (if set, ignores all other options)
|
||||
- CELERY_COMMAND_OVERRIDE=${WORKER_LOG_CONSUMER_COMMAND_OVERRIDE:-}
|
||||
# Individual argument overrides
|
||||
- CELERY_APP_MODULE=${WORKER_LOG_CONSUMER_APP_MODULE:-worker}
|
||||
- CELERY_LOG_LEVEL=${WORKER_LOG_CONSUMER_LOG_LEVEL:-INFO}
|
||||
- CELERY_HOSTNAME=${WORKER_LOG_CONSUMER_HOSTNAME:-}
|
||||
- CELERY_MAX_TASKS_PER_CHILD=${WORKER_LOG_CONSUMER_MAX_TASKS_PER_CHILD:-}
|
||||
- CELERY_TIME_LIMIT=${WORKER_LOG_CONSUMER_TIME_LIMIT:-}
|
||||
- CELERY_SOFT_TIME_LIMIT=${WORKER_LOG_CONSUMER_SOFT_TIME_LIMIT:-}
|
||||
labels:
|
||||
- traefik.enable=false
|
||||
volumes:
|
||||
- ./workflow_data:/data
|
||||
- ${TOOL_REGISTRY_CONFIG_SRC_PATH}:/data/tool_registry_config
|
||||
profiles:
|
||||
- workers-v2
|
||||
|
||||
worker-log-history-scheduler-v2:
|
||||
image: unstract/worker-unified:${VERSION}
|
||||
container_name: unstract-worker-log-history-scheduler-v2
|
||||
restart: unless-stopped
|
||||
entrypoint: ["/bin/bash"]
|
||||
command: ["/app/log_consumer/scheduler.sh"]
|
||||
env_file:
|
||||
- ../workers/.env
|
||||
- ./essentials.env
|
||||
depends_on:
|
||||
- db
|
||||
- redis
|
||||
- rabbitmq
|
||||
environment:
|
||||
- ENVIRONMENT=development
|
||||
- APPLICATION_NAME=unstract-worker-log-history-scheduler-v2
|
||||
# Scheduler interval in seconds
|
||||
- LOG_HISTORY_CONSUMER_INTERVAL=${LOG_HISTORY_CONSUMER_INTERVAL:-5}
|
||||
# Override example: TASK_TRIGGER_COMMAND=/custom/trigger/script.sh
|
||||
- TASK_TRIGGER_COMMAND=${TASK_TRIGGER_COMMAND:-}
|
||||
labels:
|
||||
- traefik.enable=false
|
||||
profiles:
|
||||
- workers-v2
|
||||
|
||||
worker-scheduler-v2:
|
||||
image: unstract/worker-unified:${VERSION}
|
||||
container_name: unstract-worker-scheduler-v2
|
||||
restart: unless-stopped
|
||||
command: ["scheduler"]
|
||||
ports:
|
||||
- "8091:8087"
|
||||
env_file:
|
||||
- ../workers/.env
|
||||
- ./essentials.env
|
||||
depends_on:
|
||||
- db
|
||||
- redis
|
||||
- rabbitmq
|
||||
environment:
|
||||
- ENVIRONMENT=development
|
||||
- APPLICATION_NAME=unstract-worker-scheduler-v2
|
||||
- WORKER_TYPE=scheduler
|
||||
- WORKER_NAME=scheduler-worker-v2
|
||||
- SCHEDULER_METRICS_PORT=8087
|
||||
- HEALTH_PORT=8087
|
||||
# Scheduler specific configs
|
||||
- SCHEDULER_QUEUE_NAME=scheduler
|
||||
# Configurable Celery options
|
||||
- CELERY_QUEUES_SCHEDULER=${CELERY_QUEUES_SCHEDULER:-scheduler}
|
||||
- CELERY_POOL=${WORKER_SCHEDULER_POOL:-prefork}
|
||||
- CELERY_PREFETCH_MULTIPLIER=${WORKER_SCHEDULER_PREFETCH_MULTIPLIER:-1}
|
||||
- CELERY_CONCURRENCY=${WORKER_SCHEDULER_CONCURRENCY:-2}
|
||||
- CELERY_EXTRA_ARGS=${WORKER_SCHEDULER_EXTRA_ARGS:-}
|
||||
# Complete command override (if set, ignores all other options)
|
||||
- CELERY_COMMAND_OVERRIDE=${WORKER_SCHEDULER_COMMAND_OVERRIDE:-}
|
||||
# Individual argument overrides
|
||||
- CELERY_APP_MODULE=${WORKER_SCHEDULER_APP_MODULE:-worker}
|
||||
- CELERY_LOG_LEVEL=${WORKER_SCHEDULER_LOG_LEVEL:-INFO}
|
||||
- CELERY_HOSTNAME=${WORKER_SCHEDULER_HOSTNAME:-}
|
||||
- CELERY_MAX_TASKS_PER_CHILD=${WORKER_SCHEDULER_MAX_TASKS_PER_CHILD:-}
|
||||
- CELERY_TIME_LIMIT=${WORKER_SCHEDULER_TIME_LIMIT:-}
|
||||
- CELERY_SOFT_TIME_LIMIT=${WORKER_SCHEDULER_SOFT_TIME_LIMIT:-}
|
||||
labels:
|
||||
- traefik.enable=false
|
||||
volumes:
|
||||
- ./workflow_data:/data
|
||||
- ${TOOL_REGISTRY_CONFIG_SRC_PATH}:/data/tool_registry_config
|
||||
profiles:
|
||||
- workers-v2
|
||||
|
||||
volumes:
|
||||
prompt_studio_data:
|
||||
unstract_data:
|
||||
|
||||
docker/dockerfiles/worker-unified.Dockerfile (new file, 85 lines)
@@ -0,0 +1,85 @@
|
||||
# Unified Worker Dockerfile - Optimized for fast builds
|
||||
FROM python:3.12.9-slim AS base
|
||||
|
||||
ARG VERSION=dev
|
||||
LABEL maintainer="Zipstack Inc." \
|
||||
description="Unified Worker Container for All Worker Types" \
|
||||
version="${VERSION}"
|
||||
|
||||
# Set environment variables (CRITICAL: PYTHONPATH makes paths work!)
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONPATH=/app:/unstract \
|
||||
BUILD_CONTEXT_PATH=workers \
|
||||
BUILD_PACKAGES_PATH=unstract \
|
||||
APP_HOME=/app \
|
||||
# OpenTelemetry configuration (disabled by default, enable in docker-compose)
|
||||
OTEL_TRACES_EXPORTER=none \
|
||||
OTEL_LOGS_EXPORTER=none \
|
||||
OTEL_SERVICE_NAME=unstract_workers
|
||||
|
||||
# Install system dependencies (minimal for workers)
|
||||
RUN apt-get update \
|
||||
&& apt-get --no-install-recommends install -y \
|
||||
curl \
|
||||
gcc \
|
||||
libmagic-dev \
|
||||
libssl-dev \
|
||||
pkg-config \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||
|
||||
# Install uv package manager
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.6.14 /uv /uvx /bin/
|
||||
|
||||
# Create non-root user early to avoid ownership issues
|
||||
RUN groupadd -r worker && useradd -r -g worker worker && \
|
||||
mkdir -p /home/worker && chown -R worker:worker /home/worker
|
||||
|
||||
# Create working directory
|
||||
WORKDIR ${APP_HOME}
|
||||
|
||||
# -----------------------------------------------
|
||||
# EXTERNAL DEPENDENCIES STAGE - This layer gets cached
|
||||
# -----------------------------------------------
|
||||
FROM base AS ext-dependencies
|
||||
|
||||
# Copy dependency files (including README.md like backend)
|
||||
COPY ${BUILD_CONTEXT_PATH}/pyproject.toml ${BUILD_CONTEXT_PATH}/uv.lock ./
|
||||
# Create empty README.md if it doesn't exist in the copy
|
||||
RUN touch README.md
|
||||
|
||||
# Copy local package dependencies to /unstract directory
|
||||
# This provides the unstract packages for imports
|
||||
COPY ${BUILD_PACKAGES_PATH}/ /unstract/
|
||||
|
||||
# Install external dependencies with --locked for FAST builds
|
||||
# No symlinks needed - PYTHONPATH handles the paths
|
||||
RUN uv sync --group deploy --locked --no-install-project --no-dev
|
||||
|
||||
# -----------------------------------------------
|
||||
# FINAL STAGE - Minimal image for production
|
||||
# -----------------------------------------------
|
||||
FROM ext-dependencies AS production
|
||||
|
||||
# Copy application code (this layer changes most frequently)
|
||||
COPY ${BUILD_CONTEXT_PATH}/ ./
|
||||
|
||||
# Set shell with pipefail for proper error handling in pipes
|
||||
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
|
||||
|
||||
# Install project and OpenTelemetry instrumentation (as root to avoid permission issues)
|
||||
# No symlinks needed - PYTHONPATH handles the paths correctly
|
||||
RUN uv sync --group deploy --locked && \
|
||||
uv run opentelemetry-bootstrap -a requirements | uv pip install --requirement - && \
|
||||
{ chmod +x ./run-worker.sh ./run-worker-docker.sh 2>/dev/null || true; } && \
|
||||
touch requirements.txt && \
|
||||
{ chown -R worker:worker ./run-worker.sh ./run-worker-docker.sh 2>/dev/null || true; }
|
||||
|
||||
# Switch to worker user
|
||||
USER worker
|
||||
|
||||
|
||||
# Default command - runs the Docker-optimized worker script
|
||||
ENTRYPOINT ["/app/run-worker-docker.sh"]
|
||||
CMD ["general"]
|
||||
docker/dockerfiles/worker-unified.Dockerfile.dockerignore (new file, 74 lines)
@@ -0,0 +1,74 @@
|
||||
# Unified Worker Docker ignore file
|
||||
# Based on worker-base.dockerignore but unified for all worker types
|
||||
|
||||
# Virtual environments
|
||||
**/.venv/
|
||||
**/venv/
|
||||
**/__pycache__/
|
||||
**/.pytest_cache/
|
||||
**/.mypy_cache/
|
||||
|
||||
# IDE and editor files
|
||||
**/.vscode/
|
||||
**/.idea/
|
||||
**/*.swp
|
||||
**/*.swo
|
||||
**/*~
|
||||
|
||||
# OS files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Git
|
||||
.git/
|
||||
.gitignore
|
||||
|
||||
# Docker files (avoid recursion)
|
||||
**/Dockerfile*
|
||||
**/*.dockerignore
|
||||
|
||||
# Build artifacts
|
||||
**/dist/
|
||||
**/build/
|
||||
**/*.egg-info/
|
||||
|
||||
# Logs
|
||||
**/*.log
|
||||
**/logs/
|
||||
|
||||
# Test files
|
||||
**/tests/
|
||||
**/test_*.py
|
||||
**/*_test.py
|
||||
|
||||
# Development files
|
||||
**/dev-*
|
||||
**/sample.*
|
||||
**/example.*
|
||||
|
||||
# Node modules (if any)
|
||||
**/node_modules/
|
||||
|
||||
# Documentation
|
||||
**/docs/
|
||||
**/*.md
|
||||
!README.md
|
||||
|
||||
# Configuration that shouldn't be in containers
|
||||
**/.env*
|
||||
**/local_settings.py
|
||||
|
||||
# Coverage reports
|
||||
**/htmlcov/
|
||||
**/.coverage
|
||||
**/coverage.xml
|
||||
|
||||
# Backup files
|
||||
**/*.bak
|
||||
**/*.backup
|
||||
**/*.orig
|
||||
|
||||
# Temporary files
|
||||
**/tmp/
|
||||
**/temp/
|
||||
**/.tmp/
|
||||
@@ -11,3 +11,103 @@ WORKER_LOGGING_AUTOSCALE=4,1
|
||||
WORKER_AUTOSCALE=4,1
|
||||
WORKER_FILE_PROCESSING_AUTOSCALE=4,1
|
||||
WORKER_FILE_PROCESSING_CALLBACK_AUTOSCALE=4,1
|
||||
|
||||
# New unified worker autoscaling (matches hierarchical configuration below)
|
||||
WORKER_API_DEPLOYMENT_AUTOSCALE=4,1 # API deployment worker autoscale
|
||||
WORKER_CALLBACK_AUTOSCALE=4,1 # Callback worker autoscale
|
||||
WORKER_GENERAL_AUTOSCALE=6,2 # General worker autoscale (enhanced)
|
||||
WORKER_FILE_PROCESSING_NEW_AUTOSCALE=8,2 # File processing unified worker autoscale
|
||||
WORKER_NOTIFICATION_AUTOSCALE=4,1 # Notification worker autoscale
|
||||
WORKER_LOG_CONSUMER_AUTOSCALE=2,1 # Log consumer worker autoscale
|
||||
WORKER_SCHEDULER_AUTOSCALE=2,1 # Scheduler worker autoscale
|
||||
|
||||
# Worker-specific configurations
|
||||
API_DEPLOYMENT_WORKER_NAME=api-deployment-worker
|
||||
API_DEPLOYMENT_HEALTH_PORT=8080
|
||||
API_DEPLOYMENT_MAX_CONCURRENT_TASKS=5
|
||||
|
||||
CALLBACK_WORKER_NAME=callback-worker
|
||||
CALLBACK_HEALTH_PORT=8083
|
||||
CALLBACK_MAX_CONCURRENT_TASKS=3
|
||||
|
||||
FILE_PROCESSING_WORKER_NAME=file-processing-worker
|
||||
FILE_PROCESSING_HEALTH_PORT=8082
|
||||
FILE_PROCESSING_MAX_CONCURRENT_TASKS=4
|
||||
|
||||
GENERAL_WORKER_NAME=general-worker
|
||||
GENERAL_HEALTH_PORT=8081
|
||||
GENERAL_MAX_CONCURRENT_TASKS=10
|
||||
|
||||
# =============================================================================
|
||||
# HIERARCHICAL CELERY CONFIGURATION SYSTEM
|
||||
# =============================================================================
|
||||
#
|
||||
# This system uses a 3-tier hierarchy for all Celery settings (most specific wins):
|
||||
# 1. {WORKER_TYPE}_{SETTING_NAME} - Worker-specific override (highest priority)
|
||||
# 2. CELERY_{SETTING_NAME} - Global override (medium priority)
|
||||
# 3. Code default - Celery standard default (lowest priority)
|
||||
#
|
||||
# Examples:
|
||||
# - CALLBACK_TASK_TIME_LIMIT=3600 (callback worker only)
|
||||
# - CELERY_TASK_TIME_LIMIT=300 (all workers)
|
||||
# - Code provides default if neither is set
|
||||
#
|
||||
# Worker types: API_DEPLOYMENT, GENERAL, FILE_PROCESSING, CALLBACK,
|
||||
# NOTIFICATION, LOG_CONSUMER, SCHEDULER
|
||||
# =============================================================================
|
||||
|
||||
# Global Celery Configuration (applies to all workers unless overridden)
|
||||
CELERY_RESULT_CHORD_RETRY_INTERVAL=3 # Global chord unlock retry interval
|
||||
CELERY_TASK_TIME_LIMIT=7200 # Global task timeout (2 hours)
|
||||
CELERY_TASK_SOFT_TIME_LIMIT=6300 # Global soft timeout (1h 45m)
|
||||
CELERY_PREFETCH_MULTIPLIER=1 # Global prefetch multiplier
|
||||
CELERY_MAX_TASKS_PER_CHILD=1000 # Global max tasks per child process
|
||||
CELERY_TASK_ACKS_LATE=true # Global acks late setting
|
||||
CELERY_TASK_DEFAULT_RETRY_DELAY=60 # Global retry delay (1 minute)
|
||||
CELERY_TASK_MAX_RETRIES=3 # Global max retries
|
||||
|
||||
# Worker-Specific Configuration Examples
|
||||
# Callback Worker - Chord settings and extended timeouts
|
||||
CALLBACK_RESULT_CHORD_RETRY_INTERVAL=3 # Callback-specific chord retry interval
|
||||
CALLBACK_TASK_TIME_LIMIT=7200 # Callback tasks need more time (2 hours)
|
||||
CALLBACK_TASK_SOFT_TIME_LIMIT=6300 # Callback soft timeout (1h 45m)
|
||||
|
||||
# File Processing Worker - Thread pool and optimized settings
|
||||
FILE_PROCESSING_POOL_TYPE=threads # Use threads instead of prefork
|
||||
FILE_PROCESSING_CONCURRENCY=4 # Fixed concurrency for file processing
|
||||
FILE_PROCESSING_TASK_TIME_LIMIT=10800 # File processing timeout (3 hours)
|
||||
|
||||
# API Deployment Worker - Autoscaling and timeout configuration
|
||||
API_DEPLOYMENT_AUTOSCALE=4,1 # Max 4, min 1 workers
|
||||
API_DEPLOYMENT_TASK_TIME_LIMIT=3600 # API deployment timeout (1 hour)
|
||||
|
||||
# General Worker - Enhanced scaling for high-throughput tasks
|
||||
GENERAL_AUTOSCALE=6,2 # Max 6, min 2 workers
|
||||
|
||||
# Docker Worker-Specific Concurrency Settings (for docker-compose.yaml)
|
||||
WORKER_API_DEPLOYMENT_CONCURRENCY=4 # API deployment fixed concurrency
|
||||
WORKER_FILE_PROCESSING_CONCURRENCY=8 # File processing fixed concurrency
|
||||
WORKER_NOTIFICATION_CONCURRENCY=4 # Notification worker concurrency
|
||||
WORKER_LOG_CONSUMER_CONCURRENCY=2 # Log consumer worker concurrency
|
||||
WORKER_SCHEDULER_CONCURRENCY=2 # Scheduler worker concurrency
|
||||
|
||||
# Notification Worker - Optimized for quick message processing
|
||||
NOTIFICATION_AUTOSCALE=4,1 # Max 4, min 1 workers
|
||||
NOTIFICATION_TASK_TIME_LIMIT=120 # Quick timeout for notifications
|
||||
|
||||
# Scheduler Worker - Conservative settings for scheduled tasks
|
||||
SCHEDULER_AUTOSCALE=2,1 # Max 2, min 1 workers
|
||||
SCHEDULER_TASK_TIME_LIMIT=1800 # Scheduler timeout (30 minutes)
|
||||
|
||||
# Log Consumer Worker - Optimized for log processing
|
||||
LOG_CONSUMER_AUTOSCALE=2,1 # Max 2, min 1 workers
|
||||
LOG_CONSUMER_TASK_TIME_LIMIT=600 # Log processing timeout (10 minutes)
|
||||
|
||||
# Worker Circuit Breaker Settings
|
||||
CIRCUIT_BREAKER_FAILURE_THRESHOLD=5
|
||||
CIRCUIT_BREAKER_RECOVERY_TIMEOUT=60
|
||||
|
||||
# Worker Health Check Settings
|
||||
HEALTH_CHECK_INTERVAL=30
|
||||
HEALTH_CHECK_TIMEOUT=10
|
||||
ENABLE_METRICS=true
|
||||
|
||||
@@ -22,6 +22,10 @@ dev = [
|
||||
"types-tzlocal~=5.1.0.1",
|
||||
]
|
||||
|
||||
workers = [
|
||||
"unstract-workers",
|
||||
]
|
||||
|
||||
hook-check-django-migrations = [
|
||||
"celery>=5.3.4",
|
||||
"cron-descriptor==1.4.0",
|
||||
@@ -55,6 +59,8 @@ unstract-tool-registry = { path = "./unstract/tool-registry", editable = true }
|
||||
unstract-flags = { path = "./unstract/flags", editable = true }
|
||||
unstract-core = { path = "./unstract/core", editable = true }
|
||||
unstract-connectors = { path = "./unstract/connectors", editable = true }
|
||||
# Workers
|
||||
unstract-workers = { path = "./workers", editable = true }
|
||||
# === Development tool configurations ===
|
||||
|
||||
[tool.ruff]
|
||||
|
||||
@@ -71,6 +71,7 @@ display_help() {
|
||||
echo -e " -p, --only-pull Only do docker images pull"
|
||||
echo -e " -b, --build-local Build docker images locally"
|
||||
echo -e " -u, --update Update services version"
|
||||
echo -e " -w, --workers-v2 Use v2 dedicated worker containers"
|
||||
echo -e " -x, --trace Enables trace mode"
|
||||
echo -e " -V, --verbose Print verbose logs"
|
||||
echo -e " -v, --version Docker images version tag (default \"latest\")"
|
||||
@@ -97,6 +98,9 @@ parse_args() {
|
||||
-u | --update)
|
||||
opt_update=true
|
||||
;;
|
||||
-w | --workers-v2)
|
||||
opt_workers_v2=true
|
||||
;;
|
||||
-x | --trace)
|
||||
set -o xtrace # display every line before execution; enables PS4
|
||||
;;
|
||||
@@ -128,6 +132,7 @@ parse_args() {
|
||||
debug "OPTION only_pull: $opt_only_pull"
|
||||
debug "OPTION build_local: $opt_build_local"
|
||||
debug "OPTION upgrade: $opt_update"
|
||||
debug "OPTION workers_v2: $opt_workers_v2"
|
||||
debug "OPTION verbose: $opt_verbose"
|
||||
debug "OPTION version: $opt_version"
|
||||
}
|
||||
@@ -280,8 +285,13 @@ build_services() {
|
||||
run_services() {
|
||||
pushd "$script_dir/docker" 1>/dev/null
|
||||
|
||||
echo -e "$blue_text""Starting docker containers in detached mode""$default_text"
|
||||
VERSION=$opt_version $docker_compose_cmd up -d
|
||||
if [ "$opt_workers_v2" = true ]; then
|
||||
echo -e "$blue_text""Starting docker containers with V2 dedicated workers in detached mode""$default_text"
|
||||
VERSION=$opt_version $docker_compose_cmd --profile workers-v2 up -d
|
||||
else
|
||||
echo -e "$blue_text""Starting docker containers with existing backend-based workers in detached mode""$default_text"
|
||||
VERSION=$opt_version $docker_compose_cmd up -d
|
||||
fi
|
||||
|
||||
if [ "$opt_update" = true ]; then
|
||||
echo ""
|
||||
@@ -324,6 +334,7 @@ opt_only_env=false
|
||||
opt_only_pull=false
|
||||
opt_build_local=false
|
||||
opt_update=false
|
||||
opt_workers_v2=false
|
||||
opt_verbose=false
|
||||
opt_version="latest"
|
||||
|
||||
@@ -331,6 +342,8 @@ script_dir=$(dirname "$(readlink -f "$BASH_SOURCE")")
|
||||
first_setup=false
|
||||
# Extract service names from docker compose file
|
||||
services=($(VERSION=$opt_version $docker_compose_cmd -f "$script_dir/docker/docker-compose.build.yaml" config --services))
|
||||
# Add workers manually for env setup
|
||||
services+=("workers")
|
||||
spawned_services=("tool-structure" "tool-sidecar")
|
||||
current_version=""
|
||||
target_branch=""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
from flask import Blueprint, Response, jsonify
|
||||
from flask import Blueprint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -9,7 +9,6 @@ health_bp = Blueprint("health", __name__)
|
||||
|
||||
|
||||
# Define a route to ping test
|
||||
@health_bp.route("/ping", methods=["GET"])
|
||||
def ping() -> Response:
|
||||
logger.info("Ping request received")
|
||||
return jsonify({"message": "pong!!!"})
|
||||
@health_bp.route("/health", methods=["GET"])
|
||||
def health_check() -> str:
|
||||
return "OK"
|
||||
|
||||
@@ -296,7 +296,7 @@ class UnstractRunner:
|
||||
settings_json = json.dumps(settings).replace("'", "\\'")
|
||||
# Prepare the tool execution command
|
||||
tool_cmd = (
|
||||
f"opentelemetry-instrument python main.py --command RUN "
|
||||
f"python main.py --command RUN "
|
||||
f"--settings '{settings_json}' --log-level DEBUG"
|
||||
)
|
||||
|
||||
|
||||
@@ -17,7 +17,8 @@ ENV \
|
||||
OTEL_METRICS_EXPORTER=none \
|
||||
OTEL_LOGS_EXPORTER=none \
|
||||
# Enable context propagation
|
||||
OTEL_PROPAGATORS="tracecontext"
|
||||
OTEL_PROPAGATORS="tracecontext" \
|
||||
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
# Install system dependencies in one layer
|
||||
RUN apt-get update && \
|
||||
|
||||
unstract/connectors/src/unstract/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
|
||||
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
|
||||
# Unstract namespace package
|
||||
@@ -1,7 +1,10 @@
|
||||
import logging
|
||||
from logging import NullHandler
|
||||
from typing import Any
|
||||
|
||||
from unstract.connectors.connection_types import ConnectionType
|
||||
|
||||
logging.getLogger(__name__).addHandler(NullHandler())
|
||||
|
||||
ConnectorDict = dict[str, dict[str, Any]]
|
||||
__all__ = [
|
||||
"ConnectionType",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
"""Unified Connection Types for Unstract Platform
|
||||
|
||||
This module provides a centralized definition of connection types used across
|
||||
the entire Unstract platform to ensure consistency and prevent duplication.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ConnectionType(str, Enum):
|
||||
"""Core connection types for workflow endpoints and connectors.
|
||||
|
||||
This enum provides the fundamental connection types used across:
|
||||
- workers/shared/enums.py
|
||||
- workers/shared/workflow/source_connector.py
|
||||
- workers/shared/workflow/destination_connector.py
|
||||
- unstract/core/src/unstract/core/data_models.py
|
||||
- unstract/core/src/unstract/core/workflow_utils.py
|
||||
"""
|
||||
|
||||
FILESYSTEM = "FILESYSTEM"
|
||||
DATABASE = "DATABASE"
|
||||
API = "API"
|
||||
MANUALREVIEW = "MANUALREVIEW"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
@property
|
||||
def is_filesystem(self) -> bool:
|
||||
"""Check if this is a filesystem connection type."""
|
||||
return self == ConnectionType.FILESYSTEM
|
||||
|
||||
@property
|
||||
def is_database(self) -> bool:
|
||||
"""Check if this is a database connection type."""
|
||||
return self == ConnectionType.DATABASE
|
||||
|
||||
@property
|
||||
def is_api(self) -> bool:
|
||||
"""Check if this is an API connection type."""
|
||||
return self == ConnectionType.API
|
||||
|
||||
@property
|
||||
def is_manual_review(self) -> bool:
|
||||
"""Check if this is a manual review connection type."""
|
||||
return self == ConnectionType.MANUALREVIEW
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, connection_type: str) -> "ConnectionType":
|
||||
"""Create ConnectionType from string, with validation.
|
||||
|
||||
Args:
|
||||
connection_type: Connection type string
|
||||
|
||||
Returns:
|
||||
ConnectionType enum value
|
||||
|
||||
Raises:
|
||||
ValueError: If connection type is not recognized or is empty
|
||||
"""
|
||||
if not connection_type:
|
||||
raise ValueError("Connection type cannot be empty")
|
||||
|
||||
connection_type_upper = connection_type.upper()
|
||||
|
||||
try:
|
||||
return cls(connection_type_upper)
|
||||
except ValueError:
|
||||
raise ValueError(f"Unknown connection type: {connection_type}")
|
||||
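Quick usage of the new enum, matching the class and import path introduced above:

```python
from unstract.connectors.connection_types import ConnectionType

ct = ConnectionType.from_string("filesystem")
assert ct is ConnectionType.FILESYSTEM
assert ct.is_filesystem and not ct.is_api
print(str(ct))  # "FILESYSTEM"
```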
@@ -3,9 +3,8 @@ from typing import Any
|
||||
|
||||
from singleton_decorator import singleton
|
||||
|
||||
from unstract.connectors import ConnectorDict # type: ignore
|
||||
from unstract.connectors.base import UnstractConnector
|
||||
from unstract.connectors.constants import Common
|
||||
from unstract.connectors.constants import Common, ConnectorDict
|
||||
from unstract.connectors.databases import connectors as db_connectors
|
||||
from unstract.connectors.enums import ConnectorMode
|
||||
from unstract.connectors.filesystems import connectors as fs_connectors
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Common:
|
||||
METADATA = "metadata"
|
||||
MODULE = "module"
|
||||
CONNECTOR = "connector"
|
||||
|
||||
|
||||
# Type definitions
|
||||
ConnectorDict = dict[str, dict[str, Any]]
|
||||
|
||||
|
||||
class DatabaseTypeConstants:
|
||||
"""Central location for all database-specific type constants."""
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from unstract.connectors import ConnectorDict # type: ignore
|
||||
from unstract.connectors.constants import ConnectorDict
|
||||
from unstract.connectors.databases.register import register_connectors
|
||||
|
||||
connectors: ConnectorDict = {}
|
||||
|
||||
@@ -7,8 +7,6 @@ from enum import Enum
from typing import Any

import google.api_core.exceptions
from google.cloud import bigquery
from google.cloud.bigquery import Client

from unstract.connectors.constants import DatabaseTypeConstants
from unstract.connectors.databases.exceptions import (
@@ -28,6 +26,9 @@ BIG_QUERY_TABLE_SIZE = 3
class BigQuery(UnstractDB):
    def __init__(self, settings: dict[str, Any]):
        super().__init__("BigQuery")
        from google.cloud import bigquery

        self.bigquery = bigquery
        self.json_credentials = json.loads(settings.get("json_credentials", "{}"))
        self.big_query_table_size = BIG_QUERY_TABLE_SIZE

@@ -62,8 +63,8 @@ class BigQuery(UnstractDB):
    def can_read() -> bool:
        return True

    def get_engine(self) -> Client:
        return bigquery.Client.from_service_account_info(  # type: ignore
    def get_engine(self) -> Any:
        return self.bigquery.Client.from_service_account_info(  # type: ignore
            info=self.json_credentials
        )

@@ -208,21 +209,23 @@ class BigQuery(UnstractDB):
                        f"@`{key}`", f"PARSE_JSON(@`{key}`)"
                    )
                    query_parameters.append(
                        bigquery.ScalarQueryParameter(key, "STRING", json_str)
                        self.bigquery.ScalarQueryParameter(key, "STRING", json_str)
                    )
                elif isinstance(value, (dict, list)):
                    # For dict/list values in STRING columns, serialize to JSON string
                    json_str = json.dumps(value) if value else None
                    query_parameters.append(
                        bigquery.ScalarQueryParameter(key, "STRING", json_str)
                        self.bigquery.ScalarQueryParameter(key, "STRING", json_str)
                    )
                else:
                    # For other values, use STRING as before
                    query_parameters.append(
                        bigquery.ScalarQueryParameter(key, "STRING", value)
                        self.bigquery.ScalarQueryParameter(key, "STRING", value)
                    )

            query_params = bigquery.QueryJobConfig(query_parameters=query_parameters)
            query_params = self.bigquery.QueryJobConfig(
                query_parameters=query_parameters
            )
            query_job = engine.query(modified_sql, job_config=query_params)
        else:
            query_job = engine.query(sql_query)
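
The BigQuery change defers the `google.cloud.bigquery` import to `__init__`, keeps the module on `self.bigquery`, and widens `get_engine()` to `Any`, so the connector module can be imported (for example, while a worker builds the registry) without requiring SDK types at module import time. A stripped-down sketch of the same deferred-import pattern, not the actual class:

```python
from typing import Any


class LazyBigQueryEngine:
    """Sketch of the deferred-import pattern used by the BigQuery connector."""

    def __init__(self, json_credentials: dict[str, Any]) -> None:
        # Importing here means merely importing this module does not require
        # google-cloud-bigquery to be importable.
        from google.cloud import bigquery

        self.bigquery = bigquery
        self.json_credentials = json_credentials

    def get_engine(self) -> Any:
        # Any instead of bigquery.Client keeps SDK types out of the signature.
        return self.bigquery.Client.from_service_account_info(
            info=self.json_credentials
        )
```
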
@@ -144,8 +144,9 @@ class PostgreSQL(UnstractDB, PsycoPgHandler):
        Returns:
            str: generates a create sql base query with the constant columns
        """
        quoted_table = self._quote_identifier(table)
        sql_query = (
            f"CREATE TABLE IF NOT EXISTS {table} "
            f"CREATE TABLE IF NOT EXISTS {quoted_table} "
            f"(id TEXT, "
            f"created_by TEXT, created_at TIMESTAMP, "
            f"metadata JSONB, "
@@ -158,8 +159,9 @@ class PostgreSQL(UnstractDB, PsycoPgHandler):
        return sql_query

    def prepare_multi_column_migration(self, table_name: str, column_name: str) -> str:
        quoted_table = self._quote_identifier(table_name)
        sql_query = (
            f"ALTER TABLE {table_name} "
            f"ALTER TABLE {quoted_table} "
            f"ADD COLUMN {column_name}_v2 JSONB, "
            f"ADD COLUMN metadata JSONB, "
            f"ADD COLUMN user_field_1 BOOLEAN DEFAULT FALSE, "
@@ -182,3 +184,41 @@ class PostgreSQL(UnstractDB, PsycoPgHandler):
            schema=self.schema,
            table_name=table_name,
        )

    @staticmethod
    def _quote_identifier(identifier: str) -> str:
        """Quote PostgreSQL identifier to handle special characters like hyphens.

        PostgreSQL identifiers with special characters must be enclosed in double quotes.
        This method adds proper quoting for table names containing hyphens, spaces,
        or other special characters.

        Args:
            identifier (str): Table name or column name to quote

        Returns:
            str: Properly quoted identifier safe for PostgreSQL
        """
        # Always quote the identifier to handle special characters like hyphens
        # This is safe even for valid identifiers and prevents SQL injection
        return f'"{identifier}"'

    def get_sql_insert_query(
        self, table_name: str, sql_keys: list[str], sql_values: list[str] | None = None
    ) -> str:
        """Override base method to add PostgreSQL-specific table name quoting.

        Generates INSERT query with properly quoted table name for PostgreSQL.

        Args:
            table_name (str): Name of the table
            sql_keys (list[str]): List of column names
            sql_values (list[str], optional): SQL values for database-specific handling (ignored for PostgreSQL)

        Returns:
            str: INSERT query with properly quoted table name
        """
        quoted_table = self._quote_identifier(table_name)
        keys_str = ", ".join(sql_keys)
        values_placeholder = ", ".join(["%s"] * len(sql_keys))
        return f"INSERT INTO {quoted_table} ({keys_str}) VALUES ({values_placeholder})"
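
A quick illustration of what the new quoting buys: with `_quote_identifier()` applied, a table name containing hyphens produces a valid statement. The functions below are standalone restatements of the two methods, for demonstration only:

```python
def _quote_identifier(identifier: str) -> str:
    # Same behaviour as PostgreSQL._quote_identifier above.
    return f'"{identifier}"'


def get_sql_insert_query(table_name: str, sql_keys: list[str]) -> str:
    quoted_table = _quote_identifier(table_name)
    keys_str = ", ".join(sql_keys)
    values_placeholder = ", ".join(["%s"] * len(sql_keys))
    return f"INSERT INTO {quoted_table} ({keys_str}) VALUES ({values_placeholder})"


# A hyphenated table name (invented for the example) now yields a quoted query:
print(get_sql_insert_query("acme-invoices-output", ["id", "metadata"]))
# INSERT INTO "acme-invoices-output" (id, metadata) VALUES (%s, %s)
```
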
@@ -33,7 +33,10 @@ def register_connectors(connectors: dict[str, Any]) -> None:
                Common.METADATA: metadata,
            }
        except ModuleNotFoundError as exception:
            logger.error(f"Error while importing connectors : {exception}")
            logger.error(
                f"Error while importing connectors {connector} : {exception}",
                exc_info=True,
            )

    if len(connectors) == 0:
        logger.warning("No connector found.")
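
The same logging change is applied to the database, filesystem, and queue registries: the failing connector name is included and the traceback is attached via `exc_info=True` instead of a single generic error line. A minimal sketch of the loop shape (module names and the import mechanism are placeholders; the real code imports connector packages discovered at registration time):

```python
import logging
from importlib import import_module
from typing import Any

logger = logging.getLogger(__name__)

CONNECTOR_MODULES = ["local_storage", "google_cloud_storage"]  # placeholder names


def register_connectors(connectors: dict[str, Any]) -> None:
    for connector in CONNECTOR_MODULES:
        try:
            module = import_module(connector)
            connectors[connector] = {"metadata": {"module": module}}
        except ModuleNotFoundError as exception:
            # New behaviour: name the connector that failed and keep the traceback.
            logger.error(
                f"Error while importing connectors {connector} : {exception}",
                exc_info=True,
            )

    if len(connectors) == 0:
        logger.warning("No connector found.")
```
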
@@ -6,10 +6,6 @@ import uuid
from enum import Enum
from typing import Any

import snowflake.connector
import snowflake.connector.errors as SnowflakeError
from snowflake.connector.connection import SnowflakeConnection

from unstract.connectors.constants import DatabaseTypeConstants
from unstract.connectors.databases.exceptions import SnowflakeProgrammingException
from unstract.connectors.databases.unstract_db import UnstractDB
@@ -88,8 +84,10 @@ class SnowflakeDB(UnstractDB):
        }
        return str(mapping.get(data_type, DatabaseTypeConstants.SNOWFLAKE_TEXT))

    def get_engine(self) -> SnowflakeConnection:
        con = snowflake.connector.connect(
    def get_engine(self) -> Any:
        from snowflake.connector import connect

        con = connect(
            user=self.user,
            password=self.password,
            account=self.account,
@@ -134,6 +132,8 @@ class SnowflakeDB(UnstractDB):
    def execute_query(
        self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
    ) -> None:
        import snowflake.connector.errors as SnowflakeError

        table_name = kwargs.get("table_name", None)
        logger.debug(f"Snowflake execute_query called with sql_query: {sql_query}")
        logger.debug(f"sql_values: {sql_values}")
@@ -169,6 +169,8 @@ class SnowflakeDB(UnstractDB):
        ) from e

    def get_information_schema(self, table_name: str) -> dict[str, str]:
        import snowflake.connector.errors as SnowflakeError

        query = f"describe table {table_name}"
        column_types: dict[str, str] = {}
        try:
@@ -1,4 +1,4 @@
from unstract.connectors import ConnectorDict # type: ignore
from unstract.connectors.constants import ConnectorDict
from unstract.connectors.filesystems.register import register_connectors

from .local_storage.local_storage import * # noqa: F401, F403
@@ -5,7 +5,7 @@ from email.utils import parsedate_to_datetime
from typing import Any

import azure.core.exceptions as AzureException
from adlfs import AzureBlobFileSystem
from fsspec import AbstractFileSystem

from unstract.connectors.exceptions import AzureHttpError
from unstract.connectors.filesystems.azure_cloud_storage.exceptions import (
@@ -14,7 +14,17 @@ from unstract.connectors.filesystems.azure_cloud_storage.exceptions import (
from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem
from unstract.filesystem import FileStorageType, FileSystem

# Suppress verbose Azure SDK HTTP request/response logging
logging.getLogger("azurefs").setLevel(logging.ERROR)
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(
    logging.WARNING
)
logging.getLogger("azure.storage.blob").setLevel(logging.WARNING)
logging.getLogger("azure.storage").setLevel(logging.WARNING)
logging.getLogger("azure.core").setLevel(logging.WARNING)
# Keep ADLFS filesystem errors visible but suppress HTTP noise
logging.getLogger("adlfs").setLevel(logging.WARNING)

logger = logging.getLogger(__name__)


@@ -23,6 +33,8 @@ class AzureCloudStorageFS(UnstractFileSystem):
    INVALID_PATH = "The specifed resource name contains invalid characters."

    def __init__(self, settings: dict[str, Any]):
        from adlfs import AzureBlobFileSystem

        super().__init__("AzureCloudStorageFS")
        account_name = settings.get("account_name", "")
        access_key = settings.get("access_key", "")
@@ -70,7 +82,7 @@ class AzureCloudStorageFS(UnstractFileSystem):
    def can_read() -> bool:
        return True

    def get_fsspec_fs(self) -> AzureBlobFileSystem:
    def get_fsspec_fs(self) -> AbstractFileSystem:
        return self.azure_fs

    def extract_metadata_file_hash(self, metadata: dict[str, Any]) -> str | None:
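
Returning `AbstractFileSystem` here (and in the GCS and Dropbox connectors below) lets worker code treat every filesystem connector uniformly through the generic fsspec interface. A hedged sketch of a caller that relies only on that interface (`list_root` is a hypothetical helper, not part of this PR):

```python
from typing import Any

from fsspec import AbstractFileSystem


def list_root(connector: Any) -> list[str]:
    # Works for any connector whose get_fsspec_fs() returns an fsspec filesystem
    # (Azure, GCS, Dropbox, ...), without importing the provider SDK here.
    fs: AbstractFileSystem = connector.get_fsspec_fs()
    return fs.ls("/")
```
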
@@ -5,7 +5,7 @@ import os
from datetime import UTC, datetime
from typing import Any

from gcsfs import GCSFileSystem
from fsspec import AbstractFileSystem

from unstract.connectors.exceptions import ConnectorError
from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem
@@ -26,6 +26,8 @@ class GoogleCloudStorageFS(UnstractFileSystem):
        project_id = settings.get("project_id", "")
        json_credentials_str = settings.get("json_credentials", "{}")
        try:
            from gcsfs import GCSFileSystem

            json_credentials = json.loads(json_credentials_str)
            self.gcs_fs = GCSFileSystem(
                token=json_credentials,
@@ -81,7 +83,7 @@ class GoogleCloudStorageFS(UnstractFileSystem):
    def can_read() -> bool:
        return True

    def get_fsspec_fs(self) -> GCSFileSystem:
    def get_fsspec_fs(self) -> AbstractFileSystem:
        return self.gcs_fs

    def extract_metadata_file_hash(self, metadata: dict[str, Any]) -> str | None:

@@ -109,8 +111,46 @@ class GoogleCloudStorageFS(UnstractFileSystem):
        Returns:
            bool: True if the path is a directory, False otherwise.
        """
        # Note: Here Metadata type seems to be always "file" even for directories
        return metadata.get("type") == "directory"
        # Primary check: Standard directory type
        if metadata.get("type") == "directory":
            return True

        # GCS-specific directory detection
        # In GCS, folders are represented as objects with specific characteristics
        object_name = metadata.get("name", "")
        size = metadata.get("size", 0)
        content_type = metadata.get("contentType", "")

        # GCS folder indicators:
        # 1. Object name ends with "/" (most reliable indicator)
        if object_name.endswith("/"):
            logger.debug(
                f"[GCS Directory Check] '{object_name}' identified as directory: name ends with '/'"
            )
            return True

        # 2. Zero-size object with text/plain content type (common for GCS folders)
        if size == 0 and content_type == "text/plain":
            logger.debug(
                f"[GCS Directory Check] '{object_name}' identified as directory: zero-size with text/plain content type"
            )
            return True

        # 3. Check for GCS-specific folder metadata
        # Some GCS folder objects have no contentType or have application/x-www-form-urlencoded
        if size == 0 and (
            not content_type
            or content_type
            in ["application/x-www-form-urlencoded", "binary/octet-stream"]
        ):
            # Additional validation: check if this looks like a folder path
            if "/" in object_name and not object_name.split("/")[-1]:  # Path ends with /
                logger.debug(
                    f"[GCS Directory Check] '{object_name}' identified as directory: zero-size folder-like object"
                )
                return True

        return False

    def extract_modified_date(self, metadata: dict[str, Any]) -> datetime | None:
        """Extract the last modified date from GCS metadata.
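
To make the new heuristics concrete, here is a standalone restatement of the added checks plus two sample metadata dicts shaped like what gcsfs reports (values invented for illustration):

```python
from typing import Any


def looks_like_gcs_directory(metadata: dict[str, Any]) -> bool:
    """Standalone restatement of the directory checks added in the hunk above."""
    if metadata.get("type") == "directory":
        return True
    object_name = metadata.get("name", "")
    size = metadata.get("size", 0)
    content_type = metadata.get("contentType", "")
    if object_name.endswith("/"):
        return True
    if size == 0 and content_type == "text/plain":
        return True
    if size == 0 and (
        not content_type
        or content_type in ["application/x-www-form-urlencoded", "binary/octet-stream"]
    ):
        if "/" in object_name and not object_name.split("/")[-1]:
            return True
    return False


# A zero-byte "folder object" is detected even though GCS reports type == "file",
# while a regular file is not.
assert looks_like_gcs_directory({"name": "bucket/reports/", "size": 0, "type": "file"})
assert not looks_like_gcs_directory(
    {"name": "bucket/reports/jan.pdf", "size": 52431, "type": "file"}
)
```
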
@@ -32,7 +32,10 @@ def register_connectors(connectors: dict[str, Any]) -> None:
                Common.METADATA: metadata,
            }
        except ModuleNotFoundError as exception:
            logger.error(f"Error while importing connectors : {exception}")
            logger.error(
                f"Error while importing connectors {connector} : {exception}",
                exc_info=True,
            )

    if len(connectors) == 0:
        logger.warning("No connector found.")
@@ -3,20 +3,18 @@ import os
from datetime import UTC, datetime
from typing import Any

from dropbox.exceptions import ApiError as DropBoxApiError
from dropbox.exceptions import DropboxException
from dropboxdrivefs import DropboxDriveFileSystem
from fsspec import AbstractFileSystem

from unstract.connectors.exceptions import ConnectorError
from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem

from .exceptions import handle_dropbox_exception

logger = logging.getLogger(__name__)


class DropboxFS(UnstractFileSystem):
    def __init__(self, settings: dict[str, Any]):
        from dropboxdrivefs import DropboxDriveFileSystem

        super().__init__("Dropbox")
        self.dropbox_fs = DropboxDriveFileSystem(token=settings["token"])
        self.path = "///"
@@ -61,7 +59,7 @@ class DropboxFS(UnstractFileSystem):
    def can_read() -> bool:
        return True

    def get_fsspec_fs(self) -> DropboxDriveFileSystem:
    def get_fsspec_fs(self) -> AbstractFileSystem:
        return self.dropbox_fs

    def extract_metadata_file_hash(self, metadata: dict[str, Any]) -> str | None:
@@ -132,10 +130,14 @@ class DropboxFS(UnstractFileSystem):

    def test_credentials(self) -> bool:
        """To test credentials for Dropbox."""
        from dropbox.exceptions import DropboxException

        try:
            # self.get_fsspec_fs().connect()
            self.get_fsspec_fs().ls("")
        except DropboxException as e:
            from .exceptions import handle_dropbox_exception

            raise handle_dropbox_exception(e) from e
        except Exception as e:
            raise ConnectorError(f"Error while connecting to Dropbox: {str(e)}") from e
@@ -143,11 +145,23 @@ class DropboxFS(UnstractFileSystem):

    @staticmethod
    def get_connector_root_dir(input_dir: str, **kwargs: Any) -> str:
        """Get roor dir of zs dropbox."""
        return f"/{input_dir.strip('/')}"
        """Get root dir of zs dropbox with backward compatibility.

        Dropbox requires leading slashes, so we override the base class behavior.
        """
        # Call base class implementation
        result = super().get_connector_root_dir(input_dir, **kwargs)

        # Dropbox needs leading slash - ensure it's present
        if not result.startswith("/"):
            result = f"/{result}"

        return result

    def create_dir_if_not_exists(self, input_dir: str) -> None:
        """Create roor dir of zs dropbox if not exists."""
        from dropbox.exceptions import ApiError as DropBoxApiError

        fs_fsspec = self.get_fsspec_fs()
        try:
            fs_fsspec.isdir(input_dir)
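
The Dropbox override now delegates to the base class and only re-adds the leading slash that Dropbox paths require. A simplified sketch with a stand-in for the base implementation (assumed here to return a relative path):

```python
# Stand-in for the base UnstractFileSystem.get_connector_root_dir behaviour.
def base_get_connector_root_dir(input_dir: str, **kwargs) -> str:
    return input_dir.strip("/")


def dropbox_get_connector_root_dir(input_dir: str, **kwargs) -> str:
    result = base_get_connector_root_dir(input_dir, **kwargs)
    # Dropbox paths must be absolute, so re-add the leading slash if needed.
    if not result.startswith("/"):
        result = f"/{result}"
    return result


print(dropbox_get_connector_root_dir("input/invoices/"))  # -> /input/invoices
```
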
@@ -4,8 +4,6 @@ import logging
import os
from typing import Any

from google.cloud import secretmanager
from google.cloud.storage import Client
from google.oauth2 import service_account
from google.oauth2.credentials import Credentials

@@ -20,6 +18,9 @@ logger = logging.getLogger(__name__)

class GCSHelper:
    def __init__(self) -> None:
        from google.cloud.storage import Client

        self.client = Client
        self.google_service_json = os.environ.get("GDRIVE_GOOGLE_SERVICE_ACCOUNT")
        self.google_project_id = os.environ.get("GDRIVE_GOOGLE_PROJECT_ID")
        if self.google_service_json is None:
@@ -39,6 +40,8 @@ class GCSHelper:
        return self.google_credentials

    def get_secret(self, secret_name: str) -> str:
        from google.cloud import secretmanager

        google_secrets_client = secretmanager.SecretManagerServiceClient(
            credentials=self.google_credentials
        )
@@ -50,7 +53,7 @@ class GCSHelper:
        return s.payload.data.decode("UTF-8")

    def get_object_checksum(self, bucket_name: str, object_name: str) -> str:
        client = Client(credentials=self.google_credentials)
        client = self.client(credentials=self.google_credentials)
        bucket = client.bucket(bucket_name)
        md5_hash_hex = ""
        try:
@@ -62,26 +65,26 @@ class GCSHelper:
        return md5_hash_hex

    def upload_file(self, bucket_name: str, object_name: str, file_path: str) -> None:
        client = Client(credentials=self.google_credentials)
        client = self.client(credentials=self.google_credentials)
        bucket = client.bucket(bucket_name)
        blob = bucket.blob(object_name)
        blob.upload_from_filename(file_path)

    def upload_text(self, bucket_name: str, object_name: str, text: str) -> None:
        client = Client(credentials=self.google_credentials)
        client = self.client(credentials=self.google_credentials)
        bucket = client.bucket(bucket_name)
        blob = bucket.blob(object_name)
        blob.upload_from_string(text)

    def upload_object(self, bucket_name: str, object_name: str, object: Any) -> None:
        client = Client(credentials=self.google_credentials)
        client = self.client(credentials=self.google_credentials)
        bucket = client.bucket(bucket_name)
        blob = bucket.blob(object_name)
        blob.upload_from_string(object, content_type="application/octet-stream")

    def read_file(self, bucket_name: str, object_name: str) -> Any:
        logger.info(f"Reading file {object_name} from bucket {bucket_name}")
        client = Client(credentials=self.google_credentials)
        client = self.client(credentials=self.google_credentials)
        bucket = client.bucket(bucket_name)
        logger.info(f"Reading file {object_name} from bucket {bucket_name}")
        try:
unstract/connectors/src/unstract/connectors/operations.py (new file, 146 lines)
@@ -0,0 +1,146 @@
"""Connector Operations for Unstract Platform

This module provides core connector operations for filesystem connectors
and connector health checks.

Used by:
- workers/shared/workflow/connectors/service.py (for worker-native operations)
"""

import logging
from typing import Any

# Import internal connector components (no try/catch needed - proper dependencies)
from unstract.connectors.constants import Common
from unstract.connectors.filesystems import connectors as fs_connectors
from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem

logger = logging.getLogger(__name__)


class ConnectorOperations:
    """Common connector operations shared between backend and workers with strict error handling"""

    @staticmethod
    def test_connector_connection(
        connector_id: str, settings: dict[str, Any]
    ) -> dict[str, Any]:
        """Test connection to connector before attempting operations.

        Args:
            connector_id: Connector ID
            settings: Connector settings

        Returns:
            Dictionary with connection test results: {'is_connected': bool, 'error': str}
        """
        try:
            # Get connector instance
            connector = ConnectorOperations.get_fs_connector(connector_id, settings)

            # Test basic connectivity by getting fsspec filesystem
            fs = connector.get_fsspec_fs()

            # For filesystem connectors, try to check if root path exists
            test_path = settings.get("path", "/")
            try:
                fs.exists(test_path)
                return {"is_connected": True, "error": None}
            except Exception as path_error:
                return {
                    "is_connected": False,
                    "error": f"Cannot access path '{test_path}': {str(path_error)}",
                }

        except Exception as e:
            return {"is_connected": False, "error": str(e)}

    @staticmethod
    def get_fs_connector(
        connector_id: str, settings: dict[str, Any]
    ) -> "UnstractFileSystem":
        """Get filesystem connector instance using exact backend BaseConnector logic.

        This replicates backend/workflow_manager/endpoint_v2/base_connector.py:get_fs_connector()

        Args:
            connector_id: Connector ID from the registry
            settings: Connector-specific settings

        Returns:
            UnstractFileSystem instance

        Raises:
            ImportError: If connector registries not available (critical error)
            ValueError: If connector_id is not supported
        """
        if not fs_connectors:
            raise RuntimeError("Filesystem connectors registry not initialized")

        if connector_id not in fs_connectors:
            available_ids = list(fs_connectors.keys())
            raise ValueError(
                f"Connector '{connector_id}' is not supported. "
                f"Available connectors: {available_ids}"
            )

        if not Common:
            raise RuntimeError("Common connector constants not initialized")

        # Use exact same pattern as backend BaseConnector
        connector_class = fs_connectors[connector_id][Common.METADATA][Common.CONNECTOR]
        return connector_class(settings)

    @staticmethod
    def get_connector_health(source_config: dict[str, Any]) -> dict[str, Any]:
        """Get health status of a source connector.

        Args:
            source_config: Source configuration dictionary

        Returns:
            Dictionary with health status and metadata
        """
        try:
            connector_id = source_config.get("connector_id") or source_config.get(
                "connection_type"
            )
            settings = source_config.get("settings", {})

            if not connector_id or not settings:
                return {
                    "is_healthy": False,
                    "connection_type": connector_id,
                    "errors": ["Missing connector configuration"],
                    "response_time_ms": None,
                }

            import time

            start_time = time.time()

            # Test connection
            connection_result = ConnectorOperations.test_connector_connection(
                connector_id, settings
            )

            response_time = int(
                (time.time() - start_time) * 1000
            )  # Convert to milliseconds

            return {
                "is_healthy": connection_result["is_connected"],
                "connection_type": connector_id,
                "errors": [connection_result["error"]]
                if connection_result["error"]
                else [],
                "response_time_ms": response_time,
            }

        except Exception as e:
            return {
                "is_healthy": False,
                "connection_type": source_config.get("connector_id", "unknown"),
                "errors": [str(e)],
                "response_time_ms": None,
            }
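
A hypothetical worker-side call into the new module (the connector ID and settings below are made up; real IDs come from the connector registry):

```python
from unstract.connectors.operations import ConnectorOperations

source_config = {
    "connector_id": "example-fs-connector-id",  # placeholder registry ID
    "settings": {"path": "input/invoices"},     # placeholder connector settings
}

health = ConnectorOperations.get_connector_health(source_config)
if not health["is_healthy"]:
    print(
        f"{health['connection_type']} unreachable after "
        f"{health['response_time_ms']} ms: {health['errors']}"
    )
```
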
@@ -1,4 +1,4 @@
from unstract.connectors import ConnectorDict
from unstract.connectors.constants import ConnectorDict
from unstract.connectors.queues.register import register_connectors

connectors: ConnectorDict = {}
@@ -32,7 +32,10 @@ def register_connectors(connectors: dict[str, Any]) -> None:
                Common.METADATA: metadata,
            }
        except ModuleNotFoundError as exception:
            logger.error(f"Error while importing connectors : {exception}")
            logger.error(
                f"Error while importing connectors {connector} : {exception}",
                exc_info=True,
            )

    if len(connectors) == 0:
        logger.warning("No connector found.")
unstract/core/src/unstract/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# Unstract namespace package
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
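
`extend_path` makes `unstract` a pkgutil-style namespace package, so distributions such as `unstract.core` and `unstract.connectors` can each ship their own `unstract/` directory and still be importable side by side:

```python
# Every installed distribution that provides an unstract/ directory
# contributes one entry to __path__.
import unstract

print(unstract.__path__)
```
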
@@ -0,0 +1,99 @@
"""Unstract Core Library

Core data models, utilities, and base classes for the Unstract platform.
Provides shared functionality between backend and worker services.
"""

# Export core data models and enums
# Export existing utilities and constants
from .constants import LogEventArgument, LogFieldName, LogProcessingTask
from .data_models import (
    ConnectionType,
    ExecutionStatus,
    FileHashData,
    SourceConnectionType,
    WorkflowExecutionData,
    WorkflowFileExecutionData,
    WorkflowType,
    serialize_dataclass_to_dict,
)

# Export worker base classes
from .worker_base import (
    CallbackTaskBase,
    FileProcessingTaskBase,
    WorkerTaskBase,
    circuit_breaker,
    create_callback_task,
    create_file_processing_task,
    create_task_decorator,
    monitor_performance,
    with_task_context,
)

# Note: Worker constants moved to workers/shared/ to remove Django dependency
# These are now available directly from workers.shared.constants and workers.shared.worker_patterns
# Export worker-specific models and enums
from .worker_models import (
    BatchExecutionResult,
    CallbackExecutionData,
    FileExecutionResult,
    NotificationMethod,
    NotificationRequest,
    PipelineStatus,
    PipelineStatusUpdateRequest,
    QueueName,
    StatusMappings,
    TaskError,
    TaskExecutionContext,
    TaskName,
    WebhookResult,
    WebhookStatus,
    WorkerTaskStatus,
    WorkflowExecutionUpdateRequest,
)

__version__ = "1.0.0"

__all__ = [
    # Core data models and enums
    "ExecutionStatus",
    "WorkflowType",
    "ConnectionType",
    "FileHashData",
    "WorkflowFileExecutionData",
    "WorkflowExecutionData",
    "SourceConnectionType",
    "serialize_dataclass_to_dict",
    # Worker models and enums
    "TaskName",
    "QueueName",
    "WorkerTaskStatus",
    "PipelineStatus",
    "WebhookStatus",
    "NotificationMethod",
    "StatusMappings",
    "WebhookResult",
    "FileExecutionResult",
    "BatchExecutionResult",
    "CallbackExecutionData",
    "WorkflowExecutionUpdateRequest",
    "PipelineStatusUpdateRequest",
    "NotificationRequest",
    "TaskExecutionContext",
    "TaskError",
    # Worker base classes
    "WorkerTaskBase",
    "FileProcessingTaskBase",
    "CallbackTaskBase",
    "create_task_decorator",
    "monitor_performance",
    "with_task_context",
    "circuit_breaker",
    "create_file_processing_task",
    "create_callback_task",
    # Existing utilities
    "LogFieldName",
    "LogEventArgument",
    "LogProcessingTask",
]
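
With these exports in place, worker code can pull the shared models straight from `unstract.core` in a process that never configures Django; for example (names taken from the export list above):

```python
# Runs in a plain Python process with no DJANGO_SETTINGS_MODULE set.
from unstract.core import ExecutionStatus, WorkerTaskBase, serialize_dataclass_to_dict

print(ExecutionStatus, WorkerTaskBase, serialize_dataclass_to_dict)
```
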
@@ -18,3 +18,114 @@ class LogEventArgument:
class LogProcessingTask:
    TASK_NAME = "logs_consumer"
    QUEUE_NAME = "celery_log_task_queue"


class FileProcessingConstants:
    """Constants for file processing operations."""

    # File chunk size for reading/writing (4MB default)
    READ_CHUNK_SIZE = 4194304  # 4MB chunks for file reading

    # Log preview size for truncating file content in logs
    LOG_PREVIEW_SIZE = 500  # 500 bytes for log preview

    # File processing timeout in seconds
    DEFAULT_PROCESSING_TIMEOUT = 300  # 5 minutes

    # Maximum file size in bytes for validation
    MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024  # 100MB

    @classmethod
    def get_chunk_size(cls) -> int:
        """Get the configured chunk size for file operations."""
        return cls.READ_CHUNK_SIZE

    @classmethod
    def get_log_preview_size(cls) -> int:
        """Get the configured log preview size."""
        return cls.LOG_PREVIEW_SIZE


class WorkerConstants:
    """General worker operation constants."""

    # Default retry attempts for worker operations
    DEFAULT_RETRY_ATTEMPTS = 3

    # Default timeout for API calls
    API_TIMEOUT = 30

    # Health check interval
    HEALTH_CHECK_INTERVAL = 30


class FilePatternConstants:
    """Constants for file pattern matching and translation."""

    # Display name to file pattern mappings
    # Maps UI-friendly display names to actual file matching patterns
    DISPLAY_NAME_TO_PATTERNS = {
        "pdf documents": ["*.pdf"],
        "word documents": ["*.doc", "*.docx"],
        "excel documents": ["*.xls", "*.xlsx"],
        "powerpoint documents": ["*.ppt", "*.pptx"],
        "text files": ["*.txt"],
        "image files": ["*.jpg", "*.jpeg", "*.png", "*.gif", "*.bmp", "*.tiff", "*.tif"],
        "csv files": ["*.csv"],
        "json files": ["*.json"],
        "xml files": ["*.xml"],
        "all files": ["*"],
        "office documents": ["*.doc", "*.docx", "*.xls", "*.xlsx", "*.ppt", "*.pptx"],
        "document files": ["*.pdf", "*.doc", "*.docx", "*.txt"],
        "spreadsheet files": ["*.xls", "*.xlsx", "*.csv"],
        "presentation files": ["*.ppt", "*.pptx"],
        "archive files": ["*.zip", "*.rar", "*.7z", "*.tar", "*.gz"],
        "video files": ["*.mp4", "*.avi", "*.mov", "*.wmv", "*.flv", "*.mkv"],
        "audio files": ["*.mp3", "*.wav", "*.flac", "*.aac", "*.ogg"],
    }

    # Common file extension categories for inference
    EXTENSION_CATEGORIES = {
        "pdf": ["*.pdf"],
        "doc": ["*.doc", "*.docx"],
        "excel": ["*.xls", "*.xlsx"],
        "image": ["*.jpg", "*.jpeg", "*.png", "*.gif", "*.bmp", "*.tiff", "*.tif"],
        "text": ["*.txt"],
        "csv": ["*.csv"],
        "json": ["*.json"],
        "xml": ["*.xml"],
        "office": ["*.doc", "*.docx", "*.xls", "*.xlsx", "*.ppt", "*.pptx"],
        "archive": ["*.zip", "*.rar", "*.7z", "*.tar", "*.gz"],
        "video": ["*.mp4", "*.avi", "*.mov", "*.wmv", "*.flv", "*.mkv"],
        "audio": ["*.mp3", "*.wav", "*.flac", "*.aac", "*.ogg"],
    }

    @classmethod
    def get_patterns_for_display_name(cls, display_name: str) -> list[str] | None:
        """Get file patterns for a given display name.

        Args:
            display_name: UI display name (e.g., "PDF documents")

        Returns:
            List of file patterns or None if not found
        """
        return cls.DISPLAY_NAME_TO_PATTERNS.get(display_name.strip().lower())

    @classmethod
    def infer_patterns_from_keyword(cls, keyword: str) -> list[str] | None:
        """Infer file patterns from a keyword.

        Args:
            keyword: Keyword to search for (e.g., "pdf", "excel")

        Returns:
            List of file patterns or None if not found
        """
        keyword_lower = keyword.strip().lower()

        for category, patterns in cls.EXTENSION_CATEGORIES.items():
            if category in keyword_lower:
                return patterns

        return None
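
A short usage sketch of the new pattern helpers (assuming they live in `unstract.core.constants` alongside `LogProcessingTask`; the file list below is invented):

```python
import fnmatch

from unstract.core.constants import FilePatternConstants

patterns = FilePatternConstants.get_patterns_for_display_name("PDF documents")
# -> ["*.pdf"]  (display names are lower-cased and stripped before lookup)

keyword_patterns = FilePatternConstants.infer_patterns_from_keyword("excel exports")
# -> ["*.xls", "*.xlsx"]  (substring match over EXTENSION_CATEGORIES keys)

# Typical use: filter listed files against the resolved patterns.
files = ["report.pdf", "data.xlsx", "notes.txt"]
selected = [f for f in files if any(fnmatch.fnmatch(f, p) for p in (patterns or []))]
print(selected)  # ['report.pdf']
```
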
Some files were not shown because too many files have changed in this diff.