Python 3.9 to 3.12 (#1231)

* python version updated from 3.9 to 3.12

* x2text-service updated with uv and python version 3.12

* x2text-service docker file updated

* Unstract packages updated with uv

* Runner updated with uv

* Promptservice updated with uv

* Platform service updated with uv

* backend service updated with uv

* root pyproject.toml file updated

* sdk version updated in services

* unstract package modules updated based on sdk version

* docker file update

* pdm lock workflow modified to support uv

* Docs updated based on uv support

* lock automation updated

* snowflake module version updated to 3.14.0

* tox updated to support UV

* tox updated with pytest

* tox updated with pytest-md-report

* tox updated with module requirements

* python migration from 3.9 to 3.12

* runner updated

* Commit uv.lock changes

* pytest.ini added

* x2text-service docker file updated

* pytest.ini removed

* environment updated to test

* docformatter hook commented out in pre-commit config

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* some pre-commit issues ignored

* pre-commit updates

* unused import removed from platform service controller

* tox issues fixed

* docker files updated

* backend dockerfile updated

* open installation issue fixed

* Tools docker file updated with base python version 3.12

* python version updated to a minimum of 3.12 in pyproject.toml (see the sketches below)

* linting issue fixed

* uv version upgraded to 0.6.14

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* migrations excluded from ruff

* added PoethePoet task runner

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat: Added poe tasks for services (#1248)

* Added poe tasks for services

* reverted FE change made by mistake

* updated tool-sidecar to uv and python to 3.12.9

* minor updates in pyproject description

* feat: platform-service logging improvements (#1255)

feat: Used flask util from core to improve logging in platform-service; added core as a dependency to platform-service

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: Platform-service build issue and numpy issue with Python 3.12 (#1258)

* fix: Platform-service build and numpy issue with Py 3.12

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: Removed backend dockerfile install statements for numpy

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* minor: Handled scenario when cost is not calculated due to no usage

* minor: Corrected content shown for workflow input

* fix: Minor fixes; used gthread workers for prompt-service and runner

* Commit uv.lock changes

* Removed unused line in tool dockerfile

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Chandrasekharan M <chandrasekharan@zipstack.com>
Co-authored-by: Chandrasekharan M <117059509+chandrasekharan-zipstack@users.noreply.github.com>
Co-authored-by: ali-zipstack <muhammad.ali@zipstack.com>
This commit is contained in:
Jaseem Jas
2025-04-24 16:07:02 +05:30
committed by GitHub
parent b381333c44
commit ba1df894d2
447 changed files with 32572 additions and 33358 deletions
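For illustrative context on the bullets above (raising the minimum Python version and moving Dockerfiles to uv), the sketches below show the typical shape of such changes. They are assumptions for illustration only, not the repository's actual files; names, base images and flags may differ.

```toml
# Sketch: raising the minimum supported interpreter in a service's pyproject.toml
[project]
requires-python = ">=3.12"
```

```dockerfile
# Sketch: a common pattern for installing locked dependencies with uv in a Dockerfile.
# The service name and base image are placeholders.
FROM python:3.12.9-slim
COPY --from=ghcr.io/astral-sh/uv:0.6.14 /uv /uvx /bin/
WORKDIR /app

# Install third-party dependencies first so this layer is cached across code changes
COPY pyproject.toml uv.lock ./
RUN uv sync --frozen --no-install-project --no-dev

# Copy the source and install the project itself
COPY . .
RUN uv sync --frozen --no-dev

CMD ["uv", "run", "python", "-m", "sample_service"]
```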

View File

@@ -1,4 +1,4 @@
name: Run tox tests
name: Run tox tests with UV
on:
push:
@@ -14,44 +14,46 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.9'
- name: Install uv
uses: astral-sh/setup-uv@v5
with:
# Install a specific version of uv.
version: "0.6.14"
python-version: 3.12.9
- name: Cache tox environments
uses: actions/cache@v4
with:
path: .tox/
key: ${{ runner.os }}-tox-${{ hashFiles('**/pyproject.toml', '**/tox.ini') }}
restore-keys: |
${{ runner.os }}-tox-
- name: Cache tox environments
uses: actions/cache@v4
with:
path: .tox/
key: ${{ runner.os }}-tox-uv-${{ hashFiles('**/pyproject.toml', '**/tox.ini') }}
restore-keys: |
${{ runner.os }}-tox-uv-
- name: Install tox
run: pip install tox
- name: Install tox with UV
run: uv pip install tox tox-uv
- name: Run tox
id: tox
run: |
tox
- name: Run tox
id: tox
run: |
tox
- name: Render the report to the PR
uses: marocchino/sticky-pull-request-comment@v2
with:
header: runner-test-report
recreate: true
path: runner-report.md
- name: Render the report to the PR
uses: marocchino/sticky-pull-request-comment@v2
with:
header: runner-test-report
recreate: true
path: runner-report.md
- name: Output reports to the job summary when tests fail
shell: bash
run: |
if [ -f "runner-report.md" ]; then
echo "<details><summary>Runner Test Report</summary>" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
cat "runner-report.md" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "</details>" >> $GITHUB_STEP_SUMMARY
fi
- name: Output reports to the job summary when tests fail
shell: bash
run: |
if [ -f "runner-report.md" ]; then
echo "<details><summary>Runner Test Report</summary>" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
cat "runner-report.md" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "</details>" >> $GITHUB_STEP_SUMMARY
fi
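The workflow above installs `tox` together with `tox-uv`, so tox provisions its test environments through uv rather than virtualenv/pip, and it publishes `runner-report.md` to the PR. A minimal `tox.ini` along these lines is sketched below; it is an assumption for illustration, and the env names, dependencies and report path are not taken from the repository.

```ini
[tox]
envlist = py312

[testenv]
# With tox-uv installed alongside tox, environments are created and
# populated by uv automatically; no extra configuration is needed here.
deps =
    pytest
    pytest-md-report
commands =
    pytest -v --md-report --md-report-flavor gfm --md-report-output runner-report.md
```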

View File

@@ -1,21 +1,21 @@
name: Automate pdm.lock
name: Automate uv.lock
on:
pull_request:
types: [opened, synchronize, reopened, ready_for_review, review_requested]
branches: [main]
paths:
- '**/pyproject.toml'
- "**/pyproject.toml"
workflow_dispatch:
inputs:
directories:
description: 'Comma-separated list of directories to update'
description: "Comma-separated list of directories to update"
required: false
default: '' # Run for all dirs specified in docker/scripts/pdm-lock-gen/pdm-lock.sh
default: "" # Run for all dirs specified in docker/scripts/uv-lock-gen/uv-lock.sh
jobs:
update_pdm_lock:
name: Update PDM lock in all directories
update_uv_lock:
name: Update UV lock in all directories
runs-on: ubuntu-latest
permissions:
@@ -29,19 +29,20 @@ jobs:
with:
ref: ${{ github.head_ref }}
- name: Set up Python
uses: actions/setup-python@v5
- name: Install uv
uses: astral-sh/setup-uv@v5
with:
python-version: '3.9'
# Install a specific version of uv.
version: "0.6.14"
python-version: 3.12.9
- name: Install PDM
run: python -m pip install pdm==2.16.1
- run: uv pip install --python=3.12.9 pip
- name: Generate PDM lockfiles
- name: Generate UV lockfiles
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
chmod +x ./docker/scripts/pdm-lock-gen/pdm-lock.sh
chmod +x ./docker/scripts/uv-lock-gen/uv-lock.sh
# Get the input from the workflow or use the default value
dirs="${{ github.event.inputs.directories }}"
@@ -49,9 +50,9 @@ jobs:
# Check if directories input is empty
if [[ -z "$dirs" ]]; then
# No directories input given, run the script without arguments (process all directories)
echo "No directories specified, running on all dirs listed in docker/scripts/pdm-lock-gen/pdm-lock.sh"
echo "No directories specified, running on all dirs listed in docker/scripts/uv-lock-gen/uv-lock.sh"
./docker/scripts/pdm-lock-gen/pdm-lock.sh
./docker/scripts/uv-lock-gen/uv-lock.sh
else
# Convert comma-separated list into an array of directories
IFS=',' read -r -a dir_array <<< "$dirs"
@@ -60,13 +61,13 @@ jobs:
echo "Processing specified directories: ${dir_array[*]}"
# Pass directories as command-line arguments to the script
./docker/scripts/pdm-lock-gen/pdm-lock.sh "${dir_array[@]}"
./docker/scripts/uv-lock-gen/uv-lock.sh "${dir_array[@]}"
fi
shell: bash
- name: Commit pdm.lock changes
- name: Commit uv.lock changes
uses: stefanzweifel/git-auto-commit-action@v5
with:
commit_message: Commit pdm.lock changes
commit_user_name: pdm-lock-automation[bot]
commit_user_email: pdm-lock-automation-bot@unstract.com
commit_message: Commit uv.lock changes
commit_user_name: uv-lock-automation[bot]
commit_user_email: uv-lock-automation-bot@unstract.com
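The renamed script `docker/scripts/uv-lock-gen/uv-lock.sh` is not shown in this diff; below is a hedged sketch of what such a wrapper typically does. The directory names are placeholders, not the actual list in the repository.

```bash
#!/bin/bash
# Illustrative sketch only, not the repository's actual uv-lock.sh.
# Regenerates uv.lock in each directory passed as an argument, or in a
# default set of service directories when no arguments are given.
set -euo pipefail

default_dirs=("backend" "platform-service" "prompt-service" "runner" "x2text-service")

if [ "$#" -gt 0 ]; then
  dirs=("$@")
else
  dirs=("${default_dirs[@]}")
fi

for dir in "${dirs[@]}"; do
  echo "Regenerating uv.lock in ${dir}"
  (cd "${dir}" && uv lock)
done
```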

.gitignore vendored
View File

@@ -630,6 +630,7 @@ backend/pluggable_apps/*
# FE Plugins
frontend/src/plugins/*
frontend/public/llm-whisperer/
# TODO: Ensure its made generic to abstract subfolder and file names

View File

@@ -1,16 +1,11 @@
---
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
default_language_version:
python: python3.9
python: python3.12
default_stages:
- pre-commit
ci:
skip:
- hadolint-docker # Fails in pre-commit CI
- hadolint-docker # Fails in pre-commit CI
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
@@ -39,70 +34,66 @@ repos:
- id: forbid-new-submodules
- id: mixed-line-ending
- id: no-commit-to-branch
- repo: https://github.com/adrienverge/yamllint
rev: v1.35.1
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.4
hooks:
- id: yamllint
args: ["-d", "relaxed"]
# language: system
# - repo: https://github.com/rhysd/actionlint
# rev: v1.6.27
# hooks:
# - id: actionlint-docker
# args: [-ignore, 'label ".+" is unknown']
- repo: https://github.com/psf/black
rev: 24.4.2
hooks:
- id: black
args: [--config=pyproject.toml, -l 88]
# language: system
exclude: |
(?x)^(
unstract/flags/src/unstract/flags/evaluation_.*\.py|
)$
- repo: https://github.com/pycqa/flake8
rev: 7.1.0
hooks:
- id: flake8
args: [--max-line-length=88]
exclude: |
(?x)^(
.*migrations/.*\.py|
core/tests/.*|
unstract/flags/src/unstract/flags/evaluation_.*\.py|
)$
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
files: "\\.(py)$"
args:
[
"--profile",
"black",
"--filter-files",
--settings-path=pyproject.toml,
]
- id: ruff
args: [--fix]
- id: ruff-format
- repo: https://github.com/hadialqattan/pycln
rev: v2.5.0
hooks:
- id: pycln
args: [--config=pyproject.toml]
# - repo: https://github.com/pycqa/docformatter
# rev: v1.7.5
# hooks:
# - id: docformatter
# language: python
- repo: https://github.com/asottile/pyupgrade
rev: v3.16.0
rev: v3.17.0
hooks:
- id: pyupgrade
entry: pyupgrade --py39-plus --keep-runtime-typing
types:
- python
# - repo: https://github.com/astral-sh/uv-pre-commit
# rev: 0.6.11
# hooks:
# - id: uv-lock
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.11.2
# hooks:
# - id: mypy
# language: system
# entry: mypy .
# pass_filenames: false
# additional_dependencies: []
- repo: https://github.com/igorshubovych/markdownlint-cli
rev: v0.42.0
hooks:
- id: markdownlint
args: [--disable, MD013]
- id: markdownlint-fix
args: [--disable, MD013]
- repo: https://github.com/gitleaks/gitleaks
rev: v8.18.4
rev: v8.18.2
hooks:
- id: gitleaks
- repo: https://github.com/Lucas-C/pre-commit-hooks-nodejs
rev: v1.1.2
hooks:
- id: htmlhint
- repo: https://github.com/hadolint/hadolint
rev: v2.12.1-beta
hooks:
@@ -114,49 +105,3 @@ repos:
- --ignore=DL3018
- --ignore=SC1091
files: Dockerfile$
- repo: https://github.com/asottile/yesqa
rev: v1.5.0
hooks:
- id: yesqa
# - repo: https://github.com/pre-commit/mirrors-eslint
# rev: "v9.0.0-beta.2" # Use the sha / tag you want to point at
# hooks:
# - id: eslint
# args: [--config=frontend/.eslintrc.json]
# files: \.[jt]sx?$ # *.js, *.jsx, *.ts and *.tsx
# types: [file]
# additional_dependencies:
# - eslint@8.41.0
# - eslint-config-google@0.14.0
# - eslint-config-prettier@8.8.0
# - eslint-plugin-prettier@4.2.1
# - eslint-plugin-react@7.32.2
# - eslint-plugin-import@2.25.2
- repo: https://github.com/Lucas-C/pre-commit-hooks-nodejs
rev: v1.1.2
hooks:
- id: htmlhint
- repo: https://github.com/igorshubovych/markdownlint-cli
rev: v0.41.0
hooks:
- id: markdownlint
args: [--disable, MD013]
- id: markdownlint-fix
args: [--disable, MD013]
- repo: https://github.com/pdm-project/pdm
rev: 2.16.1
hooks:
- id: pdm-lock-check
# - repo: local
# hooks:
# - id: run-mypy
# name: Run mypy
# entry: sh -c 'pdm run mypy .'
# language: system
# pass_filenames: false
# - id: check-django-migrations
# name: Check django migrations
# entry: sh -c 'pdm run docker/scripts/check_django_migrations.sh'
# language: system
# types: [python] # hook only runs if a python file is staged
# pass_filenames: false
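The pre-commit changes above drop black, flake8, isort and yesqa in favour of ruff (with `ruff-format`), and the commit notes that migrations are excluded from ruff. In `pyproject.toml` that usually takes a form like the sketch below; the selected rules and paths are assumptions, not the repository's actual configuration.

```toml
[tool.ruff]
line-length = 88
target-version = "py312"
extend-exclude = ["**/migrations/**"]

[tool.ruff.lint]
# pycodestyle, pyflakes, import sorting and pyupgrade-style checks
select = ["E", "F", "I", "UP"]
```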

View File

@@ -1 +1 @@
3.9.6
3.12.9

View File

@@ -5,9 +5,9 @@
## No-code LLM Platform to launch APIs and ETL Pipelines to structure unstructured documents
##
##
[![pdm-managed](https://img.shields.io/badge/pdm-managed-blueviolet)](https://pdm-project.org)
[![uv](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/uv/main/assets/badge/v0.json)](https://github.com/astral-sh/uv)
[![CLA assistant](https://cla-assistant.io/readme/badge/Zipstack/unstract)](https://cla-assistant.io/Zipstack/unstract)
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Zipstack/unstract/main.svg)](https://results.pre-commit.ci/latest/github/Zipstack/unstract/main)
[![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=Zipstack_unstract&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=Zipstack_unstract)
@@ -57,8 +57,8 @@ Next, either download a release or clone this repo and do the following:
That's all there is to it!
Follow [these steps](backend/README.md#authentication) to change the default username and password.
See [user guide](https://docs.unstract.com/unstract/unstract_platform/user_guides/run_platform) for more details on managing the platform.
Follow [these steps](backend/README.md#authentication) to change the default username and password.
See [user guide](https://docs.unstract.com/unstract/unstract_platform/user_guides/run_platform) for more details on managing the platform.
Another really quick way to experience Unstract is by signing up for our [hosted version](https://us-central.unstract.com/). It comes with a 14 day free trial!
@@ -154,9 +154,9 @@ Contributions are welcome! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for fur
## 🚨 Backup encryption key
Do copy the value of `ENCRYPTION_KEY` config in either `backend/.env` or `platform-service/.env` file to a secure location.
Do copy the value of `ENCRYPTION_KEY` config in either `backend/.env` or `platform-service/.env` file to a secure location.
Adapter credentials are encrypted by the platform using this key. Its loss or change will make all existing adapters inaccessible!
Adapter credentials are encrypted by the platform using this key. Its loss or change will make all existing adapters inaccessible!
## 📊 A note on analytics

View File

@@ -1 +1 @@
3.9.6
3.12.9

View File

@@ -15,13 +15,13 @@ Contains the backend services for Unstract written with Django and DRF.
All commands assumes that you have activated your `venv`.
```bash
# Create venv
pdm venv create -w virtualenv --with-pip
eval "$(pdm venv activate in-project)"
Install UV: https://docs.astral.sh/uv/getting-started/installation/
# Remove venv
pdm venv remove in-project
```bash
# Create venv and install dependencies
uv venv
source .venv/bin/activate # On Windows: .venv\Scripts\activate
uv sync
```
#### Installing dependencies
@@ -30,28 +30,20 @@ Go to service dir and install dependencies listed in corresponding `pyproject.to
```bash
# Install dependencies
pdm install
uv sync
# Install specific dev dependency group
pdm install --dev -G lint
uv sync --group dev
# Install production dependencies only
pdm install --prod --no-editable
uv sync --group deploy
```
#### Running scripts
PDM allows you to run scripts applicable within the service dir.
UV allows you to run python scripts applicable within the service dir.
```bash
# List the possible scripts that can be executed
pdm run -l
```
For example to run the backend (dev mode is recommended to take advantage of gunicorn's `reload` feature)
```bash
pdm run backend --dev
uv run sample_script.py
```
#### Running commands
@@ -68,16 +60,16 @@ DB_NAME='unstract_db'
DB_PORT=5432
```
- If you've made changes to the model, run `python manage.py makemigrations`, else ignore this step
- If you've made changes to the model, run `uv run manage.py makemigrations`, else ignore this step
- Run the following to apply any migrations to the DB and start the server
```bash
python manage.py migrate
python manage.py runserver localhost:8000
uv run manage.py migrate
uv run manage.py runserver localhost:8000
```
- Server will start and run at port 8000. (<http://localhost:8000>)
## Authentication
The default username is `unstract` and the default password is `unstract`.
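The `pdm run` script listing removed above was later superseded by PoethePoet tasks (#1248). Below is a hedged sketch of how such a task could be declared in a service's `pyproject.toml` and invoked through uv; the task name and command are hypothetical.

```toml
[tool.poe.tasks]
# Hypothetical task; the repository's actual task names and commands may differ.
backend = "gunicorn --bind 0.0.0.0:8000 --reload backend.wsgi:application"
```

```bash
# Run the task inside the uv-managed environment
uv run poe backend
```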

View File

@@ -6,7 +6,6 @@ from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = []

View File

@@ -4,7 +4,6 @@ from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("account_usage", "0001_initial"),
]

View File

@@ -1,5 +1,21 @@
import logging
from typing import Any, Optional, Union
from typing import Any
from django.conf import settings
from django.contrib.auth import logout as django_logout
from django.db.utils import IntegrityError
from django.middleware import csrf
from django.shortcuts import redirect
from logs_helper.log_service import LogService
from rest_framework import status
from rest_framework.request import Request
from rest_framework.response import Response
from tenant_account_v2.models import OrganizationMember as OrganizationMember
from tenant_account_v2.organization_member_service import OrganizationMemberService
from utils.cache_service import CacheService
from utils.local_context import StateStore
from utils.user_context import UserContext
from utils.user_session import UserSessionUtils
from account_v2.authentication_helper import AuthenticationHelper
from account_v2.authentication_plugin_registry import AuthenticationPluginRegistry
@@ -33,32 +49,19 @@ from account_v2.serializer import (
SetOrganizationsResponseSerializer,
)
from account_v2.user import UserService
from django.conf import settings
from django.contrib.auth import logout as django_logout
from django.db.utils import IntegrityError
from django.middleware import csrf
from django.shortcuts import redirect
from logs_helper.log_service import LogService
from rest_framework import status
from rest_framework.request import Request
from rest_framework.response import Response
from tenant_account_v2.models import OrganizationMember as OrganizationMember
from tenant_account_v2.organization_member_service import OrganizationMemberService
from utils.cache_service import CacheService
from utils.local_context import StateStore
from utils.user_context import UserContext
from utils.user_session import UserSessionUtils
logger = logging.getLogger(__name__)
class AuthenticationController:
"""Authentication Controller This controller class manages user
authentication processes."""
authentication processes.
"""
def __init__(self) -> None:
"""This method initializes the controller by selecting the appropriate
authentication plugin based on availability."""
authentication plugin based on availability.
"""
self.authentication_helper = AuthenticationHelper()
if AuthenticationPluginRegistry.is_plugin_available():
self.auth_service: AuthenticationService = (
@@ -110,7 +113,6 @@ class AuthenticationController:
Returns:
list[OrganizationData]: _description_
"""
try:
organizations = self.auth_service.user_organizations(request)
except Exception as ex:
@@ -165,9 +167,7 @@ class AuthenticationController:
if organization_id and organization_id in organization_ids:
# Set organization in user context
UserContext.set_organization_identifier(organization_id)
organization = OrganizationService.get_organization_by_org_id(
organization_id
)
organization = OrganizationService.get_organization_by_org_id(organization_id)
if not organization:
try:
organization_data: OrganizationData = (
@@ -206,7 +206,7 @@ class AuthenticationController:
f"New organization created with Id {organization_id}",
)
user_info: Optional[UserInfo] = self.get_user_info(request)
user_info: UserInfo | None = self.get_user_info(request)
serialized_user_info = SetOrganizationsResponseSerializer(user_info).data
organization_info = OrganizationSerializer(organization).data
response: Response = Response(
@@ -232,7 +232,7 @@ class AuthenticationController:
return response
return Response(status=status.HTTP_403_FORBIDDEN)
def get_user_info(self, request: Request) -> Optional[UserInfo]:
def get_user_info(self, request: Request) -> UserInfo | None:
return self.auth_service.get_user_info(request)
def is_admin_by_role(self, role: str) -> bool:
@@ -247,7 +247,7 @@ class AuthenticationController:
"""
return self.auth_service.is_admin_by_role(role=role)
def get_organization_info(self, org_id: str) -> Optional[Organization]:
def get_organization_info(self, org_id: str) -> Organization | None:
organization = OrganizationService.get_organization_by_org_id(org_id=org_id)
return organization
@@ -255,9 +255,9 @@ class AuthenticationController:
self,
user_id: str,
user_name: str,
organization_name: Optional[str] = None,
display_name: Optional[str] = None,
) -> Optional[OrganizationData]:
organization_name: str | None = None,
display_name: str | None = None,
) -> OrganizationData | None:
return self.auth_service.make_organization_and_add_member(
user_id, user_name, organization_name, display_name
)
@@ -282,7 +282,7 @@ class AuthenticationController:
return response
def get_organization_members_by_org_id(
self, organization_id: Optional[str] = None
self, organization_id: str | None = None
) -> list[OrganizationMember]:
members: list[OrganizationMember] = OrganizationMemberService.get_members()
return members
@@ -297,9 +297,7 @@ class AuthenticationController:
Returns:
OrganizationMember: OrganizationMemberEntity
"""
member: OrganizationMember = OrganizationMemberService.get_user_by_id(
id=user.id
)
member: OrganizationMember = OrganizationMemberService.get_user_by_id(id=user.id)
return member
def get_user_roles(self) -> list[UserRoleData]:
@@ -320,7 +318,7 @@ class AuthenticationController:
self,
admin: User,
org_id: str,
user_list: list[dict[str, Union[str, None]]],
user_list: list[dict[str, str | None]],
) -> list[UserInviteResponse]:
"""Invites users to join an organization.
@@ -329,6 +327,7 @@ class AuthenticationController:
org_id (str): ID of the organization to which users are invited.
user_list (list[dict[str, Union[str, None]]]):
List of user details for invitation.
Returns:
list[UserInviteResponse]: List of responses for each
user invitation.
@@ -397,7 +396,7 @@ class AuthenticationController:
def add_user_role(
self, admin: User, org_id: str, email: str, role: str
) -> Optional[str]:
) -> str | None:
admin_user = OrganizationMemberService.get_user_by_id(id=admin.id)
user = OrganizationMemberService.get_user_by_email(email=email)
if user:
@@ -414,7 +413,7 @@ class AuthenticationController:
def remove_user_role(
self, admin: User, org_id: str, email: str, role: str
) -> Optional[str]:
) -> str | None:
admin_user = OrganizationMemberService.get_user_by_id(id=admin.id)
organization_member = OrganizationMemberService.get_user_by_email(email=email)
if organization_member:
@@ -471,12 +470,10 @@ class AuthenticationController:
except IntegrityError:
logger.warning(f"Account already exists for {user.email}")
def get_or_create_user(
self, user: User
) -> Optional[Union[User, OrganizationMember]]:
def get_or_create_user(self, user: User) -> User | OrganizationMember | None:
user_service = UserService()
if user.id:
account_user: Optional[User] = user_service.get_user_by_id(user.id)
account_user: User | None = user_service.get_user_by_id(user.id)
if account_user:
return account_user
elif user.email:

View File

@@ -1,11 +1,12 @@
import logging
from typing import Any
from platform_settings_v2.platform_auth_service import PlatformAuthenticationService
from tenant_account_v2.organization_member_service import OrganizationMemberService
from account_v2.dto import MemberData
from account_v2.models import Organization, User
from account_v2.user import UserService
from platform_settings_v2.platform_auth_service import PlatformAuthenticationService
from tenant_account_v2.organization_member_service import OrganizationMemberService
logger = logging.getLogger(__name__)
@@ -14,9 +15,7 @@ class AuthenticationHelper:
def __init__(self) -> None:
pass
def list_of_members_from_user_model(
self, model_data: list[Any]
) -> list[MemberData]:
def list_of_members_from_user_model(self, model_data: list[Any]) -> list[MemberData]:
members: list[MemberData] = []
for data in model_data:
user_id = data.user_id
@@ -49,9 +48,7 @@ class AuthenticationHelper:
user = user_service.create_user(email, user_id)
return user
def create_initial_platform_key(
self, user: User, organization: Organization
) -> None:
def create_initial_platform_key(self, user: User, organization: Organization) -> None:
"""Create an initial platform key for the given user and organization.
This method generates a new platform key with the specified parameters
@@ -109,7 +106,6 @@ class AuthenticationHelper:
Parameters:
user_id (str): The user_id of the users to remove.
"""
organization_user = OrganizationMemberService.get_user_by_user_id(user_id)
if not organization_user:
logger.warning(

View File

@@ -3,15 +3,17 @@ import os
from importlib import import_module
from typing import Any
from account_v2.constants import PluginConfig
from django.apps import apps
from account_v2.constants import PluginConfig
logger = logging.getLogger(__name__)
def _load_plugins() -> dict[str, dict[str, Any]]:
"""Iterating through the Authentication plugins and register their
metadata."""
metadata.
"""
auth_app = apps.get_app_config(PluginConfig.PLUGINS_APP)
auth_package_path = auth_app.module.__package__
auth_dir = os.path.join(auth_app.path, PluginConfig.AUTH_PLUGIN_DIR)

View File

@@ -1,6 +1,17 @@
import logging
import uuid
from typing import Any, Optional
from typing import Any
from django.conf import settings
from django.contrib.auth import authenticate, login, logout
from django.contrib.auth.hashers import make_password
from django.http import HttpRequest
from django.shortcuts import redirect, render
from rest_framework.request import Request
from rest_framework.response import Response
from tenant_account_v2.models import OrganizationMember as OrganizationMember
from tenant_account_v2.organization_member_service import OrganizationMemberService
from utils.user_context import UserContext
from account_v2.authentication_helper import AuthenticationHelper
from account_v2.constants import DefaultOrg, ErrorMessage, UserLoginTemplate
@@ -18,16 +29,6 @@ from account_v2.enums import UserRole
from account_v2.models import Organization, User
from account_v2.organization import OrganizationService
from account_v2.serializer import LoginRequestSerializer
from django.conf import settings
from django.contrib.auth import authenticate, login, logout
from django.contrib.auth.hashers import make_password
from django.http import HttpRequest
from django.shortcuts import redirect, render
from rest_framework.request import Request
from rest_framework.response import Response
from tenant_account_v2.models import OrganizationMember as OrganizationMember
from tenant_account_v2.organization_member_service import OrganizationMemberService
from utils.user_context import UserContext
logger = logging.getLogger(__name__)
@@ -92,10 +93,7 @@ class AuthenticationService:
False otherwise.
"""
# Validation of user credentials
if (
username != DefaultOrg.MOCK_USER
or password != DefaultOrg.MOCK_USER_PASSWORD
):
if username != DefaultOrg.MOCK_USER or password != DefaultOrg.MOCK_USER_PASSWORD:
return False
user = authenticate(request, username=username, password=password)
@@ -193,7 +191,7 @@ class AuthenticationService:
self,
request: Request,
user: User,
data: Optional[dict[str, Any]] = None,
data: dict[str, Any] | None = None,
) -> MemberData:
member_data: MemberData = MemberData(
user_id=user.user_id,
@@ -324,7 +322,7 @@ class AuthenticationService:
logger.error(f"Failed to set default user: {str(e)}")
return False
def _get_or_create_user(self, organization: Optional[Organization]) -> User:
def _get_or_create_user(self, organization: Organization | None) -> User:
"""Get existing user or create a new one based on organization context.
Args:
@@ -365,7 +363,7 @@ class AuthenticationService:
return self._create_mock_user()
def _get_admin_user(self) -> Optional[User]:
def _get_admin_user(self) -> User | None:
"""Get the first admin user from the organization.
Returns:
@@ -376,7 +374,7 @@ class AuthenticationService:
)
return admin_members[0].user if admin_members else None
def _promote_first_member_to_admin(self) -> Optional[OrganizationMember]:
def _promote_first_member_to_admin(self) -> OrganizationMember | None:
"""Promote the first organization member to admin role.
Returns:
@@ -406,7 +404,7 @@ class AuthenticationService:
user.save()
logger.info(f"Updated user {user} with username {DefaultOrg.MOCK_USER}")
def get_user_info(self, request: Request) -> Optional[UserInfo]:
def get_user_info(self, request: Request) -> UserInfo | None:
user: User = request.user
if user:
return UserInfo(
@@ -419,16 +417,16 @@ class AuthenticationService:
else:
return None
def get_organization_info(self, org_id: str) -> Optional[Organization]:
def get_organization_info(self, org_id: str) -> Organization | None:
return OrganizationService.get_organization_by_org_id(org_id=org_id)
def make_organization_and_add_member(
self,
user_id: str,
user_name: str,
organization_name: Optional[str] = None,
display_name: Optional[str] = None,
) -> Optional[OrganizationData]:
organization_name: str | None = None,
display_name: str | None = None,
) -> OrganizationData | None:
organization: OrganizationData = OrganizationData(
id=str(uuid.uuid4()),
display_name=DefaultOrg.MOCK_ORG,
@@ -469,6 +467,6 @@ class AuthenticationService:
admin: OrganizationMember,
org_id: str,
email: str,
role: Optional[str] = None,
role: str | None = None,
) -> bool:
raise MethodNotImplemented()

View File

@@ -1,12 +1,12 @@
from account_v2.authentication_plugin_registry import AuthenticationPluginRegistry
from account_v2.authentication_service import AuthenticationService
from account_v2.constants import Common
from django.conf import settings
from django.http import HttpRequest, HttpResponse, JsonResponse
from utils.constants import Account
from utils.local_context import StateStore
from utils.user_session import UserSessionUtils
from account_v2.authentication_plugin_registry import AuthenticationPluginRegistry
from account_v2.authentication_service import AuthenticationService
from account_v2.constants import Common
from backend.constants import RequestHeader
@@ -42,9 +42,7 @@ class CustomAuthMiddleware:
if is_authenticated:
organization_id = UserSessionUtils.get_organization_id(request=request)
if request.organization_id and not organization_id:
return JsonResponse(
{"message": "Organization access denied"}, status=403
)
return JsonResponse({"message": "Organization access denied"}, status=403)
StateStore.set(Common.LOG_EVENTS_ID, request.session.session_key)
StateStore.set(Account.ORGANIZATION_ID, organization_id)
response = self.get_response(request)

View File

@@ -1,5 +1,3 @@
from typing import Optional
from rest_framework.exceptions import APIException
@@ -18,7 +16,7 @@ class DuplicateData(APIException):
status_code = 400
default_detail = "Duplicate Data"
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
def __init__(self, detail: str | None = None, code: int | None = None):
if detail is not None:
self.detail = detail
if code is not None:
@@ -30,7 +28,7 @@ class TableNotExistError(APIException):
status_code = 400
default_detail = "Unknown Table"
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
def __init__(self, detail: str | None = None, code: int | None = None):
if detail is not None:
self.detail = detail
if code is not None:
@@ -42,7 +40,7 @@ class UserNotExistError(APIException):
status_code = 400
default_detail = "Unknown User"
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
def __init__(self, detail: str | None = None, code: int | None = None):
if detail is not None:
self.detail = detail
if code is not None:

View File

@@ -1,15 +1,15 @@
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any
@dataclass
class MemberData:
user_id: str
email: Optional[str] = None
name: Optional[str] = None
picture: Optional[str] = None
role: Optional[list[str]] = None
organization_id: Optional[str] = None
email: str | None = None
name: str | None = None
picture: str | None = None
role: list[str] | None = None
organization_id: str | None = None
@dataclass
@@ -45,11 +45,11 @@ class OrganizationSignupResponse:
class UserInfo:
email: str
user_id: str
id: Optional[str] = None
name: Optional[str] = None
display_name: Optional[str] = None
family_name: Optional[str] = None
picture: Optional[str] = None
id: str | None = None
name: str | None = None
display_name: str | None = None
family_name: str | None = None
picture: str | None = None
@dataclass
@@ -97,14 +97,14 @@ class ResetUserPasswordDto:
class UserInviteResponse:
email: str
status: str
message: Optional[str] = None
message: str | None = None
@dataclass
class UserRoleData:
name: str
id: Optional[str] = None
description: Optional[str] = None
id: str | None = None
description: str | None = None
@dataclass
@@ -123,8 +123,8 @@ class MemberInvitation:
id: str
email: str
roles: list[str]
created_at: Optional[str] = None
expires_at: Optional[str] = None
created_at: str | None = None
expires_at: str | None = None
@dataclass

View File

@@ -11,7 +11,6 @@ from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [

View File

@@ -18,9 +18,7 @@ class Organization(models.Model):
name = models.CharField(max_length=NAME_SIZE)
display_name = models.CharField(max_length=NAME_SIZE)
organization_id = models.CharField(
max_length=FieldLength.ORG_NAME_SIZE, unique=True
)
organization_id = models.CharField(max_length=FieldLength.ORG_NAME_SIZE, unique=True)
created_by = models.ForeignKey(
"User",
on_delete=models.SET_NULL,

View File

@@ -1,8 +1,8 @@
import logging
from typing import Optional
from django.db import IntegrityError
from account_v2.models import Organization
from django.db import IntegrityError
Logger = logging.getLogger(__name__)
@@ -12,7 +12,7 @@ class OrganizationService:
pass
@staticmethod
def get_organization_by_org_id(org_id: str) -> Optional[Organization]:
def get_organization_by_org_id(org_id: str) -> Organization | None:
try:
return Organization.objects.get(organization_id=org_id) # type: ignore
except Organization.DoesNotExist:

View File

@@ -1,8 +1,8 @@
import re
from typing import Optional
from rest_framework import serializers
from account_v2.models import Organization, User
from rest_framework import serializers
class OrganizationSignupSerializer(serializers.Serializer):
@@ -92,18 +92,20 @@ class LoginRequestSerializer(serializers.Serializer):
username = serializers.CharField(required=True)
password = serializers.CharField(required=True)
def validate_username(self, value: Optional[str]) -> str:
def validate_username(self, value: str | None) -> str:
"""Check that the username is not empty and has at least 3
characters."""
characters.
"""
if not value or len(value) < 3:
raise serializers.ValidationError(
"Username must be at least 3 characters long."
)
return value
def validate_password(self, value: Optional[str]) -> str:
def validate_password(self, value: str | None) -> str:
"""Check that the password is not empty and has at least 3
characters."""
characters.
"""
if not value or len(value) < 3:
raise serializers.ValidationError(
"Password must be at least 3 characters long."

View File

@@ -1,3 +1,5 @@
from django.urls import path
from account_v2.views import (
callback,
create_organization,
@@ -8,7 +10,6 @@ from account_v2.views import (
set_organization,
signup,
)
from django.urls import path
urlpatterns = [
path("login", login, name="login"),

View File

@@ -1,8 +1,9 @@
import logging
from typing import Any, Optional
from typing import Any
from django.db import IntegrityError
from account_v2.models import User
from django.db import IntegrityError
Logger = logging.getLogger(__name__)
@@ -27,7 +28,7 @@ class UserService:
user.save()
return user
def get_user_by_email(self, email: str) -> Optional[User]:
def get_user_by_email(self, email: str) -> User | None:
try:
user: User = User.objects.get(email=email)
return user

View File

@@ -1,6 +1,12 @@
import logging
from typing import Any
from rest_framework import status
from rest_framework.decorators import api_view
from rest_framework.request import Request
from rest_framework.response import Response
from utils.user_session import UserSessionUtils
from account_v2.authentication_controller import AuthenticationController
from account_v2.dto import (
OrganizationSignupRequestBody,
@@ -14,11 +20,6 @@ from account_v2.serializer import (
OrganizationSignupSerializer,
UserSessionResponseSerializer,
)
from rest_framework import status
from rest_framework.decorators import api_view
from rest_framework.request import Request
from rest_framework.response import Response
from utils.user_session import UserSessionUtils
Logger = logging.getLogger(__name__)
@@ -76,6 +77,7 @@ def get_organizations(request: Request) -> Response:
"""get_organizations.
Retrieve the list of organizations to which the user belongs.
Args:
request (HttpRequest): _description_
@@ -91,6 +93,7 @@ def set_organization(request: Request, id: str) -> Response:
"""set_organization.
Set the current organization to use.
Args:
request (HttpRequest): _description_
id (String): organization Id
@@ -98,7 +101,6 @@ def set_organization(request: Request, id: str) -> Response:
Returns:
Response: Contains the User and Current organization details.
"""
auth_controller = AuthenticationController()
return auth_controller.set_user_organization(request, id)
@@ -108,6 +110,7 @@ def get_session_data(request: Request) -> Response:
"""get_session_data.
Retrieve the current session data.
Args:
request (HttpRequest): _description_
@@ -128,6 +131,7 @@ def make_session_response(
"""make_session_response.
Make the current session data.
Args:
request (HttpRequest): _description_

View File

@@ -1,19 +1,20 @@
import json
import logging
from typing import Any, Optional
from typing import Any
from account_v2.models import User
from cryptography.fernet import Fernet
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
from platform_settings_v2.platform_auth_service import PlatformAuthenticationService
from tenant_account_v2.organization_member_service import OrganizationMemberService
from adapter_processor_v2.constants import AdapterKeys, AllowedDomains
from adapter_processor_v2.exceptions import (
InternalServiceError,
InValidAdapterId,
TestAdapterError,
)
from cryptography.fernet import Fernet
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
from platform_settings_v2.platform_auth_service import PlatformAuthenticationService
from tenant_account_v2.organization_member_service import OrganizationMemberService
from unstract.sdk.adapters.adapterkit import Adapterkit
from unstract.sdk.adapters.base import Adapter
from unstract.sdk.adapters.enums import AdapterTypes
@@ -43,9 +44,7 @@ class AdapterProcessor:
updated_adapters[0].get(AdapterKeys.JSON_SCHEMA)
)
else:
logger.error(
f"Invalid adapter Id : {adapter_id} while fetching JSON Schema"
)
logger.error(f"Invalid adapter Id : {adapter_id} while fetching JSON Schema")
raise InValidAdapterId()
return schema_details
@@ -72,9 +71,7 @@ class AdapterProcessor:
AdapterKeys.NAME: each_adapter.get(AdapterKeys.NAME),
AdapterKeys.DESCRIPTION: each_adapter.get(AdapterKeys.DESCRIPTION),
AdapterKeys.ICON: each_adapter.get(AdapterKeys.ICON),
AdapterKeys.ADAPTER_TYPE: each_adapter.get(
AdapterKeys.ADAPTER_TYPE
),
AdapterKeys.ADAPTER_TYPE: each_adapter.get(AdapterKeys.ADAPTER_TYPE),
}
)
return supported_adapters
@@ -97,7 +94,6 @@ class AdapterProcessor:
adapter_class = Adapterkit().get_adapter_class_by_adapter_id(adapter_id)
if adapter_metadata.pop(AdapterKeys.ADAPTER_TYPE) == AdapterKeys.X2TEXT:
if (
adapter_metadata.get(AdapterKeys.PLATFORM_PROVIDED_UNSTRACT_KEY)
and add_unstract_key
@@ -136,7 +132,8 @@ class AdapterProcessor:
@staticmethod
def __fetch_adapters_by_key_value(key: str, value: Any) -> Adapter:
"""Fetches a list of adapters that have an attribute matching key and
value."""
value.
"""
logger.info(f"Fetching adapter list for {key} with {value}")
adapter_kit = Adapterkit()
adapters = adapter_kit.get_adapters_list()
@@ -172,10 +169,8 @@ class AdapterProcessor:
)
if default_triad.get(AdapterKeys.X2TEXT_DEFAULT, None):
user_default_adapter.default_x2text_adapter = (
AdapterInstance.objects.get(
pk=default_triad[AdapterKeys.X2TEXT_DEFAULT]
)
user_default_adapter.default_x2text_adapter = AdapterInstance.objects.get(
pk=default_triad[AdapterKeys.X2TEXT_DEFAULT]
)
user_default_adapter.save()
@@ -223,7 +218,6 @@ class AdapterProcessor:
- list[AdapterInstance]: A list of AdapterInstance objects that match
the specified adapter type.
"""
adapters: list[AdapterInstance] = AdapterInstance.objects.for_user(user).filter(
adapter_type=adapter_type.value,
)
@@ -232,8 +226,8 @@ class AdapterProcessor:
@staticmethod
def get_adapter_by_name_and_type(
adapter_type: AdapterTypes,
adapter_name: Optional[str] = None,
) -> Optional[AdapterInstance]:
adapter_name: str | None = None,
) -> AdapterInstance | None:
"""Get the adapter instance by its name and type.
Parameters:

View File

@@ -1,7 +1,6 @@
from typing import Optional
from rest_framework.exceptions import APIException
from adapter_processor_v2.constants import AdapterKeys
from rest_framework.exceptions import APIException
from unstract.sdk.exceptions import SdkError
@@ -39,9 +38,9 @@ class DuplicateAdapterNameError(APIException):
def __init__(
self,
name: Optional[str] = None,
detail: Optional[str] = None,
code: Optional[str] = None,
name: str | None = None,
detail: str | None = None,
code: str | None = None,
) -> None:
if name:
detail = self.default_detail.replace("this name", f"name '{name}'")
@@ -55,9 +54,9 @@ class TestAdapterError(APIException):
def __init__(
self,
sdk_err: SdkError,
detail: Optional[str] = None,
code: Optional[str] = None,
adapter_name: Optional[str] = None,
detail: str | None = None,
code: str | None = None,
adapter_name: str | None = None,
):
if sdk_err.status_code:
self.status_code = sdk_err.status_code
@@ -77,8 +76,8 @@ class DeleteAdapterInUseError(APIException):
def __init__(
self,
detail: Optional[str] = None,
code: Optional[str] = None,
detail: str | None = None,
code: str | None = None,
adapter_name: str = "adapter",
):
if detail is None:

View File

@@ -8,7 +8,6 @@ from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
@@ -92,9 +91,7 @@ class Migration(migrations.Migration):
),
(
"is_usable",
models.BooleanField(
db_comment="Is the Adpater Usable", default=True
),
models.BooleanField(db_comment="Is the Adpater Usable", default=True),
),
("description", models.TextField(blank=True, default=None, null=True)),
(

View File

@@ -9,9 +9,6 @@ from django.conf import settings
from django.db import models
from django.db.models import QuerySet
from tenant_account_v2.models import OrganizationMember
from unstract.sdk.adapters.adapterkit import Adapterkit
from unstract.sdk.adapters.enums import AdapterTypes
from unstract.sdk.adapters.exceptions import AdapterError
from utils.exceptions import InvalidEncryptionKey
from utils.models.base_model import BaseModel
from utils.models.organization_mixin import (
@@ -19,6 +16,10 @@ from utils.models.organization_mixin import (
DefaultOrganizationMixin,
)
from unstract.sdk.adapters.adapterkit import Adapterkit
from unstract.sdk.adapters.enums import AdapterTypes
from unstract.sdk.adapters.exceptions import AdapterError
logger = logging.getLogger(__name__)
ADAPTER_NAME_SIZE = 128
@@ -132,7 +133,6 @@ class AdapterInstance(DefaultOrganizationMixin, BaseModel):
]
def create_adapter(self) -> None:
encryption_secret: str = settings.ENCRYPTION_KEY
f: Fernet = Fernet(encryption_secret.encode("utf-8"))

View File

@@ -2,17 +2,17 @@ import json
from typing import Any
from account_v2.serializer import UserSerializer
from adapter_processor_v2.adapter_processor import AdapterProcessor
from adapter_processor_v2.constants import AdapterKeys
from cryptography.fernet import Fernet
from django.conf import settings
from rest_framework import serializers
from rest_framework.serializers import ModelSerializer
from unstract.sdk.adapters.constants import Common as common
from unstract.sdk.adapters.enums import AdapterTypes
from adapter_processor_v2.adapter_processor import AdapterProcessor
from adapter_processor_v2.constants import AdapterKeys
from backend.constants import FieldLengthConstants as FLC
from backend.serializers import AuditSerializer
from unstract.sdk.adapters.constants import Common as common
from unstract.sdk.adapters.enums import AdapterTypes
from .models import AdapterInstance, UserDefaultAdapter
@@ -31,12 +31,8 @@ class BaseAdapterSerializer(AuditSerializer):
class DefaultAdapterSerializer(serializers.Serializer):
llm_default = serializers.CharField(max_length=FLC.UUID_LENGTH, required=False)
embedding_default = serializers.CharField(
max_length=FLC.UUID_LENGTH, required=False
)
vector_db_default = serializers.CharField(
max_length=FLC.UUID_LENGTH, required=False
)
embedding_default = serializers.CharField(max_length=FLC.UUID_LENGTH, required=False)
vector_db_default = serializers.CharField(max_length=FLC.UUID_LENGTH, required=False)
class AdapterInstanceSerializer(BaseAdapterSerializer):
@@ -51,9 +47,7 @@ class AdapterInstanceSerializer(BaseAdapterSerializer):
f: Fernet = Fernet(encryption_secret.encode("utf-8"))
json_string: str = json.dumps(data.pop(AdapterKeys.ADAPTER_METADATA))
data[AdapterKeys.ADAPTER_METADATA_B] = f.encrypt(
json_string.encode("utf-8")
)
data[AdapterKeys.ADAPTER_METADATA_B] = f.encrypt(json_string.encode("utf-8"))
return data
@@ -79,7 +73,6 @@ class AdapterInstanceSerializer(BaseAdapterSerializer):
class AdapterInfoSerializer(BaseAdapterSerializer):
context_window_size = serializers.SerializerMethodField()
class Meta(BaseAdapterSerializer.Meta):

View File

@@ -1,10 +1,11 @@
from django.urls import path
from rest_framework.urlpatterns import format_suffix_patterns
from adapter_processor_v2.views import (
AdapterInstanceViewSet,
AdapterViewSet,
DefaultAdapterViewSet,
)
from django.urls import path
from rest_framework.urlpatterns import format_suffix_patterns
default_triad = DefaultAdapterViewSet.as_view(
{"post": "configure_default_triad", "get": "get_default_triad"}

View File

@@ -1,25 +1,7 @@
import logging
import uuid
from typing import Any, Optional
from typing import Any
from adapter_processor_v2.adapter_processor import AdapterProcessor
from adapter_processor_v2.constants import AdapterKeys
from adapter_processor_v2.exceptions import (
CannotDeleteDefaultAdapter,
DeleteAdapterInUseError,
DuplicateAdapterNameError,
IdIsMandatory,
InValidType,
)
from adapter_processor_v2.serializers import (
AdapterInfoSerializer,
AdapterInstanceSerializer,
AdapterListSerializer,
DefaultAdapterSerializer,
SharedUserListSerializer,
TestAdapterSerializer,
UserDefaultAdapterSerializer,
)
from django.db import IntegrityError
from django.db.models import ProtectedError, QuerySet
from django.http import HttpRequest
@@ -40,6 +22,25 @@ from rest_framework.viewsets import GenericViewSet, ModelViewSet
from tenant_account_v2.organization_member_service import OrganizationMemberService
from utils.filtering import FilterHelper
from adapter_processor_v2.adapter_processor import AdapterProcessor
from adapter_processor_v2.constants import AdapterKeys
from adapter_processor_v2.exceptions import (
CannotDeleteDefaultAdapter,
DeleteAdapterInUseError,
DuplicateAdapterNameError,
IdIsMandatory,
InValidType,
)
from adapter_processor_v2.serializers import (
AdapterInfoSerializer,
AdapterInstanceSerializer,
AdapterListSerializer,
DefaultAdapterSerializer,
SharedUserListSerializer,
TestAdapterSerializer,
UserDefaultAdapterSerializer,
)
from .constants import AdapterKeys as constant
from .models import AdapterInstance, UserDefaultAdapter
@@ -130,11 +131,9 @@ class AdapterViewSet(GenericViewSet):
class AdapterInstanceViewSet(ModelViewSet):
serializer_class = AdapterInstanceSerializer
def get_permissions(self) -> list[Any]:
if self.action in ["update", "retrieve"]:
return [IsFrictionLessAdapter()]
@@ -148,7 +147,7 @@ class AdapterInstanceViewSet(ModelViewSet):
# User cant view/update metadata but can delete/share etc
return [IsOwner()]
def get_queryset(self) -> Optional[QuerySet]:
def get_queryset(self) -> QuerySet | None:
if filter_args := FilterHelper.build_filter_args(
self.request,
constant.ADAPTER_TYPE,
@@ -237,9 +236,7 @@ class AdapterInstanceViewSet(ModelViewSet):
name=serializer.validated_data.get(AdapterKeys.ADAPTER_NAME)
)
headers = self.get_success_headers(serializer.data)
return Response(
serializer.data, status=status.HTTP_201_CREATED, headers=headers
)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
def destroy(
self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any]
@@ -261,13 +258,11 @@ class AdapterInstanceViewSet(ModelViewSet):
)
or (
adapter_type == AdapterKeys.EMBEDDING
and adapter_instance
== user_default_adapter.default_embedding_adapter
and adapter_instance == user_default_adapter.default_embedding_adapter
)
or (
adapter_type == AdapterKeys.VECTOR_DB
and adapter_instance
== user_default_adapter.default_vector_db_adapter
and adapter_instance == user_default_adapter.default_vector_db_adapter
)
or (
adapter_type == AdapterKeys.X2TEXT
@@ -337,7 +332,6 @@ class AdapterInstanceViewSet(ModelViewSet):
@action(detail=True, methods=["get"])
def list_of_shared_users(self, request: HttpRequest, pk: Any = None) -> Response:
adapter = self.get_object()
serialized_instances = SharedUserListSerializer(adapter).data
@@ -346,7 +340,6 @@ class AdapterInstanceViewSet(ModelViewSet):
@action(detail=True, methods=["get"])
def adapter_info(self, request: HttpRequest, pk: uuid) -> Response:
adapter = self.get_object()
serialized_instances = AdapterInfoSerializer(adapter).data

View File

@@ -1,23 +1,24 @@
import logging
from typing import Any, Optional
from typing import Any
from plugins.api.dto import metadata
from api_v2.postman_collection.dto import PostmanCollection
from plugins.api.dto import metadata
logger = logging.getLogger(__name__)
class ApiDeploymentDTORegistry:
_dto_class: Optional[Any] = None # Store the selected DTO class (cached)
_dto_class: Any | None = None # Store the selected DTO class (cached)
@classmethod
def load_dto(cls) -> Optional[Any]:
def load_dto(cls) -> Any | None:
class_name = PostmanCollection.__name__
if metadata.get(class_name):
return metadata[class_name].class_name
return PostmanCollection # Return as soon as we find a valid DTO
@classmethod
def get_dto(cls) -> Optional[type]:
def get_dto(cls) -> type | None:
"""Returns the first available DTO class, or None if unavailable."""
return cls.load_dto()

View File

@@ -1,6 +1,18 @@
import json
import logging
from typing import Any, Optional
from typing import Any
from django.conf import settings
from django.db.models import QuerySet
from django.http import HttpResponse
from permissions.permission import IsOwner
from rest_framework import serializers, status, views, viewsets
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import Serializer
from utils.enums import CeleryTaskState
from workflow_manager.workflow_v2.dto import ExecutionResponse
from api_v2.api_deployment_dto_registry import ApiDeploymentDTORegistry
from api_v2.constants import ApiExecution
@@ -14,25 +26,12 @@ from api_v2.serializers import (
ExecutionQuerySerializer,
ExecutionRequestSerializer,
)
from django.conf import settings
from django.db.models import QuerySet
from django.http import HttpResponse
from permissions.permission import IsOwner
from rest_framework import serializers, status, views, viewsets
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import Serializer
from utils.enums import CeleryTaskState
from workflow_manager.workflow_v2.dto import ExecutionResponse
logger = logging.getLogger(__name__)
class DeploymentExecution(views.APIView):
def initialize_request(
self, request: Request, *args: Any, **kwargs: Any
) -> Request:
def initialize_request(self, request: Request, *args: Any, **kwargs: Any) -> Request:
"""To remove csrf request for public API.
Args:
@@ -41,7 +40,7 @@ class DeploymentExecution(views.APIView):
Returns:
Request: _description_
"""
setattr(request, "csrf_processing_done", True)
request.csrf_processing_done = True
return super().initialize_request(request, *args, **kwargs)
@DeploymentHelper.validate_api_key
@@ -85,9 +84,7 @@ class DeploymentExecution(views.APIView):
include_metrics = serializer.validated_data.get(ApiExecution.INCLUDE_METRICS)
# Fetch execution status
response: ExecutionResponse = DeploymentHelper.get_execution_status(
execution_id
)
response: ExecutionResponse = DeploymentHelper.get_execution_status(execution_id)
# Determine response status
response_status = status.HTTP_422_UNPROCESSABLE_ENTITY
if response.execution_status == CeleryTaskState.COMPLETED.value:
@@ -113,7 +110,7 @@ class DeploymentExecution(views.APIView):
class APIDeploymentViewSet(viewsets.ModelViewSet):
permission_classes = [IsOwner]
def get_queryset(self) -> Optional[QuerySet]:
def get_queryset(self) -> QuerySet | None:
return APIDeployment.objects.filter(created_by=self.request.user)
def get_serializer_class(self) -> serializers.Serializer:
@@ -122,7 +119,7 @@ class APIDeploymentViewSet(viewsets.ModelViewSet):
return APIDeploymentSerializer
@action(detail=True, methods=["get"])
def fetch_one(self, request: Request, pk: Optional[str] = None) -> Response:
def fetch_one(self, request: Request, pk: str | None = None) -> Response:
"""Custom action to fetch a single instance."""
instance = self.get_object()
serializer = self.get_serializer(instance)
@@ -134,9 +131,7 @@ class APIDeploymentViewSet(viewsets.ModelViewSet):
serializer: Serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
api_key = DeploymentHelper.create_api_key(
serializer=serializer, request=request
)
api_key = DeploymentHelper.create_api_key(serializer=serializer, request=request)
response_serializer = DeploymentResponseSerializer(
{"api_key": api_key.api_key, **serializer.data}
)
@@ -150,7 +145,7 @@ class APIDeploymentViewSet(viewsets.ModelViewSet):
@action(detail=True, methods=["get"])
def download_postman_collection(
self, request: Request, pk: Optional[str] = None
self, request: Request, pk: str | None = None
) -> Response:
"""Downloads a Postman Collection of the API deployment instance."""
instance = self.get_object()

View File

@@ -2,9 +2,10 @@ import logging
from functools import wraps
from typing import Any
from api_v2.exceptions import Forbidden
from rest_framework.request import Request
from api_v2.exceptions import Forbidden
logger = logging.getLogger(__name__)
@@ -56,5 +57,6 @@ class BaseAPIKeyValidator:
self: Any, request: Request, func: Any, api_key: str, *args: Any, **kwargs: Any
) -> Any:
"""Process and validate API key with specific logic required by
subclasses."""
subclasses.
"""
pass

View File

@@ -1,4 +1,9 @@
from typing import Optional
from pipeline_v2.exceptions import PipelineNotFound
from pipeline_v2.pipeline_processor import PipelineProcessor
from rest_framework import serializers, viewsets
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response
from api_v2.deployment_helper import DeploymentHelper
from api_v2.exceptions import APINotFound, PathVariablesNotFound
@@ -6,12 +11,6 @@ from api_v2.key_helper import KeyHelper
from api_v2.models import APIKey
from api_v2.permission import IsOwnerOrOrganizationMember
from api_v2.serializers import APIKeyListSerializer, APIKeySerializer
from pipeline_v2.exceptions import PipelineNotFound
from pipeline_v2.pipeline_processor import PipelineProcessor
from rest_framework import serializers, viewsets
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response
class APIKeyViewSet(viewsets.ModelViewSet):
@@ -27,8 +26,8 @@ class APIKeyViewSet(viewsets.ModelViewSet):
def api_keys(
self,
request: Request,
api_id: Optional[str] = None,
pipeline_id: Optional[str] = None,
api_id: str | None = None,
pipeline_id: str | None = None,
) -> Response:
"""Custom action to fetch api keys of an api deployment."""
if api_id:
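
The import churn in this and the following files is a regrouping rather than a behavioural change: standard-library and third-party imports stay at the top, while imports from the module's own app (api_v2 here, connector_processor or notification_v2 elsewhere) move into the last group, which is consistent with isort-style first-party detection in Ruff; the precise configuration is an assumption. A sketch of the resulting shape, assuming the backend environment where these packages resolve:

# Standard library
import logging

# Third-party packages and sibling Django apps
from pipeline_v2.pipeline_processor import PipelineProcessor
from rest_framework import viewsets

# Imports from the module's own app come last
from api_v2.exceptions import APINotFound
from api_v2.key_helper import KeyHelper

logger = logging.getLogger(__name__)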

View File

@@ -1,18 +1,7 @@
import logging
from typing import Any, Optional
from typing import Any
from urllib.parse import urlencode
from api_v2.api_key_validator import BaseAPIKeyValidator
from api_v2.exceptions import (
ApiKeyCreateException,
APINotFound,
InactiveAPI,
InvalidAPIRequest,
)
from api_v2.key_helper import KeyHelper
from api_v2.models import APIDeployment, APIKey
from api_v2.serializers import APIExecutionResponseSerializer
from api_v2.utils import APIDeploymentUtils
from django.conf import settings
from django.core.files.uploadedfile import UploadedFile
from rest_framework.request import Request
@@ -29,6 +18,18 @@ from workflow_manager.workflow_v2.execution import WorkflowExecutionServiceHelpe
from workflow_manager.workflow_v2.models import Workflow, WorkflowExecution
from workflow_manager.workflow_v2.workflow_helper import WorkflowHelper
from api_v2.api_key_validator import BaseAPIKeyValidator
from api_v2.exceptions import (
ApiKeyCreateException,
APINotFound,
InactiveAPI,
InvalidAPIRequest,
)
from api_v2.key_helper import KeyHelper
from api_v2.models import APIDeployment, APIKey
from api_v2.serializers import APIExecutionResponseSerializer
from api_v2.utils import APIDeploymentUtils
logger = logging.getLogger(__name__)
@@ -55,7 +56,7 @@ class DeploymentHelper(BaseAPIKeyValidator):
return func(self, request, *args, **kwargs)
@staticmethod
def validate_api(api_deployment: Optional[APIDeployment], api_key: str) -> None:
def validate_api(api_deployment: APIDeployment | None, api_key: str) -> None:
"""Validating API and API key.
Args:
@@ -75,11 +76,12 @@ class DeploymentHelper(BaseAPIKeyValidator):
@staticmethod
def validate_and_get_workflow(workflow_id: str) -> Workflow:
"""Validate that the specified workflow_id exists in the Workflow
model."""
model.
"""
return WorkflowHelper.get_workflow_by_id(workflow_id)
@staticmethod
def get_api_by_id(api_id: str) -> Optional[APIDeployment]:
def get_api_by_id(api_id: str) -> APIDeployment | None:
return APIDeploymentUtils.get_api_by_id(api_id=api_id)
@staticmethod
@@ -102,7 +104,7 @@ class DeploymentHelper(BaseAPIKeyValidator):
@staticmethod
def get_deployment_by_api_name(
api_name: str,
) -> Optional[APIDeployment]:
) -> APIDeployment | None:
"""Get and return the APIDeployment object by api_name."""
try:
api: APIDeployment = APIDeployment.objects.get(api_name=api_name)

View File

@@ -1,5 +1,3 @@
from typing import Optional
from rest_framework.exceptions import APIException
@@ -24,9 +22,7 @@ class ApiKeyCreateException(APIException):
class Forbidden(APIException):
status_code = 403
default_detail = (
"User is forbidden from performing this action. Please contact admin"
)
default_detail = "User is forbidden from performing this action. Please contact admin"
class APINotFound(NotFoundException):
@@ -53,8 +49,8 @@ class NoActiveAPIKeyError(APIException):
def __init__(
self,
detail: Optional[str] = None,
code: Optional[str] = None,
detail: str | None = None,
code: str | None = None,
deployment_name: str = "this deployment",
):
if detail is None:

View File

@@ -1,7 +1,8 @@
from api_v2.api_deployment_views import DeploymentExecution
from django.urls import re_path
from rest_framework.urlpatterns import format_suffix_patterns
from api_v2.api_deployment_views import DeploymentExecution
execute = DeploymentExecution.as_view()

View File

@@ -1,22 +1,20 @@
import logging
from typing import Union
from api_v2.exceptions import UnauthorizedKey
from api_v2.models import APIDeployment, APIKey
from api_v2.serializers import APIKeySerializer
from django.core.exceptions import ValidationError
from pipeline_v2.models import Pipeline
from rest_framework.request import Request
from workflow_manager.workflow_v2.workflow_helper import WorkflowHelper
from api_v2.exceptions import UnauthorizedKey
from api_v2.models import APIDeployment, APIKey
from api_v2.serializers import APIKeySerializer
logger = logging.getLogger(__name__)
class KeyHelper:
@staticmethod
def validate_api_key(
api_key: str, instance: Union[APIDeployment, Pipeline]
) -> None:
def validate_api_key(api_key: str, instance: APIDeployment | Pipeline) -> None:
"""Validate api key.
Args:
@@ -44,7 +42,7 @@ class KeyHelper:
return api_keys
@staticmethod
def has_access(api_key: APIKey, instance: Union[APIDeployment, Pipeline]) -> bool:
def has_access(api_key: APIKey, instance: APIDeployment | Pipeline) -> bool:
"""Check if the provided API key has access to the specified API
instance.
@@ -66,15 +64,15 @@ class KeyHelper:
@staticmethod
def validate_workflow_exists(workflow_id: str) -> None:
"""Validate that the specified workflow_id exists in the Workflow
model."""
model.
"""
WorkflowHelper.get_workflow_by_id(workflow_id)
@staticmethod
def create_api_key(
deployment: Union[APIDeployment, Pipeline], request: Request
) -> APIKey:
def create_api_key(deployment: APIDeployment | Pipeline, request: Request) -> APIKey:
"""Create an APIKey entity using the data from the provided
APIDeployment or Pipeline instance."""
APIDeployment or Pipeline instance.
"""
api_key_serializer = APIKeySerializer(
data=deployment.api_key_data,
context={"deployment": deployment, "request": request},

View File

@@ -8,7 +8,6 @@ from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [

View File

@@ -2,7 +2,6 @@ import uuid
from typing import Any
from account_v2.models import User
from api_v2.constants import ApiExecution
from django.db import models
from pipeline_v2.models import Pipeline
from utils.models.base_model import BaseModel
@@ -13,6 +12,8 @@ from utils.models.organization_mixin import (
from utils.user_context import UserContext
from workflow_manager.workflow_v2.models.workflow import Workflow
from api_v2.constants import ApiExecution
API_NAME_MAX_LENGTH = 30
DESCRIPTION_MAX_LENGTH = 255
API_ENDPOINT_MAX_LENGTH = 255

View File

@@ -1,18 +1,17 @@
import logging
from api_v2.models import APIDeployment
from notification_v2.helper import NotificationHelper
from notification_v2.models import Notification
from pipeline_v2.dto import PipelineStatusPayload
from workflow_manager.workflow_v2.models.execution import WorkflowExecution
from api_v2.models import APIDeployment
logger = logging.getLogger(__name__)
class APINotification:
def __init__(
self, api: APIDeployment, workflow_execution: WorkflowExecution
) -> None:
def __init__(self, api: APIDeployment, workflow_execution: WorkflowExecution) -> None:
self.notifications = Notification.objects.filter(api=api, is_active=True)
self.api = api
self.workflow_execution = workflow_execution

View File

@@ -4,7 +4,8 @@ from utils.user_context import UserContext
class IsOwnerOrOrganizationMember(IsOwner):
"""Permission that grants access if the user is the owner or belongs to the
same organization."""
same organization.
"""
def has_object_permission(self, request, view, obj):
# Check if the user is the owner via base class logic

View File

@@ -1,5 +1,7 @@
class CollectionKey:
POSTMAN_COLLECTION_V210 = "https://schema.getpostman.com/json/collection/v2.1.0/collection.json" # noqa: E501
POSTMAN_COLLECTION_V210 = (
"https://schema.getpostman.com/json/collection/v2.1.0/collection.json" # noqa: E501
)
EXECUTE_API_KEY = "Process document"
EXECUTE_PIPELINE_API_KEY = "Process pipeline"
STATUS_API_KEY = "Execution status"

View File

@@ -1,14 +1,15 @@
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from typing import Any, Optional, Union
from typing import Any
from urllib.parse import urlencode, urljoin
from django.conf import settings
from pipeline_v2.models import Pipeline
from utils.request import HTTPMethod
from api_v2.constants import ApiExecution
from api_v2.models import APIDeployment
from api_v2.postman_collection.constants import CollectionKey
from django.conf import settings
from pipeline_v2.models import Pipeline
from utils.request import HTTPMethod
@dataclass
@@ -21,8 +22,8 @@ class HeaderItem:
class FormDataItem:
key: str
type: str
src: Optional[str] = None
value: Optional[str] = None
src: str | None = None
value: str | None = None
def __post_init__(self) -> None:
if self.type == "file":
@@ -46,7 +47,7 @@ class RequestItem:
method: HTTPMethod
url: str
header: list[HeaderItem]
body: Optional[BodyItem] = None
body: BodyItem | None = None
@dataclass
@@ -63,7 +64,6 @@ class PostmanInfo:
class APIBase(ABC):
@abstractmethod
def get_form_data_items(self) -> list[FormDataItem]:
pass
@@ -189,7 +189,7 @@ class PostmanCollection:
@classmethod
def create(
cls,
instance: Union[APIDeployment, Pipeline],
instance: APIDeployment | Pipeline,
api_key: str = CollectionKey.AUTH_QUERY_PARAM_DEFAULT,
) -> "PostmanCollection":
"""Creates a PostmanCollection instance.

View File

@@ -1,9 +1,7 @@
import uuid
from collections import OrderedDict
from typing import Any, Union
from typing import Any
from api_v2.constants import ApiExecution
from api_v2.models import APIDeployment, APIKey
from django.core.validators import RegexValidator
from pipeline_v2.models import Pipeline
from rest_framework.serializers import (
@@ -22,6 +20,8 @@ from utils.serializer.integrity_error_mixin import IntegrityErrorMixin
from workflow_manager.workflow_v2.exceptions import ExecutionDoesNotExistError
from workflow_manager.workflow_v2.models.execution import WorkflowExecution
from api_v2.constants import ApiExecution
from api_v2.models import APIDeployment, APIKey
from backend.serializers import AuditSerializer
@@ -76,8 +76,9 @@ class APIKeySerializer(AuditSerializer):
def to_representation(self, instance: APIKey) -> OrderedDict[str, Any]:
"""Override the to_representation method to include additional
context."""
deployment: Union[APIDeployment, Pipeline] = self.context.get("deployment")
context.
"""
deployment: APIDeployment | Pipeline = self.context.get("deployment")
representation: OrderedDict[str, Any] = super().to_representation(instance)
if deployment:
@@ -89,9 +90,7 @@ class APIKeySerializer(AuditSerializer):
elif isinstance(deployment, Pipeline):
representation["api"] = None
representation["pipeline"] = deployment.id
representation["description"] = (
f"API Key for {deployment.pipeline_name}"
)
representation["description"] = f"API Key for {deployment.pipeline_name}"
else:
raise ValueError(
"Context must be an instance of APIDeployment or Pipeline"

View File

@@ -1,8 +1,9 @@
from api_v2.api_deployment_views import APIDeploymentViewSet, DeploymentExecution
from api_v2.api_key_views import APIKeyViewSet
from django.urls import path
from rest_framework.urlpatterns import format_suffix_patterns
from api_v2.api_deployment_views import APIDeploymentViewSet, DeploymentExecution
from api_v2.api_key_views import APIKeyViewSet
deployment = APIDeploymentViewSet.as_view(
{
"get": APIDeploymentViewSet.list.__name__,

View File

@@ -1,13 +1,12 @@
from typing import Optional
from workflow_manager.workflow_v2.models.execution import WorkflowExecution
from api_v2.models import APIDeployment
from api_v2.notification import APINotification
from workflow_manager.workflow_v2.models.execution import WorkflowExecution
class APIDeploymentUtils:
@staticmethod
def get_api_by_id(api_id: str) -> Optional[APIDeployment]:
def get_api_by_id(api_id: str) -> APIDeployment | None:
"""Retrieves an APIDeployment instance by its unique ID.
Args:
@@ -39,7 +38,5 @@ class APIDeploymentUtils:
Returns:
None
"""
api_notification = APINotification(
api=api, workflow_execution=workflow_execution
)
api_notification = APINotification(api=api, workflow_execution=workflow_execution)
api_notification.send()

View File

@@ -32,9 +32,7 @@ TaskRegistry()
app.config_from_object("backend.celery_config.CeleryConfig")
app.autodiscover_tasks()
logger.debug(
f"Celery Configuration:\n" f"{pformat(app.conf.table(with_defaults=True))}"
)
logger.debug(f"Celery Configuration:\n {pformat(app.conf.table(with_defaults=True))}")
# Define the queues to purge when the Celery broker is restarted.
queues_to_purge = [ExecutionLogConstants.CELERY_QUEUE_NAME]

View File

@@ -10,7 +10,8 @@ logger = logging.getLogger(__name__)
class DatabaseWrapper(PostgresDatabaseWrapper):
"""Custom DatabaseWrapper to manage PostgreSQL connections and set the
search path."""
search path.
"""
def get_new_connection(self, conn_params):
"""Establish a new database connection or reuse an existing one, and
@@ -45,8 +46,6 @@ class DatabaseWrapper(PostgresDatabaseWrapper):
)
with connection.cursor() as cursor:
cursor.execute(f"SET search_path TO {settings.DB_SCHEMA}")
logger.debug(
f"Successfully set search_path for DB connection ID {conn_id}."
)
logger.debug(f"Successfully set search_path for DB connection ID {conn_id}.")
finally:
connection.autocommit = original_autocommit

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any
from rest_framework.exceptions import APIException
from rest_framework.response import Response
@@ -12,8 +12,8 @@ class UnstractBaseException(APIException):
def __init__(
self,
detail: Optional[str] = None,
core_err: Optional[ConnectorBaseException] = None,
detail: str | None = None,
core_err: ConnectorBaseException | None = None,
**kwargs: Any,
) -> None:
if detail is None:
@@ -21,9 +21,7 @@ class UnstractBaseException(APIException):
if core_err and core_err.user_message:
detail = core_err.user_message
if detail and "Name or service not known" in str(detail):
detail = (
"Failed to establish a new connection: " "Name or service not known"
)
detail = "Failed to establish a new connection: " "Name or service not known"
super().__init__(detail=detail, **kwargs)
self._core_err = core_err
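
The detail = ... change above only joins lines; the two adjacent string literals remain and Python concatenates them at compile time, so the message is identical before and after. A quick check:

detail_wrapped = (
    "Failed to establish a new connection: " "Name or service not known"
)
detail_joined = "Failed to establish a new connection: " "Name or service not known"

# Adjacent literals are merged by the parser, so both spellings are the same string.
assert detail_wrapped == detail_joined
assert detail_joined == "Failed to establish a new connection: Name or service not known"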

View File

@@ -2,6 +2,7 @@
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/4.2/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views

View File

@@ -2,6 +2,7 @@
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/4.2/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views

View File

@@ -11,7 +11,6 @@ https://docs.djangoproject.com/en/4.2/ref/settings/
import os
from pathlib import Path
from typing import Optional
from dotenv import find_dotenv, load_dotenv
from utils.common_utils import CommonUtils
@@ -19,9 +18,7 @@ from utils.common_utils import CommonUtils
missing_settings = []
def get_required_setting(
setting_key: str, default: Optional[str] = None
) -> Optional[str]:
def get_required_setting(setting_key: str, default: str | None = None) -> str | None:
"""Get the value of an environment variable specified by the given key. Add
missing keys to `missing_settings` so that exception can be raised at the
end.
@@ -64,9 +61,7 @@ LOGIN_NEXT_URL = os.environ.get("LOGIN_NEXT_URL", "http://localhost:3000/org")
LANDING_URL = os.environ.get("LANDING_URL", "http://localhost:3000/landing")
ERROR_URL = os.environ.get("ERROR_URL", "http://localhost:3000/error")
DJANGO_APP_BACKEND_URL = os.environ.get(
"DJANGO_APP_BACKEND_URL", "http://localhost:8000"
)
DJANGO_APP_BACKEND_URL = os.environ.get("DJANGO_APP_BACKEND_URL", "http://localhost:8000")
INTERNAL_SERVICE_API_KEY = os.environ.get("INTERNAL_SERVICE_API_KEY")
GOOGLE_STORAGE_ACCESS_KEY_ID = os.environ.get("GOOGLE_STORAGE_ACCESS_KEY_ID")
@@ -481,7 +476,7 @@ for key in [
"GOOGLE_OAUTH2_KEY",
"GOOGLE_OAUTH2_SECRET",
]:
exec("SOCIAL_AUTH_{key} = os.environ.get('{key}')".format(key=key))
exec(f"SOCIAL_AUTH_{key} = os.environ.get('{key}')")
SOCIAL_AUTH_PIPELINE = (
# Checks if user is authenticated
@@ -520,6 +515,4 @@ if missing_settings:
)
raise ValueError(ERROR_MESSAGE)
ENABLE_HIGHLIGHT_API_DEPLOYMENT = os.environ.get(
"ENABLE_HIGHLIGHT_API_DEPLOYMENT", False
)
ENABLE_HIGHLIGHT_API_DEPLOYMENT = os.environ.get("ENABLE_HIGHLIGHT_API_DEPLOYMENT", False)

View File

@@ -2,6 +2,7 @@
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/4.2/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views

View File

@@ -2,6 +2,7 @@
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/4.2/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views

View File

@@ -16,9 +16,7 @@ class Command(BaseCommand):
parser.add_argument(
"--schema",
type=str,
help=(
"Optional schema name to drop. Overrides env 'DB_SCHEMA' if specified"
),
help=("Optional schema name to drop. Overrides env 'DB_SCHEMA' if specified"),
)
def handle(self, *args, **kwargs):

View File

@@ -1,5 +1,3 @@
from typing import Optional
from rest_framework.exceptions import APIException
@@ -19,8 +17,8 @@ class MissingParamException(APIException):
def __init__(
self,
code: Optional[str] = None,
param: Optional[str] = None,
code: str | None = None,
param: str | None = None,
) -> None:
detail = f"Bad request, missing parameter: {param}"
super().__init__(detail, code)

View File

@@ -9,7 +9,6 @@ from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [

View File

@@ -3,8 +3,6 @@ import uuid
from typing import Any
from account_v2.models import User
from connector_auth_v2.constants import SocialAuthConstants
from connector_auth_v2.pipeline.google import GoogleAuthHelper
from django.db import models
from django.db.models.query import QuerySet
from rest_framework.request import Request
@@ -12,6 +10,9 @@ from social_django.fields import JSONField
from social_django.models import AbstractUserSocialAuth, DjangoStorage
from social_django.strategy import DjangoStrategy
from connector_auth_v2.constants import SocialAuthConstants
from connector_auth_v2.pipeline.google import GoogleAuthHelper
logger = logging.getLogger(__name__)
@@ -63,7 +64,8 @@ class ConnectorAuth(AbstractUserSocialAuth):
def refresh_token(self, strategy, *args, **kwargs): # type: ignore
"""Override of Python Social Auth (PSA)'s refresh_token functionality
to store uid, provider."""
to store uid, provider.
"""
token = self.extra_data.get("refresh_token") or self.extra_data.get(
"access_token"
)

View File

@@ -1,5 +1,5 @@
import logging
from typing import Any, Optional
from typing import Any
from account_v2.models import User
from connector_auth_v2.constants import ConnectorAuthKey, SocialAuthConstants
@@ -68,7 +68,7 @@ class ConnectorAuthHelper:
@staticmethod
def get_oauth_creds_from_cache(
cache_key: str, delete_key: bool = True
) -> Optional[dict[str, str]]:
) -> dict[str, str] | None:
"""Retrieves oauth credentials from the cache.
Args:
@@ -84,7 +84,8 @@ class ConnectorAuthHelper:
@staticmethod
def get_or_create_connector_auth(
oauth_credentials: dict[str, str], user: User = None # type: ignore
oauth_credentials: dict[str, str],
user: User = None, # type: ignore
) -> ConnectorAuth:
"""Gets or creates a ConnectorAuth object.

View File

@@ -1,8 +1,6 @@
import logging
import uuid
from connector_auth_v2.constants import SocialAuthConstants
from connector_auth_v2.exceptions import KeyNotConfigured
from django.conf import settings
from rest_framework import status, viewsets
from rest_framework.request import Request
@@ -10,6 +8,9 @@ from rest_framework.response import Response
from rest_framework.versioning import URLPathVersioning
from utils.user_session import UserSessionUtils
from connector_auth_v2.constants import SocialAuthConstants
from connector_auth_v2.exceptions import KeyNotConfigured
logger = logging.getLogger(__name__)

View File

@@ -1,10 +1,13 @@
# mypy: ignore-errors
import json
import logging
from typing import Any, Optional
from typing import Any
from connector_auth_v2.constants import ConnectorAuthKey
from connector_auth_v2.pipeline.common import ConnectorAuthHelper
from connector_v2.constants import ConnectorInstanceKey as CIKey
from backend.exceptions import UnstractFSException
from connector_processor.constants import ConnectorKeys
from connector_processor.exceptions import (
InValidConnectorId,
@@ -12,9 +15,6 @@ from connector_processor.exceptions import (
OAuthTimeOut,
TestConnectorInputError,
)
from connector_v2.constants import ConnectorInstanceKey as CIKey
from backend.exceptions import UnstractFSException
from unstract.connectors.base import UnstractConnector
from unstract.connectors.connectorkit import Connectorkit
from unstract.connectors.enums import ConnectorMode
@@ -25,10 +25,11 @@ logger = logging.getLogger(__name__)
def fetch_connectors_by_key_value(
key: str, value: Any, connector_mode: Optional[ConnectorMode] = None
key: str, value: Any, connector_mode: ConnectorMode | None = None
) -> list[UnstractConnector]:
"""Fetches a list of connectors that have an attribute matching key and
value."""
value.
"""
logger.info(f"Fetching connector list for {key} with {value}")
connector_kit = Connectorkit()
connectors = connector_kit.get_connectors_list(mode=connector_mode)
@@ -42,9 +43,7 @@ class ConnectorProcessor:
schema_details: dict = {}
if connector_id == UnstractCloudStorage.get_id():
return schema_details
updated_connectors = fetch_connectors_by_key_value(
ConnectorKeys.ID, connector_id
)
updated_connectors = fetch_connectors_by_key_value(ConnectorKeys.ID, connector_id)
if len(updated_connectors) == 0:
logger.error(
f"Invalid connector Id : {connector_id} "
@@ -70,7 +69,7 @@ class ConnectorProcessor:
@staticmethod
def get_all_supported_connectors(
type: str, connector_mode: Optional[ConnectorMode] = None
type: str, connector_mode: ConnectorMode | None = None
) -> list[dict]:
"""Function to return list of all supported connectors except PCS."""
supported_connectors = []
@@ -119,9 +118,7 @@ class ConnectorProcessor:
raise OAuthTimeOut()
try:
connector_impl = Connectorkit().get_connector_by_id(
connector_id, credentials
)
connector_impl = Connectorkit().get_connector_by_id(connector_id, credentials)
test_result = connector_impl.test_credentials()
logger.info(f"{connector_id} test result: {test_result}")
return test_result

View File

@@ -1,6 +1,7 @@
from connector_processor.views import ConnectorViewSet
from django.urls import path
from connector_processor.views import ConnectorViewSet
from . import views
connector_test = ConnectorViewSet.as_view({"post": "test"})

View File

@@ -1,7 +1,3 @@
from connector_processor.connector_processor import ConnectorProcessor
from connector_processor.constants import ConnectorKeys
from connector_processor.exceptions import IdIsMandatory, InValidType
from connector_processor.serializers import TestConnectorSerializer
from connector_v2.constants import ConnectorInstanceKey as CIKey
from django.http.request import HttpRequest
from django.http.response import HttpResponse
@@ -13,6 +9,11 @@ from rest_framework.serializers import Serializer
from rest_framework.versioning import URLPathVersioning
from rest_framework.viewsets import GenericViewSet
from connector_processor.connector_processor import ConnectorProcessor
from connector_processor.constants import ConnectorKeys
from connector_processor.exceptions import IdIsMandatory, InValidType
from connector_processor.serializers import TestConnectorSerializer
@api_view(("GET",))
def get_connector_schema(request: HttpRequest) -> HttpResponse:

View File

@@ -1,14 +1,14 @@
import logging
from typing import Any, Optional
from typing import Any
from account_v2.models import User
from connector_v2.constants import ConnectorInstanceConstant
from connector_v2.models import ConnectorInstance
from connector_v2.unstract_account import UnstractAccount
from django.conf import settings
from utils.user_context import UserContext
from workflow_manager.workflow_v2.models.workflow import Workflow
from connector_v2.constants import ConnectorInstanceConstant
from connector_v2.models import ConnectorInstance
from connector_v2.unstract_account import UnstractAccount
from unstract.connectors.filesystems.ucs import UnstractCloudStorage
from unstract.connectors.filesystems.ucs.constants import UCSKey
@@ -78,9 +78,9 @@ class ConnectorInstanceHelper:
def get_connector_instances_by_workflow(
workflow_id: str,
connector_type: tuple[str, str],
connector_mode: Optional[tuple[int, str]] = None,
values: Optional[list[str]] = None,
connector_name: Optional[str] = None,
connector_mode: tuple[int, str] | None = None,
values: list[str] | None = None,
connector_name: str | None = None,
) -> list[ConnectorInstance]:
"""Method to get connector instances by workflow.
@@ -121,9 +121,9 @@ class ConnectorInstanceHelper:
def get_connector_instance_by_workflow(
workflow_id: str,
connector_type: tuple[str, str],
connector_mode: Optional[tuple[int, str]] = None,
connector_name: Optional[str] = None,
) -> Optional[ConnectorInstance]:
connector_mode: tuple[int, str] | None = None,
connector_name: str | None = None,
) -> ConnectorInstance | None:
"""Get one connector instance.
Use this method if the connector instance is unique for \
@@ -162,7 +162,7 @@ class ConnectorInstanceHelper:
def get_input_connector_instance_by_name_for_workflow(
workflow_id: str,
connector_name: str,
) -> Optional[ConnectorInstance]:
) -> ConnectorInstance | None:
"""Method to get Input connector instance name from the workflow.
Args:
@@ -182,7 +182,7 @@ class ConnectorInstanceHelper:
def get_output_connector_instance_by_name_for_workflow(
workflow_id: str,
connector_name: str,
) -> Optional[ConnectorInstance]:
) -> ConnectorInstance | None:
"""Method to get output connector name by Workflow.
Args:
@@ -232,7 +232,7 @@ class ConnectorInstanceHelper:
@staticmethod
def get_file_system_input_connector_instances_by_workflow(
workflow_id: str, values: Optional[list[str]] = None
workflow_id: str, values: list[str] | None = None
) -> list[ConnectorInstance]:
"""Method to fetch file system connector by workflow.
@@ -252,7 +252,7 @@ class ConnectorInstanceHelper:
@staticmethod
def get_file_system_output_connector_instances_by_workflow(
workflow_id: str, values: Optional[list[str]] = None
workflow_id: str, values: list[str] | None = None
) -> list[ConnectorInstance]:
"""Method to get file system output connector by workflow.
@@ -272,7 +272,7 @@ class ConnectorInstanceHelper:
@staticmethod
def get_database_input_connector_instances_by_workflow(
workflow_id: str, values: Optional[list[str]] = None
workflow_id: str, values: list[str] | None = None
) -> list[ConnectorInstance]:
"""Method to fetch input database connectors by workflow.
@@ -292,7 +292,7 @@ class ConnectorInstanceHelper:
@staticmethod
def get_database_output_connector_instances_by_workflow(
workflow_id: str, values: Optional[list[str]] = None
workflow_id: str, values: list[str] | None = None
) -> list[ConnectorInstance]:
"""Method to fetch output database connectors by workflow.

View File

@@ -6,9 +6,7 @@ class ConnectorInstanceKey:
CONNECTOR_VERSION = "connector_version"
CONNECTOR_AUTH = "connector_auth"
CONNECTOR_METADATA = "connector_metadata"
CONNECTOR_EXISTS = (
"Connector with this configuration already exists in this project."
)
CONNECTOR_EXISTS = "Connector with this configuration already exists in this project."
DUPLICATE_API = "It appears that a duplicate call may have been made."

View File

@@ -8,7 +8,6 @@ from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
@@ -39,9 +38,7 @@ class Migration(migrations.Migration):
("connector_version", models.CharField(default="", max_length=64)),
(
"connector_type",
models.CharField(
choices=[("INPUT", "Input"), ("OUTPUT", "Output")]
),
models.CharField(choices=[("INPUT", "Input"), ("OUTPUT", "Output")]),
),
(
"connector_mode",

View File

@@ -87,7 +87,8 @@ class ConnectorInstance(DefaultOrganizationMixin, BaseModel):
# TODO: Remove if unused
def get_connector_metadata(self) -> dict[str, str]:
"""Gets connector metadata and refreshes the tokens if needed in case
of OAuth."""
of OAuth.
"""
tokens_refreshed = False
if self.connector_auth:
(

View File

@@ -1,19 +1,19 @@
import json
import logging
from collections import OrderedDict
from typing import Any, Optional
from typing import Any
from connector_auth_v2.models import ConnectorAuth
from connector_auth_v2.pipeline.common import ConnectorAuthHelper
from connector_processor.connector_processor import ConnectorProcessor
from connector_processor.constants import ConnectorKeys
from connector_processor.exceptions import OAuthTimeOut
from connector_v2.constants import ConnectorInstanceKey as CIKey
from cryptography.fernet import Fernet
from django.conf import settings
from utils.serializer_utils import SerializerUtils
from backend.serializers import AuditSerializer
from connector_v2.constants import ConnectorInstanceKey as CIKey
from unstract.connectors.filesystems.ucs import UnstractCloudStorage
from .models import ConnectorInstance
@@ -41,7 +41,7 @@ class ConnectorInstanceSerializer(AuditSerializer):
def save(self, **kwargs): # type: ignore
user = self.context.get("request").user or None
connector_id: str = kwargs[CIKey.CONNECTOR_ID]
connector_oauth: Optional[ConnectorAuth] = None
connector_oauth: ConnectorAuth | None = None
if (
ConnectorInstance.supportsOAuth(connector_id=connector_id)
and CIKey.CONNECTOR_METADATA in kwargs

View File

@@ -12,7 +12,6 @@ pytestmark = pytest.mark.django_db
class TestConnector(APITestCase):
def test_connector_list(self) -> None:
"""Tests to List the connectors."""
url = reverse("connectors_v1-list")
response = self.client.get(url)
@@ -20,7 +19,6 @@ class TestConnector(APITestCase):
def test_connectors_detail(self) -> None:
"""Tests to fetch a connector with given pk."""
url = reverse("connectors_v1-detail", kwargs={"pk": 1})
response = self.client.get(url)
@@ -28,7 +26,6 @@ class TestConnector(APITestCase):
def test_connectors_detail_not_found(self) -> None:
"""Tests for negative case to fetch non exiting key."""
url = reverse("connectors_v1-detail", kwargs={"pk": 768})
response = self.client.get(url)
@@ -36,7 +33,6 @@ class TestConnector(APITestCase):
def test_connectors_create(self) -> None:
"""Tests to create a new ConnectorInstance."""
url = reverse("connectors_v1-list")
data = {
"org": 1,
@@ -57,8 +53,8 @@ class TestConnector(APITestCase):
def test_connectors_create_with_json_list(self) -> None:
"""Tests to create a new connector with list included in the json
field."""
field.
"""
url = reverse("connectors_v1-list")
data = {
"org": 1,
@@ -80,7 +76,6 @@ class TestConnector(APITestCase):
def test_connectors_create_with_nested_json(self) -> None:
"""Tests to create a new connector with json field as nested json."""
url = reverse("connectors_v1-list")
data = {
"org": 1,
@@ -102,7 +97,6 @@ class TestConnector(APITestCase):
def test_connectors_create_bad_request(self) -> None:
"""Tests for negative case to throw error on a wrong access."""
url = reverse("connectors_v1-list")
data = {
"org": 5,
@@ -123,7 +117,6 @@ class TestConnector(APITestCase):
def test_connectors_update_json_field(self) -> None:
"""Tests to update connector with json field update."""
url = reverse("connectors_v1-detail", kwargs={"pk": 1})
data = {
"org": 1,
@@ -144,7 +137,6 @@ class TestConnector(APITestCase):
def test_connectors_update(self) -> None:
"""Tests to update connector update single field."""
url = reverse("connectors_v1-detail", kwargs={"pk": 1})
data = {
"org": 1,
@@ -166,7 +158,6 @@ class TestConnector(APITestCase):
def test_connectors_update_pk(self) -> None:
"""Tests the PUT method for 400 error."""
url = reverse("connectors_v1-detail", kwargs={"pk": 1})
data = {
"org": 2,
@@ -187,7 +178,6 @@ class TestConnector(APITestCase):
def test_connectors_update_json_fields(self) -> None:
"""Tests to update ConnectorInstance."""
url = reverse("connectors_v1-detail", kwargs={"pk": 1})
data = {
"org": 1,
@@ -203,16 +193,13 @@ class TestConnector(APITestCase):
},
}
response = self.client.put(url, data, format="json")
nested_value = response.data["connector_metadata"]["sample_metadata_json"][
"key1"
]
nested_value = response.data["connector_metadata"]["sample_metadata_json"]["key1"]
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(nested_value, "value1")
def test_connectors_update_json_list_fields(self) -> None:
"""Tests to update connector to the third second level of json."""
url = reverse("connectors_v1-detail", kwargs={"pk": 1})
data = {
"org": 1,
@@ -229,9 +216,7 @@ class TestConnector(APITestCase):
},
}
response = self.client.put(url, data, format="json")
nested_value = response.data["connector_metadata"]["sample_metadata_json"][
"key1"
]
nested_value = response.data["connector_metadata"]["sample_metadata_json"]["key1"]
nested_list = response.data["connector_metadata"]["file_list"]
last_val = nested_list.pop()
@@ -287,7 +272,6 @@ class TestConnector(APITestCase):
def test_connectors_update_field(self) -> None:
"""Tests the PATCH method."""
url = reverse("connectors_v1-detail", kwargs={"pk": 1})
data = {"connector_id": "e3a4512m-efgb-48d5-98a9-3983ntest"}
response = self.client.patch(url, data, format="json")
@@ -301,7 +285,6 @@ class TestConnector(APITestCase):
def test_connectors_update_json_field_patch(self) -> None:
"""Tests the PATCH method."""
url = reverse("connectors_v1-detail", kwargs={"pk": 1})
data = {
"connector_metadata": {
@@ -322,7 +305,6 @@ class TestConnector(APITestCase):
def test_connectors_delete(self) -> None:
"""Tests the DELETE method."""
url = reverse("connectors_v1-detail", kwargs={"pk": 1})
response = self.client.delete(url, format="json")
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
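
The test diffs above mostly drop the blank line that used to sit between a test's docstring and its first statement, matching the pydocstyle rule against blank lines after a function docstring (assumed to come from the new Ruff config). The resulting shape, with the URL name mirrored from the tests above and the assertion kept illustrative:

from django.urls import reverse
from rest_framework import status
from rest_framework.test import APITestCase

class TestConnectorSketch(APITestCase):
    def test_connector_list(self) -> None:
        """Tests to list the connectors."""
        url = reverse("connectors_v1-list")  # no blank line after the docstring anymore
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)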

View File

@@ -1,12 +1,11 @@
import logging
from typing import Any, Optional
from typing import Any
from account_v2.custom_exceptions import DuplicateData
from connector_auth_v2.constants import ConnectorAuthKey
from connector_auth_v2.exceptions import CacheMissException, MissingParamException
from connector_auth_v2.pipeline.common import ConnectorAuthHelper
from connector_processor.exceptions import OAuthTimeOut
from connector_v2.constants import ConnectorInstanceKey as CIKey
from django.db import IntegrityError
from django.db.models import QuerySet
from rest_framework import status, viewsets
@@ -15,6 +14,7 @@ from rest_framework.versioning import URLPathVersioning
from utils.filtering import FilterHelper
from backend.constants import RequestKey
from connector_v2.constants import ConnectorInstanceKey as CIKey
from .models import ConnectorInstance
from .serializers import ConnectorInstanceSerializer
@@ -26,7 +26,7 @@ class ConnectorInstanceViewSet(viewsets.ModelViewSet):
versioning_class = URLPathVersioning
serializer_class = ConnectorInstanceSerializer
def get_queryset(self) -> Optional[QuerySet]:
def get_queryset(self) -> QuerySet | None:
filter_args = FilterHelper.build_filter_args(
self.request,
RequestKey.WORKFLOW,
@@ -40,7 +40,7 @@ class ConnectorInstanceViewSet(viewsets.ModelViewSet):
queryset = ConnectorInstance.objects.all()
return queryset
def _get_connector_metadata(self, connector_id: str) -> Optional[dict[str, str]]:
def _get_connector_metadata(self, connector_id: str) -> dict[str, str] | None:
"""Gets connector metadata for the ConnectorInstance.
For non oauth based - obtains from request
@@ -118,6 +118,4 @@ class ConnectorInstanceViewSet(viewsets.ModelViewSet):
{CIKey.DUPLICATE_API}"
)
headers = self.get_success_headers(serializer.data)
return Response(
serializer.data, status=status.HTTP_201_CREATED, headers=headers
)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)

View File

@@ -4,9 +4,10 @@ This module defines the URL patterns for the feature_flags app.
"""
from django.urls import path
from feature_flag.views import FeatureFlagViewSet
from rest_framework.urlpatterns import format_suffix_patterns
from feature_flag.views import FeatureFlagViewSet
feature_flags_list = FeatureFlagViewSet.as_view(
{
"post": "evaluate",

View File

@@ -1,15 +1,15 @@
"""
Feature Flag view file
"""Feature Flag view file
Returns:
evaluate response
"""
import logging
from feature_flag.helper import FeatureFlagHelper
from rest_framework import status, viewsets
from rest_framework.response import Response
from feature_flag.helper import FeatureFlagHelper
logger = logging.getLogger(__name__)

View File

@@ -1,5 +1,3 @@
from typing import Optional
from rest_framework.exceptions import APIException
from backend.exceptions import UnstractBaseException
@@ -69,7 +67,7 @@ class ValidationError(APIException):
status_code = 400
default_detail = "Validation Error"
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
def __init__(self, detail: str | None = None, code: int | None = None):
if detail is not None:
self.detail = detail
if code is not None:
@@ -81,7 +79,7 @@ class FileDeletionFailed(APIException):
status_code = 400
default_detail = "Unable to delete file."
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
def __init__(self, detail: str | None = None, code: int | None = None):
if detail is not None:
self.detail = detail
if code is not None:

View File

@@ -1,7 +1,7 @@
import os
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional, Union
from typing import Any
from file_management.constants import FileInformationKey
@@ -10,12 +10,12 @@ from file_management.constants import FileInformationKey
class FileInformation:
name: str
type: str
modified_at: Optional[datetime]
content_type: Optional[str]
modified_at: datetime | None
content_type: str | None
size: int
def __init__(
self, file_info: dict[str, Any], file_content_type: Optional[str] = None
self, file_info: dict[str, Any], file_content_type: str | None = None
) -> None:
self.name = os.path.normpath(file_info[FileInformationKey.FILE_NAME])
self.type = file_info[FileInformationKey.FILE_TYPE]
@@ -27,7 +27,7 @@ class FileInformation:
self.size = file_info[FileInformationKey.FILE_SIZE]
@staticmethod
def parse_datetime(dt_string: Optional[Union[str, datetime]]) -> Optional[datetime]:
def parse_datetime(dt_string: str | datetime | None) -> datetime | None:
if isinstance(dt_string, str):
return datetime.strptime(dt_string, "%Y-%m-%dT%H:%M:%S.%f%z")
return dt_string
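
parse_datetime above accepts either a ready datetime or a string and relies on strptime with a %z offset. A standalone example of the format it parses (the sample timestamp is made up):

from datetime import datetime

sample = "2024-01-15T10:30:00.123456+0000"
parsed = datetime.strptime(sample, "%Y-%m-%dT%H:%M:%S.%f%z")

# %z yields an aware datetime carrying the parsed UTC offset.
print(parsed.isoformat())  # 2024-01-15T10:30:00.123456+00:00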

View File

@@ -10,6 +10,9 @@ from connector_v2.models import ConnectorInstance
from deprecated import deprecated
from django.conf import settings
from django.http import StreamingHttpResponse
from fsspec import AbstractFileSystem
from pydrive2.files import ApiRequestError
from file_management.exceptions import (
ConnectorApiRequestError,
ConnectorClassNotFound,
@@ -22,9 +25,6 @@ from file_management.exceptions import (
TenantDirCreationError,
)
from file_management.file_management_dto import FileInformation
from fsspec import AbstractFileSystem
from pydrive2.files import ApiRequestError
from unstract.connectors.filesystems import connectors as fs_connectors
from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem
@@ -32,7 +32,6 @@ logger = logging.getLogger(__name__)
class FileManagerHelper:
@staticmethod
def get_file_system(connector: ConnectorInstance) -> UnstractFileSystem:
"""Creates the `UnstractFileSystem` for the corresponding connector."""
@@ -70,7 +69,8 @@ class FileManagerHelper:
@staticmethod
def get_files(fs: AbstractFileSystem, file_path: str) -> list[FileInformation]:
"""Iterate through the directories and make a list of
FileInformation."""
FileInformation.
"""
if not file_path.endswith("/"):
file_path += "/"

View File

@@ -1,7 +1,8 @@
from file_management.constants import FileInformationKey
from rest_framework import serializers
from utils.FileValidator import FileValidator
from file_management.constants import FileInformationKey
class FileInfoSerializer(serializers.Serializer):
name = serializers.CharField()

View File

@@ -3,6 +3,14 @@ from typing import Any
from connector_v2.models import ConnectorInstance
from django.http import HttpRequest
from oauth2client.client import HttpAccessTokenRefreshError
from prompt_studio.prompt_studio_document_manager_v2.models import DocumentManager
from rest_framework import serializers, status, viewsets
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.versioning import URLPathVersioning
from utils.user_session import UserSessionUtils
from file_management.exceptions import (
ConnectorInstanceNotFound,
ConnectorOAuthError,
@@ -15,14 +23,6 @@ from file_management.serializer import (
FileListRequestSerializer,
FileUploadSerializer,
)
from oauth2client.client import HttpAccessTokenRefreshError
from prompt_studio.prompt_studio_document_manager_v2.models import DocumentManager
from rest_framework import serializers, status, viewsets
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.versioning import URLPathVersioning
from utils.user_session import UserSessionUtils
from unstract.connectors.exceptions import ConnectorError
from unstract.connectors.filesystems.local_storage.local_storage import LocalStorageFS
@@ -96,9 +96,7 @@ class FileManagementViewSet(viewsets.ModelViewSet):
for uploaded_file in uploaded_files:
file_name = uploaded_file.name
logger.info(
f"Uploading file: {file_name}" if file_name else "Uploading file"
)
logger.info(f"Uploading file: {file_name}" if file_name else "Uploading file")
FileManagerHelper.upload_file(file_system, path, uploaded_file, file_name)
return Response({"message": "Files are uploaded successfully!"})

View File

@@ -3,6 +3,4 @@ from rest_framework.urlpatterns import format_suffix_patterns
from .views import health_check
urlpatterns = format_suffix_patterns(
[path("health", health_check, name="health-check")]
)
urlpatterns = format_suffix_patterns([path("health", health_check, name="health-check")])

View File

@@ -4,7 +4,6 @@ from utils.cache_service import CacheService
class LogService:
@staticmethod
def remove_logs_on_logout(session_id: str) -> None:
if session_id:
key_pattern = f"{LogService.generate_redis_key(session_id=session_id)}*"

View File

@@ -1,6 +1,6 @@
import json
import logging
from datetime import datetime, timezone
from datetime import UTC, datetime
from django.conf import settings
from django.http import HttpRequest
@@ -57,11 +57,9 @@ class LogsHelperViewSet(viewsets.ModelViewSet):
# Extract the log message from the validated data
log: str = serializer.validated_data.get("log")
log_data = json.loads(log)
timestamp = datetime.now(timezone.utc).timestamp()
timestamp = datetime.now(UTC).timestamp()
redis_key = (
f"{LogService.generate_redis_key(session_id=session_id)}:{timestamp}"
)
redis_key = f"{LogService.generate_redis_key(session_id=session_id)}:{timestamp}"
CacheService.set_key(redis_key, log_data, logs_expiry)
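
The timestamp change above leans on datetime.UTC, an alias for datetime.timezone.utc that was added in Python 3.11, so it only becomes available after this migration off 3.9. A minimal check:

from datetime import UTC, datetime, timezone

# UTC (Python 3.11+) is the same object as timezone.utc, so both spellings
# produce equivalent aware timestamps.
assert UTC is timezone.utc
timestamp = datetime.now(UTC).timestamp()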

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python
"""Django's command-line utility for administrative tasks."""
import os
import sys

View File

@@ -1,7 +1,7 @@
import json
import logging
import traceback
from typing import Any, Optional
from typing import Any
from django.conf import settings
from django.http import HttpRequest, HttpResponse
@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
# Set via settings.REST_FRAMEWORK.EXCEPTION_HANDLER.
def drf_logging_exc_handler(exc: Exception, context: Any) -> Optional[Response]:
def drf_logging_exc_handler(exc: Exception, context: Any) -> Response | None:
"""Custom exception handler for DRF.
DRF's exception handler takes care of Http404, PermissionDenied and
@@ -31,7 +31,7 @@ def drf_logging_exc_handler(exc: Exception, context: Any) -> Optional[Response]:
handled by another method in the middleware
"""
request = context.get("request")
response: Optional[Response] = exception_handler(exc=exc, context=context)
response: Response | None = exception_handler(exc=exc, context=context)
ExceptionLoggingMiddleware.format_exc_and_log(
request=request, response=response, exception=exc
)
@@ -54,7 +54,7 @@ class ExceptionLoggingMiddleware:
def process_exception(
self, request: HttpRequest, exception: Exception
) -> Optional[HttpResponse]:
) -> HttpResponse | None:
"""Django hook to handle exceptions by a middleware.
Args:
@@ -78,7 +78,7 @@ class ExceptionLoggingMiddleware:
@staticmethod
def format_exc_and_log(
request: Request, exception: Exception, response: Optional[Response] = None
request: Request, exception: Exception, response: Response | None = None
) -> None:
"""Format the exception to be logged and logs it.
@@ -90,18 +90,7 @@ class ExceptionLoggingMiddleware:
if response:
status_code = response.status_code
if status_code >= 500:
message = "{method} {url} {status}\n\n{error}\n\n````{tb}````".format(
method=request.method,
url=request.build_absolute_uri(),
status=status_code,
error=repr(exception),
tb=traceback.format_exc(),
)
message = f"{request.method} {request.build_absolute_uri()} {status_code}\n\n{repr(exception)}\n\n````{traceback.format_exc()}````"
else:
message = "{method} {url} {status} {error}".format(
method=request.method,
url=request.build_absolute_uri(),
status=status_code,
error=repr(exception),
)
message = f"{request.method} {request.build_absolute_uri()} {status_code} {repr(exception)}"
logger.error(message)

View File

@@ -4,6 +4,5 @@ from log_request_id.middleware import RequestIDMiddleware
class CustomRequestIDMiddleware(RequestIDMiddleware):
def _generate_id(self):
return str(uuid.uuid4())

View File

@@ -4,7 +4,7 @@ import os
import time
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, Optional
from typing import Any
import psycopg2
from django.conf import settings
@@ -111,9 +111,7 @@ class DataMigrator:
)
conn.commit()
def _fetch_schema_names(
self, schemas_to_migrate: list[str]
) -> list[tuple[int, str]]:
def _fetch_schema_names(self, schemas_to_migrate: list[str]) -> list[tuple[int, str]]:
"""Fetches schema names and their IDs from the destination database
based on the provided schema list. Supports fetching all schemas if
'_ALL_' is specified.
@@ -126,7 +124,6 @@ class DataMigrator:
list[tuple[int, str]]: A list of tuples containing the ID and schema name.
"""
with self._db_connect_and_cursor(self.dest_db_config) as (conn, cur):
# Process schemas_to_migrate: trim spaces and remove empty entries
schemas_to_migrate = [
schema.strip() for schema in schemas_to_migrate if schema.strip()
@@ -153,7 +150,7 @@ class DataMigrator:
dest_cursor: cursor,
column_names: list[str],
column_transformations: dict[str, dict[str, Any]],
) -> Optional[tuple[Any, ...]]:
) -> tuple[Any, ...] | None:
"""Prepares and migrates the relational keys of a single row from the
source database to the destination database, updating specific column
values based on provided transformations.
@@ -209,7 +206,8 @@ class DataMigrator:
column_names: list[str],
) -> tuple[Any, ...]:
"""Convert specified field in the row to JSON format if it is a
list."""
list.
"""
if key not in column_names:
return row
@@ -328,10 +326,12 @@ class DataMigrator:
This method retrieves the maximum ID value from the specified table
and sets the next auto-increment value for the table accordingly.
Args:
dest_cursor (cursor): The cursor for the destination database.
dest_table (str): The name of the table to adjust.
dest_conn (connection): The connection to the destination database.
Returns:
None
"""
@@ -367,8 +367,7 @@ class DataMigrator:
if max_id is None:
logger.info(
f"Table '{dest_table}' is empty. No need to adjust "
"auto-increment."
f"Table '{dest_table}' is empty. No need to adjust " "auto-increment."
)
return
@@ -386,7 +385,7 @@ class DataMigrator:
raise
def migrate(
self, migrations: list[dict[str, str]], organization_id: Optional[str] = None
self, migrations: list[dict[str, str]], organization_id: str | None = None
) -> None:
self._create_tracking_table_if_not_exists()
@@ -397,7 +396,6 @@ class DataMigrator:
dest_cursor,
),
):
for migration in migrations:
migration_name = migration["name"]
logger.info(f"Migration '{migration_name}' started")
@@ -504,9 +502,7 @@ class Command(BaseCommand):
schemas_to_migrate = schemas_to_migrate.split(",")
# Organization Data (Schema)
schema_names = migrator._fetch_schema_names(
schemas_to_migrate=schemas_to_migrate
)
schema_names = migrator._fetch_schema_names(schemas_to_migrate=schemas_to_migrate)
for organization_id, schema_name in schema_names:
if schema_name == "public":

View File

@@ -1,6 +1,7 @@
class MigrationQuery:
"""This class contains methods to generate SQL queries for various
migration operations."""
migration operations.
"""
def __init__(self, v2_schema) -> None:
self.v2_schema = v2_schema
@@ -253,8 +254,8 @@ class MigrationQuery:
def get_organization_migrations(
self, schema: str, organization_id: str
) -> list[dict[str, str]]:
"""
Returns a list of dictionaries containing the organization migration details.
"""Returns a list of dictionaries containing the organization migration details.
Args:
schema (str): The name of the schema for the organization.
organization_id (str): The ID of the organization.

View File

@@ -26,7 +26,7 @@ class NotificationHelper:
payload (Any): The data to be sent with the notification. This can be any
format expected by the provider
Returns:
Returns:
None
"""
for notification in notifications:

View File

@@ -7,7 +7,6 @@ from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [

View File

@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from django.conf import settings
from notification_v2.models import Notification
@@ -15,7 +16,8 @@ class NotificationProvider(ABC):
@abstractmethod
def send(self):
"""Method to be overridden in child classes for sending the
notification."""
notification.
"""
raise NotImplementedError("Subclasses should implement this method.")
def validate(self):

View File

@@ -1,8 +1,9 @@
import logging
from typing import Any, Optional
from typing import Any
import requests
from celery import shared_task
from notification_v2.enums import AuthorizationType
from notification_v2.provider.notification_provider import NotificationProvider
@@ -48,8 +49,8 @@ class Webhook(NotificationProvider):
return super().validate()
def get_headers(self):
"""
Get the headers for the notification based on the authorization type and key.
"""Get the headers for the notification based on the authorization type and key.
Raises:
ValueError: _description_
@@ -95,9 +96,7 @@ class Webhook(NotificationProvider):
# Check if custom header type has required details
if authorization_type == AuthorizationType.CUSTOM_HEADER:
if not authorization_header or not authorization_key:
raise ValueError(
"Custom header or key missing for custom authorization."
)
raise ValueError("Custom header or key missing for custom authorization.")
return headers
@@ -108,7 +107,7 @@ def send_webhook_notification(
payload: Any,
headers: Any = None,
timeout: int = 10,
max_retries: Optional[int] = None,
max_retries: int | None = None,
retry_delay: int = 10,
):
"""Celery task to send a webhook with retries and error handling.

View File

@@ -107,7 +107,8 @@ class NotificationSerializer(serializers.ModelSerializer):
def validate_name(self, value):
"""Check uniqueness of the name with respect to either 'api' or
'pipeline'."""
'pipeline'.
"""
api = self.initial_data.get("api", getattr(self.instance, "api", None))
pipeline = self.initial_data.get(
"pipeline", getattr(self.instance, "pipeline", None)

View File

@@ -1,11 +1,12 @@
from api_v2.deployment_helper import DeploymentHelper
from api_v2.exceptions import APINotFound
from notification_v2.constants import NotificationUrlConstant
from pipeline_v2.exceptions import PipelineNotFound
from pipeline_v2.models import Pipeline
from pipeline_v2.pipeline_processor import PipelineProcessor
from rest_framework import viewsets
from notification_v2.constants import NotificationUrlConstant
from .models import Notification
from .serializers import NotificationSerializer

backend/pdm.lock (generated, 5736 lines changed): file diff suppressed because it is too large.

View File

@@ -11,7 +11,6 @@ class IsOwner(permissions.BasePermission):
"""Custom permission to only allow owners of an object."""
def has_object_permission(self, request: Request, view: APIView, obj: Any) -> bool:
return True if obj.created_by == request.user else False
@@ -27,7 +26,6 @@ class IsOwnerOrSharedUser(permissions.BasePermission):
"""Custom permission to only allow owners and shared users of an object."""
def has_object_permission(self, request: Request, view: APIView, obj: Any) -> bool:
return (
True
if (
@@ -40,10 +38,10 @@ class IsOwnerOrSharedUser(permissions.BasePermission):
class IsOwnerOrSharedUserOrSharedToOrg(permissions.BasePermission):
"""Custom permission to only allow owners and shared users of an object or
if it is shared to org."""
if it is shared to org.
"""
def has_object_permission(self, request: Request, view: APIView, obj: Any) -> bool:
return (
True
if (
@@ -57,12 +55,12 @@ class IsOwnerOrSharedUserOrSharedToOrg(permissions.BasePermission):
class IsFrictionLessAdapter(permissions.BasePermission):
"""Hack for friction-less onboarding not allowing user to view or updating
friction less adapter."""
friction less adapter.
"""
def has_object_permission(
self, request: Request, view: APIView, obj: AdapterInstance
) -> bool:
if obj.is_friction_less:
return False
@@ -71,12 +69,12 @@ class IsFrictionLessAdapter(permissions.BasePermission):
class IsFrictionLessAdapterDelete(permissions.BasePermission):
"""Hack for friction-less onboarding Allows frticon less adapter to rmoved
by an org member."""
by an org member.
"""
def has_object_permission(
self, request: Request, view: APIView, obj: AdapterInstance
) -> bool:
if obj.is_friction_less:
return True

View File

@@ -4,11 +4,12 @@ from typing import Any
from api_v2.api_key_validator import BaseAPIKeyValidator
from api_v2.exceptions import InvalidAPIRequest
from api_v2.key_helper import KeyHelper
from pipeline_v2.exceptions import PipelineNotFound
from pipeline_v2.pipeline_processor import PipelineProcessor
from rest_framework.request import Request
from utils.user_context import UserContext
from pipeline_v2.exceptions import PipelineNotFound
from pipeline_v2.pipeline_processor import PipelineProcessor
logger = logging.getLogger(__name__)

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any
class PipelineStatusPayload:
@@ -8,8 +8,8 @@ class PipelineStatusPayload:
pipeline_id: str,
pipeline_name: str,
status: str,
execution_id: Optional[str] = None,
error_message: Optional[str] = None,
execution_id: str | None = None,
error_message: str | None = None,
):
self.type = type
self.pipeline_id = pipeline_id

View File

@@ -1,5 +1,3 @@
from typing import Optional
from rest_framework.exceptions import APIException
@@ -24,14 +22,13 @@ class InactivePipelineError(APIException):
def __init__(
self,
pipeline_name: Optional[str] = None,
detail: Optional[str] = None,
code: Optional[str] = None,
pipeline_name: str | None = None,
detail: str | None = None,
code: str | None = None,
):
if pipeline_name:
self.default_detail = (
f"Pipeline '{pipeline_name}' is inactive, "
"please activate the pipeline"
f"Pipeline '{pipeline_name}' is inactive, " "please activate the pipeline"
)
super().__init__(detail, code)

View File

@@ -1,11 +1,8 @@
import logging
from typing import Any, Optional
from typing import Any
from django.conf import settings
from django.urls import reverse
from pipeline_v2.constants import PipelineKey, PipelineURL
from pipeline_v2.models import Pipeline
from pipeline_v2.pipeline_processor import PipelineProcessor
from rest_framework.request import Request
from rest_framework.response import Response
from utils.request.constants import RequestConstants
@@ -13,6 +10,9 @@ from workflow_manager.workflow_v2.constants import WorkflowExecutionKey, Workflo
from workflow_manager.workflow_v2.views import WorkflowViewSet
from backend.constants import RequestHeader
from pipeline_v2.constants import PipelineKey, PipelineURL
from pipeline_v2.models import Pipeline
from pipeline_v2.pipeline_processor import PipelineProcessor
logger = logging.getLogger(__name__)
@@ -24,7 +24,7 @@ class PipelineManager:
def execute_pipeline(
request: Request,
pipeline_id: str,
execution_id: Optional[str] = None,
execution_id: str | None = None,
) -> Response:
"""Used to execute a pipeline.
@@ -45,9 +45,10 @@ class PipelineManager:
@staticmethod
def get_pipeline_execution_data_for_scheduled_run(
pipeline_id: str,
) -> Optional[dict[str, Any]]:
) -> dict[str, Any] | None:
"""Gets the required data to be passed while executing a pipeline Any
changes to pipeline execution needs to be propagated here."""
changes to pipeline execution needs to be propagated here.
"""
callback_url = settings.DJANGO_APP_BACKEND_URL + reverse(
PipelineURL.EXECUTE_NAMESPACE
)

View File

@@ -8,7 +8,6 @@ from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [

Some files were not shown because too many files have changed in this diff.