diff --git a/.github/workflows/production-build.yaml b/.github/workflows/production-build.yaml index 07c319bd..d7e2d2df 100644 --- a/.github/workflows/production-build.yaml +++ b/.github/workflows/production-build.yaml @@ -11,7 +11,7 @@ jobs: build-and-push: runs-on: ubuntu-latest strategy: - matrix: + matrix: service_name: [backend, frontend, document-service, platform-service, prompt-service, worker, x2text-service] steps: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c67d62e4..6ea24139 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,33 +40,33 @@ repos: hooks: - id: yamllint args: ["-d", "relaxed"] - language: system - - repo: https://github.com/rhysd/actionlint - rev: v1.6.27 - hooks: - - id: actionlint-docker - args: [-ignore, 'label ".+" is unknown'] + # language: system + # - repo: https://github.com/rhysd/actionlint + # rev: v1.6.27 + # hooks: + # - id: actionlint-docker + # args: [-ignore, 'label ".+" is unknown'] - repo: https://github.com/psf/black rev: 24.3.0 hooks: - id: black args: [--config=pyproject.toml, -l 80] - language: system + # language: system exclude: | (?x)^( unstract/flags/src/unstract/flags/evaluation_.*\.py| )$ - - repo: https://github.com/pycqa/flake8 - rev: 7.0.0 - hooks: - - id: flake8 - args: [--max-line-length=80] - exclude: | - (?x)^( - .*migrations/.*\.py| - core/tests/.*| - unstract/flags/src/unstract/flags/evaluation_.*\.py| - )$ + # - repo: https://github.com/pycqa/flake8 + # rev: 7.0.0 + # hooks: + # - id: flake8 + # args: [--max-line-length=80] + # exclude: | + # (?x)^( + # .*migrations/.*\.py| + # core/tests/.*| + # unstract/flags/src/unstract/flags/evaluation_.*\.py| + # )$ - repo: https://github.com/pycqa/isort rev: 5.13.2 hooks: @@ -104,35 +104,35 @@ repos: rev: v8.18.2 hooks: - id: gitleaks - - repo: https://github.com/hadolint/hadolint - rev: v2.12.1-beta - hooks: - - id: hadolint-docker - args: - - --ignore=DL3003 - - --ignore=DL3008 - - --ignore=DL3013 - - --ignore=DL3018 - - --ignore=SC1091 - files: Dockerfile$ + # - repo: https://github.com/hadolint/hadolint + # rev: v2.12.1-beta + # hooks: + # - id: hadolint-docker + # args: + # - --ignore=DL3003 + # - --ignore=DL3008 + # - --ignore=DL3013 + # - --ignore=DL3018 + # - --ignore=SC1091 + # files: Dockerfile$ - repo: https://github.com/asottile/yesqa rev: v1.5.0 hooks: - id: yesqa - - repo: https://github.com/pre-commit/mirrors-eslint - rev: "v9.0.0-beta.2" # Use the sha / tag you want to point at - hooks: - - id: eslint - args: [--config=frontend/.eslintrc.json] - files: \.[jt]sx?$ # *.js, *.jsx, *.ts and *.tsx - types: [file] - additional_dependencies: - - eslint@8.41.0 - - eslint-config-google@0.14.0 - - eslint-config-prettier@8.8.0 - - eslint-plugin-prettier@4.2.1 - - eslint-plugin-react@7.32.2 - - eslint-plugin-import@2.25.2 + # - repo: https://github.com/pre-commit/mirrors-eslint + # rev: "v9.0.0-beta.2" # Use the sha / tag you want to point at + # hooks: + # - id: eslint + # args: [--config=frontend/.eslintrc.json] + # files: \.[jt]sx?$ # *.js, *.jsx, *.ts and *.tsx + # types: [file] + # additional_dependencies: + # - eslint@8.41.0 + # - eslint-config-google@0.14.0 + # - eslint-config-prettier@8.8.0 + # - eslint-plugin-prettier@4.2.1 + # - eslint-plugin-react@7.32.2 + # - eslint-plugin-import@2.25.2 - repo: https://github.com/Lucas-C/pre-commit-hooks-nodejs rev: v1.1.2 hooks: @@ -155,16 +155,16 @@ repos: rev: 2.12.4 hooks: - id: pdm-lock-check - - repo: local - hooks: - - id: run-mypy - name: Run mypy - entry: sh -c 'pdm run mypy .' - language: system - pass_filenames: false - - id: check-django-migrations - name: Check django migrations - entry: sh -c 'pdm run docker/scripts/check_django_migrations.sh' - language: system - types: [python] # hook only runs if a python file is staged - pass_filenames: false + # - repo: local + # hooks: + # - id: run-mypy + # name: Run mypy + # entry: sh -c 'pdm run mypy .' + # language: system + # pass_filenames: false + # - id: check-django-migrations + # name: Check django migrations + # entry: sh -c 'pdm run docker/scripts/check_django_migrations.sh' + # language: system + # types: [python] # hook only runs if a python file is staged + # pass_filenames: false diff --git a/README.md b/README.md index b789c35f..090305a3 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,14 @@ ## No-code LLM Platform to launch APIs and ETL Pipelines to structure unstructured documents +[![CLA assistant](https://cla-assistant.io/readme/badge/Zipstack/unstract)](https://cla-assistant.io/Zipstack/unstract) +[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Zipstack/unstract/main.svg)](https://results.pre-commit.ci/latest/github/Zipstack/unstract/main) +[![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=Zipstack_unstract&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=Zipstack_unstract) +[![Bugs](https://sonarcloud.io/api/project_badges/measure?project=Zipstack_unstract&metric=bugs)](https://sonarcloud.io/summary/new_code?id=Zipstack_unstract) +[![Code Smells](https://sonarcloud.io/api/project_badges/measure?project=Zipstack_unstract&metric=code_smells)](https://sonarcloud.io/summary/new_code?id=Zipstack_unstract) +[![Coverage](https://sonarcloud.io/api/project_badges/measure?project=Zipstack_unstract&metric=coverage)](https://sonarcloud.io/summary/new_code?id=Zipstack_unstract) +[![Duplicated Lines (%)](https://sonarcloud.io/api/project_badges/measure?project=Zipstack_unstract&metric=duplicated_lines_density)](https://sonarcloud.io/summary/new_code?id=Zipstack_unstract) + ## 🤖 Go beyond co-pilots @@ -121,4 +129,4 @@ Contributions are welcome! Please read [CONTRIBUTE.md](CONTRIBUTE.md) for furthe ## 👋 Join the LLM-powered automation community -[Join great conversations](https://join-slack.unstract.com) around LLMs, their ecosystem and leveraging them to automate the previously unautomatable! \ No newline at end of file +[Join great conversations](https://join-slack.unstract.com) around LLMs, their ecosystem and leveraging them to automate the previously unautomatable! diff --git a/backend/account/admin.py b/backend/account/admin.py index 6eab2ecf..e0b96cce 100644 --- a/backend/account/admin.py +++ b/backend/account/admin.py @@ -2,4 +2,4 @@ from django.contrib import admin from .models import Organization, User -admin.site.register([Organization, User]) \ No newline at end of file +admin.site.register([Organization, User]) diff --git a/backend/account/authentication_controller.py b/backend/account/authentication_controller.py index 197e82f8..275c5b83 100644 --- a/backend/account/authentication_controller.py +++ b/backend/account/authentication_controller.py @@ -121,10 +121,10 @@ class AuthenticationController: return redirect(f"{settings.ERROR_URL}") if member.organization_id and member.role and len(member.role) > 0: - organization: Optional[ - Organization - ] = OrganizationService.get_organization_by_org_id( - member.organization_id + organization: Optional[Organization] = ( + OrganizationService.get_organization_by_org_id( + member.organization_id + ) ) if organization: try: @@ -192,9 +192,9 @@ class AuthenticationController: new_organization = False organization_ids = CacheService.get_user_organizations(user.user_id) if not organization_ids: - z_organizations: list[ - OrganizationData - ] = self.auth_service.get_organizations_by_user_id(user.user_id) + z_organizations: list[OrganizationData] = ( + self.auth_service.get_organizations_by_user_id(user.user_id) + ) organization_ids = {org.id for org in z_organizations} if organization_id and organization_id in organization_ids: organization = OrganizationService.get_organization_by_org_id( @@ -242,9 +242,9 @@ class AuthenticationController: }, ) # Update user session data in redis - user_session_info: dict[ - str, Any - ] = CacheService.get_user_session_info(user.email) + user_session_info: dict[str, Any] = ( + CacheService.get_user_session_info(user.email) + ) user_session_info["current_org"] = organization_id CacheService.set_user_session_info(user_session_info) response.set_cookie(Cookie.ORG_ID, organization_id) diff --git a/backend/account/migrations/0001_initial.py b/backend/account/migrations/0001_initial.py index a925b7fd..0ff85cfe 100644 --- a/backend/account/migrations/0001_initial.py +++ b/backend/account/migrations/0001_initial.py @@ -29,7 +29,10 @@ class Migration(migrations.Migration): verbose_name="ID", ), ), - ("password", models.CharField(max_length=128, verbose_name="password")), + ( + "password", + models.CharField(max_length=128, verbose_name="password"), + ), ( "last_login", models.DateTimeField( @@ -96,7 +99,8 @@ class Migration(migrations.Migration): ( "date_joined", models.DateTimeField( - default=django.utils.timezone.now, verbose_name="date joined" + default=django.utils.timezone.now, + verbose_name="date joined", ), ), ("user_id", models.CharField()), @@ -218,9 +222,14 @@ class Migration(migrations.Migration): ), ( "domain", - models.CharField(db_index=True, max_length=253, unique=True), + models.CharField( + db_index=True, max_length=253, unique=True + ), + ), + ( + "is_primary", + models.BooleanField(db_index=True, default=True), ), - ("is_primary", models.BooleanField(db_index=True, default=True)), ( "tenant", models.ForeignKey( diff --git a/backend/account/migrations/0003_platformkey.py b/backend/account/migrations/0003_platformkey.py index fc8d8028..80cdc227 100644 --- a/backend/account/migrations/0003_platformkey.py +++ b/backend/account/migrations/0003_platformkey.py @@ -1,9 +1,10 @@ # Generated by Django 4.2.1 on 2023-11-02 05:22 +import uuid + +import django.db.models.deletion from django.conf import settings from django.db import migrations, models -import django.db.models.deletion -import uuid class Migration(migrations.Migration): @@ -27,7 +28,9 @@ class Migration(migrations.Migration): ("key", models.UUIDField(default=uuid.uuid4)), ( "key_name", - models.CharField(blank=True, max_length=64, null=True, unique=True), + models.CharField( + blank=True, max_length=64, null=True, unique=True + ), ), ("is_active", models.BooleanField(default=False)), ( diff --git a/backend/account/organization.py b/backend/account/organization.py index b9f599d1..a584bf3a 100644 --- a/backend/account/organization.py +++ b/backend/account/organization.py @@ -8,6 +8,8 @@ from django.db import IntegrityError Logger = logging.getLogger(__name__) subscription_loader = load_plugins() + + class OrganizationService: def __init__(self): # type: ignore pass @@ -36,11 +38,12 @@ class OrganizationService: cls = subscription_plugin[SubscriptionConfig.METADATA][ SubscriptionConfig.METADATA_SERVICE_CLASS ] - cls.add( - organization_id=organization_id) + cls.add(organization_id=organization_id) except IntegrityError as error: - Logger.info(f"[Duplicate Id] Failed to create Organization Error: {error}") + Logger.info( + f"[Duplicate Id] Failed to create Organization Error: {error}" + ) raise error # Add one or more domains for the tenant domain = Domain() diff --git a/backend/account/subscription_loader.py b/backend/account/subscription_loader.py index 9d133619..c516dc3b 100644 --- a/backend/account/subscription_loader.py +++ b/backend/account/subscription_loader.py @@ -24,13 +24,17 @@ def load_plugins() -> list[Any]: """Iterate through the subscription plugins and register them.""" plugins_app = apps.get_app_config(SubscriptionConfig.PLUGINS_APP) package_path = plugins_app.module.__package__ - subscription_dir = os.path.join(plugins_app.path, SubscriptionConfig.PLUGIN_DIR) - subscription_package_path = f"{package_path}.{SubscriptionConfig.PLUGIN_DIR}" + subscription_dir = os.path.join( + plugins_app.path, SubscriptionConfig.PLUGIN_DIR + ) + subscription_package_path = ( + f"{package_path}.{SubscriptionConfig.PLUGIN_DIR}" + ) subscription_plugins: list[Any] = [] if not os.path.exists(subscription_dir): return subscription_plugins - + for item in os.listdir(subscription_dir): # Loads a plugin if it is in a directory. if os.path.isdir(os.path.join(subscription_dir, item)): @@ -76,4 +80,4 @@ def load_plugins() -> list[Any]: if len(subscription_plugins) == 0: logger.info("No subscription plugins found.") - return subscription_plugins \ No newline at end of file + return subscription_plugins diff --git a/backend/account/urls.py b/backend/account/urls.py index 8cc60e65..1987336a 100644 --- a/backend/account/urls.py +++ b/backend/account/urls.py @@ -15,6 +15,10 @@ urlpatterns = [ path("logout", logout, name="logout"), path("callback", callback, name="callback"), path("organization", get_organizations, name="get_organizations"), - path("organization//set", set_organization, name="set_organization"), - path("organization/create", create_organization, name="create_organization"), + path( + "organization//set", set_organization, name="set_organization" + ), + path( + "organization/create", create_organization, name="create_organization" + ), ] diff --git a/backend/adapter_processor/adapter_processor.py b/backend/adapter_processor/adapter_processor.py index 2ef15d47..4a5d1737 100644 --- a/backend/adapter_processor/adapter_processor.py +++ b/backend/adapter_processor/adapter_processor.py @@ -99,18 +99,18 @@ class AdapterProcessor: adapter_metadata.pop(AdapterKeys.ADAPTER_TYPE) == AdapterKeys.X2TEXT ): - adapter_metadata[ - X2TextConstants.X2TEXT_HOST - ] = settings.X2TEXT_HOST - adapter_metadata[ - X2TextConstants.X2TEXT_PORT - ] = settings.X2TEXT_PORT + adapter_metadata[X2TextConstants.X2TEXT_HOST] = ( + settings.X2TEXT_HOST + ) + adapter_metadata[X2TextConstants.X2TEXT_PORT] = ( + settings.X2TEXT_PORT + ) platform_key = ( PlatformAuthenticationService.get_active_platform_key() ) - adapter_metadata[ - X2TextConstants.PLATFORM_SERVICE_API_KEY - ] = str(platform_key.key) + adapter_metadata[X2TextConstants.PLATFORM_SERVICE_API_KEY] = ( + str(platform_key.key) + ) adapter_instance = adapter_class(adapter_metadata) test_result: bool = adapter_instance.test_connection() diff --git a/backend/adapter_processor/views.py b/backend/adapter_processor/views.py index 2ee90c41..fac3ceab 100644 --- a/backend/adapter_processor/views.py +++ b/backend/adapter_processor/views.py @@ -112,9 +112,9 @@ class AdapterViewSet(GenericViewSet): adapter_metadata = serializer.validated_data.get( AdapterKeys.ADAPTER_METADATA ) - adapter_metadata[ - AdapterKeys.ADAPTER_TYPE - ] = serializer.validated_data.get(AdapterKeys.ADAPTER_TYPE) + adapter_metadata[AdapterKeys.ADAPTER_TYPE] = ( + serializer.validated_data.get(AdapterKeys.ADAPTER_TYPE) + ) try: test_result = AdapterProcessor.test_adapter( adapter_id=adapter_id, adapter_metadata=adapter_metadata diff --git a/backend/api/serializers.py b/backend/api/serializers.py index ece1e6de..b1622b43 100644 --- a/backend/api/serializers.py +++ b/backend/api/serializers.py @@ -3,7 +3,6 @@ from typing import Any, Union from api.constants import ApiExecution from api.models import APIDeployment, APIKey -from backend.serializers import AuditSerializer from django.core.validators import RegexValidator from rest_framework.serializers import ( CharField, @@ -14,6 +13,8 @@ from rest_framework.serializers import ( ValidationError, ) +from backend.serializers import AuditSerializer + class APIDeploymentSerializer(AuditSerializer): class Meta: diff --git a/backend/apps/constants.py b/backend/apps/constants.py index 67be60df..15297aef 100644 --- a/backend/apps/constants.py +++ b/backend/apps/constants.py @@ -1,4 +1,2 @@ class AppConstants: """Constants for Apps.""" - - \ No newline at end of file diff --git a/backend/apps/exceptions.py b/backend/apps/exceptions.py index 7836889b..fb1980ae 100644 --- a/backend/apps/exceptions.py +++ b/backend/apps/exceptions.py @@ -3,4 +3,4 @@ from rest_framework.exceptions import APIException class FetchAppListFailed(APIException): status_code = 400 - default_detail = "Failed to fetch App list." \ No newline at end of file + default_detail = "Failed to fetch App list." diff --git a/backend/apps/urls.py b/backend/apps/urls.py index 918cedca..d60c9b2b 100644 --- a/backend/apps/urls.py +++ b/backend/apps/urls.py @@ -1,5 +1,5 @@ -from django.urls import path from apps import views +from django.urls import path from rest_framework.urlpatterns import format_suffix_patterns urlpatterns = format_suffix_patterns( diff --git a/backend/backend/celery.py b/backend/backend/celery.py index 1c790f67..0ff550bb 100644 --- a/backend/backend/celery.py +++ b/backend/backend/celery.py @@ -1,5 +1,4 @@ -"""This module contains the Celery configuration for the backend -project.""" +"""This module contains the Celery configuration for the backend project.""" import os diff --git a/backend/backend/public_urls.py b/backend/backend/public_urls.py index 6f2d7958..db497f5c 100644 --- a/backend/backend/public_urls.py +++ b/backend/backend/public_urls.py @@ -17,9 +17,8 @@ Including another URLconf from account.admin import admin from django.conf import settings from django.conf.urls import * # noqa: F401, F403 -from django.urls import include, path from django.conf.urls.static import static -from django.conf import settings +from django.urls import include, path path_prefix = settings.PATH_PREFIX api_path_prefix = settings.API_DEPLOYMENT_PATH_PREFIX @@ -38,4 +37,4 @@ urlpatterns = [ # Feature flags path(f"{path_prefix}/flags/", include("feature_flag.urls")), ] -urlpatterns += static(settings.STATIC_URL, document_root=settings.STATIC_ROOT) \ No newline at end of file +urlpatterns += static(settings.STATIC_URL, document_root=settings.STATIC_ROOT) diff --git a/backend/backend/serializers.py b/backend/backend/serializers.py index 151a334d..7cb42953 100644 --- a/backend/backend/serializers.py +++ b/backend/backend/serializers.py @@ -1,8 +1,9 @@ from typing import Any -from backend.constants import RequestKey from rest_framework.serializers import ModelSerializer +from backend.constants import RequestKey + class AuditSerializer(ModelSerializer): def create(self, validated_data: dict[str, Any]) -> Any: diff --git a/backend/backend/wsgi.py b/backend/backend/wsgi.py index 7a762dd3..9a654eb5 100644 --- a/backend/backend/wsgi.py +++ b/backend/backend/wsgi.py @@ -11,7 +11,6 @@ import os from django.conf import settings from django.core.wsgi import get_wsgi_application from dotenv import load_dotenv - from utils.log_events import start_server load_dotenv() diff --git a/backend/connector/fields.py b/backend/connector/fields.py index 2ade7a89..b96b206b 100644 --- a/backend/connector/fields.py +++ b/backend/connector/fields.py @@ -1,16 +1,16 @@ +import logging from datetime import datetime from connector_auth.constants import SocialAuthConstants from connector_auth.models import ConnectorAuth from django.db import models -import logging logger = logging.getLogger(__name__) class ConnectorAuthJSONField(models.JSONField): def from_db_value(self, value, expression, connection): # type: ignore - """ Overrding default function. """ + """Overrding default function.""" metadata = super().from_db_value(value, expression, connection) provider = metadata.get(SocialAuthConstants.PROVIDER) uid = metadata.get(SocialAuthConstants.UID) diff --git a/backend/connector/serializers.py b/backend/connector/serializers.py index 820b47d6..992e2cfb 100644 --- a/backend/connector/serializers.py +++ b/backend/connector/serializers.py @@ -76,10 +76,10 @@ class ConnectorInstanceSerializer(AuditSerializer): if SerializerUtils.check_context_for_GET_or_POST(context=self.context): rep.pop(CIKey.CONNECTOR_AUTH) # set icon fields for UI - rep[ - ConnectorKeys.ICON - ] = ConnectorProcessor.get_connector_data_with_key( - instance.connector_id, ConnectorKeys.ICON + rep[ConnectorKeys.ICON] = ( + ConnectorProcessor.get_connector_data_with_key( + instance.connector_id, ConnectorKeys.ICON + ) ) encryption_secret: str = settings.ENCRYPTION_KEY f: Fernet = Fernet(encryption_secret.encode("utf-8")) diff --git a/backend/connector/tests/connector_tests.py b/backend/connector/tests/connector_tests.py index f967a7a4..4e909f68 100644 --- a/backend/connector/tests/connector_tests.py +++ b/backend/connector/tests/connector_tests.py @@ -45,7 +45,10 @@ class TestConnector(APITestCase): "modified_by": 2, "modified_at": "2023-06-14T05:28:47.759Z", "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", - "connector_metadata": {"drive_link": "sample_url", "sharable_link": True}, + "connector_metadata": { + "drive_link": "sample_url", + "sharable_link": True, + }, } response = self.client.post(url, data, format="json") @@ -200,9 +203,9 @@ class TestConnector(APITestCase): }, } response = self.client.put(url, data, format="json") - nested_value = response.data["connector_metadata"]["sample_metadata_json"][ - "key1" - ] + nested_value = response.data["connector_metadata"][ + "sample_metadata_json" + ]["key1"] self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(nested_value, "value1") @@ -226,9 +229,9 @@ class TestConnector(APITestCase): }, } response = self.client.put(url, data, format="json") - nested_value = response.data["connector_metadata"]["sample_metadata_json"][ - "key1" - ] + nested_value = response.data["connector_metadata"][ + "sample_metadata_json" + ]["key1"] nested_list = response.data["connector_metadata"]["file_list"] last_val = nested_list.pop() @@ -293,7 +296,9 @@ class TestConnector(APITestCase): self.assertEqual( connector_id, - ConnectorInstance.objects.get(connector_id=connector_id).connector_id, + ConnectorInstance.objects.get( + connector_id=connector_id + ).connector_id, ) def test_connectors_update_json_field_patch(self) -> None: @@ -304,7 +309,10 @@ class TestConnector(APITestCase): "connector_metadata": { "drive_link": "patch_update_url", "sharable_link": True, - "sample_metadata_json": {"key1": "patch_update1", "key2": "value2"}, + "sample_metadata_json": { + "key1": "patch_update1", + "key2": "value2", + }, } } diff --git a/backend/connector/urls.py b/backend/connector/urls.py index 84afb940..42403352 100644 --- a/backend/connector/urls.py +++ b/backend/connector/urls.py @@ -5,7 +5,12 @@ from .views import ConnectorInstanceViewSet as CIViewSet connector_list = CIViewSet.as_view({"get": "list", "post": "create"}) connector_detail = CIViewSet.as_view( - {"get": "retrieve", "put": "update", "patch": "partial_update", "delete": "destroy"} + { + "get": "retrieve", + "put": "update", + "patch": "partial_update", + "delete": "destroy", + } ) urlpatterns = format_suffix_patterns( diff --git a/backend/connector/views.py b/backend/connector/views.py index ef524493..428a3678 100644 --- a/backend/connector/views.py +++ b/backend/connector/views.py @@ -2,7 +2,6 @@ import logging from typing import Any, Optional from account.custom_exceptions import DuplicateData -from backend.constants import RequestKey from connector.constants import ConnectorInstanceKey as CIKey from connector_auth.constants import ConnectorAuthKey from connector_auth.exceptions import CacheMissException, MissingParamException @@ -15,6 +14,8 @@ from rest_framework.response import Response from rest_framework.versioning import URLPathVersioning from utils.filtering import FilterHelper +from backend.constants import RequestKey + from .models import ConnectorInstance from .serializers import ConnectorInstanceSerializer diff --git a/backend/connector_auth/models.py b/backend/connector_auth/models.py index bc3058a6..a9331e5b 100644 --- a/backend/connector_auth/models.py +++ b/backend/connector_auth/models.py @@ -41,7 +41,10 @@ class ConnectorAuth(AbstractUserSocialAuth): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) user = models.ForeignKey( - User, related_name="connector_auth", on_delete=models.SET_NULL, null=True + User, + related_name="connector_auth", + on_delete=models.SET_NULL, + null=True, ) def __str__(self) -> str: @@ -54,7 +57,10 @@ class ConnectorAuth(AbstractUserSocialAuth): def set_extra_data(self, extra_data=None): # type: ignore ConnectorAuth.check_credential_format(extra_data) - if extra_data[SocialAuthConstants.PROVIDER] == SocialAuthConstants.GOOGLE_OAUTH: + if ( + extra_data[SocialAuthConstants.PROVIDER] + == SocialAuthConstants.GOOGLE_OAUTH + ): extra_data = GoogleAuthHelper.enrich_connector_metadata(extra_data) return super().set_extra_data(extra_data) @@ -67,13 +73,17 @@ class ConnectorAuth(AbstractUserSocialAuth): backend = self.get_backend_instance(strategy) if token and backend and hasattr(backend, "refresh_token"): response = backend.refresh_token(token, *args, **kwargs) - extra_data = backend.extra_data(self, self.uid, response, self.extra_data) + extra_data = backend.extra_data( + self, self.uid, response, self.extra_data + ) extra_data[SocialAuthConstants.PROVIDER] = backend.name extra_data[SocialAuthConstants.UID] = self.uid if self.set_extra_data(extra_data): # type: ignore self.save() - def get_and_refresh_tokens(self, request: Request = None) -> tuple[JSONField, bool]: + def get_and_refresh_tokens( + self, request: Request = None + ) -> tuple[JSONField, bool]: """Uses Social Auth's ability to refresh tokens if necessary. Returns: diff --git a/backend/connector_auth/pipeline/common.py b/backend/connector_auth/pipeline/common.py index 5f5861bb..96d8e365 100644 --- a/backend/connector_auth/pipeline/common.py +++ b/backend/connector_auth/pipeline/common.py @@ -13,7 +13,9 @@ from social_core.backends.oauth import BaseOAuth2 logger = logging.getLogger(__name__) -def check_user_exists(backend: BaseOAuth2, user: User, **kwargs: Any) -> dict[str, str]: +def check_user_exists( + backend: BaseOAuth2, user: User, **kwargs: Any +) -> dict[str, str]: """Checks if user is authenticated (will be handled in auth middleware, present as a fail safe) @@ -46,9 +48,12 @@ def cache_oauth_creds( regarding expiry, uid (unique ID given by provider) and provider. """ cache_key = kwargs.get("cache_key") or backend.strategy.session_get( - settings.SOCIAL_AUTH_FIELDS_STORED_IN_SESSION[0], ConnectorAuthKey.OAUTH_KEY + settings.SOCIAL_AUTH_FIELDS_STORED_IN_SESSION[0], + ConnectorAuthKey.OAUTH_KEY, + ) + extra_data = backend.extra_data( + user, uid, response, details, *args, **kwargs ) - extra_data = backend.extra_data(user, uid, response, details, *args, **kwargs) extra_data[SocialAuthConstants.PROVIDER] = backend.name extra_data[SocialAuthConstants.UID] = uid diff --git a/backend/connector_auth/pipeline/google.py b/backend/connector_auth/pipeline/google.py index 8da2a3ec..6c505a31 100644 --- a/backend/connector_auth/pipeline/google.py +++ b/backend/connector_auth/pipeline/google.py @@ -1,11 +1,13 @@ from datetime import datetime, timedelta -from unstract.connectors.filesystems.google_drive.constants import GDriveConstants - from connector_auth.constants import SocialAuthConstants as AuthConstants from connector_auth.exceptions import EnrichConnectorMetadataException from connector_processor.constants import ConnectorKeys +from unstract.connectors.filesystems.google_drive.constants import ( + GDriveConstants, +) + class GoogleAuthHelper: @staticmethod @@ -24,9 +26,9 @@ class GoogleAuthHelper: ) # Used by Unstract - kwargs[ - ConnectorKeys.PATH - ] = GDriveConstants.ROOT_PREFIX # Acts as a prefix for all paths + kwargs[ConnectorKeys.PATH] = ( + GDriveConstants.ROOT_PREFIX + ) # Acts as a prefix for all paths kwargs[AuthConstants.REFRESH_AFTER] = token_expiry.strftime( AuthConstants.REFRESH_AFTER_FORMAT ) diff --git a/backend/connector_processor/connector_processor.py b/backend/connector_processor/connector_processor.py index cf2f5536..3edd053e 100644 --- a/backend/connector_processor/connector_processor.py +++ b/backend/connector_processor/connector_processor.py @@ -15,6 +15,7 @@ from connector_processor.exceptions import ( TestConnectorException, TestConnectorInputException, ) + from unstract.connectors.base import UnstractConnector from unstract.connectors.connectorkit import Connectorkit from unstract.connectors.enums import ConnectorMode diff --git a/backend/connector_processor/exceptions.py b/backend/connector_processor/exceptions.py index 77ddeea6..1df8079d 100644 --- a/backend/connector_processor/exceptions.py +++ b/backend/connector_processor/exceptions.py @@ -1,5 +1,6 @@ -from backend.exceptions import UnstractBaseException from rest_framework.exceptions import APIException + +from backend.exceptions import UnstractBaseException from unstract.connectors.exceptions import ConnectorError diff --git a/backend/connector_processor/serializers.py b/backend/connector_processor/serializers.py index f1d1a69e..81fcf566 100644 --- a/backend/connector_processor/serializers.py +++ b/backend/connector_processor/serializers.py @@ -1,6 +1,7 @@ -from backend.constants import FieldLengthConstants as FLC from rest_framework import serializers +from backend.constants import FieldLengthConstants as FLC + class TestConnectorSerializer(serializers.Serializer): connector_id = serializers.CharField(max_length=FLC.CONNECTOR_ID_LENGTH) diff --git a/backend/connector_processor/urls.py b/backend/connector_processor/urls.py index 95d69312..790cff1e 100644 --- a/backend/connector_processor/urls.py +++ b/backend/connector_processor/urls.py @@ -6,7 +6,11 @@ from . import views connector_test = ConnectorViewSet.as_view({"post": "test"}) urlpatterns = [ - path("connector_schema/", views.get_connector_schema, name="get_connector_schema"), + path( + "connector_schema/", + views.get_connector_schema, + name="get_connector_schema", + ), path( "supported_connectors/", views.get_supported_connectors, diff --git a/backend/docs/urls.py b/backend/docs/urls.py index ff6071f2..83260b01 100644 --- a/backend/docs/urls.py +++ b/backend/docs/urls.py @@ -12,5 +12,9 @@ schema_view = get_schema_view( ) urlpatterns = [ - path("doc/", schema_view.with_ui("redoc", cache_timeout=0), name="schema-redoc"), + path( + "doc/", + schema_view.with_ui("redoc", cache_timeout=0), + name="schema-redoc", + ), ] diff --git a/backend/feature_flag/urls.py b/backend/feature_flag/urls.py index 288dc152..0bc4c274 100644 --- a/backend/feature_flag/urls.py +++ b/backend/feature_flag/urls.py @@ -2,6 +2,7 @@ This module defines the URL patterns for the feature_flags app. """ + import feature_flag.views as views from django.urls import path diff --git a/backend/feature_flag/views.py b/backend/feature_flag/views.py index 964f0ac4..6e155f5e 100644 --- a/backend/feature_flag/views.py +++ b/backend/feature_flag/views.py @@ -3,12 +3,14 @@ Returns: evaluate response """ + import logging from rest_framework import status from rest_framework.decorators import api_view from rest_framework.request import Request from rest_framework.response import Response + from unstract.flags.client import EvaluationClient logger = logging.getLogger(__name__) diff --git a/backend/file_management/constants.py b/backend/file_management/constants.py index 8ada8331..a17b60b4 100644 --- a/backend/file_management/constants.py +++ b/backend/file_management/constants.py @@ -7,8 +7,8 @@ class FileInformationKey: FILE_UPLOAD_ALLOWED_EXT = ["pdf"] FILE_UPLOAD_ALLOWED_MIME = ["application/pdf"] + class FileViewTypes: ORIGINAL = "ORIGINAL" EXTRACT = "EXTRACT" SUMMARIZE = "SUMMARIZE" - diff --git a/backend/file_management/file_management_helper.py b/backend/file_management/file_management_helper.py index 257d8d83..2e904f94 100644 --- a/backend/file_management/file_management_helper.py +++ b/backend/file_management/file_management_helper.py @@ -126,9 +126,9 @@ class FileManagerHelper: response = StreamingHttpResponse( file, content_type=file_content_type ) - response[ - "Content-Disposition" - ] = f"attachment; filename={base_name}" + response["Content-Disposition"] = ( + f"attachment; filename={base_name}" + ) return response except ApiRequestError as exception: FileManagerHelper.logger.error( @@ -194,8 +194,7 @@ class FileManagerHelper: elif file_content_type == "text/plain": with fs.open(file_path, "r") as file: - FileManagerHelper.logger.info( - f"Reading text file: {file_path}") + FileManagerHelper.logger.info(f"Reading text file: {file_path}") text_content = file.read() return text_content else: diff --git a/backend/file_management/views.py b/backend/file_management/views.py index b494f016..db2a5429 100644 --- a/backend/file_management/views.py +++ b/backend/file_management/views.py @@ -145,12 +145,13 @@ class FileManagementViewSet(viewsets.ModelViewSet): # Create a record in the db for the file document = PromptStudioDocumentHelper.create( - tool_id=tool_id, document_name=file_name) + tool_id=tool_id, document_name=file_name + ) # Create a dictionary to store document data doc = { "document_id": document.document_id, "document_name": document.document_name, - "tool": document.tool.tool_id + "tool": document.tool.tool_id, } # Store file logger.info( @@ -177,7 +178,7 @@ class FileManagementViewSet(viewsets.ModelViewSet): tool_id: str = serializer.validated_data.get("tool_id") view_type: str = serializer.validated_data.get("view_type") - filename_without_extension = file_name.rsplit('.', 1)[0] + filename_without_extension = file_name.rsplit(".", 1)[0] if view_type == FileViewTypes.EXTRACT: file_name = ( f"{FileViewTypes.EXTRACT.lower()}/" @@ -189,20 +190,19 @@ class FileManagementViewSet(viewsets.ModelViewSet): f"{filename_without_extension}.txt" ) - file_path = ( - file_path - ) = FileManagerHelper.handle_sub_directory_for_tenants( - request.org_id, - is_create=True, - user_id=request.user.user_id, - tool_id=tool_id, + file_path = file_path = ( + FileManagerHelper.handle_sub_directory_for_tenants( + request.org_id, + is_create=True, + user_id=request.user.user_id, + tool_id=tool_id, + ) ) file_system = LocalStorageFS(settings={"path": file_path}) if not file_path.endswith("/"): file_path += "/" file_path += file_name - contents = FileManagerHelper.fetch_file_contents( - file_system, file_path) + contents = FileManagerHelper.fetch_file_contents(file_system, file_path) return Response({"data": contents}, status=status.HTTP_200_OK) @action(detail=True, methods=["get"]) diff --git a/backend/init.sql b/backend/init.sql index a4c65998..6ae1249d 100644 --- a/backend/init.sql +++ b/backend/init.sql @@ -3,4 +3,4 @@ ALTER ROLE unstract_dev SET default_transaction_isolation TO 'read committed'; ALTER ROLE unstract_dev SET timezone TO 'UTC'; ALTER USER unstract_dev CREATEDB; GRANT ALL PRIVILEGES ON DATABASE unstract_db TO unstract_dev; -CREATE DATABASE unstract; \ No newline at end of file +CREATE DATABASE unstract; diff --git a/backend/pipeline/manager.py b/backend/pipeline/manager.py index fbbef487..3c6bfbc8 100644 --- a/backend/pipeline/manager.py +++ b/backend/pipeline/manager.py @@ -1,7 +1,6 @@ import logging from typing import Any, Optional -from backend.constants import RequestHeader from django.conf import settings from django.urls import reverse from pipeline.constants import PipelineKey, PipelineURL @@ -11,9 +10,14 @@ from pipeline.pipeline_processor import PipelineProcessor from rest_framework.request import Request from rest_framework.response import Response from utils.request.constants import RequestConstants -from workflow_manager.workflow.constants import WorkflowExecutionKey, WorkflowKey +from workflow_manager.workflow.constants import ( + WorkflowExecutionKey, + WorkflowKey, +) from workflow_manager.workflow.views import WorkflowViewSet +from backend.constants import RequestHeader + logger = logging.getLogger(__name__) diff --git a/backend/pipeline/urls.py b/backend/pipeline/urls.py index 9f2de279..7af3dfca 100644 --- a/backend/pipeline/urls.py +++ b/backend/pipeline/urls.py @@ -10,7 +10,12 @@ pipeline_list = PipelineViewSet.as_view( } ) pipeline_detail = PipelineViewSet.as_view( - {"get": "retrieve", "put": "update", "patch": "partial_update", "delete": "destroy"} + { + "get": "retrieve", + "put": "update", + "patch": "partial_update", + "delete": "destroy", + } ) pipeline_execute = PipelineViewSet.as_view({"post": "execute"}) diff --git a/backend/platform_settings/exceptions.py b/backend/platform_settings/exceptions.py index 009dbf66..4e54b58e 100644 --- a/backend/platform_settings/exceptions.py +++ b/backend/platform_settings/exceptions.py @@ -36,6 +36,7 @@ class InvalidRequest(APIException): status_code = 401 default_detail = "Invalid Request" + class DuplicateData(APIException): status_code = 400 default_detail = "Duplicate Data" diff --git a/backend/platform_settings/platform_auth_helper.py b/backend/platform_settings/platform_auth_helper.py index c3d7d2d4..167db67d 100644 --- a/backend/platform_settings/platform_auth_helper.py +++ b/backend/platform_settings/platform_auth_helper.py @@ -32,7 +32,9 @@ class PlatformAuthHelper: ) raise error if not auth_controller.is_admin_by_role(member.role): - logger.error("User is not having right access to perform this operation.") + logger.error( + "User is not having right access to perform this operation." + ) raise UserForbidden() else: pass diff --git a/backend/platform_settings/serializers.py b/backend/platform_settings/serializers.py index f8e320e1..603883c6 100644 --- a/backend/platform_settings/serializers.py +++ b/backend/platform_settings/serializers.py @@ -1,9 +1,10 @@ # serializers.py from account.models import PlatformKey -from backend.serializers import AuditSerializer from rest_framework import serializers +from backend.serializers import AuditSerializer + class PlatformKeySerializer(AuditSerializer): class Meta: diff --git a/backend/plugins/authentication/auth_sample/auth_service.py b/backend/plugins/authentication/auth_sample/auth_service.py index 996ec59f..3ec4e788 100644 --- a/backend/plugins/authentication/auth_sample/auth_service.py +++ b/backend/plugins/authentication/auth_sample/auth_service.py @@ -7,7 +7,13 @@ from rest_framework.request import Request from rest_framework.response import Response from .auth_helper import AuthHelper -from .dto import AuthOrganization, ResetUserPasswordDto, TokenData, User, UserInfo +from .dto import ( + AuthOrganization, + ResetUserPasswordDto, + TokenData, + User, + UserInfo, +) from .enums import Region from .exceptions import MethodNotImplemented @@ -36,7 +42,10 @@ class AuthService(ABC): self, user: User, token: Optional[dict[str, Any]] = None ) -> Optional[UserInfo]: return UserInfo( - id=user.id, name=user.username, display_name=user.username, email=user.email + id=user.id, + name=user.username, + display_name=user.username, + email=user.email, ) def get_organization_info(self, org_id: str) -> Any: @@ -64,7 +73,9 @@ class AuthService(ABC): def get_user_id_from_token(self, token: dict[str, Any]) -> Response: return token["userinfo"]["sub"] - def get_organization_members_by_org_id(self, organization_id: str) -> Response: + def get_organization_members_by_org_id( + self, organization_id: str + ) -> Response: raise MethodNotImplemented() def reset_user_password(self, user: User) -> ResetUserPasswordDto: diff --git a/backend/project/serializers.py b/backend/project/serializers.py index dea74676..13906954 100644 --- a/backend/project/serializers.py +++ b/backend/project/serializers.py @@ -1,10 +1,11 @@ from typing import Any -from backend.serializers import AuditSerializer from project.models import Project from workflow_manager.workflow.constants import WorkflowKey from workflow_manager.workflow.serializers import WorkflowSerializer +from backend.serializers import AuditSerializer + class ProjectSerializer(AuditSerializer): class Meta: diff --git a/backend/project/tests/project_tests.py b/backend/project/tests/project_tests.py index 891cbd89..40ce8190 100644 --- a/backend/project/tests/project_tests.py +++ b/backend/project/tests/project_tests.py @@ -79,7 +79,8 @@ class TestProjects(APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual( - project_name, Project.objects.get(project_name=project_name).project_name + project_name, + Project.objects.get(project_name=project_name).project_name, ) def test_projects_update_pk(self) -> None: @@ -107,7 +108,8 @@ class TestProjects(APITestCase): project_name = response.data["project_name"] self.assertEqual( - project_name, Project.objects.get(project_name=project_name).project_name + project_name, + Project.objects.get(project_name=project_name).project_name, ) def test_projects_delete(self) -> None: diff --git a/backend/project/urls.py b/backend/project/urls.py index b7805e28..4af97d0d 100644 --- a/backend/project/urls.py +++ b/backend/project/urls.py @@ -5,19 +5,30 @@ from .views import ProjectViewSet project_list = ProjectViewSet.as_view({"get": "list", "post": "create"}) project_detail = ProjectViewSet.as_view( - {"get": "retrieve", "put": "update", "patch": "partial_update", "delete": "destroy"} + { + "get": "retrieve", + "put": "update", + "patch": "partial_update", + "delete": "destroy", + } ) project_settings = ProjectViewSet.as_view( {"get": "project_settings", "put": "project_settings"} ) -project_settings_schema = ProjectViewSet.as_view({"get": "project_settings_schema"}) +project_settings_schema = ProjectViewSet.as_view( + {"get": "project_settings_schema"} +) urlpatterns = format_suffix_patterns( [ path("projects/", project_list, name="projects-list"), path("projects//", project_detail, name="projects-detail"), - path("projects//settings/", project_settings, name="project-settings"), + path( + "projects//settings/", + project_settings, + name="project-settings", + ), path( "projects/settings/", project_settings_schema, diff --git a/backend/prompt/tests/conftest.py b/backend/prompt/tests/conftest.py index 0d9415cc..c64c7226 100644 --- a/backend/prompt/tests/conftest.py +++ b/backend/prompt/tests/conftest.py @@ -1,10 +1,10 @@ import pytest - from django.core.management import call_command -@pytest.fixture(scope='session') + +@pytest.fixture(scope="session") def django_db_setup(django_db_blocker): fixtures = ["./prompt/tests/fixtures/prompts_001.json"] with django_db_blocker.unblock(): - call_command('loaddata', *fixtures) \ No newline at end of file + call_command("loaddata", *fixtures) diff --git a/backend/prompt/tests/test_urls.py b/backend/prompt/tests/test_urls.py index f4f2799f..94469f5c 100644 --- a/backend/prompt/tests/test_urls.py +++ b/backend/prompt/tests/test_urls.py @@ -1,10 +1,9 @@ import pytest from django.urls import reverse +from prompt.models import Prompt from rest_framework import status from rest_framework.test import APITestCase -from prompt.models import Prompt - pytestmark = pytest.mark.django_db @@ -27,8 +26,9 @@ class TestPrompts(APITestCase): def test_prompts_detail_throw_404(self): """Tests whether a 404 error is thrown on retrieving a prompt.""" - url = reverse("prompts-detail", - kwargs={"pk": 200}) # Prompt doesn't exist + url = reverse( + "prompts-detail", kwargs={"pk": 200} + ) # Prompt doesn't exist response = self.client.get(url) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) diff --git a/backend/prompt/urls.py b/backend/prompt/urls.py index 809ac5c4..6d04bca1 100644 --- a/backend/prompt/urls.py +++ b/backend/prompt/urls.py @@ -1,20 +1,21 @@ from django.urls import path from rest_framework.urlpatterns import format_suffix_patterns + from .views import PromptViewSet -prompt_list = PromptViewSet.as_view({ - 'get': 'list', - 'post': 'create' -}) -prompt_detail = PromptViewSet.as_view({ - 'get': 'retrieve', - 'put': 'update', - 'patch': 'partial_update', - 'delete': 'destroy' -}) - -urlpatterns = format_suffix_patterns([ - path('prompt/', prompt_list, name='prompt-list'), - path('prompt//', prompt_detail, name='prompt-detail'), -]) +prompt_list = PromptViewSet.as_view({"get": "list", "post": "create"}) +prompt_detail = PromptViewSet.as_view( + { + "get": "retrieve", + "put": "update", + "patch": "partial_update", + "delete": "destroy", + } +) +urlpatterns = format_suffix_patterns( + [ + path("prompt/", prompt_list, name="prompt-list"), + path("prompt//", prompt_detail, name="prompt-detail"), + ] +) diff --git a/backend/prompt_studio/prompt_profile_manager/migrations/0008_profilemanager_migration.py b/backend/prompt_studio/prompt_profile_manager/migrations/0008_profilemanager_migration.py index 2f2634fc..33f3cca4 100644 --- a/backend/prompt_studio/prompt_profile_manager/migrations/0008_profilemanager_migration.py +++ b/backend/prompt_studio/prompt_profile_manager/migrations/0008_profilemanager_migration.py @@ -18,7 +18,7 @@ class Migration(migrations.Migration): ( "prompt_studio_core", "0007_remove_customtool_default_profile_and_more", - ) + ), ] def MigrateProfileManager(apps: Any, schema_editor: Any) -> None: diff --git a/backend/prompt_studio/prompt_profile_manager/migrations/0009_alter_profilemanager_prompt_studio_tool.py b/backend/prompt_studio/prompt_profile_manager/migrations/0009_alter_profilemanager_prompt_studio_tool.py index e35da9b9..bec9bd15 100644 --- a/backend/prompt_studio/prompt_profile_manager/migrations/0009_alter_profilemanager_prompt_studio_tool.py +++ b/backend/prompt_studio/prompt_profile_manager/migrations/0009_alter_profilemanager_prompt_studio_tool.py @@ -6,7 +6,10 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ("prompt_studio_core", "0008_customtool_exclude_failed_customtool_monitor_llm"), + ( + "prompt_studio_core", + "0008_customtool_exclude_failed_customtool_monitor_llm", + ), ("prompt_profile_manager", "0008_profilemanager_migration"), ] diff --git a/backend/prompt_studio/prompt_profile_manager/serializers.py b/backend/prompt_studio/prompt_profile_manager/serializers.py index a217a436..1f9ef62d 100644 --- a/backend/prompt_studio/prompt_profile_manager/serializers.py +++ b/backend/prompt_studio/prompt_profile_manager/serializers.py @@ -22,19 +22,19 @@ class ProfileManagerSerializer(AuditSerializer): vector_db = rep[ProfileManagerKeys.VECTOR_STORE] x2text = rep[ProfileManagerKeys.X2TEXT] if llm: - rep[ - ProfileManagerKeys.LLM - ] = AdapterProcessor.get_adapter_instance_by_id(llm) + rep[ProfileManagerKeys.LLM] = ( + AdapterProcessor.get_adapter_instance_by_id(llm) + ) if embedding: - rep[ - ProfileManagerKeys.EMBEDDING_MODEL - ] = AdapterProcessor.get_adapter_instance_by_id(embedding) + rep[ProfileManagerKeys.EMBEDDING_MODEL] = ( + AdapterProcessor.get_adapter_instance_by_id(embedding) + ) if vector_db: - rep[ - ProfileManagerKeys.VECTOR_STORE - ] = AdapterProcessor.get_adapter_instance_by_id(vector_db) + rep[ProfileManagerKeys.VECTOR_STORE] = ( + AdapterProcessor.get_adapter_instance_by_id(vector_db) + ) if x2text: - rep[ - ProfileManagerKeys.X2TEXT - ] = AdapterProcessor.get_adapter_instance_by_id(x2text) + rep[ProfileManagerKeys.X2TEXT] = ( + AdapterProcessor.get_adapter_instance_by_id(x2text) + ) return rep diff --git a/backend/prompt_studio/prompt_profile_manager/views.py b/backend/prompt_studio/prompt_profile_manager/views.py index f070f897..e6a8159b 100644 --- a/backend/prompt_studio/prompt_profile_manager/views.py +++ b/backend/prompt_studio/prompt_profile_manager/views.py @@ -45,7 +45,9 @@ class ProfileManagerView(viewsets.ModelViewSet): def create( self, request: HttpRequest, *args: tuple[Any], **kwargs: dict[str, Any] ) -> Response: - serializer: ProfileManagerSerializer = self.get_serializer(data=request.data) + serializer: ProfileManagerSerializer = self.get_serializer( + data=request.data + ) # Overriding default exception behaviour # TO DO : Handle model related exceptions. serializer.is_valid(raise_exception=True) diff --git a/backend/prompt_studio/prompt_studio/migrations/0002_prompt_eval_metrics.py b/backend/prompt_studio/prompt_studio/migrations/0002_prompt_eval_metrics.py index 22f61a29..3f9f7fb8 100644 --- a/backend/prompt_studio/prompt_studio/migrations/0002_prompt_eval_metrics.py +++ b/backend/prompt_studio/prompt_studio/migrations/0002_prompt_eval_metrics.py @@ -6,43 +6,43 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('prompt_studio', '0001_initial'), + ("prompt_studio", "0001_initial"), ] operations = [ migrations.AddField( - model_name='toolstudioprompt', - name='eval_guidance_completeness', + model_name="toolstudioprompt", + name="eval_guidance_completeness", field=models.BooleanField(default=True), ), migrations.AddField( - model_name='toolstudioprompt', - name='eval_guidance_toxicity', + model_name="toolstudioprompt", + name="eval_guidance_toxicity", field=models.BooleanField(default=True), ), migrations.AddField( - model_name='toolstudioprompt', - name='eval_quality_correctness', + model_name="toolstudioprompt", + name="eval_quality_correctness", field=models.BooleanField(default=True), ), migrations.AddField( - model_name='toolstudioprompt', - name='eval_quality_faithfulness', + model_name="toolstudioprompt", + name="eval_quality_faithfulness", field=models.BooleanField(default=True), ), migrations.AddField( - model_name='toolstudioprompt', - name='eval_quality_relevance', + model_name="toolstudioprompt", + name="eval_quality_relevance", field=models.BooleanField(default=True), ), migrations.AddField( - model_name='toolstudioprompt', - name='eval_security_pii', + model_name="toolstudioprompt", + name="eval_security_pii", field=models.BooleanField(default=True), ), migrations.AddField( - model_name='toolstudioprompt', - name='evaluate', + model_name="toolstudioprompt", + name="evaluate", field=models.BooleanField(default=True), ), ] diff --git a/backend/prompt_studio/prompt_studio/serializers.py b/backend/prompt_studio/prompt_studio/serializers.py index 18fe6d70..56d7f0fa 100644 --- a/backend/prompt_studio/prompt_studio/serializers.py +++ b/backend/prompt_studio/prompt_studio/serializers.py @@ -1,6 +1,7 @@ -from backend.serializers import AuditSerializer from rest_framework import serializers +from backend.serializers import AuditSerializer + from .models import ToolStudioPrompt diff --git a/backend/prompt_studio/prompt_studio_core/migrations/0006_alter_customtool_summarize_as_source_and_more.py b/backend/prompt_studio/prompt_studio_core/migrations/0006_alter_customtool_summarize_as_source_and_more.py index 566fee5b..27f52948 100644 --- a/backend/prompt_studio/prompt_studio_core/migrations/0006_alter_customtool_summarize_as_source_and_more.py +++ b/backend/prompt_studio/prompt_studio_core/migrations/0006_alter_customtool_summarize_as_source_and_more.py @@ -5,7 +5,10 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ("prompt_studio_core", "0005_alter_customtool_default_profile_and_more"), + ( + "prompt_studio_core", + "0005_alter_customtool_default_profile_and_more", + ), ] operations = [ @@ -13,7 +16,8 @@ class Migration(migrations.Migration): model_name="customtool", name="summarize_as_source", field=models.BooleanField( - db_comment="Flag to use summarized content as source", default=False + db_comment="Flag to use summarized content as source", + default=False, ), ), migrations.AlterField( diff --git a/backend/prompt_studio/prompt_studio_core/models.py b/backend/prompt_studio/prompt_studio_core/models.py index 960c0723..4878f0d9 100644 --- a/backend/prompt_studio/prompt_studio_core/models.py +++ b/backend/prompt_studio/prompt_studio_core/models.py @@ -48,12 +48,12 @@ class CustomTool(BaseModel): preamble = models.TextField( blank=True, db_comment="Preamble to the prompts", - default=DefaultPrompts.PREAMBLE + default=DefaultPrompts.PREAMBLE, ) postamble = models.TextField( blank=True, db_comment="Appended as postable to prompts.", - default=DefaultPrompts.POSTAMBLE + default=DefaultPrompts.POSTAMBLE, ) prompt_grammer = models.JSONField( null=True, blank=True, db_comment="Synonymous words used in prompt" diff --git a/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py b/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py index 958590b2..57be7f5c 100644 --- a/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py +++ b/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py @@ -184,9 +184,9 @@ class PromptStudioHelper: Returns: List[ToolStudioPrompt]: List of instance of the model """ - prompt_instances: list[ - ToolStudioPrompt - ] = ToolStudioPrompt.objects.filter(tool_id=tool_id) + prompt_instances: list[ToolStudioPrompt] = ( + ToolStudioPrompt.objects.filter(tool_id=tool_id) + ) return prompt_instances @staticmethod @@ -509,9 +509,9 @@ class PromptStudioHelper: ) output: dict[str, Any] = {} - output[ - TSPKeys.ASSERTION_FAILURE_PROMPT - ] = prompt.assertion_failure_prompt + output[TSPKeys.ASSERTION_FAILURE_PROMPT] = ( + prompt.assertion_failure_prompt + ) output[TSPKeys.ASSERT_PROMPT] = prompt.assert_prompt output[TSPKeys.IS_ASSERT] = prompt.is_assert output[TSPKeys.PROMPT] = prompt.prompt @@ -526,12 +526,12 @@ class PromptStudioHelper: output[TSPKeys.GRAMMAR] = grammar_list output[TSPKeys.TYPE] = prompt.enforce_type output[TSPKeys.NAME] = prompt.prompt_key - output[ - TSPKeys.RETRIEVAL_STRATEGY - ] = prompt.profile_manager.retrieval_strategy - output[ - TSPKeys.SIMILARITY_TOP_K - ] = prompt.profile_manager.similarity_top_k + output[TSPKeys.RETRIEVAL_STRATEGY] = ( + prompt.profile_manager.retrieval_strategy + ) + output[TSPKeys.SIMILARITY_TOP_K] = ( + prompt.profile_manager.similarity_top_k + ) output[TSPKeys.SECTION] = prompt.profile_manager.section output[TSPKeys.X2TEXT_ADAPTER] = x2text # Eval settings for the prompt @@ -547,9 +547,9 @@ class PromptStudioHelper: ] = tool.exclude_failed output[TSPKeys.ENABLE_CHALLENGE] = tool.enable_challenge output[TSPKeys.CHALLENGE_LLM] = challenge_llm - output[ - TSPKeys.SINGLE_PASS_EXTRACTION_MODE - ] = tool.single_pass_extraction_mode + output[TSPKeys.SINGLE_PASS_EXTRACTION_MODE] = ( + tool.single_pass_extraction_mode + ) for attr in dir(prompt): if attr.startswith(TSPKeys.EVAL_METRIC_PREFIX): attr_val = getattr(prompt, attr) diff --git a/backend/prompt_studio/prompt_studio_core/views.py b/backend/prompt_studio/prompt_studio_core/views.py index c5ee0d0d..27e1a8f2 100644 --- a/backend/prompt_studio/prompt_studio_core/views.py +++ b/backend/prompt_studio/prompt_studio_core/views.py @@ -131,9 +131,9 @@ class PromptStudioCoreView(viewsets.ModelViewSet): Response: Reponse of dropdown dict """ try: - select_choices: dict[ - str, Any - ] = PromptStudioHelper.get_select_fields() + select_choices: dict[str, Any] = ( + PromptStudioHelper.get_select_fields() + ) return Response(select_choices, status=status.HTTP_200_OK) except Exception as e: logger.error(f"Error occured while fetching select fields {e}") diff --git a/backend/prompt_studio/prompt_studio_document_manager/apps.py b/backend/prompt_studio/prompt_studio_document_manager/apps.py index 90bd8928..b89cdd4a 100644 --- a/backend/prompt_studio/prompt_studio_document_manager/apps.py +++ b/backend/prompt_studio/prompt_studio_document_manager/apps.py @@ -2,4 +2,4 @@ from django.apps import AppConfig class PromptStudioDocumentManagerConfig(AppConfig): - name = 'prompt_studio.prompt_studio_document_manager' + name = "prompt_studio.prompt_studio_document_manager" diff --git a/backend/prompt_studio/prompt_studio_document_manager/migrations/0001_initial.py b/backend/prompt_studio/prompt_studio_document_manager/migrations/0001_initial.py index f926cdca..4132bf29 100644 --- a/backend/prompt_studio/prompt_studio_document_manager/migrations/0001_initial.py +++ b/backend/prompt_studio/prompt_studio_document_manager/migrations/0001_initial.py @@ -12,7 +12,10 @@ class Migration(migrations.Migration): dependencies = [ migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ("prompt_studio_core", "0007_remove_customtool_default_profile_and_more"), + ( + "prompt_studio_core", + "0007_remove_customtool_default_profile_and_more", + ), ] operations = [ @@ -33,7 +36,8 @@ class Migration(migrations.Migration): ( "document_name", models.CharField( - db_comment="Field to store the document name", editable=False + db_comment="Field to store the document name", + editable=False, ), ), ( @@ -71,7 +75,8 @@ class Migration(migrations.Migration): migrations.AddConstraint( model_name="documentmanager", constraint=models.UniqueConstraint( - fields=("document_name", "tool"), name="unique_document_name_tool" + fields=("document_name", "tool"), + name="unique_document_name_tool", ), ), ] diff --git a/backend/prompt_studio/prompt_studio_index_manager/constants.py b/backend/prompt_studio/prompt_studio_index_manager/constants.py index 2a0588cb..6cf3f5e5 100644 --- a/backend/prompt_studio/prompt_studio_index_manager/constants.py +++ b/backend/prompt_studio/prompt_studio_index_manager/constants.py @@ -1,3 +1,3 @@ class IndexManagerKeys: PROFILE_MANAGER = "profile_manager" - DOCUMENT_MANAGER = "document_manager" \ No newline at end of file + DOCUMENT_MANAGER = "document_manager" diff --git a/backend/prompt_studio/prompt_studio_index_manager/migrations/0001_initial.py b/backend/prompt_studio/prompt_studio_index_manager/migrations/0001_initial.py index 3869e2e7..48841d07 100644 --- a/backend/prompt_studio/prompt_studio_index_manager/migrations/0001_initial.py +++ b/backend/prompt_studio/prompt_studio_index_manager/migrations/0001_initial.py @@ -52,7 +52,8 @@ class Migration(migrations.Migration): ( "index_ids_history", models.JSONField( - db_comment="List of index ids", default=list), + db_comment="List of index ids", default=list + ), ), ( "created_by", diff --git a/backend/prompt_studio/prompt_studio_output_manager/exceptions.py b/backend/prompt_studio/prompt_studio_output_manager/exceptions.py index 63551469..f1153091 100644 --- a/backend/prompt_studio/prompt_studio_output_manager/exceptions.py +++ b/backend/prompt_studio/prompt_studio_output_manager/exceptions.py @@ -4,4 +4,3 @@ from rest_framework.exceptions import APIException class InternalError(APIException): status_code = 400 default_detail = "Internal service error." - diff --git a/backend/prompt_studio/prompt_studio_output_manager/migrations/0001_initial.py b/backend/prompt_studio/prompt_studio_output_manager/migrations/0001_initial.py index 726064b1..b0eaae9a 100644 --- a/backend/prompt_studio/prompt_studio_output_manager/migrations/0001_initial.py +++ b/backend/prompt_studio/prompt_studio_output_manager/migrations/0001_initial.py @@ -1,9 +1,10 @@ # Generated by Django 4.2.1 on 2024-02-07 11:20 +import uuid + +import django.db.models.deletion from django.conf import settings from django.db import migrations, models -import django.db.models.deletion -import uuid class Migration(migrations.Migration): @@ -31,7 +32,10 @@ class Migration(migrations.Migration): serialize=False, ), ), - ("output", models.CharField(db_comment="Field to store output")), + ( + "output", + models.CharField(db_comment="Field to store output"), + ), ( "created_by", models.ForeignKey( diff --git a/backend/prompt_studio/prompt_studio_output_manager/migrations/0003_alter_promptstudiooutputmanager_doc_name.py b/backend/prompt_studio/prompt_studio_output_manager/migrations/0003_alter_promptstudiooutputmanager_doc_name.py index 515012c0..69f2c5a1 100644 --- a/backend/prompt_studio/prompt_studio_output_manager/migrations/0003_alter_promptstudiooutputmanager_doc_name.py +++ b/backend/prompt_studio/prompt_studio_output_manager/migrations/0003_alter_promptstudiooutputmanager_doc_name.py @@ -1,12 +1,15 @@ # Generated by Django 4.2.1 on 2024-02-07 19:53 -from django.db import migrations, models import django.utils.timezone +from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ("prompt_studio_output_manager", "0002_promptstudiooutputmanager_doc_name"), + ( + "prompt_studio_output_manager", + "0002_promptstudiooutputmanager_doc_name", + ), ] operations = [ diff --git a/backend/prompt_studio/prompt_studio_output_manager/migrations/0004_alter_promptstudiooutputmanager_doc_name.py b/backend/prompt_studio/prompt_studio_output_manager/migrations/0004_alter_promptstudiooutputmanager_doc_name.py index b51b8023..76f6081b 100644 --- a/backend/prompt_studio/prompt_studio_output_manager/migrations/0004_alter_promptstudiooutputmanager_doc_name.py +++ b/backend/prompt_studio/prompt_studio_output_manager/migrations/0004_alter_promptstudiooutputmanager_doc_name.py @@ -15,6 +15,8 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="promptstudiooutputmanager", name="doc_name", - field=models.CharField(db_comment="Field to store the document name"), + field=models.CharField( + db_comment="Field to store the document name" + ), ), ] diff --git a/backend/prompt_studio/prompt_studio_output_manager/migrations/0005_alter_promptstudiooutputmanager_profile_manager_and_more.py b/backend/prompt_studio/prompt_studio_output_manager/migrations/0005_alter_promptstudiooutputmanager_profile_manager_and_more.py index 61067a38..2274060e 100644 --- a/backend/prompt_studio/prompt_studio_output_manager/migrations/0005_alter_promptstudiooutputmanager_profile_manager_and_more.py +++ b/backend/prompt_studio/prompt_studio_output_manager/migrations/0005_alter_promptstudiooutputmanager_profile_manager_and_more.py @@ -1,7 +1,7 @@ # Generated by Django 4.2.1 on 2024-02-07 20:53 -from django.db import migrations, models import django.db.models.deletion +from django.db import migrations, models class Migration(migrations.Migration): diff --git a/backend/prompt_studio/prompt_studio_output_manager/migrations/0010_delete_duplicate_rows.py b/backend/prompt_studio/prompt_studio_output_manager/migrations/0010_delete_duplicate_rows.py index 1c23c113..14a63fe8 100644 --- a/backend/prompt_studio/prompt_studio_output_manager/migrations/0010_delete_duplicate_rows.py +++ b/backend/prompt_studio/prompt_studio_output_manager/migrations/0010_delete_duplicate_rows.py @@ -5,34 +5,39 @@ from django.db import migrations, models def delete_duplicates_and_nulls(apps, schema_editor): prompt_studio_output_manager = apps.get_model( - "prompt_studio_output_manager", "PromptStudioOutputManager") + "prompt_studio_output_manager", "PromptStudioOutputManager" + ) # Delete rows where prompt_id, document_manager, profile_manager, or tool_id is NULL prompt_studio_output_manager.objects.filter( - models.Q(prompt_id=None) | - models.Q(document_manager=None) | - models.Q(profile_manager=None) | - models.Q(tool_id=None) + models.Q(prompt_id=None) + | models.Q(document_manager=None) + | models.Q(profile_manager=None) + | models.Q(tool_id=None) ).delete() # Find duplicate rows based on unique constraint fields and count their occurrences - duplicates = prompt_studio_output_manager.objects.values( - 'prompt_id', 'document_manager', 'profile_manager', 'tool_id' - ).annotate( - count=models.Count('prompt_output_id') - ).filter( - count__gt=1 # Filter to only get rows that have duplicates + duplicates = ( + prompt_studio_output_manager.objects.values( + "prompt_id", "document_manager", "profile_manager", "tool_id" + ) + .annotate(count=models.Count("prompt_output_id")) + .filter(count__gt=1) # Filter to only get rows that have duplicates ) # Iterate over each set of duplicates found for duplicate in duplicates: # Find all instances of duplicates for the current set - pks = prompt_studio_output_manager.objects.filter( - prompt_id=duplicate['prompt_id'], - document_manager=duplicate['document_manager'], - profile_manager=duplicate['profile_manager'], - tool_id=duplicate['tool_id'] - ).order_by('-created_at').values_list('pk')[1:] # Order by created_at descending and skip the first one (keep the latest) + pks = ( + prompt_studio_output_manager.objects.filter( + prompt_id=duplicate["prompt_id"], + document_manager=duplicate["document_manager"], + profile_manager=duplicate["profile_manager"], + tool_id=duplicate["tool_id"], + ) + .order_by("-created_at") + .values_list("pk")[1:] + ) # Order by created_at descending and skip the first one (keep the latest) # Delete the duplicate rows prompt_studio_output_manager.objects.filter(pk__in=pks).delete() @@ -47,6 +52,7 @@ class Migration(migrations.Migration): ] operations = [ - migrations.RunPython(delete_duplicates_and_nulls, - reverse_code=migrations.RunPython.noop), + migrations.RunPython( + delete_duplicates_and_nulls, reverse_code=migrations.RunPython.noop + ), ] diff --git a/backend/prompt_studio/prompt_studio_output_manager/migrations/0011_promptstudiooutputmanager_is_single_pass_extract_and_more.py b/backend/prompt_studio/prompt_studio_output_manager/migrations/0011_promptstudiooutputmanager_is_single_pass_extract_and_more.py index e81ee2ab..adae6726 100644 --- a/backend/prompt_studio/prompt_studio_output_manager/migrations/0011_promptstudiooutputmanager_is_single_pass_extract_and_more.py +++ b/backend/prompt_studio/prompt_studio_output_manager/migrations/0011_promptstudiooutputmanager_is_single_pass_extract_and_more.py @@ -6,8 +6,14 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ("prompt_studio_core", "0008_customtool_exclude_failed_customtool_monitor_llm"), - ("prompt_profile_manager", "0009_alter_profilemanager_prompt_studio_tool"), + ( + "prompt_studio_core", + "0008_customtool_exclude_failed_customtool_monitor_llm", + ), + ( + "prompt_profile_manager", + "0009_alter_profilemanager_prompt_studio_tool", + ), ("prompt_studio", "0006_alter_toolstudioprompt_prompt_key_and_more"), ("prompt_studio_output_manager", "0010_delete_duplicate_rows"), ] @@ -17,7 +23,8 @@ class Migration(migrations.Migration): model_name="promptstudiooutputmanager", name="is_single_pass_extract", field=models.BooleanField( - db_comment="Is the single pass extraction mode active", default=False + db_comment="Is the single pass extraction mode active", + default=False, ), ), migrations.AlterField( diff --git a/backend/prompt_studio/prompt_studio_output_manager/models.py b/backend/prompt_studio/prompt_studio_output_manager/models.py index f4380334..8e6acfbc 100644 --- a/backend/prompt_studio/prompt_studio_output_manager/models.py +++ b/backend/prompt_studio/prompt_studio_output_manager/models.py @@ -72,8 +72,13 @@ class PromptStudioOutputManager(BaseModel): class Meta: constraints = [ models.UniqueConstraint( - fields=["prompt_id", "document_manager", "profile_manager", - "tool_id", "is_single_pass_extract"], + fields=[ + "prompt_id", + "document_manager", + "profile_manager", + "tool_id", + "is_single_pass_extract", + ], name="unique_prompt_output", ), ] diff --git a/backend/prompt_studio/prompt_studio_output_manager/serializers.py b/backend/prompt_studio/prompt_studio_output_manager/serializers.py index c0a0af5b..57dadf07 100644 --- a/backend/prompt_studio/prompt_studio_output_manager/serializers.py +++ b/backend/prompt_studio/prompt_studio_output_manager/serializers.py @@ -7,4 +7,3 @@ class PromptStudioOutputSerializer(AuditSerializer): class Meta: model = PromptStudioOutputManager fields = "__all__" - diff --git a/backend/prompt_studio/prompt_studio_output_manager/views.py b/backend/prompt_studio/prompt_studio_output_manager/views.py index 6167f2b1..affcbd0d 100644 --- a/backend/prompt_studio/prompt_studio_output_manager/views.py +++ b/backend/prompt_studio/prompt_studio_output_manager/views.py @@ -43,9 +43,9 @@ class PromptStudioOutputView(viewsets.ModelViewSet): is_single_pass_extract_param ) - filter_args[ - PromptStudioOutputManagerKeys.IS_SINGLE_PASS_EXTRACT - ] = is_single_pass_extract + filter_args[PromptStudioOutputManagerKeys.IS_SINGLE_PASS_EXTRACT] = ( + is_single_pass_extract + ) if filter_args: queryset = PromptStudioOutputManager.objects.filter(**filter_args) diff --git a/backend/prompt_studio/prompt_studio_registry/prompt_studio_registry_helper.py b/backend/prompt_studio/prompt_studio_registry/prompt_studio_registry_helper.py index c9f328f7..991bf030 100644 --- a/backend/prompt_studio/prompt_studio_registry/prompt_studio_registry_helper.py +++ b/backend/prompt_studio/prompt_studio_registry/prompt_studio_registry_helper.py @@ -115,10 +115,10 @@ class PromptStudioRegistryHelper: PromptStudioRegistryHelper.frame_properties(tool=custom_tool) ) spec: Spec = PromptStudioRegistryHelper.frame_spec(tool=custom_tool) - prompts: list[ - ToolStudioPrompt - ] = PromptStudioHelper.fetch_prompt_from_tool( - tool_id=custom_tool.tool_id + prompts: list[ToolStudioPrompt] = ( + PromptStudioHelper.fetch_prompt_from_tool( + tool_id=custom_tool.tool_id + ) ) metadata = PromptStudioRegistryHelper.frame_export_json( tool=custom_tool, prompts=prompts @@ -195,9 +195,9 @@ class PromptStudioRegistryHelper: adapter_id = str(prompt.profile_manager.embedding_model.adapter_id) embedding_suffix = adapter_id.split("|")[0] - output[ - JsonSchemaKey.ASSERTION_FAILURE_PROMPT - ] = prompt.assertion_failure_prompt + output[JsonSchemaKey.ASSERTION_FAILURE_PROMPT] = ( + prompt.assertion_failure_prompt + ) output[JsonSchemaKey.ASSERT_PROMPT] = prompt.assert_prompt output[JsonSchemaKey.IS_ASSERT] = prompt.is_assert output[JsonSchemaKey.PROMPT] = prompt.prompt @@ -206,21 +206,21 @@ class PromptStudioRegistryHelper: output[JsonSchemaKey.VECTOR_DB] = vector_db output[JsonSchemaKey.EMBEDDING] = embedding_model output[JsonSchemaKey.X2TEXT_ADAPTER] = x2text - output[ - JsonSchemaKey.CHUNK_OVERLAP - ] = prompt.profile_manager.chunk_overlap + output[JsonSchemaKey.CHUNK_OVERLAP] = ( + prompt.profile_manager.chunk_overlap + ) output[JsonSchemaKey.LLM] = llm output[JsonSchemaKey.PREAMBLE] = tool.preamble output[JsonSchemaKey.POSTAMBLE] = tool.postamble output[JsonSchemaKey.GRAMMAR] = grammar_list output[JsonSchemaKey.TYPE] = prompt.enforce_type output[JsonSchemaKey.NAME] = prompt.prompt_key - output[ - JsonSchemaKey.RETRIEVAL_STRATEGY - ] = prompt.profile_manager.retrieval_strategy - output[ - JsonSchemaKey.SIMILARITY_TOP_K - ] = prompt.profile_manager.similarity_top_k + output[JsonSchemaKey.RETRIEVAL_STRATEGY] = ( + prompt.profile_manager.retrieval_strategy + ) + output[JsonSchemaKey.SIMILARITY_TOP_K] = ( + prompt.profile_manager.similarity_top_k + ) output[JsonSchemaKey.SECTION] = prompt.profile_manager.section output[JsonSchemaKey.REINDEX] = prompt.profile_manager.reindex output[JsonSchemaKey.EMBEDDING_SUFFIX] = embedding_suffix diff --git a/backend/prompt_studio/prompt_studio_registry/serializers.py b/backend/prompt_studio/prompt_studio_registry/serializers.py index 690a10fb..fe74848f 100644 --- a/backend/prompt_studio/prompt_studio_registry/serializers.py +++ b/backend/prompt_studio/prompt_studio_registry/serializers.py @@ -1,6 +1,7 @@ -from backend.serializers import AuditSerializer from rest_framework import serializers +from backend.serializers import AuditSerializer + from .models import PromptStudioRegistry diff --git a/backend/scheduler/serializer.py b/backend/scheduler/serializer.py index feadb34b..585cba98 100644 --- a/backend/scheduler/serializer.py +++ b/backend/scheduler/serializer.py @@ -1,12 +1,13 @@ import logging from typing import Any -from backend.constants import FieldLengthConstants as FieldLength from django.conf import settings from pipeline.manager import PipelineManager from rest_framework import serializers from scheduler.constants import SchedulerConstants as SC +from backend.constants import FieldLengthConstants as FieldLength + logger = logging.getLogger(__name__) JOB_NAME_LENGTH = 255 @@ -24,7 +25,9 @@ class JobKwargsSerializer(serializers.Serializer): class SchedulerKwargsSerializer(serializers.Serializer): coalesce = serializers.BooleanField() - misfire_grace_time = serializers.IntegerField(allow_null=True, required=False) + misfire_grace_time = serializers.IntegerField( + allow_null=True, required=False + ) max_instances = serializers.IntegerField() replace_existing = serializers.BooleanField() @@ -41,10 +44,10 @@ class AddJobSerializer(serializers.Serializer): def to_internal_value(self, data: dict[str, Any]) -> dict[str, Any]: if SC.NAME not in data: data[SC.NAME] = f"Job-{data[SC.ID]}" - data[ - SC.JOB_KWARGS - ] = PipelineManager.get_pipeline_execution_data_for_scheduled_run( - pipeline_id=data[SC.ID] + data[SC.JOB_KWARGS] = ( + PipelineManager.get_pipeline_execution_data_for_scheduled_run( + pipeline_id=data[SC.ID] + ) ) data[SC.SCHEDULER_KWARGS] = settings.SCHEDULER_KWARGS return super().to_internal_value(data) # type: ignore diff --git a/backend/tenant_account/invitation_views.py b/backend/tenant_account/invitation_views.py index 8d727eb7..de49bd84 100644 --- a/backend/tenant_account/invitation_views.py +++ b/backend/tenant_account/invitation_views.py @@ -15,8 +15,10 @@ class InvitationViewSet(viewsets.ViewSet): @action(detail=False, methods=["GET"]) def list_invitations(self, request: Request) -> Response: auth_controller = AuthenticationController() - invitations: list[MemberInvitation] = auth_controller.get_user_invitations( - organization_id=request.org_id, + invitations: list[MemberInvitation] = ( + auth_controller.get_user_invitations( + organization_id=request.org_id, + ) ) serialized_members = ListInvitationsResponseSerializer( invitations, many=True diff --git a/backend/tenant_account/models.py b/backend/tenant_account/models.py index f4b5b797..c7d7866d 100644 --- a/backend/tenant_account/models.py +++ b/backend/tenant_account/models.py @@ -5,7 +5,10 @@ from django.db import models class OrganizationMember(models.Model): member_id = models.BigAutoField(primary_key=True) user = models.OneToOneField( - User, on_delete=models.CASCADE, default=None, related_name="organization_member" + User, + on_delete=models.CASCADE, + default=None, + related_name="organization_member", ) role = models.CharField() diff --git a/backend/tenant_account/users_view.py b/backend/tenant_account/users_view.py index 65b7dcfb..2f885210 100644 --- a/backend/tenant_account/users_view.py +++ b/backend/tenant_account/users_view.py @@ -79,7 +79,8 @@ class OrganizationUserViewSet(viewsets.ViewSet): # z_code = request.COOKIES.get(Cookie.Z_CODE) user_info = auth_controller.get_user_info(request) role = auth_controller.get_organization_members_by_user( - request.user) + request.user + ) if not user_info: return Response( status=status.HTTP_404_NOT_FOUND, @@ -89,10 +90,10 @@ class OrganizationUserViewSet(viewsets.ViewSet): # Temporary fix for getting user role along with user info. # Proper implementation would be adding role field to UserInfo. serialized_user_info["is_admin"] = auth_controller.is_admin_by_role( - role.role) + role.role + ) return Response( - status=status.HTTP_200_OK, data={ - "user": serialized_user_info} + status=status.HTTP_200_OK, data={"user": serialized_user_info} ) except Exception as error: Logger.error(f"Error while get User : {error}") @@ -112,7 +113,6 @@ class OrganizationUserViewSet(viewsets.ViewSet): ) response_serializer = UserInviteResponseSerializer( - invite_response, many=True ) @@ -157,11 +157,12 @@ class OrganizationUserViewSet(viewsets.ViewSet): def get_organization_members(self, request: Request) -> Response: auth_controller = AuthenticationController() if request.org_id: - members: list[ - OrganizationMember - ] = auth_controller.get_organization_members_by_org_id() + members: list[OrganizationMember] = ( + auth_controller.get_organization_members_by_org_id() + ) serialized_members = OrganizationMemberSerializer( - members, many=True).data + members, many=True + ).data return Response( status=status.HTTP_200_OK, data={"message": "success", "members": serialized_members}, diff --git a/backend/tenant_account/views.py b/backend/tenant_account/views.py index 1f95d07a..598a72ce 100644 --- a/backend/tenant_account/views.py +++ b/backend/tenant_account/views.py @@ -37,7 +37,9 @@ def get_roles(request: Request) -> Response: @api_view(["POST"]) def reset_password(request: Request) -> Response: auth_controller = AuthenticationController() - data: ResetUserPasswordDto = auth_controller.reset_user_password(request.user) + data: ResetUserPasswordDto = auth_controller.reset_user_password( + request.user + ) if data.status: return Response( status=status.HTTP_200_OK, diff --git a/backend/tool_instance/tool_processor.py b/backend/tool_instance/tool_processor.py index 32313f10..acc0ab06 100644 --- a/backend/tool_instance/tool_processor.py +++ b/backend/tool_instance/tool_processor.py @@ -60,7 +60,7 @@ class ToolProcessor: ) schema_json: dict[str, Any] = schema.to_dict() return schema_json - + @staticmethod def update_schema_with_adapter_configurations( schema: Spec, user: User @@ -134,12 +134,12 @@ class ToolProcessor: def get_tool_list(user: User) -> list[dict[str, Any]]: """Function to get a list of tools.""" tool_registry = ToolRegistry() - prompt_studio_tools: list[ - dict[str, Any] - ] = PromptStudioRegistryHelper.fetch_json_for_registry(user) - tool_list: list[ - dict[str, Any] - ] = tool_registry.fetch_tools_descriptions() + prompt_studio_tools: list[dict[str, Any]] = ( + PromptStudioRegistryHelper.fetch_json_for_registry(user) + ) + tool_list: list[dict[str, Any]] = ( + tool_registry.fetch_tools_descriptions() + ) tool_list = tool_list + prompt_studio_tools return tool_list diff --git a/backend/tool_instance/views.py b/backend/tool_instance/views.py index 75d30320..98bf6c8a 100644 --- a/backend/tool_instance/views.py +++ b/backend/tool_instance/views.py @@ -3,7 +3,6 @@ import uuid from typing import Any from account.custom_exceptions import DuplicateData -from backend.constants import RequestKey from django.db import IntegrityError from django.db.models.query import QuerySet from rest_framework import serializers, status, viewsets @@ -14,7 +13,10 @@ from rest_framework.versioning import URLPathVersioning from tool_instance.constants import ToolInstanceErrors from tool_instance.constants import ToolInstanceKey as TIKey from tool_instance.constants import ToolKey -from tool_instance.exceptions import FetchToolListFailed, ToolFunctionIsMandatory +from tool_instance.exceptions import ( + FetchToolListFailed, + ToolFunctionIsMandatory, +) from tool_instance.models import ToolInstance from tool_instance.serializers import ( ToolInstanceReorderSerializer as TIReorderSerializer, @@ -25,6 +27,8 @@ from tool_instance.tool_processor import ToolProcessor from utils.filtering import FilterHelper from workflow_manager.workflow.constants import WorkflowKey +from backend.constants import RequestKey + logger = logging.getLogger(__name__) @@ -51,7 +55,8 @@ def get_tool_list(request: Request) -> Response: try: logger.info("Fetching tools from the tool registry...") return Response( - data=ToolProcessor.get_tool_list(request.user), status=status.HTTP_200_OK + data=ToolProcessor.get_tool_list(request.user), + status=status.HTTP_200_OK, ) except Exception as exc: logger.error(f"Failed to fetch tools: {exc}") @@ -117,10 +122,10 @@ class ToolInstanceViewSet(viewsets.ModelViewSet): instance (ToolInstance): Instance being deleted. """ lookup = {"step__gt": instance.step} - next_tool_instances: list[ - ToolInstance - ] = ToolInstanceHelper.get_tool_instances_by_workflow( - instance.workflow.id, TIKey.STEP, lookup=lookup + next_tool_instances: list[ToolInstance] = ( + ToolInstanceHelper.get_tool_instances_by_workflow( + instance.workflow.id, TIKey.STEP, lookup=lookup + ) ) super().perform_destroy(instance) diff --git a/backend/utils/local_context.py b/backend/utils/local_context.py index 534ec06f..4479fefd 100644 --- a/backend/utils/local_context.py +++ b/backend/utils/local_context.py @@ -1,7 +1,7 @@ import os import threading -from typing import Any from enum import Enum +from typing import Any class ConcurrencyMode(Enum): @@ -14,10 +14,8 @@ class Exceptions: class StateStore: - - mode = os.environ.get( - "CONCURRENCY_MODE", ConcurrencyMode.THREAD - ) + + mode = os.environ.get("CONCURRENCY_MODE", ConcurrencyMode.THREAD) # Thread-safe storage. thread_local = threading.local() diff --git a/backend/utils/request/request.py b/backend/utils/request/request.py index 9c05df15..e976e827 100644 --- a/backend/utils/request/request.py +++ b/backend/utils/request/request.py @@ -28,11 +28,15 @@ def make_http_request( if verb == HTTPMethod.GET: response = pyrequests.get(url, params=params, headers=headers) elif verb == HTTPMethod.POST: - response = pyrequests.post(url, json=data, params=params, headers=headers) + response = pyrequests.post( + url, json=data, params=params, headers=headers + ) elif verb == HTTPMethod.DELETE: response = pyrequests.delete(url, params=params, headers=headers) else: - raise ValueError("Invalid HTTP verb. Supported verbs: GET, POST, DELETE") + raise ValueError( + "Invalid HTTP verb. Supported verbs: GET, POST, DELETE" + ) response.raise_for_status() return_val: str = ( diff --git a/backend/workflow_manager/endpoint/base_connector.py b/backend/workflow_manager/endpoint/base_connector.py index f6e782fc..7d4d7e6a 100644 --- a/backend/workflow_manager/endpoint/base_connector.py +++ b/backend/workflow_manager/endpoint/base_connector.py @@ -4,15 +4,16 @@ from typing import Any from django.conf import settings from django.db import connection from fsspec import AbstractFileSystem -from unstract.connectors.filesystems import connectors -from unstract.connectors.filesystems.unstract_file_system import ( - UnstractFileSystem, -) from unstract.workflow_execution.execution_file_handler import ( ExecutionFileHandler, ) from utils.constants import Common +from unstract.connectors.filesystems import connectors +from unstract.connectors.filesystems.unstract_file_system import ( + UnstractFileSystem, +) + class BaseConnector(ExecutionFileHandler): """Base class for connectors providing common methods and utilities.""" diff --git a/backend/workflow_manager/endpoint/views.py b/backend/workflow_manager/endpoint/views.py index b84c6f9b..b1cc1c35 100644 --- a/backend/workflow_manager/endpoint/views.py +++ b/backend/workflow_manager/endpoint/views.py @@ -14,8 +14,8 @@ class WorkflowEndpointViewSet(viewsets.ModelViewSet): queryset = WorkflowEndpoint.objects.all() serializer_class = WorkflowEndpointSerializer - def get_queryset(self) -> QuerySet: - + def get_queryset(self) -> QuerySet: + queryset = ( WorkflowEndpoint.objects.all() .select_related("workflow") diff --git a/backend/workflow_manager/workflow/exceptions.py b/backend/workflow_manager/workflow/exceptions.py index 94422b07..4715a971 100644 --- a/backend/workflow_manager/workflow/exceptions.py +++ b/backend/workflow_manager/workflow/exceptions.py @@ -37,6 +37,7 @@ class InvalidRequest(APIException): status_code = 400 default_detail = "Invalid Request" + class MissingEnvException(APIException): status_code = 500 default_detail = "At least one active platform key should be available." @@ -73,4 +74,4 @@ class WorkflowExecutionBadRequestException(APIException): class ToolValidationError(APIException): status_code = 400 - default_detail = "Tool validation error" \ No newline at end of file + default_detail = "Tool validation error" diff --git a/backend/workflow_manager/workflow/generator.py b/backend/workflow_manager/workflow/generator.py index f611cb2f..f52a0a9c 100644 --- a/backend/workflow_manager/workflow/generator.py +++ b/backend/workflow_manager/workflow/generator.py @@ -6,13 +6,14 @@ from rest_framework.request import Request from tool_instance.constants import ToolInstanceKey as TIKey from tool_instance.exceptions import ToolInstantiationError from tool_instance.tool_processor import ToolProcessor -from unstract.core.llm_workflow_generator.llm_interface import LLMInterface from unstract.tool_registry.dto import Tool from workflow_manager.workflow.constants import WorkflowKey from workflow_manager.workflow.dto import ProvisionalWorkflow from workflow_manager.workflow.exceptions import WorkflowGenerationError from workflow_manager.workflow.models.workflow import Workflow as WorkflowModel +from unstract.core.llm_workflow_generator.llm_interface import LLMInterface + logger = logging.getLogger(__name__) @@ -69,8 +70,8 @@ class WorkflowGenerator: self._request = request def generate_workflow(self, tools: list[Tool]) -> None: - """Used to talk to the GPT model through core and obtain a - provisional workflow for the user to work with.""" + """Used to talk to the GPT model through core and obtain a provisional + workflow for the user to work with.""" self._provisional_wf = self._get_provisional_workflow(tools) @staticmethod diff --git a/backend/workflow_manager/workflow/serializers.py b/backend/workflow_manager/workflow/serializers.py index e0e4001c..b9ba92db 100644 --- a/backend/workflow_manager/workflow/serializers.py +++ b/backend/workflow_manager/workflow/serializers.py @@ -1,8 +1,6 @@ import logging from typing import Any, Optional, Union -from backend.constants import RequestKey -from backend.serializers import AuditSerializer from project.constants import ProjectKey from rest_framework.serializers import ( CharField, @@ -16,11 +14,17 @@ from rest_framework.serializers import ( from tool_instance.serializers import ToolInstanceSerializer from tool_instance.tool_instance_helper import ToolInstanceHelper from workflow_manager.endpoint.models import WorkflowEndpoint -from workflow_manager.workflow.constants import WorkflowExecutionKey, WorkflowKey +from workflow_manager.workflow.constants import ( + WorkflowExecutionKey, + WorkflowKey, +) from workflow_manager.workflow.exceptions import WorkflowGenerationError from workflow_manager.workflow.generator import WorkflowGenerator from workflow_manager.workflow.models.workflow import Workflow +from backend.constants import RequestKey +from backend.serializers import AuditSerializer + logger = logging.getLogger(__name__) diff --git a/backend/workflow_manager/workflow/workflow_helper.py b/backend/workflow_manager/workflow/workflow_helper.py index 0db2789e..d84aafb2 100644 --- a/backend/workflow_manager/workflow/workflow_helper.py +++ b/backend/workflow_manager/workflow/workflow_helper.py @@ -215,10 +215,10 @@ class WorkflowHelper: workflow_execution: Optional[WorkflowExecution] = None, execution_mode: Optional[tuple[str, str]] = None, ) -> ExecutionResponse: - tool_instances: list[ - ToolInstance - ] = ToolInstanceHelper.get_tool_instances_by_workflow( - workflow.id, ToolInstanceKey.STEP + tool_instances: list[ToolInstance] = ( + ToolInstanceHelper.get_tool_instances_by_workflow( + workflow.id, ToolInstanceKey.STEP + ) ) WorkflowHelper.validate_tool_instances_meta( diff --git a/document-service/src/unstract/document_service/main.py b/document-service/src/unstract/document_service/main.py index 764e9cf5..fe88c609 100644 --- a/document-service/src/unstract/document_service/main.py +++ b/document-service/src/unstract/document_service/main.py @@ -11,13 +11,18 @@ from odf import teletype, text from odf.opendocument import load logging.basicConfig( - level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s : %(message)s" + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s : %(message)s", ) UPLOAD_FOLDER = os.environ.get("UPLOAD_FOLDER", "/tmp/document_service/upload") -PROCESS_FOLDER = os.environ.get("PROCESS_FOLDER", "/tmp/document_service/process") +PROCESS_FOLDER = os.environ.get( + "PROCESS_FOLDER", "/tmp/document_service/process" +) LIBREOFFICE_PYTHON = os.environ.get("LIBREOFFICE_PYTHON", "/usr/bin/python3") -MAX_FILE_SIZE = int(os.environ.get("MAX_FILE_SIZE", 10485760)) # 10 * 1024 * 1024 +MAX_FILE_SIZE = int( + os.environ.get("MAX_FILE_SIZE", 10485760) +) # 10 * 1024 * 1024 SERVICE_API_TOKEN = os.environ.get("SERVICE_API_TOKEN", "") app = Flask("document_service") @@ -99,7 +104,9 @@ def upload_file(): redis_host = os.environ.get("REDIS_HOST") redis_port = int(os.environ.get("REDIS_PORT")) redis_password = os.environ.get("REDIS_PASSWORD") - r = redis.Redis(host=redis_host, port=redis_port, password=redis_password) + r = redis.Redis( + host=redis_host, port=redis_port, password=redis_password + ) # TODO: Create a file reaper process to look at uploaded time and delete redis_key = f"upload_time:{account_id}_{file_name}" current_timestamp = int(time.time()) @@ -123,7 +130,9 @@ def find_and_replace(): output_format = request.args.get("output_format").lower() find_and_replace_text = request.json - app.logger.info(f"Find and replace for file {file_name} for account {account_id}") + app.logger.info( + f"Find and replace for file {file_name} for account {account_id}" + ) app.logger.info(f"Output format: {output_format}") if output_format not in ["pdf"]: @@ -143,7 +152,9 @@ def find_and_replace(): try: command = f"{LIBREOFFICE_PYTHON} -m unoserver.converter --convert-to odt \ --filter writer8 {file_namex} {file_name_odt}" - result = subprocess.run(command, shell=True, capture_output=True, text=True) + result = subprocess.run( + command, shell=True, capture_output=True, text=True + ) app.logger.info(result) if result.returncode != 0: app.logger.error( @@ -155,7 +166,9 @@ def find_and_replace(): app.logger.info( f"File converted to ODT format successfully! {file_name_odt}" ) - app.logger.info(f"ODT convertion result: {result.stdout} | {result.stderr}") + app.logger.info( + f"ODT convertion result: {result.stdout} | {result.stderr}" + ) except Exception as e: app.logger.error(f"Error while converting file to ODT format: {e}") return "Error while converting file to ODT format!", 500 @@ -169,9 +182,13 @@ def find_and_replace(): replace_str = find_and_replace_text[find_str] for element in doc.getElementsByType(text.Span): if find_str in teletype.extractText(element): - app.logger.info(f"Found {find_str} in {teletype.extractText(element)}") + app.logger.info( + f"Found {find_str} in {teletype.extractText(element)}" + ) new_element = text.Span() - new_element.setAttribute("stylename", element.getAttribute("stylename")) + new_element.setAttribute( + "stylename", element.getAttribute("stylename") + ) t = teletype.extractText(element) t = t.replace(find_str, replace_str) new_element.addText(t) @@ -188,7 +205,9 @@ def find_and_replace(): f"{LIBREOFFICE_PYTHON} -m unoserver.converter --convert-to pdf " f"--filter writer_pdf_Export {file_name_odt} {file_name_output}" ) - result = subprocess.run(command, shell=True, capture_output=True, text=True) + result = subprocess.run( + command, shell=True, capture_output=True, text=True + ) if result.returncode != 0: app.logger.error( f"Failed to convert file to {output_format} format: " @@ -200,9 +219,13 @@ def find_and_replace(): f"File converted to {output_format} format successfully! " f"{file_name_output}" ) - app.logger.info(f"ODT convertion result: {result.stdout} | {result.stderr}") + app.logger.info( + f"ODT convertion result: {result.stdout} | {result.stderr}" + ) except Exception as e: - app.logger.error(f"Error while converting file to {output_format} format: {e}") + app.logger.error( + f"Error while converting file to {output_format} format: {e}" + ) return f"Error while converting file to {output_format} format!", 500 return send_file(file_name_output, as_attachment=True) diff --git a/frontend/public/icons/connector-icons/HTTP.svg b/frontend/public/icons/connector-icons/HTTP.svg index 7f4d49e4..764296d2 100644 --- a/frontend/public/icons/connector-icons/HTTP.svg +++ b/frontend/public/icons/connector-icons/HTTP.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/frontend/public/icons/connector-icons/google_bigquery-icon.svg b/frontend/public/icons/connector-icons/google_bigquery-icon.svg index 34e84de1..7f113442 100644 --- a/frontend/public/icons/connector-icons/google_bigquery-icon.svg +++ b/frontend/public/icons/connector-icons/google_bigquery-icon.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/frontend/src/assets/steps.svg b/frontend/src/assets/steps.svg index 5e8229c2..9477f61a 100644 --- a/frontend/src/assets/steps.svg +++ b/frontend/src/assets/steps.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/frontend/src/components/agency/actions/Actions.css b/frontend/src/components/agency/actions/Actions.css index a6b50499..1ec61eb9 100644 --- a/frontend/src/components/agency/actions/Actions.css +++ b/frontend/src/components/agency/actions/Actions.css @@ -18,4 +18,4 @@ .step-icon { opacity: 0.60; image-rendering: pixelated; -} \ No newline at end of file +} diff --git a/frontend/src/components/agency/cards-list/CardList.css b/frontend/src/components/agency/cards-list/CardList.css index 5197ffcb..62e73a17 100644 --- a/frontend/src/components/agency/cards-list/CardList.css +++ b/frontend/src/components/agency/cards-list/CardList.css @@ -27,4 +27,4 @@ } .tool-dragging { background-color: #DAE3EC; -} \ No newline at end of file +} diff --git a/frontend/src/components/agency/steps/Steps.css b/frontend/src/components/agency/steps/Steps.css index 9acf8304..a76aee99 100644 --- a/frontend/src/components/agency/steps/Steps.css +++ b/frontend/src/components/agency/steps/Steps.css @@ -62,4 +62,4 @@ .ds-set-card-select { width: 120px; -} \ No newline at end of file +} diff --git a/frontend/src/components/agency/tool-icon/ToolIcon.css b/frontend/src/components/agency/tool-icon/ToolIcon.css index f2c7732b..db7834b7 100644 --- a/frontend/src/components/agency/tool-icon/ToolIcon.css +++ b/frontend/src/components/agency/tool-icon/ToolIcon.css @@ -4,4 +4,4 @@ border-radius: 5px; border: 1px solid #FFD2DB; padding: 5px; -} \ No newline at end of file +} diff --git a/frontend/src/components/agency/tool-settings/ToolSettings.css b/frontend/src/components/agency/tool-settings/ToolSettings.css index 74ab755e..2cdec0fe 100644 --- a/frontend/src/components/agency/tool-settings/ToolSettings.css +++ b/frontend/src/components/agency/tool-settings/ToolSettings.css @@ -6,4 +6,4 @@ .tool-settings-submit-btn { padding: 10px 0px; -} \ No newline at end of file +} diff --git a/frontend/src/components/custom-tools/custom-synonyms/CustomSynonyms.css b/frontend/src/components/custom-tools/custom-synonyms/CustomSynonyms.css index d8c4727a..a62ceab7 100644 --- a/frontend/src/components/custom-tools/custom-synonyms/CustomSynonyms.css +++ b/frontend/src/components/custom-tools/custom-synonyms/CustomSynonyms.css @@ -7,4 +7,4 @@ .cus-syn-del { font-size: 10px; -} \ No newline at end of file +} diff --git a/frontend/src/components/custom-tools/document-parser/DocumentParser.css b/frontend/src/components/custom-tools/document-parser/DocumentParser.css index 86cf33e8..293932ce 100644 --- a/frontend/src/components/custom-tools/document-parser/DocumentParser.css +++ b/frontend/src/components/custom-tools/document-parser/DocumentParser.css @@ -10,4 +10,4 @@ .doc-parser-pad-bottom { padding-bottom: 6px; -} \ No newline at end of file +} diff --git a/frontend/src/components/custom-tools/edit-tool-info/EditToolInfo.css b/frontend/src/components/custom-tools/edit-tool-info/EditToolInfo.css index 70f1df2d..076ee82e 100644 --- a/frontend/src/components/custom-tools/edit-tool-info/EditToolInfo.css +++ b/frontend/src/components/custom-tools/edit-tool-info/EditToolInfo.css @@ -6,4 +6,4 @@ .edit-tool-info-helper-text { font-size: 12px; -} \ No newline at end of file +} diff --git a/frontend/src/components/custom-tools/footer-layout/FooterLayout.css b/frontend/src/components/custom-tools/footer-layout/FooterLayout.css index 8266e99b..c0303827 100644 --- a/frontend/src/components/custom-tools/footer-layout/FooterLayout.css +++ b/frontend/src/components/custom-tools/footer-layout/FooterLayout.css @@ -2,4 +2,4 @@ .tool-ide-main-footer-layout { padding: 8px 14px; -} \ No newline at end of file +} diff --git a/frontend/src/components/custom-tools/footer/Footer.css b/frontend/src/components/custom-tools/footer/Footer.css index 2a8b58d6..abf17411 100644 --- a/frontend/src/components/custom-tools/footer/Footer.css +++ b/frontend/src/components/custom-tools/footer/Footer.css @@ -4,4 +4,4 @@ display: flex; justify-content: space-between; margin-left: auto; -} \ No newline at end of file +} diff --git a/frontend/src/components/custom-tools/generate-index/GenerateIndex.css b/frontend/src/components/custom-tools/generate-index/GenerateIndex.css index 05d65dba..b75d612c 100644 --- a/frontend/src/components/custom-tools/generate-index/GenerateIndex.css +++ b/frontend/src/components/custom-tools/generate-index/GenerateIndex.css @@ -18,4 +18,4 @@ .gen-index-icon { font-size: 24px; -} \ No newline at end of file +} diff --git a/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.css b/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.css index 0cfe905d..87ec096d 100644 --- a/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.css +++ b/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.css @@ -17,4 +17,4 @@ .manage-docs-div { margin: 0; -} \ No newline at end of file +} diff --git a/frontend/src/components/custom-tools/manage-llm-profiles/ManageLlmProfiles.css b/frontend/src/components/custom-tools/manage-llm-profiles/ManageLlmProfiles.css index f4bce61c..fcd61731 100644 --- a/frontend/src/components/custom-tools/manage-llm-profiles/ManageLlmProfiles.css +++ b/frontend/src/components/custom-tools/manage-llm-profiles/ManageLlmProfiles.css @@ -2,4 +2,4 @@ .manage-llm-pro-icon { font-size: 10px; -} \ No newline at end of file +} diff --git a/frontend/src/components/custom-tools/output-analyzer-list/OutputAnalyzerList.css b/frontend/src/components/custom-tools/output-analyzer-list/OutputAnalyzerList.css index 358230c0..05d207d6 100644 --- a/frontend/src/components/custom-tools/output-analyzer-list/OutputAnalyzerList.css +++ b/frontend/src/components/custom-tools/output-analyzer-list/OutputAnalyzerList.css @@ -68,4 +68,4 @@ .output-analyzer-card-gap { margin-bottom: 12px; -} \ No newline at end of file +} diff --git a/frontend/src/components/custom-tools/output-for-doc-modal/OutputForDocModal.css b/frontend/src/components/custom-tools/output-for-doc-modal/OutputForDocModal.css index 3513d093..98214a0c 100644 --- a/frontend/src/components/custom-tools/output-for-doc-modal/OutputForDocModal.css +++ b/frontend/src/components/custom-tools/output-for-doc-modal/OutputForDocModal.css @@ -15,4 +15,3 @@ .output-doc-gap { margin-bottom: 8px; } - diff --git a/frontend/src/components/custom-tools/tools-main/ToolsMain.css b/frontend/src/components/custom-tools/tools-main/ToolsMain.css index a0f13564..ee47a367 100644 --- a/frontend/src/components/custom-tools/tools-main/ToolsMain.css +++ b/frontend/src/components/custom-tools/tools-main/ToolsMain.css @@ -23,4 +23,4 @@ .tools-main-footer { border-top: 1px #D9D9D9 solid; -} \ No newline at end of file +} diff --git a/frontend/src/components/error/UnAuthorized/Unauthorized.css b/frontend/src/components/error/UnAuthorized/Unauthorized.css index bb1b7cda..d7abe844 100644 --- a/frontend/src/components/error/UnAuthorized/Unauthorized.css +++ b/frontend/src/components/error/UnAuthorized/Unauthorized.css @@ -6,4 +6,4 @@ .unauth-text{ color:#666; font-size: 24px; -} \ No newline at end of file +} diff --git a/frontend/src/components/input-output/configure-ds/ConfigureDs.css b/frontend/src/components/input-output/configure-ds/ConfigureDs.css index 8af01af9..8e8bbc77 100644 --- a/frontend/src/components/input-output/configure-ds/ConfigureDs.css +++ b/frontend/src/components/input-output/configure-ds/ConfigureDs.css @@ -20,4 +20,3 @@ .config-tc-btn { background-color: #4BB543 !important; } - diff --git a/frontend/src/components/input-output/edit-ds-modal/EditDsModal.css b/frontend/src/components/input-output/edit-ds-modal/EditDsModal.css index 220d9afa..ba721b4b 100644 --- a/frontend/src/components/input-output/edit-ds-modal/EditDsModal.css +++ b/frontend/src/components/input-output/edit-ds-modal/EditDsModal.css @@ -8,4 +8,4 @@ .edit-ds-modal { width: 25% !important; -} \ No newline at end of file +} diff --git a/frontend/src/components/input-output/file-system/FileSystem.css b/frontend/src/components/input-output/file-system/FileSystem.css index 40872f4f..462593c4 100644 --- a/frontend/src/components/input-output/file-system/FileSystem.css +++ b/frontend/src/components/input-output/file-system/FileSystem.css @@ -49,4 +49,4 @@ .ant-tree-treenode-selected .ant-typography{ color: #fff !important; -} \ No newline at end of file +} diff --git a/frontend/src/components/input-output/list-of-sources/ListOfSources.css b/frontend/src/components/input-output/list-of-sources/ListOfSources.css index 45a6aae1..031ee02c 100644 --- a/frontend/src/components/input-output/list-of-sources/ListOfSources.css +++ b/frontend/src/components/input-output/list-of-sources/ListOfSources.css @@ -12,4 +12,4 @@ .list-of-srcs > .list { padding-top: 20px; -} \ No newline at end of file +} diff --git a/frontend/src/components/log-in/Login.css b/frontend/src/components/log-in/Login.css index 08bf177f..ced8c095 100644 --- a/frontend/src/components/log-in/Login.css +++ b/frontend/src/components/log-in/Login.css @@ -14,7 +14,7 @@ .login-right-section { width: 50%; - background-color: #ECEFF3; + background-color: #ECEFF3; } .right-section-text-wrapper{ display: flex; @@ -80,4 +80,4 @@ } .button-margin { margin-top: 60px; -} \ No newline at end of file +} diff --git a/frontend/src/components/oauth-ds/google/GoogleOAuthButton.css b/frontend/src/components/oauth-ds/google/GoogleOAuthButton.css index b54b2ead..11b85560 100644 --- a/frontend/src/components/oauth-ds/google/GoogleOAuthButton.css +++ b/frontend/src/components/oauth-ds/google/GoogleOAuthButton.css @@ -2,4 +2,4 @@ .google-oauth-layout { margin-bottom: 20px; -} \ No newline at end of file +} diff --git a/frontend/src/components/oauth-ds/oauth-status/OAuthStatus.css b/frontend/src/components/oauth-ds/oauth-status/OAuthStatus.css index 3187753d..500542b9 100644 --- a/frontend/src/components/oauth-ds/oauth-status/OAuthStatus.css +++ b/frontend/src/components/oauth-ds/oauth-status/OAuthStatus.css @@ -8,4 +8,4 @@ margin-top: 25px; font-size: 24px; font-weight: bold; -} \ No newline at end of file +} diff --git a/frontend/src/components/pipelines-or-deployments/header/Header.css b/frontend/src/components/pipelines-or-deployments/header/Header.css index d1217d6b..13abd5ed 100644 --- a/frontend/src/components/pipelines-or-deployments/header/Header.css +++ b/frontend/src/components/pipelines-or-deployments/header/Header.css @@ -19,4 +19,4 @@ .header-name { padding: 0px 8px; -} \ No newline at end of file +} diff --git a/frontend/src/components/pipelines-or-deployments/pipelines/Pipelines.css b/frontend/src/components/pipelines-or-deployments/pipelines/Pipelines.css index 95970cac..0509db89 100644 --- a/frontend/src/components/pipelines-or-deployments/pipelines/Pipelines.css +++ b/frontend/src/components/pipelines-or-deployments/pipelines/Pipelines.css @@ -3,4 +3,4 @@ height: 100%; display: flex; flex-direction: column; -} \ No newline at end of file +} diff --git a/frontend/src/components/profile/Profile.css b/frontend/src/components/profile/Profile.css index efb24bbb..1859feff 100644 --- a/frontend/src/components/profile/Profile.css +++ b/frontend/src/components/profile/Profile.css @@ -16,4 +16,4 @@ .header-text .typo-text { font-size: 18px; -} \ No newline at end of file +} diff --git a/frontend/src/components/rjsf-custom-widgets/array-field/ArrayField.css b/frontend/src/components/rjsf-custom-widgets/array-field/ArrayField.css index 22e61160..27aba5bf 100644 --- a/frontend/src/components/rjsf-custom-widgets/array-field/ArrayField.css +++ b/frontend/src/components/rjsf-custom-widgets/array-field/ArrayField.css @@ -12,4 +12,4 @@ .array-field-select { width: 100%; -} \ No newline at end of file +} diff --git a/frontend/src/components/settings/invite/InviteEditUser.css b/frontend/src/components/settings/invite/InviteEditUser.css index fb61e135..27c4c95e 100644 --- a/frontend/src/components/settings/invite/InviteEditUser.css +++ b/frontend/src/components/settings/invite/InviteEditUser.css @@ -31,4 +31,4 @@ .form-select{ margin: 0; -} \ No newline at end of file +} diff --git a/frontend/src/components/settings/settings/Settings.css b/frontend/src/components/settings/settings/Settings.css index 87849d76..51f78ced 100644 --- a/frontend/src/components/settings/settings/Settings.css +++ b/frontend/src/components/settings/settings/Settings.css @@ -20,4 +20,4 @@ .settings-plt-typo { font-size: 14px; -} \ No newline at end of file +} diff --git a/frontend/src/components/settings/users/Users.css b/frontend/src/components/settings/users/Users.css index d5f3cc3a..f0890f55 100644 --- a/frontend/src/components/settings/users/Users.css +++ b/frontend/src/components/settings/users/Users.css @@ -14,4 +14,4 @@ .delete-user-modal{ width: 400px; -} \ No newline at end of file +} diff --git a/frontend/src/components/tool-settings/list-of-items/ListOfItems.css b/frontend/src/components/tool-settings/list-of-items/ListOfItems.css index 08a3ead2..fbde1161 100644 --- a/frontend/src/components/tool-settings/list-of-items/ListOfItems.css +++ b/frontend/src/components/tool-settings/list-of-items/ListOfItems.css @@ -11,4 +11,4 @@ object-fit: cover; width: 100%; height: auto; -} \ No newline at end of file +} diff --git a/frontend/src/components/widgets/custom-button/CustomButton.css b/frontend/src/components/widgets/custom-button/CustomButton.css index 7896718e..da3fc49c 100644 --- a/frontend/src/components/widgets/custom-button/CustomButton.css +++ b/frontend/src/components/widgets/custom-button/CustomButton.css @@ -10,4 +10,4 @@ background-color: #0e4274 !important; border-color: #0e4274; color: #FFFFFF; -} \ No newline at end of file +} diff --git a/frontend/src/components/widgets/grid-view/GridView.css b/frontend/src/components/widgets/grid-view/GridView.css index 0d31ba74..8ae88330 100644 --- a/frontend/src/components/widgets/grid-view/GridView.css +++ b/frontend/src/components/widgets/grid-view/GridView.css @@ -18,4 +18,4 @@ min-width: 210px; display: grid; grid-template-rows: 1fr auto; -} \ No newline at end of file +} diff --git a/frontend/src/components/widgets/spinner-loader/SpinnerLoader.css b/frontend/src/components/widgets/spinner-loader/SpinnerLoader.css index f76b656a..11337206 100644 --- a/frontend/src/components/widgets/spinner-loader/SpinnerLoader.css +++ b/frontend/src/components/widgets/spinner-loader/SpinnerLoader.css @@ -20,4 +20,4 @@ .spinner-loader-layout .ant-spin-dot-item:nth-child(4) { background-color: #FF4D6D !important; /* Change the color of the fourth dot */ -} \ No newline at end of file +} diff --git a/frontend/src/components/widgets/top-bar/TopBar.css b/frontend/src/components/widgets/top-bar/TopBar.css index 319eab2b..52a380d6 100644 --- a/frontend/src/components/widgets/top-bar/TopBar.css +++ b/frontend/src/components/widgets/top-bar/TopBar.css @@ -18,4 +18,4 @@ display: inline; line-height: 24px; padding-left: 10px; -} \ No newline at end of file +} diff --git a/frontend/src/layouts/content-center-layout/ContentCenterLayout.css b/frontend/src/layouts/content-center-layout/ContentCenterLayout.css index bc6f5133..603fdb1d 100644 --- a/frontend/src/layouts/content-center-layout/ContentCenterLayout.css +++ b/frontend/src/layouts/content-center-layout/ContentCenterLayout.css @@ -28,4 +28,3 @@ .content-center-body-3 { text-align: center; } - diff --git a/frontend/src/layouts/island-layout/IslandLayout.css b/frontend/src/layouts/island-layout/IslandLayout.css index 111f3227..a5314b80 100644 --- a/frontend/src/layouts/island-layout/IslandLayout.css +++ b/frontend/src/layouts/island-layout/IslandLayout.css @@ -8,4 +8,4 @@ .island-layout > div { background-color: var(--white); height: 100%; -} \ No newline at end of file +} diff --git a/prompt-service/.gitignore b/prompt-service/.gitignore index e68e813a..df775c6b 100644 --- a/prompt-service/.gitignore +++ b/prompt-service/.gitignore @@ -161,4 +161,4 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -plugins \ No newline at end of file +plugins diff --git a/tools/classifier/src/main.py b/tools/classifier/src/main.py index 71dc677f..feab3cfa 100644 --- a/tools/classifier/src/main.py +++ b/tools/classifier/src/main.py @@ -29,7 +29,9 @@ class UnstractClassifier(BaseTool): elif len(bins) < 2: self.stream_error_and_exit("At least two bins are required") if not llm_adapter_instance_id: - self.stream_error_and_exit("Choose an LLM to process the classifier") + self.stream_error_and_exit( + "Choose an LLM to process the classifier" + ) if not text_extraction_adapter_id: self.stream_error_and_exit("Choose an LLM to extract the documents") diff --git a/tools/text_extractor/src/main.py b/tools/text_extractor/src/main.py index a4cba3f7..390da2a0 100644 --- a/tools/text_extractor/src/main.py +++ b/tools/text_extractor/src/main.py @@ -46,19 +46,21 @@ class TextExtractor(BaseTool): text_extraction_adapter_id = settings["extractorId"] source_name = self.get_exec_metadata.get(MetadataKey.SOURCE_NAME) - self.stream_log(f"Extractor ID: {text_extraction_adapter_id} " - "has been retrieved from settings.") - - input_log = ( - f"Processing file: \n\n`{source_name}`" + self.stream_log( + f"Extractor ID: {text_extraction_adapter_id} " + "has been retrieved from settings." ) + + input_log = f"Processing file: \n\n`{source_name}`" self.stream_update(input_log, state=LogState.INPUT_UPDATE) tool_extraction = X2Text(tool=self) text_extraction_adapter = tool_extraction.get_x2text( - adapter_instance_id=text_extraction_adapter_id) + adapter_instance_id=text_extraction_adapter_id + ) self.stream_log( - "Text extraction adapter has been created successfully.") + "Text extraction adapter has been created successfully." + ) extracted_text = text_extraction_adapter.process( input_file_path=input_file ) @@ -66,7 +68,7 @@ class TextExtractor(BaseTool): self.stream_log("Text has been extracted successfully.") - first_5_lines = '\n\n'.join(extracted_text.split('\n')[:5]) + first_5_lines = "\n\n".join(extracted_text.split("\n")[:5]) output_log = ( f"### Text\n\n```text\n{first_5_lines}\n```\n\n...(truncated)" ) @@ -75,8 +77,7 @@ class TextExtractor(BaseTool): try: self.stream_log("Preparing to write the extracted text.") if source_name: - output_path = ( - Path(output_dir) / f"{Path(source_name).stem}.txt") + output_path = Path(output_dir) / f"{Path(source_name).stem}.txt" with open(output_path, "w", encoding="utf-8") as file: file.write(extracted_text) @@ -93,11 +94,11 @@ class TextExtractor(BaseTool): def convert_to_actual_string(self, text: Any) -> str: if isinstance(text, bytes): - return text.decode('utf-8') + return text.decode("utf-8") elif isinstance(text, str): if text.startswith("b'") and text.endswith("'"): bytes_text: bytes = ast.literal_eval(text) - return bytes_text.decode('utf-8') + return bytes_text.decode("utf-8") else: return text else: diff --git a/unstract/connectors/src/unstract/connectors/databases/bigquery/bigquery.py b/unstract/connectors/src/unstract/connectors/databases/bigquery/bigquery.py index ca92f8da..74183682 100644 --- a/unstract/connectors/src/unstract/connectors/databases/bigquery/bigquery.py +++ b/unstract/connectors/src/unstract/connectors/databases/bigquery/bigquery.py @@ -4,6 +4,7 @@ from typing import Any from google.cloud import bigquery from google.cloud.bigquery import Client + from unstract.connectors.databases.unstract_db import UnstractDB from unstract.connectors.exceptions import ConnectorError @@ -29,9 +30,7 @@ class BigQuery(UnstractDB): @staticmethod def get_icon() -> str: - return ( - "/icons/connector-icons/Bigquery.png" - ) + return "/icons/connector-icons/Bigquery.png" @staticmethod def get_json_schema() -> str: diff --git a/unstract/connectors/src/unstract/connectors/databases/mariadb/mariadb.py b/unstract/connectors/src/unstract/connectors/databases/mariadb/mariadb.py index abba5d34..cd1997ea 100644 --- a/unstract/connectors/src/unstract/connectors/databases/mariadb/mariadb.py +++ b/unstract/connectors/src/unstract/connectors/databases/mariadb/mariadb.py @@ -3,6 +3,7 @@ from typing import Any import pymysql from pymysql.connections import Connection + from unstract.connectors.databases.unstract_db import UnstractDB @@ -30,9 +31,7 @@ class MariaDB(UnstractDB): @staticmethod def get_icon() -> str: - return ( - "/icons/connector-icons/MariaDB.png" - ) + return "/icons/connector-icons/MariaDB.png" @staticmethod def get_json_schema() -> str: diff --git a/unstract/connectors/src/unstract/connectors/databases/mssql/mssql.py b/unstract/connectors/src/unstract/connectors/databases/mssql/mssql.py index bd7a0078..3f134238 100644 --- a/unstract/connectors/src/unstract/connectors/databases/mssql/mssql.py +++ b/unstract/connectors/src/unstract/connectors/databases/mssql/mssql.py @@ -3,6 +3,7 @@ from typing import Any import pymssql from pymssql import Connection + from unstract.connectors.databases.unstract_db import UnstractDB diff --git a/unstract/connectors/src/unstract/connectors/databases/mysql/mysql.py b/unstract/connectors/src/unstract/connectors/databases/mysql/mysql.py index 2c7091d7..f2fb369d 100644 --- a/unstract/connectors/src/unstract/connectors/databases/mysql/mysql.py +++ b/unstract/connectors/src/unstract/connectors/databases/mysql/mysql.py @@ -3,6 +3,7 @@ from typing import Any import pymysql from pymysql.connections import Connection + from unstract.connectors.databases.unstract_db import UnstractDB diff --git a/unstract/connectors/src/unstract/connectors/databases/postgresql/postgresql.py b/unstract/connectors/src/unstract/connectors/databases/postgresql/postgresql.py index 7e97175d..7b2ddf05 100644 --- a/unstract/connectors/src/unstract/connectors/databases/postgresql/postgresql.py +++ b/unstract/connectors/src/unstract/connectors/databases/postgresql/postgresql.py @@ -3,6 +3,7 @@ from typing import Any import psycopg2 from psycopg2.extensions import connection + from unstract.connectors.databases.unstract_db import UnstractDB @@ -20,7 +21,11 @@ class PostgreSQL(UnstractDB): if not self.schema: self.schema = "public" if not self.connection_url and not ( - self.user and self.password and self.host and self.port and self.database + self.user + and self.password + and self.host + and self.port + and self.database ): raise ValueError( "Either ConnectionURL or connection parameters must be provided." @@ -40,9 +45,7 @@ class PostgreSQL(UnstractDB): @staticmethod def get_icon() -> str: - return ( - "/icons/connector-icons/Postgresql.png" - ) + return "/icons/connector-icons/Postgresql.png" @staticmethod def get_json_schema() -> str: diff --git a/unstract/connectors/src/unstract/connectors/databases/redshift/redshift.py b/unstract/connectors/src/unstract/connectors/databases/redshift/redshift.py index 015da7b5..ca7c3564 100644 --- a/unstract/connectors/src/unstract/connectors/databases/redshift/redshift.py +++ b/unstract/connectors/src/unstract/connectors/databases/redshift/redshift.py @@ -3,6 +3,7 @@ from typing import Any import psycopg2 from psycopg2.extensions import connection + from unstract.connectors.databases.unstract_db import UnstractDB @@ -33,9 +34,7 @@ class Redshift(UnstractDB): @staticmethod def get_icon() -> str: - return ( - "/icons/connector-icons/Redshift.png" - ) + return "/icons/connector-icons/Redshift.png" @staticmethod def get_json_schema() -> str: diff --git a/unstract/connectors/src/unstract/connectors/databases/snowflake/snowflake.py b/unstract/connectors/src/unstract/connectors/databases/snowflake/snowflake.py index a9f6b005..8c8feb47 100644 --- a/unstract/connectors/src/unstract/connectors/databases/snowflake/snowflake.py +++ b/unstract/connectors/src/unstract/connectors/databases/snowflake/snowflake.py @@ -3,6 +3,7 @@ from typing import Any import snowflake.connector from snowflake.connector.connection import SnowflakeConnection + from unstract.connectors.databases.unstract_db import UnstractDB @@ -32,9 +33,7 @@ class SnowflakeDB(UnstractDB): @staticmethod def get_icon() -> str: - return ( - "/icons/connector-icons/Snowflake.png" - ) + return "/icons/connector-icons/Snowflake.png" @staticmethod def get_json_schema() -> str: diff --git a/unstract/connectors/src/unstract/connectors/filesystems/azure_cloud_storage/azure_cloud_storage.py b/unstract/connectors/src/unstract/connectors/filesystems/azure_cloud_storage/azure_cloud_storage.py index 2835fffd..91389f34 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/azure_cloud_storage/azure_cloud_storage.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/azure_cloud_storage/azure_cloud_storage.py @@ -36,9 +36,7 @@ class AzureCloudStorageFS(UnstractFileSystem): @staticmethod def get_icon() -> str: - return ( - "/icons/connector-icons/azure_blob_storage.png" - ) + return "/icons/connector-icons/azure_blob_storage.png" @staticmethod def get_json_schema() -> str: diff --git a/unstract/connectors/src/unstract/connectors/filesystems/box/box.py b/unstract/connectors/src/unstract/connectors/filesystems/box/box.py index 2d10a180..7229413c 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/box/box.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/box/box.py @@ -7,8 +7,11 @@ from typing import Any from boxfs import BoxFileSystem from boxsdk import JWTAuth from boxsdk.exception import BoxOAuthException + from unstract.connectors.exceptions import ConnectorError -from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem +from unstract.connectors.filesystems.unstract_file_system import ( + UnstractFileSystem, +) logger = logging.getLogger(__name__) logging.getLogger("boxsdk").setLevel(logging.ERROR) @@ -77,7 +80,9 @@ class BoxFS(UnstractFileSystem): @staticmethod def get_description() -> str: - return "Fetch and store data to and from the Box content management system" + return ( + "Fetch and store data to and from the Box content management system" + ) @staticmethod def get_icon() -> str: diff --git a/unstract/connectors/src/unstract/connectors/filesystems/google_cloud_storage/google_cloud_storage.py b/unstract/connectors/src/unstract/connectors/filesystems/google_cloud_storage/google_cloud_storage.py index 902aec48..ba1f32fb 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/google_cloud_storage/google_cloud_storage.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/google_cloud_storage/google_cloud_storage.py @@ -35,9 +35,7 @@ class GoogleCloudStorageFS(UnstractFileSystem): @staticmethod def get_icon() -> str: - return ( - "/icons/connector-icons/google_cloud_storage.png" - ) + return "/icons/connector-icons/google_cloud_storage.png" @staticmethod def get_json_schema() -> str: diff --git a/unstract/connectors/src/unstract/connectors/filesystems/google_drive/google_drive.py b/unstract/connectors/src/unstract/connectors/filesystems/google_drive/google_drive.py index e91275b4..7fe9f1e6 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/google_drive/google_drive.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/google_drive/google_drive.py @@ -6,9 +6,14 @@ from typing import Any from oauth2client.client import OAuth2Credentials from pydrive2.auth import GoogleAuth from pydrive2.fs import GDriveFileSystem + from unstract.connectors.exceptions import ConnectorError -from unstract.connectors.filesystems.google_drive.constants import GDriveConstants -from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem +from unstract.connectors.filesystems.google_drive.constants import ( + GDriveConstants, +) +from unstract.connectors.filesystems.unstract_file_system import ( + UnstractFileSystem, +) from unstract.connectors.gcs_helper import GCSHelper logger = logging.getLogger(__name__) @@ -34,7 +39,9 @@ class GoogleDriveFS(UnstractFileSystem): "invalid": False, "access_token": settings["access_token"], "refresh_token": settings["refresh_token"], - GDriveConstants.TOKEN_EXPIRY: settings[GDriveConstants.TOKEN_EXPIRY], + GDriveConstants.TOKEN_EXPIRY: settings[ + GDriveConstants.TOKEN_EXPIRY + ], } gauth = GoogleAuth( settings_file=f"{os.path.dirname(__file__)}/static/settings.yaml", @@ -59,7 +66,7 @@ class GoogleDriveFS(UnstractFileSystem): @staticmethod def get_icon() -> str: - return "/icons/connector-icons/Google%20Drive.png" # noqa + return "/icons/connector-icons/Google%20Drive.png" @staticmethod def get_json_schema() -> str: diff --git a/unstract/connectors/src/unstract/connectors/filesystems/http/http.py b/unstract/connectors/src/unstract/connectors/filesystems/http/http.py index 3a5f6a36..9169086a 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/http/http.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/http/http.py @@ -4,8 +4,11 @@ from typing import Any import aiohttp from fsspec.implementations.http import HTTPFileSystem + from unstract.connectors.exceptions import ConnectorError -from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem +from unstract.connectors.filesystems.unstract_file_system import ( + UnstractFileSystem, +) logger = logging.getLogger(__name__) @@ -22,7 +25,9 @@ class HttpFS(UnstractFileSystem): "base_url": settings["base_url"], } if all(settings.get(key) for key in ("username", "password")): - basic_auth = aiohttp.BasicAuth(settings["username"], settings["password"]) + basic_auth = aiohttp.BasicAuth( + settings["username"], settings["password"] + ) client_kwargs.update({"auth": basic_auth}) self.http_fs = HTTPFileSystem(client_kwargs=client_kwargs) diff --git a/unstract/connectors/src/unstract/connectors/filesystems/local_storage/local_storage.py b/unstract/connectors/src/unstract/connectors/filesystems/local_storage/local_storage.py index 38f73734..b51ab2f9 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/local_storage/local_storage.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/local_storage/local_storage.py @@ -3,6 +3,7 @@ import os from typing import Any, Optional from fsspec.implementations.local import LocalFileSystem + from unstract.connectors.filesystems.unstract_file_system import ( UnstractFileSystem, ) diff --git a/unstract/connectors/src/unstract/connectors/filesystems/minio/minio.py b/unstract/connectors/src/unstract/connectors/filesystems/minio/minio.py index fe3dc762..f22bf4c2 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/minio/minio.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/minio/minio.py @@ -3,6 +3,7 @@ import os from typing import Any from s3fs.core import S3FileSystem + from unstract.connectors.exceptions import ConnectorError from unstract.connectors.filesystems.unstract_file_system import ( UnstractFileSystem, @@ -47,9 +48,7 @@ class MinioFS(UnstractFileSystem): @staticmethod def get_icon() -> str: - return ( - "/icons/connector-icons/S3.png" - ) + return "/icons/connector-icons/S3.png" @staticmethod def get_json_schema() -> str: diff --git a/unstract/connectors/src/unstract/connectors/filesystems/register.py b/unstract/connectors/src/unstract/connectors/filesystems/register.py index d820e5f6..5bcc637c 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/register.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/register.py @@ -4,7 +4,9 @@ from importlib import import_module from typing import Any from unstract.connectors.constants import Common -from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem +from unstract.connectors.filesystems.unstract_file_system import ( + UnstractFileSystem, +) logger = logging.getLogger(__name__) @@ -22,7 +24,9 @@ def register_connectors(connectors: dict[str, Any]) -> None: module = import_module(full_module_path) metadata = getattr(module, Common.METADATA, {}) if metadata.get("is_active", False): - connector_class: UnstractFileSystem = metadata[Common.CONNECTOR] + connector_class: UnstractFileSystem = metadata[ + Common.CONNECTOR + ] connector_id = connector_class.get_id() if not connector_id or (connector_id in connectors): logger.warning(f"Duplicate Id : {connector_id}") diff --git a/unstract/connectors/src/unstract/connectors/filesystems/ucs/ucs.py b/unstract/connectors/src/unstract/connectors/filesystems/ucs/ucs.py index eb95b2f7..e0cd0a62 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/ucs/ucs.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/ucs/ucs.py @@ -23,7 +23,7 @@ class UnstractCloudStorage(MinioFS): @staticmethod def get_icon() -> str: - return "/icons/connector-icons/Pandora%20Storage.png" # noqa + return "/icons/connector-icons/Pandora%20Storage.png" @staticmethod def get_json_schema() -> str: diff --git a/unstract/connectors/src/unstract/connectors/filesystems/unstract_file_system.py b/unstract/connectors/src/unstract/connectors/filesystems/unstract_file_system.py index b7a931c5..304488e0 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/unstract_file_system.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/unstract_file_system.py @@ -2,6 +2,7 @@ import logging from abc import ABC, abstractmethod from fsspec import AbstractFileSystem + from unstract.connectors.base import UnstractConnector from unstract.connectors.enums import ConnectorMode diff --git a/unstract/connectors/src/unstract/connectors/filesystems/zs_dropbox/exceptions.py b/unstract/connectors/src/unstract/connectors/filesystems/zs_dropbox/exceptions.py index 0c0d9cf2..1453a8f2 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/zs_dropbox/exceptions.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/zs_dropbox/exceptions.py @@ -2,6 +2,7 @@ from dropbox.auth import AuthError from dropbox.exceptions import ApiError from dropbox.exceptions import AuthError as ExcAuthError from dropbox.exceptions import DropboxException + from unstract.connectors.exceptions import ConnectorError diff --git a/unstract/connectors/src/unstract/connectors/filesystems/zs_dropbox/zs_dropbox.py b/unstract/connectors/src/unstract/connectors/filesystems/zs_dropbox/zs_dropbox.py index 1effa9c8..2ac4bb0e 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/zs_dropbox/zs_dropbox.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/zs_dropbox/zs_dropbox.py @@ -4,6 +4,7 @@ from typing import Any from dropbox.exceptions import DropboxException from dropboxdrivefs import DropboxDriveFileSystem + from unstract.connectors.exceptions import ConnectorError from unstract.connectors.filesystems.unstract_file_system import ( UnstractFileSystem, @@ -35,9 +36,7 @@ class DropboxFS(UnstractFileSystem): @staticmethod def get_icon() -> str: # TODO: Add an icon to GCS and serve it - return ( - "/icons/connector-icons/Dropbox.png" - ) + return "/icons/connector-icons/Dropbox.png" @staticmethod def get_json_schema() -> str: diff --git a/unstract/connectors/src/unstract/connectors/gcs_helper.py b/unstract/connectors/src/unstract/connectors/gcs_helper.py index 2e606adb..82c351c1 100644 --- a/unstract/connectors/src/unstract/connectors/gcs_helper.py +++ b/unstract/connectors/src/unstract/connectors/gcs_helper.py @@ -31,8 +31,10 @@ class GCSHelper: "GOOGLE_PROJECT_ID environment variable not set" ) - self.google_credentials = service_account.Credentials.from_service_account_info( - json.loads(self.google_service_json) + self.google_credentials = ( + service_account.Credentials.from_service_account_info( + json.loads(self.google_service_json) + ) ) def get_google_credentials(self) -> Credentials: @@ -58,22 +60,30 @@ class GCSHelper: md5_hash_bytes = base64.b64decode(blob.md5_hash) md5_hash_hex = md5_hash_bytes.hex() except Exception: - logger.error(f"Could not get blob {object_name} from bucket {bucket_name}") + logger.error( + f"Could not get blob {object_name} from bucket {bucket_name}" + ) return md5_hash_hex - def upload_file(self, bucket_name: str, object_name: str, file_path: str) -> None: + def upload_file( + self, bucket_name: str, object_name: str, file_path: str + ) -> None: client = Client(credentials=self.google_credentials) bucket = client.bucket(bucket_name) blob = bucket.blob(object_name) blob.upload_from_filename(file_path) - def upload_text(self, bucket_name: str, object_name: str, text: str) -> None: + def upload_text( + self, bucket_name: str, object_name: str, text: str + ) -> None: client = Client(credentials=self.google_credentials) bucket = client.bucket(bucket_name) blob = bucket.blob(object_name) blob.upload_from_string(text) - def upload_object(self, bucket_name: str, object_name: str, object: Any) -> None: + def upload_object( + self, bucket_name: str, object_name: str, object: Any + ) -> None: client = Client(credentials=self.google_credentials) bucket = client.bucket(bucket_name) blob = bucket.blob(object_name) @@ -91,4 +101,6 @@ class GCSHelper: logger.info(f"Reading file {object_name} from bucket {bucket_name}") return obj except Exception: - logger.error(f"Could not get blob {object_name} from bucket {bucket_name}") + logger.error( + f"Could not get blob {object_name} from bucket {bucket_name}" + ) diff --git a/unstract/connectors/tests/filesystems/test_google_drive_fs.py b/unstract/connectors/tests/filesystems/test_google_drive_fs.py index bbbf8669..bfeccbac 100644 --- a/unstract/connectors/tests/filesystems/test_google_drive_fs.py +++ b/unstract/connectors/tests/filesystems/test_google_drive_fs.py @@ -1,8 +1,8 @@ -# flake8: noqa - import unittest -from unstract.connectors.filesystems.google_drive.google_drive import GoogleDriveFS +from unstract.connectors.filesystems.google_drive.google_drive import ( + GoogleDriveFS, +) class TestGoogleDriveFS(unittest.TestCase): diff --git a/unstract/connectors/tests/filesystems/test_miniofs.py b/unstract/connectors/tests/filesystems/test_miniofs.py index 0443bbde..1ab09fcc 100644 --- a/unstract/connectors/tests/filesystems/test_miniofs.py +++ b/unstract/connectors/tests/filesystems/test_miniofs.py @@ -31,7 +31,9 @@ class TestMinoFS(unittest.TestCase): access_key = os.environ.get("MINIO_ACCESS_KEY_ID") secret_key = os.environ.get("MINIO_SECRET_ACCESS_KEY") print(access_key, secret_key) - bucket_name = os.environ.get("FREE_STORAGE_AWS_BUCKET_NAME", "minio-test") + bucket_name = os.environ.get( + "FREE_STORAGE_AWS_BUCKET_NAME", "minio-test" + ) s3 = MinioFS( { "key": access_key, diff --git a/unstract/core/src/unstract/core/llm_helper/config.py b/unstract/core/src/unstract/core/llm_helper/config.py index c9239f03..2e711300 100644 --- a/unstract/core/src/unstract/core/llm_helper/config.py +++ b/unstract/core/src/unstract/core/llm_helper/config.py @@ -33,22 +33,29 @@ class AzureOpenAIConfig: @classmethod def from_env(cls) -> "AzureOpenAIConfig": kwargs = { - "model": UnstractUtils.get_env(OpenAIKeys.OPENAI_API_MODEL, raise_err=True), + "model": UnstractUtils.get_env( + OpenAIKeys.OPENAI_API_MODEL, raise_err=True + ), "deployment_name": UnstractUtils.get_env( OpenAIKeys.OPENAI_API_ENGINE, raise_err=True ), "engine": UnstractUtils.get_env( OpenAIKeys.OPENAI_API_ENGINE, raise_err=True ), - "api_key": UnstractUtils.get_env(OpenAIKeys.OPENAI_API_KEY, raise_err=True), + "api_key": UnstractUtils.get_env( + OpenAIKeys.OPENAI_API_KEY, raise_err=True + ), "api_version": UnstractUtils.get_env( - OpenAIKeys.OPENAI_API_VERSION, default=OpenAIDefaults.OPENAI_API_VERSION + OpenAIKeys.OPENAI_API_VERSION, + default=OpenAIDefaults.OPENAI_API_VERSION, ), "azure_endpoint": UnstractUtils.get_env( - OpenAIKeys.OPENAI_API_BASE, default=OpenAIDefaults.OPENAI_API_BASE + OpenAIKeys.OPENAI_API_BASE, + default=OpenAIDefaults.OPENAI_API_BASE, ), "api_type": UnstractUtils.get_env( - OpenAIKeys.OPENAI_API_TYPE, default=OpenAIDefaults.OPENAI_API_TYPE + OpenAIKeys.OPENAI_API_TYPE, + default=OpenAIDefaults.OPENAI_API_TYPE, ), } return cls(**kwargs) diff --git a/unstract/core/src/unstract/core/llm_helper/llm_cache.py b/unstract/core/src/unstract/core/llm_helper/llm_cache.py index 5af5b67d..580c51e5 100644 --- a/unstract/core/src/unstract/core/llm_helper/llm_cache.py +++ b/unstract/core/src/unstract/core/llm_helper/llm_cache.py @@ -12,7 +12,9 @@ class LLMCache: redis_host = os.environ.get("REDIS_HOST") redis_port = os.environ.get("REDIS_PORT") if redis_host is None or redis_port is None: - raise RuntimeError("REDIS_HOST or REDIS_PORT environment variable not set") + raise RuntimeError( + "REDIS_HOST or REDIS_PORT environment variable not set" + ) redis_password = os.environ.get("REDIS_PASSWORD", None) if redis_password and ( redis_password == "" or redis_password.lower() == "none" @@ -105,7 +107,10 @@ class LLMCache: """ logger.info(f"Clearing cache with prefix: {self.cache_key_prefix}") keys_to_delete = [ - key for key in self.llm_cache.scan_iter(match=self.cache_key_prefix + "*") + key + for key in self.llm_cache.scan_iter( + match=self.cache_key_prefix + "*" + ) ] if keys_to_delete: return self.delete(*keys_to_delete) diff --git a/unstract/core/src/unstract/core/llm_helper/llm_helper.py b/unstract/core/src/unstract/core/llm_helper/llm_helper.py index 4874a44d..bdecc6ab 100644 --- a/unstract/core/src/unstract/core/llm_helper/llm_helper.py +++ b/unstract/core/src/unstract/core/llm_helper/llm_helper.py @@ -4,6 +4,7 @@ import time from typing import Optional from llama_index.llms import AzureOpenAI + from unstract.core.llm_helper.config import AzureOpenAIConfig from unstract.core.llm_helper.enums import LLMResult, PromptContext from unstract.core.llm_helper.llm_cache import LLMCache @@ -46,7 +47,9 @@ class LLMHelper: prompt_for_model = self.prompt if self.prompt_context == PromptContext.GENERATE_CRON_STRING: - prompt_for_model = prompt_for_model.replace("{$user_prompt}", user_prompt) + prompt_for_model = prompt_for_model.replace( + "{$user_prompt}", user_prompt + ) return prompt_for_model @@ -67,7 +70,9 @@ class LLMHelper: if ai_service == "azure-open-ai": logger.info("Using Azure OpenAI") if use_cache: - response = self.llm_cache.get_for_prompt(prompt=prompt_for_model) + response = self.llm_cache.get_for_prompt( + prompt=prompt_for_model + ) if response: return LLMResponse( result=LLMResult.OK, output=response, cost_type="cache" @@ -102,8 +107,12 @@ class LLMHelper: logger.info(f"OpenAI Response: {resp}") time_taken = end_time - start_time - self.llm_cache.set_for_prompt(prompt=prompt_for_model, response=resp) - return LLMResponse(output=resp, cost_type=ai_service, time_taken=time_taken) + self.llm_cache.set_for_prompt( + prompt=prompt_for_model, response=resp + ) + return LLMResponse( + output=resp, cost_type=ai_service, time_taken=time_taken + ) else: logger.error(f"AI service '{ai_service}' not found") return LLMResponse( diff --git a/unstract/core/src/unstract/core/llm_workflow_generator/llm_interface.py b/unstract/core/src/unstract/core/llm_workflow_generator/llm_interface.py index c8289883..d4efb2c4 100644 --- a/unstract/core/src/unstract/core/llm_workflow_generator/llm_interface.py +++ b/unstract/core/src/unstract/core/llm_workflow_generator/llm_interface.py @@ -7,9 +7,10 @@ import uuid import redis from llama_index.llms import AzureOpenAI -from unstract.core.llm_helper.config import AzureOpenAIConfig from unstract.tool_registry.dto import Properties, Tool +from unstract.core.llm_helper.config import AzureOpenAIConfig + # Refactor dated: 19/12/2023 ( Removal of Appkit removal) diff --git a/unstract/core/src/unstract/core/utilities.py b/unstract/core/src/unstract/core/utilities.py index 0364d28d..8f45dfed 100644 --- a/unstract/core/src/unstract/core/utilities.py +++ b/unstract/core/src/unstract/core/utilities.py @@ -4,7 +4,9 @@ from typing import Optional class UnstractUtils: @staticmethod - def get_env(env_key: str, default: Optional[str] = None, raise_err=False) -> str: + def get_env( + env_key: str, default: Optional[str] = None, raise_err=False + ) -> str: """Returns the value of an env variable. If its empty or None, raises an error diff --git a/unstract/core/tests/llm_helper/test_llm_cache.py b/unstract/core/tests/llm_helper/test_llm_cache.py index 3f638924..d94f7440 100644 --- a/unstract/core/tests/llm_helper/test_llm_cache.py +++ b/unstract/core/tests/llm_helper/test_llm_cache.py @@ -18,7 +18,9 @@ class LLMCacheTests(unittest.TestCase): cache.set_for_prompt("prompt1", "response1") cache.set_for_prompt("prompt2", "response2") cache.clear_by_prefix() - self.assertEqual(cache.get_for_prompt("prompt1"), "", "Cache is not cleared") + self.assertEqual( + cache.get_for_prompt("prompt1"), "", "Cache is not cleared" + ) if __name__ == "__main__": diff --git a/unstract/core/tests/test_pubsub_helper.py b/unstract/core/tests/test_pubsub_helper.py index ae3e1026..33a70cd5 100644 --- a/unstract/core/tests/test_pubsub_helper.py +++ b/unstract/core/tests/test_pubsub_helper.py @@ -11,7 +11,9 @@ class PubSubHelperTestCase(unittest.TestCase): ) ps2 = Log.publish( project_guid="test", - message=Log.log(level="ERROR", stage="COMPILE", message="Compile failed"), + message=Log.log( + level="ERROR", stage="COMPILE", message="Compile failed" + ), ) self.assertEqual(ps1, True) self.assertEqual(ps2, True) diff --git a/unstract/flags/src/unstract/flags/client.py b/unstract/flags/src/unstract/flags/client.py index 6dae709a..75c87f6f 100644 --- a/unstract/flags/src/unstract/flags/client.py +++ b/unstract/flags/src/unstract/flags/client.py @@ -2,6 +2,7 @@ import os from typing import Optional import grpc + from unstract.flags import evaluation_pb2, evaluation_pb2_grpc diff --git a/unstract/flags/src/unstract/flags/tests/test_client.py b/unstract/flags/src/unstract/flags/tests/test_client.py index 16884820..7b8ad4eb 100644 --- a/unstract/flags/src/unstract/flags/tests/test_client.py +++ b/unstract/flags/src/unstract/flags/tests/test_client.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock, patch import grpc + from unstract.flags.client import EvaluationClient diff --git a/unstract/tool-registry/src/unstract/tool_registry/helper.py b/unstract/tool-registry/src/unstract/tool_registry/helper.py index 057a147d..1dff798c 100644 --- a/unstract/tool-registry/src/unstract/tool_registry/helper.py +++ b/unstract/tool-registry/src/unstract/tool_registry/helper.py @@ -104,9 +104,9 @@ class ToolRegistryHelper: image_tag=tool_meta.tag, ) - tool_properties: Optional[ - dict[str, Any] - ] = tool_sandbox.get_properties() + tool_properties: Optional[dict[str, Any]] = ( + tool_sandbox.get_properties() + ) if not tool_properties: return {} return tool_properties diff --git a/unstract/tool-registry/src/unstract/tool_registry/schema_validator.py b/unstract/tool-registry/src/unstract/tool_registry/schema_validator.py index 9ededb2f..ef37a891 100644 --- a/unstract/tool-registry/src/unstract/tool_registry/schema_validator.py +++ b/unstract/tool-registry/src/unstract/tool_registry/schema_validator.py @@ -27,7 +27,9 @@ class JsonSchemaValidator: logger.error(f"Validation error: {e}") raise InvalidSchemaInput - def validate_and_filter(self, data: dict[str, Any]) -> Optional[dict[str, Any]]: + def validate_and_filter( + self, data: dict[str, Any] + ) -> Optional[dict[str, Any]]: """Validates the input data against the schema and filters the data based on the schema's properties. diff --git a/unstract/tool-registry/src/unstract/tool_registry/tool_registry.py b/unstract/tool-registry/src/unstract/tool_registry/tool_registry.py index dcb6d233..73d54e3c 100644 --- a/unstract/tool-registry/src/unstract/tool_registry/tool_registry.py +++ b/unstract/tool-registry/src/unstract/tool_registry/tool_registry.py @@ -211,9 +211,9 @@ class ToolRegistry: ) if not tool_meta: continue - properties: Optional[ - dict[str, Any] - ] = self.helper.get_tool_properties(tool_meta=tool_meta) + properties: Optional[dict[str, Any]] = ( + self.helper.get_tool_properties(tool_meta=tool_meta) + ) spec: Optional[dict[str, Any]] = self.helper.get_tool_spec( tool_meta=tool_meta ) @@ -235,9 +235,9 @@ class ToolRegistry: "icon": icon, } else: - tools: dict[ - str, dict[str, Any] - ] = self.helper.get_all_tools_from_disk() + tools: dict[str, dict[str, Any]] = ( + self.helper.get_all_tools_from_disk() + ) for tool, configuration in tools.items(): properties = configuration.get("properties") spec = configuration.get("spec") diff --git a/unstract/tool-registry/tests/test_tool_registry.py b/unstract/tool-registry/tests/test_tool_registry.py index 7fff83b0..4a09a459 100644 --- a/unstract/tool-registry/tests/test_tool_registry.py +++ b/unstract/tool-registry/tests/test_tool_registry.py @@ -25,9 +25,13 @@ class TestToolRegistry(unittest.TestCase): "tools": [self.TEST_IMAGE_URL], } directory = os.path.dirname(os.path.abspath(__file__)) - registry_file_path = os.path.join(directory, TestToolRegistry.REGISTRY_FILE) + registry_file_path = os.path.join( + directory, TestToolRegistry.REGISTRY_FILE + ) with open(registry_file_path, "w") as yaml_file: - yaml.dump(test_registry_content, yaml_file, default_flow_style=False) + yaml.dump( + test_registry_content, yaml_file, default_flow_style=False + ) # Apply the patch to 'run_tool_and_get_logs' before each test method self.mock_run_tool_and_get_logs = patch.object( @@ -87,7 +91,9 @@ class TestToolRegistry(unittest.TestCase): def test_get_tool_properties_by_tool_id(self) -> None: tool_id = "document_indexer" - properties = self.registry.get_tool_properties_by_tool_id(tool_id=tool_id) + properties = self.registry.get_tool_properties_by_tool_id( + tool_id=tool_id + ) self.assertIsNotNone(properties) self.assertEqual(properties.get("function_name"), tool_id) diff --git a/unstract/workflow-execution/src/unstract/workflow_execution/tools_utils.py b/unstract/workflow-execution/src/unstract/workflow_execution/tools_utils.py index ab1edaea..b0f8b6f6 100644 --- a/unstract/workflow-execution/src/unstract/workflow_execution/tools_utils.py +++ b/unstract/workflow-execution/src/unstract/workflow_execution/tools_utils.py @@ -3,8 +3,6 @@ import os from typing import Any, Optional from redis import Redis - -from unstract.core.pubsub_helper import LogPublisher from unstract.tool_registry import ToolRegistry from unstract.tool_sandbox import ToolSandbox from unstract.workflow_execution.constants import ToolExecution @@ -17,6 +15,8 @@ from unstract.workflow_execution.exceptions import ( ToolNotFoundException, ) +from unstract.core.pubsub_helper import LogPublisher + logger = logging.getLogger(__name__) @@ -79,9 +79,9 @@ class ToolsUtils: dict[str, dict[str, Any]]: tools """ tool_uids = [tool_instance.tool_id for tool_instance in tool_instances] - tools: dict[ - str, dict[str, Any] - ] = self.tool_registry.get_available_tools(tool_uids) + tools: dict[str, dict[str, Any]] = ( + self.tool_registry.get_available_tools(tool_uids) + ) if not ( all(tool_uid in tools for tool_uid in tool_uids) and len(tool_uids) == len(tools) diff --git a/unstract/workflow-execution/src/unstract/workflow_execution/workflow_execution.py b/unstract/workflow-execution/src/unstract/workflow_execution/workflow_execution.py index 4703bbe8..2b3f1daa 100644 --- a/unstract/workflow-execution/src/unstract/workflow_execution/workflow_execution.py +++ b/unstract/workflow-execution/src/unstract/workflow_execution/workflow_execution.py @@ -4,8 +4,6 @@ import time from typing import Any, Optional, Union import redis - -from unstract.core.pubsub_helper import LogPublisher from unstract.tool_sandbox import ToolSandbox from unstract.workflow_execution.constants import StepExecution, ToolExecution from unstract.workflow_execution.dto import ToolInstance, WorkflowDto @@ -27,6 +25,8 @@ from unstract.workflow_execution.execution_file_handler import ( ) from unstract.workflow_execution.tools_utils import ToolsUtils +from unstract.core.pubsub_helper import LogPublisher + logger = logging.getLogger(__name__) diff --git a/unstract/workflow-execution/tests/workflow_test.py b/unstract/workflow-execution/tests/workflow_test.py index 906d6485..a10922a6 100644 --- a/unstract/workflow-execution/tests/workflow_test.py +++ b/unstract/workflow-execution/tests/workflow_test.py @@ -18,10 +18,16 @@ def get_mock_tool_instances() -> list[ToolInstance]: input=item["input"], output=item["output"], metadata=item["metadata"], - input_file_connector=ConnectorInstance(**item["input_file_connector"]), - output_file_connector=ConnectorInstance(**item["output_file_connector"]), + input_file_connector=ConnectorInstance( + **item["input_file_connector"] + ), + output_file_connector=ConnectorInstance( + **item["output_file_connector"] + ), input_db_connector=ConnectorInstance(**item["input_db_connector"]), - output_db_connector=ConnectorInstance(**item["output_db_connector"]), + output_db_connector=ConnectorInstance( + **item["output_db_connector"] + ), tool_settings=ToolSettings(**item["tool_settings"]), ) for item in tool_instance_data @@ -85,8 +91,8 @@ class TestWorkflow(unittest.TestCase): mock_tool_instances = get_mock_tool_instances() mock_tool_utils = Mock() - mock_tool_utils.validate_tool_instance_with_tools.side_effect = Exception( - "Test error message" + mock_tool_utils.validate_tool_instance_with_tools.side_effect = ( + Exception("Test error message") ) workflow = Workflow( diff --git a/worker/src/unstract/worker/worker.py b/worker/src/unstract/worker/worker.py index a28ddfc5..acd8a1d0 100644 --- a/worker/src/unstract/worker/worker.py +++ b/worker/src/unstract/worker/worker.py @@ -6,12 +6,11 @@ import uuid from typing import Any, Optional from dotenv import load_dotenv - -from unstract.core.pubsub_helper import LogPublisher from unstract.worker.constants import Env, LogType, ToolKey import docker from docker import DockerClient # type: ignore[attr-defined] +from unstract.core.pubsub_helper import LogPublisher load_dotenv()