Python 3.9 to 3.12 (#1231)
* python version updated from 3.9 into 3.12 * x2text-service updated with uv and python version 3.12 * x2text-service docker file updated * Unstract packages updated with uv * Runner updated with uv * Promptservice updated with uv * Platform service updated with uv * backend service updated with uv * root pyproject.toml file updated * sdk version updated in services * unstract package modules updated based on sdk version: * docker file update * pdm lock workflow modified to support uv * Docs updated based on uv support * lock automation updated * snowflake module version updated into 3.14.0 * tox updated to support UV * tox updated to support UV * tox updated with pytest * tox updated with pytest-md-report * tox updated with module requirements * python migration from 3.9 to 3.12 * tox updated with module requirements * runner updated * Commit uv.lock changes * runner updated * Commit uv.lock changes * pytest.ini added * x2text-service docker file updated * pytest.ini removed * environment updated to test * docformatter commented on pre-commit * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * some pre-commit issues ignored * some pre-commit issues ignored * some pre-commit issues ignored * some pre-commit issues ignored * some pre-commit issues ignored * pre-commit updates * un used import removed from platfrom service controller * tox issue fixed * tox issue fixed * docker files updated * backend dockerfile updated * open installation issue fixed * Tools docker file updated with base python version 3.12 * python version updated into min 3.12 in pyproject.toml * linting issue fixed * uv version upgraded into 0.6.14 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * migrations excluded from ruff * added PoethePoet task runner * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: Added poe tasks for services (#1248) * Added poe tasks for services * reverted FE change made by mistake * updated tool-sidecar to uv and python to 3.12.9 * minor updates in pyproject descreption * feat: platform-service logging improvements (#1255) feat: Used flask util from core to improve logging in platform-service, added core as a dependency to platform-service: * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: Platform-service build issue and numpy issue with Python 3.12 (#1258) * fix: Platform-service build and numpy issue with Py 3.12 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: Removed backend dockerfile install statements for numpy --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * minor: Handled scenario when cost is not calculated due to no usage * minor: Corrected content shown for workflow input * fix: Minor fixes, used gthread for prompt-service, runner * Commit uv.lock changes * Removed unused line in tool dockerfile --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Chandrasekharan M <chandrasekharan@zipstack.com> Co-authored-by: Chandrasekharan M <117059509+chandrasekharan-zipstack@users.noreply.github.com> Co-authored-by: ali-zipstack <muhammad.ali@zipstack.com>
This commit is contained in:
74
.github/workflows/ci-test.yaml
vendored
74
.github/workflows/ci-test.yaml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Run tox tests
|
||||
name: Run tox tests with UV
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -14,44 +14,46 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.9'
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
# Install a specific version of uv.
|
||||
version: "0.6.14"
|
||||
python-version: 3.12.9
|
||||
|
||||
- name: Cache tox environments
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: .tox/
|
||||
key: ${{ runner.os }}-tox-${{ hashFiles('**/pyproject.toml', '**/tox.ini') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-tox-
|
||||
- name: Cache tox environments
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: .tox/
|
||||
key: ${{ runner.os }}-tox-uv-${{ hashFiles('**/pyproject.toml', '**/tox.ini') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-tox-uv-
|
||||
|
||||
- name: Install tox
|
||||
run: pip install tox
|
||||
- name: Install tox with UV
|
||||
run: uv pip install tox tox-uv
|
||||
|
||||
- name: Run tox
|
||||
id: tox
|
||||
run: |
|
||||
tox
|
||||
- name: Run tox
|
||||
id: tox
|
||||
run: |
|
||||
tox
|
||||
|
||||
- name: Render the report to the PR
|
||||
uses: marocchino/sticky-pull-request-comment@v2
|
||||
with:
|
||||
header: runner-test-report
|
||||
recreate: true
|
||||
path: runner-report.md
|
||||
- name: Render the report to the PR
|
||||
uses: marocchino/sticky-pull-request-comment@v2
|
||||
with:
|
||||
header: runner-test-report
|
||||
recreate: true
|
||||
path: runner-report.md
|
||||
|
||||
- name: Output reports to the job summary when tests fail
|
||||
shell: bash
|
||||
run: |
|
||||
if [ -f "runner-report.md" ]; then
|
||||
echo "<details><summary>Runner Test Report</summary>" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
cat "runner-report.md" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "</details>" >> $GITHUB_STEP_SUMMARY
|
||||
fi
|
||||
- name: Output reports to the job summary when tests fail
|
||||
shell: bash
|
||||
run: |
|
||||
if [ -f "runner-report.md" ]; then
|
||||
echo "<details><summary>Runner Test Report</summary>" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
cat "runner-report.md" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "</details>" >> $GITHUB_STEP_SUMMARY
|
||||
fi
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
name: Automate pdm.lock
|
||||
name: Automate uv.lock
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review, review_requested]
|
||||
branches: [main]
|
||||
paths:
|
||||
- '**/pyproject.toml'
|
||||
- "**/pyproject.toml"
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
directories:
|
||||
description: 'Comma-separated list of directories to update'
|
||||
description: "Comma-separated list of directories to update"
|
||||
required: false
|
||||
default: '' # Run for all dirs specified in docker/scripts/pdm-lock-gen/pdm-lock.sh
|
||||
default: "" # Run for all dirs specified in docker/scripts/uv-lock-gen/uv-lock.sh
|
||||
|
||||
jobs:
|
||||
update_pdm_lock:
|
||||
name: Update PDM lock in all directories
|
||||
update_uv_lock:
|
||||
name: Update UV lock in all directories
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
@@ -29,19 +29,20 @@ jobs:
|
||||
with:
|
||||
ref: ${{ github.head_ref }}
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
python-version: '3.9'
|
||||
# Install a specific version of uv.
|
||||
version: "0.6.14"
|
||||
python-version: 3.12.9
|
||||
|
||||
- name: Install PDM
|
||||
run: python -m pip install pdm==2.16.1
|
||||
- run: uv pip install --python=3.12.9 pip
|
||||
|
||||
- name: Generate PDM lockfiles
|
||||
- name: Generate UV lockfiles
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
chmod +x ./docker/scripts/pdm-lock-gen/pdm-lock.sh
|
||||
chmod +x ./docker/scripts/uv-lock-gen/uv-lock.sh
|
||||
|
||||
# Get the input from the workflow or use the default value
|
||||
dirs="${{ github.event.inputs.directories }}"
|
||||
@@ -49,9 +50,9 @@ jobs:
|
||||
# Check if directories input is empty
|
||||
if [[ -z "$dirs" ]]; then
|
||||
# No directories input given, run the script without arguments (process all directories)
|
||||
echo "No directories specified, running on all dirs listed in docker/scripts/pdm-lock-gen/pdm-lock.sh"
|
||||
echo "No directories specified, running on all dirs listed in docker/scripts/uv-lock-gen/uv-lock.sh"
|
||||
|
||||
./docker/scripts/pdm-lock-gen/pdm-lock.sh
|
||||
./docker/scripts/uv-lock-gen/uv-lock.sh
|
||||
else
|
||||
# Convert comma-separated list into an array of directories
|
||||
IFS=',' read -r -a dir_array <<< "$dirs"
|
||||
@@ -60,13 +61,13 @@ jobs:
|
||||
echo "Processing specified directories: ${dir_array[*]}"
|
||||
|
||||
# Pass directories as command-line arguments to the script
|
||||
./docker/scripts/pdm-lock-gen/pdm-lock.sh "${dir_array[@]}"
|
||||
./docker/scripts/uv-lock-gen/uv-lock.sh "${dir_array[@]}"
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Commit pdm.lock changes
|
||||
- name: Commit uv.lock changes
|
||||
uses: stefanzweifel/git-auto-commit-action@v5
|
||||
with:
|
||||
commit_message: Commit pdm.lock changes
|
||||
commit_user_name: pdm-lock-automation[bot]
|
||||
commit_user_email: pdm-lock-automation-bot@unstract.com
|
||||
commit_message: Commit uv.lock changes
|
||||
commit_user_name: uv-lock-automation[bot]
|
||||
commit_user_email: uv-lock-automation-bot@unstract.com
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -630,6 +630,7 @@ backend/pluggable_apps/*
|
||||
|
||||
# FE Plugins
|
||||
frontend/src/plugins/*
|
||||
frontend/public/llm-whisperer/
|
||||
|
||||
|
||||
# TODO: Ensure its made generic to abstract subfolder and file names
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
---
|
||||
# See https://pre-commit.com for more information
|
||||
# See https://pre-commit.com/hooks.html for more hooks
|
||||
|
||||
default_language_version:
|
||||
python: python3.9
|
||||
python: python3.12
|
||||
default_stages:
|
||||
- pre-commit
|
||||
|
||||
ci:
|
||||
skip:
|
||||
- hadolint-docker # Fails in pre-commit CI
|
||||
|
||||
- hadolint-docker # Fails in pre-commit CI
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
@@ -39,70 +34,66 @@ repos:
|
||||
- id: forbid-new-submodules
|
||||
- id: mixed-line-ending
|
||||
- id: no-commit-to-branch
|
||||
- repo: https://github.com/adrienverge/yamllint
|
||||
rev: v1.35.1
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.3.4
|
||||
hooks:
|
||||
- id: yamllint
|
||||
args: ["-d", "relaxed"]
|
||||
# language: system
|
||||
# - repo: https://github.com/rhysd/actionlint
|
||||
# rev: v1.6.27
|
||||
# hooks:
|
||||
# - id: actionlint-docker
|
||||
# args: [-ignore, 'label ".+" is unknown']
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 24.4.2
|
||||
hooks:
|
||||
- id: black
|
||||
args: [--config=pyproject.toml, -l 88]
|
||||
# language: system
|
||||
exclude: |
|
||||
(?x)^(
|
||||
unstract/flags/src/unstract/flags/evaluation_.*\.py|
|
||||
)$
|
||||
- repo: https://github.com/pycqa/flake8
|
||||
rev: 7.1.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
args: [--max-line-length=88]
|
||||
exclude: |
|
||||
(?x)^(
|
||||
.*migrations/.*\.py|
|
||||
core/tests/.*|
|
||||
unstract/flags/src/unstract/flags/evaluation_.*\.py|
|
||||
)$
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.13.2
|
||||
hooks:
|
||||
- id: isort
|
||||
files: "\\.(py)$"
|
||||
args:
|
||||
[
|
||||
"--profile",
|
||||
"black",
|
||||
"--filter-files",
|
||||
--settings-path=pyproject.toml,
|
||||
]
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
|
||||
- repo: https://github.com/hadialqattan/pycln
|
||||
rev: v2.5.0
|
||||
hooks:
|
||||
- id: pycln
|
||||
args: [--config=pyproject.toml]
|
||||
|
||||
# - repo: https://github.com/pycqa/docformatter
|
||||
# rev: v1.7.5
|
||||
# hooks:
|
||||
# - id: docformatter
|
||||
# language: python
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.16.0
|
||||
rev: v3.17.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
entry: pyupgrade --py39-plus --keep-runtime-typing
|
||||
types:
|
||||
- python
|
||||
|
||||
# - repo: https://github.com/astral-sh/uv-pre-commit
|
||||
# rev: 0.6.11
|
||||
# hooks:
|
||||
# - id: uv-lock
|
||||
|
||||
# - repo: https://github.com/pre-commit/mirrors-mypy
|
||||
# rev: v1.11.2
|
||||
# hooks:
|
||||
# - id: mypy
|
||||
# language: system
|
||||
# entry: mypy .
|
||||
# pass_filenames: false
|
||||
# additional_dependencies: []
|
||||
|
||||
- repo: https://github.com/igorshubovych/markdownlint-cli
|
||||
rev: v0.42.0
|
||||
hooks:
|
||||
- id: markdownlint
|
||||
args: [--disable, MD013]
|
||||
- id: markdownlint-fix
|
||||
args: [--disable, MD013]
|
||||
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.18.4
|
||||
rev: v8.18.2
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
- repo: https://github.com/Lucas-C/pre-commit-hooks-nodejs
|
||||
rev: v1.1.2
|
||||
hooks:
|
||||
- id: htmlhint
|
||||
|
||||
- repo: https://github.com/hadolint/hadolint
|
||||
rev: v2.12.1-beta
|
||||
hooks:
|
||||
@@ -114,49 +105,3 @@ repos:
|
||||
- --ignore=DL3018
|
||||
- --ignore=SC1091
|
||||
files: Dockerfile$
|
||||
- repo: https://github.com/asottile/yesqa
|
||||
rev: v1.5.0
|
||||
hooks:
|
||||
- id: yesqa
|
||||
# - repo: https://github.com/pre-commit/mirrors-eslint
|
||||
# rev: "v9.0.0-beta.2" # Use the sha / tag you want to point at
|
||||
# hooks:
|
||||
# - id: eslint
|
||||
# args: [--config=frontend/.eslintrc.json]
|
||||
# files: \.[jt]sx?$ # *.js, *.jsx, *.ts and *.tsx
|
||||
# types: [file]
|
||||
# additional_dependencies:
|
||||
# - eslint@8.41.0
|
||||
# - eslint-config-google@0.14.0
|
||||
# - eslint-config-prettier@8.8.0
|
||||
# - eslint-plugin-prettier@4.2.1
|
||||
# - eslint-plugin-react@7.32.2
|
||||
# - eslint-plugin-import@2.25.2
|
||||
- repo: https://github.com/Lucas-C/pre-commit-hooks-nodejs
|
||||
rev: v1.1.2
|
||||
hooks:
|
||||
- id: htmlhint
|
||||
- repo: https://github.com/igorshubovych/markdownlint-cli
|
||||
rev: v0.41.0
|
||||
hooks:
|
||||
- id: markdownlint
|
||||
args: [--disable, MD013]
|
||||
- id: markdownlint-fix
|
||||
args: [--disable, MD013]
|
||||
- repo: https://github.com/pdm-project/pdm
|
||||
rev: 2.16.1
|
||||
hooks:
|
||||
- id: pdm-lock-check
|
||||
# - repo: local
|
||||
# hooks:
|
||||
# - id: run-mypy
|
||||
# name: Run mypy
|
||||
# entry: sh -c 'pdm run mypy .'
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
# - id: check-django-migrations
|
||||
# name: Check django migrations
|
||||
# entry: sh -c 'pdm run docker/scripts/check_django_migrations.sh'
|
||||
# language: system
|
||||
# types: [python] # hook only runs if a python file is staged
|
||||
# pass_filenames: false
|
||||
|
||||
@@ -1 +1 @@
|
||||
3.9.6
|
||||
3.12.9
|
||||
|
||||
12
README.md
12
README.md
@@ -5,9 +5,9 @@
|
||||
|
||||
## No-code LLM Platform to launch APIs and ETL Pipelines to structure unstructured documents
|
||||
|
||||
##
|
||||
##
|
||||
|
||||
[](https://pdm-project.org)
|
||||
[](https://github.com/astral-sh/uv)
|
||||
[](https://cla-assistant.io/Zipstack/unstract)
|
||||
[](https://results.pre-commit.ci/latest/github/Zipstack/unstract/main)
|
||||
[](https://sonarcloud.io/summary/new_code?id=Zipstack_unstract)
|
||||
@@ -57,8 +57,8 @@ Next, either download a release or clone this repo and do the following:
|
||||
|
||||
That's all there is to it!
|
||||
|
||||
Follow [these steps](backend/README.md#authentication) to change the default username and password.
|
||||
See [user guide](https://docs.unstract.com/unstract/unstract_platform/user_guides/run_platform) for more details on managing the platform.
|
||||
Follow [these steps](backend/README.md#authentication) to change the default username and password.
|
||||
See [user guide](https://docs.unstract.com/unstract/unstract_platform/user_guides/run_platform) for more details on managing the platform.
|
||||
|
||||
Another really quick way to experience Unstract is by signing up for our [hosted version](https://us-central.unstract.com/). It comes with a 14 day free trial!
|
||||
|
||||
@@ -154,9 +154,9 @@ Contributions are welcome! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for fur
|
||||
|
||||
## 🚨 Backup encryption key
|
||||
|
||||
Do copy the value of `ENCRYPTION_KEY` config in either `backend/.env` or `platform-service/.env` file to a secure location.
|
||||
Do copy the value of `ENCRYPTION_KEY` config in either `backend/.env` or `platform-service/.env` file to a secure location.
|
||||
|
||||
Adapter credentials are encrypted by the platform using this key. Its loss or change will make all existing adapters inaccessible!
|
||||
Adapter credentials are encrypted by the platform using this key. Its loss or change will make all existing adapters inaccessible!
|
||||
|
||||
## 📊 A note on analytics
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
3.9.6
|
||||
3.12.9
|
||||
|
||||
@@ -15,13 +15,13 @@ Contains the backend services for Unstract written with Django and DRF.
|
||||
|
||||
All commands assumes that you have activated your `venv`.
|
||||
|
||||
```bash
|
||||
# Create venv
|
||||
pdm venv create -w virtualenv --with-pip
|
||||
eval "$(pdm venv activate in-project)"
|
||||
Install UV: https://docs.astral.sh/uv/getting-started/installation/
|
||||
|
||||
# Remove venv
|
||||
pdm venv remove in-project
|
||||
```bash
|
||||
# Create venv and install dependencies
|
||||
uv venv
|
||||
source .venv/bin/activate # On Windows: .venv\Scripts\activate
|
||||
uv sync
|
||||
```
|
||||
|
||||
#### Installing dependencies
|
||||
@@ -30,28 +30,20 @@ Go to service dir and install dependencies listed in corresponding `pyproject.to
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
pdm install
|
||||
uv sync
|
||||
|
||||
# Install specific dev dependency group
|
||||
pdm install --dev -G lint
|
||||
|
||||
uv sync --group dev
|
||||
# Install production dependencies only
|
||||
pdm install --prod --no-editable
|
||||
uv sync --group deploy
|
||||
```
|
||||
|
||||
#### Running scripts
|
||||
|
||||
PDM allows you to run scripts applicable within the service dir.
|
||||
UV allows you to run python scripts applicable within the service dir.
|
||||
|
||||
```bash
|
||||
# List the possible scripts that can be executed
|
||||
pdm run -l
|
||||
```
|
||||
|
||||
For example to run the backend (dev mode is recommended to take advantage of gunicorn's `reload` feature)
|
||||
|
||||
```bash
|
||||
pdm run backend --dev
|
||||
uv run sample_script.py
|
||||
```
|
||||
|
||||
#### Running commands
|
||||
@@ -68,16 +60,16 @@ DB_NAME='unstract_db'
|
||||
DB_PORT=5432
|
||||
```
|
||||
|
||||
- If you've made changes to the model, run `python manage.py makemigrations`, else ignore this step
|
||||
- If you've made changes to the model, run `uv run manage.py makemigrations`, else ignore this step
|
||||
- Run the following to apply any migrations to the DB and start the server
|
||||
|
||||
```bash
|
||||
python manage.py migrate
|
||||
python manage.py runserver localhost:8000
|
||||
uv run manage.py migrate
|
||||
uv run manage.py runserver localhost:8000
|
||||
```
|
||||
|
||||
- Server will start and run at port 8000. (<http://localhost:8000>)
|
||||
|
||||
|
||||
## Authentication
|
||||
|
||||
The default username is `unstract` and the default password is `unstract`.
|
||||
|
||||
@@ -6,7 +6,6 @@ from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = []
|
||||
|
||||
@@ -4,7 +4,6 @@ from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("account_usage", "0001_initial"),
|
||||
]
|
||||
|
||||
@@ -1,5 +1,21 @@
|
||||
import logging
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import logout as django_logout
|
||||
from django.db.utils import IntegrityError
|
||||
from django.middleware import csrf
|
||||
from django.shortcuts import redirect
|
||||
from logs_helper.log_service import LogService
|
||||
from rest_framework import status
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from tenant_account_v2.models import OrganizationMember as OrganizationMember
|
||||
from tenant_account_v2.organization_member_service import OrganizationMemberService
|
||||
from utils.cache_service import CacheService
|
||||
from utils.local_context import StateStore
|
||||
from utils.user_context import UserContext
|
||||
from utils.user_session import UserSessionUtils
|
||||
|
||||
from account_v2.authentication_helper import AuthenticationHelper
|
||||
from account_v2.authentication_plugin_registry import AuthenticationPluginRegistry
|
||||
@@ -33,32 +49,19 @@ from account_v2.serializer import (
|
||||
SetOrganizationsResponseSerializer,
|
||||
)
|
||||
from account_v2.user import UserService
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import logout as django_logout
|
||||
from django.db.utils import IntegrityError
|
||||
from django.middleware import csrf
|
||||
from django.shortcuts import redirect
|
||||
from logs_helper.log_service import LogService
|
||||
from rest_framework import status
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from tenant_account_v2.models import OrganizationMember as OrganizationMember
|
||||
from tenant_account_v2.organization_member_service import OrganizationMemberService
|
||||
from utils.cache_service import CacheService
|
||||
from utils.local_context import StateStore
|
||||
from utils.user_context import UserContext
|
||||
from utils.user_session import UserSessionUtils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthenticationController:
|
||||
"""Authentication Controller This controller class manages user
|
||||
authentication processes."""
|
||||
authentication processes.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""This method initializes the controller by selecting the appropriate
|
||||
authentication plugin based on availability."""
|
||||
authentication plugin based on availability.
|
||||
"""
|
||||
self.authentication_helper = AuthenticationHelper()
|
||||
if AuthenticationPluginRegistry.is_plugin_available():
|
||||
self.auth_service: AuthenticationService = (
|
||||
@@ -110,7 +113,6 @@ class AuthenticationController:
|
||||
Returns:
|
||||
list[OrganizationData]: _description_
|
||||
"""
|
||||
|
||||
try:
|
||||
organizations = self.auth_service.user_organizations(request)
|
||||
except Exception as ex:
|
||||
@@ -165,9 +167,7 @@ class AuthenticationController:
|
||||
if organization_id and organization_id in organization_ids:
|
||||
# Set organization in user context
|
||||
UserContext.set_organization_identifier(organization_id)
|
||||
organization = OrganizationService.get_organization_by_org_id(
|
||||
organization_id
|
||||
)
|
||||
organization = OrganizationService.get_organization_by_org_id(organization_id)
|
||||
if not organization:
|
||||
try:
|
||||
organization_data: OrganizationData = (
|
||||
@@ -206,7 +206,7 @@ class AuthenticationController:
|
||||
f"New organization created with Id {organization_id}",
|
||||
)
|
||||
|
||||
user_info: Optional[UserInfo] = self.get_user_info(request)
|
||||
user_info: UserInfo | None = self.get_user_info(request)
|
||||
serialized_user_info = SetOrganizationsResponseSerializer(user_info).data
|
||||
organization_info = OrganizationSerializer(organization).data
|
||||
response: Response = Response(
|
||||
@@ -232,7 +232,7 @@ class AuthenticationController:
|
||||
return response
|
||||
return Response(status=status.HTTP_403_FORBIDDEN)
|
||||
|
||||
def get_user_info(self, request: Request) -> Optional[UserInfo]:
|
||||
def get_user_info(self, request: Request) -> UserInfo | None:
|
||||
return self.auth_service.get_user_info(request)
|
||||
|
||||
def is_admin_by_role(self, role: str) -> bool:
|
||||
@@ -247,7 +247,7 @@ class AuthenticationController:
|
||||
"""
|
||||
return self.auth_service.is_admin_by_role(role=role)
|
||||
|
||||
def get_organization_info(self, org_id: str) -> Optional[Organization]:
|
||||
def get_organization_info(self, org_id: str) -> Organization | None:
|
||||
organization = OrganizationService.get_organization_by_org_id(org_id=org_id)
|
||||
return organization
|
||||
|
||||
@@ -255,9 +255,9 @@ class AuthenticationController:
|
||||
self,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
organization_name: Optional[str] = None,
|
||||
display_name: Optional[str] = None,
|
||||
) -> Optional[OrganizationData]:
|
||||
organization_name: str | None = None,
|
||||
display_name: str | None = None,
|
||||
) -> OrganizationData | None:
|
||||
return self.auth_service.make_organization_and_add_member(
|
||||
user_id, user_name, organization_name, display_name
|
||||
)
|
||||
@@ -282,7 +282,7 @@ class AuthenticationController:
|
||||
return response
|
||||
|
||||
def get_organization_members_by_org_id(
|
||||
self, organization_id: Optional[str] = None
|
||||
self, organization_id: str | None = None
|
||||
) -> list[OrganizationMember]:
|
||||
members: list[OrganizationMember] = OrganizationMemberService.get_members()
|
||||
return members
|
||||
@@ -297,9 +297,7 @@ class AuthenticationController:
|
||||
Returns:
|
||||
OrganizationMember: OrganizationMemberEntity
|
||||
"""
|
||||
member: OrganizationMember = OrganizationMemberService.get_user_by_id(
|
||||
id=user.id
|
||||
)
|
||||
member: OrganizationMember = OrganizationMemberService.get_user_by_id(id=user.id)
|
||||
return member
|
||||
|
||||
def get_user_roles(self) -> list[UserRoleData]:
|
||||
@@ -320,7 +318,7 @@ class AuthenticationController:
|
||||
self,
|
||||
admin: User,
|
||||
org_id: str,
|
||||
user_list: list[dict[str, Union[str, None]]],
|
||||
user_list: list[dict[str, str | None]],
|
||||
) -> list[UserInviteResponse]:
|
||||
"""Invites users to join an organization.
|
||||
|
||||
@@ -329,6 +327,7 @@ class AuthenticationController:
|
||||
org_id (str): ID of the organization to which users are invited.
|
||||
user_list (list[dict[str, Union[str, None]]]):
|
||||
List of user details for invitation.
|
||||
|
||||
Returns:
|
||||
list[UserInviteResponse]: List of responses for each
|
||||
user invitation.
|
||||
@@ -397,7 +396,7 @@ class AuthenticationController:
|
||||
|
||||
def add_user_role(
|
||||
self, admin: User, org_id: str, email: str, role: str
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
admin_user = OrganizationMemberService.get_user_by_id(id=admin.id)
|
||||
user = OrganizationMemberService.get_user_by_email(email=email)
|
||||
if user:
|
||||
@@ -414,7 +413,7 @@ class AuthenticationController:
|
||||
|
||||
def remove_user_role(
|
||||
self, admin: User, org_id: str, email: str, role: str
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
admin_user = OrganizationMemberService.get_user_by_id(id=admin.id)
|
||||
organization_member = OrganizationMemberService.get_user_by_email(email=email)
|
||||
if organization_member:
|
||||
@@ -471,12 +470,10 @@ class AuthenticationController:
|
||||
except IntegrityError:
|
||||
logger.warning(f"Account already exists for {user.email}")
|
||||
|
||||
def get_or_create_user(
|
||||
self, user: User
|
||||
) -> Optional[Union[User, OrganizationMember]]:
|
||||
def get_or_create_user(self, user: User) -> User | OrganizationMember | None:
|
||||
user_service = UserService()
|
||||
if user.id:
|
||||
account_user: Optional[User] = user_service.get_user_by_id(user.id)
|
||||
account_user: User | None = user_service.get_user_by_id(user.id)
|
||||
if account_user:
|
||||
return account_user
|
||||
elif user.email:
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from platform_settings_v2.platform_auth_service import PlatformAuthenticationService
|
||||
from tenant_account_v2.organization_member_service import OrganizationMemberService
|
||||
|
||||
from account_v2.dto import MemberData
|
||||
from account_v2.models import Organization, User
|
||||
from account_v2.user import UserService
|
||||
from platform_settings_v2.platform_auth_service import PlatformAuthenticationService
|
||||
from tenant_account_v2.organization_member_service import OrganizationMemberService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -14,9 +15,7 @@ class AuthenticationHelper:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def list_of_members_from_user_model(
|
||||
self, model_data: list[Any]
|
||||
) -> list[MemberData]:
|
||||
def list_of_members_from_user_model(self, model_data: list[Any]) -> list[MemberData]:
|
||||
members: list[MemberData] = []
|
||||
for data in model_data:
|
||||
user_id = data.user_id
|
||||
@@ -49,9 +48,7 @@ class AuthenticationHelper:
|
||||
user = user_service.create_user(email, user_id)
|
||||
return user
|
||||
|
||||
def create_initial_platform_key(
|
||||
self, user: User, organization: Organization
|
||||
) -> None:
|
||||
def create_initial_platform_key(self, user: User, organization: Organization) -> None:
|
||||
"""Create an initial platform key for the given user and organization.
|
||||
|
||||
This method generates a new platform key with the specified parameters
|
||||
@@ -109,7 +106,6 @@ class AuthenticationHelper:
|
||||
Parameters:
|
||||
user_id (str): The user_id of the users to remove.
|
||||
"""
|
||||
|
||||
organization_user = OrganizationMemberService.get_user_by_user_id(user_id)
|
||||
if not organization_user:
|
||||
logger.warning(
|
||||
|
||||
@@ -3,15 +3,17 @@ import os
|
||||
from importlib import import_module
|
||||
from typing import Any
|
||||
|
||||
from account_v2.constants import PluginConfig
|
||||
from django.apps import apps
|
||||
|
||||
from account_v2.constants import PluginConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _load_plugins() -> dict[str, dict[str, Any]]:
|
||||
"""Iterating through the Authentication plugins and register their
|
||||
metadata."""
|
||||
metadata.
|
||||
"""
|
||||
auth_app = apps.get_app_config(PluginConfig.PLUGINS_APP)
|
||||
auth_package_path = auth_app.module.__package__
|
||||
auth_dir = os.path.join(auth_app.path, PluginConfig.AUTH_PLUGIN_DIR)
|
||||
|
||||
@@ -1,6 +1,17 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import authenticate, login, logout
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.http import HttpRequest
|
||||
from django.shortcuts import redirect, render
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from tenant_account_v2.models import OrganizationMember as OrganizationMember
|
||||
from tenant_account_v2.organization_member_service import OrganizationMemberService
|
||||
from utils.user_context import UserContext
|
||||
|
||||
from account_v2.authentication_helper import AuthenticationHelper
|
||||
from account_v2.constants import DefaultOrg, ErrorMessage, UserLoginTemplate
|
||||
@@ -18,16 +29,6 @@ from account_v2.enums import UserRole
|
||||
from account_v2.models import Organization, User
|
||||
from account_v2.organization import OrganizationService
|
||||
from account_v2.serializer import LoginRequestSerializer
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import authenticate, login, logout
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.http import HttpRequest
|
||||
from django.shortcuts import redirect, render
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from tenant_account_v2.models import OrganizationMember as OrganizationMember
|
||||
from tenant_account_v2.organization_member_service import OrganizationMemberService
|
||||
from utils.user_context import UserContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -92,10 +93,7 @@ class AuthenticationService:
|
||||
False otherwise.
|
||||
"""
|
||||
# Validation of user credentials
|
||||
if (
|
||||
username != DefaultOrg.MOCK_USER
|
||||
or password != DefaultOrg.MOCK_USER_PASSWORD
|
||||
):
|
||||
if username != DefaultOrg.MOCK_USER or password != DefaultOrg.MOCK_USER_PASSWORD:
|
||||
return False
|
||||
|
||||
user = authenticate(request, username=username, password=password)
|
||||
@@ -193,7 +191,7 @@ class AuthenticationService:
|
||||
self,
|
||||
request: Request,
|
||||
user: User,
|
||||
data: Optional[dict[str, Any]] = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> MemberData:
|
||||
member_data: MemberData = MemberData(
|
||||
user_id=user.user_id,
|
||||
@@ -324,7 +322,7 @@ class AuthenticationService:
|
||||
logger.error(f"Failed to set default user: {str(e)}")
|
||||
return False
|
||||
|
||||
def _get_or_create_user(self, organization: Optional[Organization]) -> User:
|
||||
def _get_or_create_user(self, organization: Organization | None) -> User:
|
||||
"""Get existing user or create a new one based on organization context.
|
||||
|
||||
Args:
|
||||
@@ -365,7 +363,7 @@ class AuthenticationService:
|
||||
|
||||
return self._create_mock_user()
|
||||
|
||||
def _get_admin_user(self) -> Optional[User]:
|
||||
def _get_admin_user(self) -> User | None:
|
||||
"""Get the first admin user from the organization.
|
||||
|
||||
Returns:
|
||||
@@ -376,7 +374,7 @@ class AuthenticationService:
|
||||
)
|
||||
return admin_members[0].user if admin_members else None
|
||||
|
||||
def _promote_first_member_to_admin(self) -> Optional[OrganizationMember]:
|
||||
def _promote_first_member_to_admin(self) -> OrganizationMember | None:
|
||||
"""Promote the first organization member to admin role.
|
||||
|
||||
Returns:
|
||||
@@ -406,7 +404,7 @@ class AuthenticationService:
|
||||
user.save()
|
||||
logger.info(f"Updated user {user} with username {DefaultOrg.MOCK_USER}")
|
||||
|
||||
def get_user_info(self, request: Request) -> Optional[UserInfo]:
|
||||
def get_user_info(self, request: Request) -> UserInfo | None:
|
||||
user: User = request.user
|
||||
if user:
|
||||
return UserInfo(
|
||||
@@ -419,16 +417,16 @@ class AuthenticationService:
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_organization_info(self, org_id: str) -> Optional[Organization]:
|
||||
def get_organization_info(self, org_id: str) -> Organization | None:
|
||||
return OrganizationService.get_organization_by_org_id(org_id=org_id)
|
||||
|
||||
def make_organization_and_add_member(
|
||||
self,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
organization_name: Optional[str] = None,
|
||||
display_name: Optional[str] = None,
|
||||
) -> Optional[OrganizationData]:
|
||||
organization_name: str | None = None,
|
||||
display_name: str | None = None,
|
||||
) -> OrganizationData | None:
|
||||
organization: OrganizationData = OrganizationData(
|
||||
id=str(uuid.uuid4()),
|
||||
display_name=DefaultOrg.MOCK_ORG,
|
||||
@@ -469,6 +467,6 @@ class AuthenticationService:
|
||||
admin: OrganizationMember,
|
||||
org_id: str,
|
||||
email: str,
|
||||
role: Optional[str] = None,
|
||||
role: str | None = None,
|
||||
) -> bool:
|
||||
raise MethodNotImplemented()
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from account_v2.authentication_plugin_registry import AuthenticationPluginRegistry
|
||||
from account_v2.authentication_service import AuthenticationService
|
||||
from account_v2.constants import Common
|
||||
from django.conf import settings
|
||||
from django.http import HttpRequest, HttpResponse, JsonResponse
|
||||
from utils.constants import Account
|
||||
from utils.local_context import StateStore
|
||||
from utils.user_session import UserSessionUtils
|
||||
|
||||
from account_v2.authentication_plugin_registry import AuthenticationPluginRegistry
|
||||
from account_v2.authentication_service import AuthenticationService
|
||||
from account_v2.constants import Common
|
||||
from backend.constants import RequestHeader
|
||||
|
||||
|
||||
@@ -42,9 +42,7 @@ class CustomAuthMiddleware:
|
||||
if is_authenticated:
|
||||
organization_id = UserSessionUtils.get_organization_id(request=request)
|
||||
if request.organization_id and not organization_id:
|
||||
return JsonResponse(
|
||||
{"message": "Organization access denied"}, status=403
|
||||
)
|
||||
return JsonResponse({"message": "Organization access denied"}, status=403)
|
||||
StateStore.set(Common.LOG_EVENTS_ID, request.session.session_key)
|
||||
StateStore.set(Account.ORGANIZATION_ID, organization_id)
|
||||
response = self.get_response(request)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from rest_framework.exceptions import APIException
|
||||
|
||||
|
||||
@@ -18,7 +16,7 @@ class DuplicateData(APIException):
|
||||
status_code = 400
|
||||
default_detail = "Duplicate Data"
|
||||
|
||||
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
|
||||
def __init__(self, detail: str | None = None, code: int | None = None):
|
||||
if detail is not None:
|
||||
self.detail = detail
|
||||
if code is not None:
|
||||
@@ -30,7 +28,7 @@ class TableNotExistError(APIException):
|
||||
status_code = 400
|
||||
default_detail = "Unknown Table"
|
||||
|
||||
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
|
||||
def __init__(self, detail: str | None = None, code: int | None = None):
|
||||
if detail is not None:
|
||||
self.detail = detail
|
||||
if code is not None:
|
||||
@@ -42,7 +40,7 @@ class UserNotExistError(APIException):
|
||||
status_code = 400
|
||||
default_detail = "Unknown User"
|
||||
|
||||
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
|
||||
def __init__(self, detail: str | None = None, code: int | None = None):
|
||||
if detail is not None:
|
||||
self.detail = detail
|
||||
if code is not None:
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemberData:
|
||||
user_id: str
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
picture: Optional[str] = None
|
||||
role: Optional[list[str]] = None
|
||||
organization_id: Optional[str] = None
|
||||
email: str | None = None
|
||||
name: str | None = None
|
||||
picture: str | None = None
|
||||
role: list[str] | None = None
|
||||
organization_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -45,11 +45,11 @@ class OrganizationSignupResponse:
|
||||
class UserInfo:
|
||||
email: str
|
||||
user_id: str
|
||||
id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
display_name: Optional[str] = None
|
||||
family_name: Optional[str] = None
|
||||
picture: Optional[str] = None
|
||||
id: str | None = None
|
||||
name: str | None = None
|
||||
display_name: str | None = None
|
||||
family_name: str | None = None
|
||||
picture: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -97,14 +97,14 @@ class ResetUserPasswordDto:
|
||||
class UserInviteResponse:
|
||||
email: str
|
||||
status: str
|
||||
message: Optional[str] = None
|
||||
message: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserRoleData:
|
||||
name: str
|
||||
id: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
id: str | None = None
|
||||
description: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -123,8 +123,8 @@ class MemberInvitation:
|
||||
id: str
|
||||
email: str
|
||||
roles: list[str]
|
||||
created_at: Optional[str] = None
|
||||
expires_at: Optional[str] = None
|
||||
created_at: str | None = None
|
||||
expires_at: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -11,7 +11,6 @@ from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
|
||||
@@ -18,9 +18,7 @@ class Organization(models.Model):
|
||||
|
||||
name = models.CharField(max_length=NAME_SIZE)
|
||||
display_name = models.CharField(max_length=NAME_SIZE)
|
||||
organization_id = models.CharField(
|
||||
max_length=FieldLength.ORG_NAME_SIZE, unique=True
|
||||
)
|
||||
organization_id = models.CharField(max_length=FieldLength.ORG_NAME_SIZE, unique=True)
|
||||
created_by = models.ForeignKey(
|
||||
"User",
|
||||
on_delete=models.SET_NULL,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from django.db import IntegrityError
|
||||
|
||||
from account_v2.models import Organization
|
||||
from django.db import IntegrityError
|
||||
|
||||
Logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -12,7 +12,7 @@ class OrganizationService:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_organization_by_org_id(org_id: str) -> Optional[Organization]:
|
||||
def get_organization_by_org_id(org_id: str) -> Organization | None:
|
||||
try:
|
||||
return Organization.objects.get(organization_id=org_id) # type: ignore
|
||||
except Organization.DoesNotExist:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from account_v2.models import Organization, User
|
||||
from rest_framework import serializers
|
||||
|
||||
|
||||
class OrganizationSignupSerializer(serializers.Serializer):
|
||||
@@ -92,18 +92,20 @@ class LoginRequestSerializer(serializers.Serializer):
|
||||
username = serializers.CharField(required=True)
|
||||
password = serializers.CharField(required=True)
|
||||
|
||||
def validate_username(self, value: Optional[str]) -> str:
|
||||
def validate_username(self, value: str | None) -> str:
|
||||
"""Check that the username is not empty and has at least 3
|
||||
characters."""
|
||||
characters.
|
||||
"""
|
||||
if not value or len(value) < 3:
|
||||
raise serializers.ValidationError(
|
||||
"Username must be at least 3 characters long."
|
||||
)
|
||||
return value
|
||||
|
||||
def validate_password(self, value: Optional[str]) -> str:
|
||||
def validate_password(self, value: str | None) -> str:
|
||||
"""Check that the password is not empty and has at least 3
|
||||
characters."""
|
||||
characters.
|
||||
"""
|
||||
if not value or len(value) < 3:
|
||||
raise serializers.ValidationError(
|
||||
"Password must be at least 3 characters long."
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from django.urls import path
|
||||
|
||||
from account_v2.views import (
|
||||
callback,
|
||||
create_organization,
|
||||
@@ -8,7 +10,6 @@ from account_v2.views import (
|
||||
set_organization,
|
||||
signup,
|
||||
)
|
||||
from django.urls import path
|
||||
|
||||
urlpatterns = [
|
||||
path("login", login, name="login"),
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from django.db import IntegrityError
|
||||
|
||||
from account_v2.models import User
|
||||
from django.db import IntegrityError
|
||||
|
||||
Logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -27,7 +28,7 @@ class UserService:
|
||||
user.save()
|
||||
return user
|
||||
|
||||
def get_user_by_email(self, email: str) -> Optional[User]:
|
||||
def get_user_by_email(self, email: str) -> User | None:
|
||||
try:
|
||||
user: User = User.objects.get(email=email)
|
||||
return user
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from rest_framework import status
|
||||
from rest_framework.decorators import api_view
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from utils.user_session import UserSessionUtils
|
||||
|
||||
from account_v2.authentication_controller import AuthenticationController
|
||||
from account_v2.dto import (
|
||||
OrganizationSignupRequestBody,
|
||||
@@ -14,11 +20,6 @@ from account_v2.serializer import (
|
||||
OrganizationSignupSerializer,
|
||||
UserSessionResponseSerializer,
|
||||
)
|
||||
from rest_framework import status
|
||||
from rest_framework.decorators import api_view
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from utils.user_session import UserSessionUtils
|
||||
|
||||
Logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -76,6 +77,7 @@ def get_organizations(request: Request) -> Response:
|
||||
"""get_organizations.
|
||||
|
||||
Retrieve the list of organizations to which the user belongs.
|
||||
|
||||
Args:
|
||||
request (HttpRequest): _description_
|
||||
|
||||
@@ -91,6 +93,7 @@ def set_organization(request: Request, id: str) -> Response:
|
||||
"""set_organization.
|
||||
|
||||
Set the current organization to use.
|
||||
|
||||
Args:
|
||||
request (HttpRequest): _description_
|
||||
id (String): organization Id
|
||||
@@ -98,7 +101,6 @@ def set_organization(request: Request, id: str) -> Response:
|
||||
Returns:
|
||||
Response: Contains the User and Current organization details.
|
||||
"""
|
||||
|
||||
auth_controller = AuthenticationController()
|
||||
return auth_controller.set_user_organization(request, id)
|
||||
|
||||
@@ -108,6 +110,7 @@ def get_session_data(request: Request) -> Response:
|
||||
"""get_session_data.
|
||||
|
||||
Retrieve the current session data.
|
||||
|
||||
Args:
|
||||
request (HttpRequest): _description_
|
||||
|
||||
@@ -128,6 +131,7 @@ def make_session_response(
|
||||
"""make_session_response.
|
||||
|
||||
Make the current session data.
|
||||
|
||||
Args:
|
||||
request (HttpRequest): _description_
|
||||
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from account_v2.models import User
|
||||
from cryptography.fernet import Fernet
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from platform_settings_v2.platform_auth_service import PlatformAuthenticationService
|
||||
from tenant_account_v2.organization_member_service import OrganizationMemberService
|
||||
|
||||
from adapter_processor_v2.constants import AdapterKeys, AllowedDomains
|
||||
from adapter_processor_v2.exceptions import (
|
||||
InternalServiceError,
|
||||
InValidAdapterId,
|
||||
TestAdapterError,
|
||||
)
|
||||
from cryptography.fernet import Fernet
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from platform_settings_v2.platform_auth_service import PlatformAuthenticationService
|
||||
from tenant_account_v2.organization_member_service import OrganizationMemberService
|
||||
from unstract.sdk.adapters.adapterkit import Adapterkit
|
||||
from unstract.sdk.adapters.base import Adapter
|
||||
from unstract.sdk.adapters.enums import AdapterTypes
|
||||
@@ -43,9 +44,7 @@ class AdapterProcessor:
|
||||
updated_adapters[0].get(AdapterKeys.JSON_SCHEMA)
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Invalid adapter Id : {adapter_id} while fetching JSON Schema"
|
||||
)
|
||||
logger.error(f"Invalid adapter Id : {adapter_id} while fetching JSON Schema")
|
||||
raise InValidAdapterId()
|
||||
return schema_details
|
||||
|
||||
@@ -72,9 +71,7 @@ class AdapterProcessor:
|
||||
AdapterKeys.NAME: each_adapter.get(AdapterKeys.NAME),
|
||||
AdapterKeys.DESCRIPTION: each_adapter.get(AdapterKeys.DESCRIPTION),
|
||||
AdapterKeys.ICON: each_adapter.get(AdapterKeys.ICON),
|
||||
AdapterKeys.ADAPTER_TYPE: each_adapter.get(
|
||||
AdapterKeys.ADAPTER_TYPE
|
||||
),
|
||||
AdapterKeys.ADAPTER_TYPE: each_adapter.get(AdapterKeys.ADAPTER_TYPE),
|
||||
}
|
||||
)
|
||||
return supported_adapters
|
||||
@@ -97,7 +94,6 @@ class AdapterProcessor:
|
||||
adapter_class = Adapterkit().get_adapter_class_by_adapter_id(adapter_id)
|
||||
|
||||
if adapter_metadata.pop(AdapterKeys.ADAPTER_TYPE) == AdapterKeys.X2TEXT:
|
||||
|
||||
if (
|
||||
adapter_metadata.get(AdapterKeys.PLATFORM_PROVIDED_UNSTRACT_KEY)
|
||||
and add_unstract_key
|
||||
@@ -136,7 +132,8 @@ class AdapterProcessor:
|
||||
@staticmethod
|
||||
def __fetch_adapters_by_key_value(key: str, value: Any) -> Adapter:
|
||||
"""Fetches a list of adapters that have an attribute matching key and
|
||||
value."""
|
||||
value.
|
||||
"""
|
||||
logger.info(f"Fetching adapter list for {key} with {value}")
|
||||
adapter_kit = Adapterkit()
|
||||
adapters = adapter_kit.get_adapters_list()
|
||||
@@ -172,10 +169,8 @@ class AdapterProcessor:
|
||||
)
|
||||
|
||||
if default_triad.get(AdapterKeys.X2TEXT_DEFAULT, None):
|
||||
user_default_adapter.default_x2text_adapter = (
|
||||
AdapterInstance.objects.get(
|
||||
pk=default_triad[AdapterKeys.X2TEXT_DEFAULT]
|
||||
)
|
||||
user_default_adapter.default_x2text_adapter = AdapterInstance.objects.get(
|
||||
pk=default_triad[AdapterKeys.X2TEXT_DEFAULT]
|
||||
)
|
||||
|
||||
user_default_adapter.save()
|
||||
@@ -223,7 +218,6 @@ class AdapterProcessor:
|
||||
- list[AdapterInstance]: A list of AdapterInstance objects that match
|
||||
the specified adapter type.
|
||||
"""
|
||||
|
||||
adapters: list[AdapterInstance] = AdapterInstance.objects.for_user(user).filter(
|
||||
adapter_type=adapter_type.value,
|
||||
)
|
||||
@@ -232,8 +226,8 @@ class AdapterProcessor:
|
||||
@staticmethod
|
||||
def get_adapter_by_name_and_type(
|
||||
adapter_type: AdapterTypes,
|
||||
adapter_name: Optional[str] = None,
|
||||
) -> Optional[AdapterInstance]:
|
||||
adapter_name: str | None = None,
|
||||
) -> AdapterInstance | None:
|
||||
"""Get the adapter instance by its name and type.
|
||||
|
||||
Parameters:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Optional
|
||||
from rest_framework.exceptions import APIException
|
||||
|
||||
from adapter_processor_v2.constants import AdapterKeys
|
||||
from rest_framework.exceptions import APIException
|
||||
from unstract.sdk.exceptions import SdkError
|
||||
|
||||
|
||||
@@ -39,9 +38,9 @@ class DuplicateAdapterNameError(APIException):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
detail: Optional[str] = None,
|
||||
code: Optional[str] = None,
|
||||
name: str | None = None,
|
||||
detail: str | None = None,
|
||||
code: str | None = None,
|
||||
) -> None:
|
||||
if name:
|
||||
detail = self.default_detail.replace("this name", f"name '{name}'")
|
||||
@@ -55,9 +54,9 @@ class TestAdapterError(APIException):
|
||||
def __init__(
|
||||
self,
|
||||
sdk_err: SdkError,
|
||||
detail: Optional[str] = None,
|
||||
code: Optional[str] = None,
|
||||
adapter_name: Optional[str] = None,
|
||||
detail: str | None = None,
|
||||
code: str | None = None,
|
||||
adapter_name: str | None = None,
|
||||
):
|
||||
if sdk_err.status_code:
|
||||
self.status_code = sdk_err.status_code
|
||||
@@ -77,8 +76,8 @@ class DeleteAdapterInUseError(APIException):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
detail: Optional[str] = None,
|
||||
code: Optional[str] = None,
|
||||
detail: str | None = None,
|
||||
code: str | None = None,
|
||||
adapter_name: str = "adapter",
|
||||
):
|
||||
if detail is None:
|
||||
|
||||
@@ -8,7 +8,6 @@ from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
@@ -92,9 +91,7 @@ class Migration(migrations.Migration):
|
||||
),
|
||||
(
|
||||
"is_usable",
|
||||
models.BooleanField(
|
||||
db_comment="Is the Adpater Usable", default=True
|
||||
),
|
||||
models.BooleanField(db_comment="Is the Adpater Usable", default=True),
|
||||
),
|
||||
("description", models.TextField(blank=True, default=None, null=True)),
|
||||
(
|
||||
|
||||
@@ -9,9 +9,6 @@ from django.conf import settings
|
||||
from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
from tenant_account_v2.models import OrganizationMember
|
||||
from unstract.sdk.adapters.adapterkit import Adapterkit
|
||||
from unstract.sdk.adapters.enums import AdapterTypes
|
||||
from unstract.sdk.adapters.exceptions import AdapterError
|
||||
from utils.exceptions import InvalidEncryptionKey
|
||||
from utils.models.base_model import BaseModel
|
||||
from utils.models.organization_mixin import (
|
||||
@@ -19,6 +16,10 @@ from utils.models.organization_mixin import (
|
||||
DefaultOrganizationMixin,
|
||||
)
|
||||
|
||||
from unstract.sdk.adapters.adapterkit import Adapterkit
|
||||
from unstract.sdk.adapters.enums import AdapterTypes
|
||||
from unstract.sdk.adapters.exceptions import AdapterError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ADAPTER_NAME_SIZE = 128
|
||||
@@ -132,7 +133,6 @@ class AdapterInstance(DefaultOrganizationMixin, BaseModel):
|
||||
]
|
||||
|
||||
def create_adapter(self) -> None:
|
||||
|
||||
encryption_secret: str = settings.ENCRYPTION_KEY
|
||||
f: Fernet = Fernet(encryption_secret.encode("utf-8"))
|
||||
|
||||
|
||||
@@ -2,17 +2,17 @@ import json
|
||||
from typing import Any
|
||||
|
||||
from account_v2.serializer import UserSerializer
|
||||
from adapter_processor_v2.adapter_processor import AdapterProcessor
|
||||
from adapter_processor_v2.constants import AdapterKeys
|
||||
from cryptography.fernet import Fernet
|
||||
from django.conf import settings
|
||||
from rest_framework import serializers
|
||||
from rest_framework.serializers import ModelSerializer
|
||||
from unstract.sdk.adapters.constants import Common as common
|
||||
from unstract.sdk.adapters.enums import AdapterTypes
|
||||
|
||||
from adapter_processor_v2.adapter_processor import AdapterProcessor
|
||||
from adapter_processor_v2.constants import AdapterKeys
|
||||
from backend.constants import FieldLengthConstants as FLC
|
||||
from backend.serializers import AuditSerializer
|
||||
from unstract.sdk.adapters.constants import Common as common
|
||||
from unstract.sdk.adapters.enums import AdapterTypes
|
||||
|
||||
from .models import AdapterInstance, UserDefaultAdapter
|
||||
|
||||
@@ -31,12 +31,8 @@ class BaseAdapterSerializer(AuditSerializer):
|
||||
|
||||
class DefaultAdapterSerializer(serializers.Serializer):
|
||||
llm_default = serializers.CharField(max_length=FLC.UUID_LENGTH, required=False)
|
||||
embedding_default = serializers.CharField(
|
||||
max_length=FLC.UUID_LENGTH, required=False
|
||||
)
|
||||
vector_db_default = serializers.CharField(
|
||||
max_length=FLC.UUID_LENGTH, required=False
|
||||
)
|
||||
embedding_default = serializers.CharField(max_length=FLC.UUID_LENGTH, required=False)
|
||||
vector_db_default = serializers.CharField(max_length=FLC.UUID_LENGTH, required=False)
|
||||
|
||||
|
||||
class AdapterInstanceSerializer(BaseAdapterSerializer):
|
||||
@@ -51,9 +47,7 @@ class AdapterInstanceSerializer(BaseAdapterSerializer):
|
||||
f: Fernet = Fernet(encryption_secret.encode("utf-8"))
|
||||
json_string: str = json.dumps(data.pop(AdapterKeys.ADAPTER_METADATA))
|
||||
|
||||
data[AdapterKeys.ADAPTER_METADATA_B] = f.encrypt(
|
||||
json_string.encode("utf-8")
|
||||
)
|
||||
data[AdapterKeys.ADAPTER_METADATA_B] = f.encrypt(json_string.encode("utf-8"))
|
||||
|
||||
return data
|
||||
|
||||
@@ -79,7 +73,6 @@ class AdapterInstanceSerializer(BaseAdapterSerializer):
|
||||
|
||||
|
||||
class AdapterInfoSerializer(BaseAdapterSerializer):
|
||||
|
||||
context_window_size = serializers.SerializerMethodField()
|
||||
|
||||
class Meta(BaseAdapterSerializer.Meta):
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from django.urls import path
|
||||
from rest_framework.urlpatterns import format_suffix_patterns
|
||||
|
||||
from adapter_processor_v2.views import (
|
||||
AdapterInstanceViewSet,
|
||||
AdapterViewSet,
|
||||
DefaultAdapterViewSet,
|
||||
)
|
||||
from django.urls import path
|
||||
from rest_framework.urlpatterns import format_suffix_patterns
|
||||
|
||||
default_triad = DefaultAdapterViewSet.as_view(
|
||||
{"post": "configure_default_triad", "get": "get_default_triad"}
|
||||
|
||||
@@ -1,25 +1,7 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from adapter_processor_v2.adapter_processor import AdapterProcessor
|
||||
from adapter_processor_v2.constants import AdapterKeys
|
||||
from adapter_processor_v2.exceptions import (
|
||||
CannotDeleteDefaultAdapter,
|
||||
DeleteAdapterInUseError,
|
||||
DuplicateAdapterNameError,
|
||||
IdIsMandatory,
|
||||
InValidType,
|
||||
)
|
||||
from adapter_processor_v2.serializers import (
|
||||
AdapterInfoSerializer,
|
||||
AdapterInstanceSerializer,
|
||||
AdapterListSerializer,
|
||||
DefaultAdapterSerializer,
|
||||
SharedUserListSerializer,
|
||||
TestAdapterSerializer,
|
||||
UserDefaultAdapterSerializer,
|
||||
)
|
||||
from django.db import IntegrityError
|
||||
from django.db.models import ProtectedError, QuerySet
|
||||
from django.http import HttpRequest
|
||||
@@ -40,6 +22,25 @@ from rest_framework.viewsets import GenericViewSet, ModelViewSet
|
||||
from tenant_account_v2.organization_member_service import OrganizationMemberService
|
||||
from utils.filtering import FilterHelper
|
||||
|
||||
from adapter_processor_v2.adapter_processor import AdapterProcessor
|
||||
from adapter_processor_v2.constants import AdapterKeys
|
||||
from adapter_processor_v2.exceptions import (
|
||||
CannotDeleteDefaultAdapter,
|
||||
DeleteAdapterInUseError,
|
||||
DuplicateAdapterNameError,
|
||||
IdIsMandatory,
|
||||
InValidType,
|
||||
)
|
||||
from adapter_processor_v2.serializers import (
|
||||
AdapterInfoSerializer,
|
||||
AdapterInstanceSerializer,
|
||||
AdapterListSerializer,
|
||||
DefaultAdapterSerializer,
|
||||
SharedUserListSerializer,
|
||||
TestAdapterSerializer,
|
||||
UserDefaultAdapterSerializer,
|
||||
)
|
||||
|
||||
from .constants import AdapterKeys as constant
|
||||
from .models import AdapterInstance, UserDefaultAdapter
|
||||
|
||||
@@ -130,11 +131,9 @@ class AdapterViewSet(GenericViewSet):
|
||||
|
||||
|
||||
class AdapterInstanceViewSet(ModelViewSet):
|
||||
|
||||
serializer_class = AdapterInstanceSerializer
|
||||
|
||||
def get_permissions(self) -> list[Any]:
|
||||
|
||||
if self.action in ["update", "retrieve"]:
|
||||
return [IsFrictionLessAdapter()]
|
||||
|
||||
@@ -148,7 +147,7 @@ class AdapterInstanceViewSet(ModelViewSet):
|
||||
# User cant view/update metadata but can delete/share etc
|
||||
return [IsOwner()]
|
||||
|
||||
def get_queryset(self) -> Optional[QuerySet]:
|
||||
def get_queryset(self) -> QuerySet | None:
|
||||
if filter_args := FilterHelper.build_filter_args(
|
||||
self.request,
|
||||
constant.ADAPTER_TYPE,
|
||||
@@ -237,9 +236,7 @@ class AdapterInstanceViewSet(ModelViewSet):
|
||||
name=serializer.validated_data.get(AdapterKeys.ADAPTER_NAME)
|
||||
)
|
||||
headers = self.get_success_headers(serializer.data)
|
||||
return Response(
|
||||
serializer.data, status=status.HTTP_201_CREATED, headers=headers
|
||||
)
|
||||
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
|
||||
|
||||
def destroy(
|
||||
self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any]
|
||||
@@ -261,13 +258,11 @@ class AdapterInstanceViewSet(ModelViewSet):
|
||||
)
|
||||
or (
|
||||
adapter_type == AdapterKeys.EMBEDDING
|
||||
and adapter_instance
|
||||
== user_default_adapter.default_embedding_adapter
|
||||
and adapter_instance == user_default_adapter.default_embedding_adapter
|
||||
)
|
||||
or (
|
||||
adapter_type == AdapterKeys.VECTOR_DB
|
||||
and adapter_instance
|
||||
== user_default_adapter.default_vector_db_adapter
|
||||
and adapter_instance == user_default_adapter.default_vector_db_adapter
|
||||
)
|
||||
or (
|
||||
adapter_type == AdapterKeys.X2TEXT
|
||||
@@ -337,7 +332,6 @@ class AdapterInstanceViewSet(ModelViewSet):
|
||||
|
||||
@action(detail=True, methods=["get"])
|
||||
def list_of_shared_users(self, request: HttpRequest, pk: Any = None) -> Response:
|
||||
|
||||
adapter = self.get_object()
|
||||
|
||||
serialized_instances = SharedUserListSerializer(adapter).data
|
||||
@@ -346,7 +340,6 @@ class AdapterInstanceViewSet(ModelViewSet):
|
||||
|
||||
@action(detail=True, methods=["get"])
|
||||
def adapter_info(self, request: HttpRequest, pk: uuid) -> Response:
|
||||
|
||||
adapter = self.get_object()
|
||||
|
||||
serialized_instances = AdapterInfoSerializer(adapter).data
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from plugins.api.dto import metadata
|
||||
|
||||
from api_v2.postman_collection.dto import PostmanCollection
|
||||
from plugins.api.dto import metadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApiDeploymentDTORegistry:
|
||||
_dto_class: Optional[Any] = None # Store the selected DTO class (cached)
|
||||
_dto_class: Any | None = None # Store the selected DTO class (cached)
|
||||
|
||||
@classmethod
|
||||
def load_dto(cls) -> Optional[Any]:
|
||||
def load_dto(cls) -> Any | None:
|
||||
class_name = PostmanCollection.__name__
|
||||
if metadata.get(class_name):
|
||||
return metadata[class_name].class_name
|
||||
return PostmanCollection # Return as soon as we find a valid DTO
|
||||
|
||||
@classmethod
|
||||
def get_dto(cls) -> Optional[type]:
|
||||
def get_dto(cls) -> type | None:
|
||||
"""Returns the first available DTO class, or None if unavailable."""
|
||||
return cls.load_dto()
|
||||
|
||||
@@ -1,6 +1,18 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
from django.db.models import QuerySet
|
||||
from django.http import HttpResponse
|
||||
from permissions.permission import IsOwner
|
||||
from rest_framework import serializers, status, views, viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import Serializer
|
||||
from utils.enums import CeleryTaskState
|
||||
from workflow_manager.workflow_v2.dto import ExecutionResponse
|
||||
|
||||
from api_v2.api_deployment_dto_registry import ApiDeploymentDTORegistry
|
||||
from api_v2.constants import ApiExecution
|
||||
@@ -14,25 +26,12 @@ from api_v2.serializers import (
|
||||
ExecutionQuerySerializer,
|
||||
ExecutionRequestSerializer,
|
||||
)
|
||||
from django.conf import settings
|
||||
from django.db.models import QuerySet
|
||||
from django.http import HttpResponse
|
||||
from permissions.permission import IsOwner
|
||||
from rest_framework import serializers, status, views, viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import Serializer
|
||||
from utils.enums import CeleryTaskState
|
||||
from workflow_manager.workflow_v2.dto import ExecutionResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeploymentExecution(views.APIView):
|
||||
def initialize_request(
|
||||
self, request: Request, *args: Any, **kwargs: Any
|
||||
) -> Request:
|
||||
def initialize_request(self, request: Request, *args: Any, **kwargs: Any) -> Request:
|
||||
"""To remove csrf request for public API.
|
||||
|
||||
Args:
|
||||
@@ -41,7 +40,7 @@ class DeploymentExecution(views.APIView):
|
||||
Returns:
|
||||
Request: _description_
|
||||
"""
|
||||
setattr(request, "csrf_processing_done", True)
|
||||
request.csrf_processing_done = True
|
||||
return super().initialize_request(request, *args, **kwargs)
|
||||
|
||||
@DeploymentHelper.validate_api_key
|
||||
@@ -85,9 +84,7 @@ class DeploymentExecution(views.APIView):
|
||||
include_metrics = serializer.validated_data.get(ApiExecution.INCLUDE_METRICS)
|
||||
|
||||
# Fetch execution status
|
||||
response: ExecutionResponse = DeploymentHelper.get_execution_status(
|
||||
execution_id
|
||||
)
|
||||
response: ExecutionResponse = DeploymentHelper.get_execution_status(execution_id)
|
||||
# Determine response status
|
||||
response_status = status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
if response.execution_status == CeleryTaskState.COMPLETED.value:
|
||||
@@ -113,7 +110,7 @@ class DeploymentExecution(views.APIView):
|
||||
class APIDeploymentViewSet(viewsets.ModelViewSet):
|
||||
permission_classes = [IsOwner]
|
||||
|
||||
def get_queryset(self) -> Optional[QuerySet]:
|
||||
def get_queryset(self) -> QuerySet | None:
|
||||
return APIDeployment.objects.filter(created_by=self.request.user)
|
||||
|
||||
def get_serializer_class(self) -> serializers.Serializer:
|
||||
@@ -122,7 +119,7 @@ class APIDeploymentViewSet(viewsets.ModelViewSet):
|
||||
return APIDeploymentSerializer
|
||||
|
||||
@action(detail=True, methods=["get"])
|
||||
def fetch_one(self, request: Request, pk: Optional[str] = None) -> Response:
|
||||
def fetch_one(self, request: Request, pk: str | None = None) -> Response:
|
||||
"""Custom action to fetch a single instance."""
|
||||
instance = self.get_object()
|
||||
serializer = self.get_serializer(instance)
|
||||
@@ -134,9 +131,7 @@ class APIDeploymentViewSet(viewsets.ModelViewSet):
|
||||
serializer: Serializer = self.get_serializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
self.perform_create(serializer)
|
||||
api_key = DeploymentHelper.create_api_key(
|
||||
serializer=serializer, request=request
|
||||
)
|
||||
api_key = DeploymentHelper.create_api_key(serializer=serializer, request=request)
|
||||
response_serializer = DeploymentResponseSerializer(
|
||||
{"api_key": api_key.api_key, **serializer.data}
|
||||
)
|
||||
@@ -150,7 +145,7 @@ class APIDeploymentViewSet(viewsets.ModelViewSet):
|
||||
|
||||
@action(detail=True, methods=["get"])
|
||||
def download_postman_collection(
|
||||
self, request: Request, pk: Optional[str] = None
|
||||
self, request: Request, pk: str | None = None
|
||||
) -> Response:
|
||||
"""Downloads a Postman Collection of the API deployment instance."""
|
||||
instance = self.get_object()
|
||||
|
||||
@@ -2,9 +2,10 @@ import logging
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from api_v2.exceptions import Forbidden
|
||||
from rest_framework.request import Request
|
||||
|
||||
from api_v2.exceptions import Forbidden
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -56,5 +57,6 @@ class BaseAPIKeyValidator:
|
||||
self: Any, request: Request, func: Any, api_key: str, *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Process and validate API key with specific logic required by
|
||||
subclasses."""
|
||||
subclasses.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
from typing import Optional
|
||||
from pipeline_v2.exceptions import PipelineNotFound
|
||||
from pipeline_v2.pipeline_processor import PipelineProcessor
|
||||
from rest_framework import serializers, viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
|
||||
from api_v2.deployment_helper import DeploymentHelper
|
||||
from api_v2.exceptions import APINotFound, PathVariablesNotFound
|
||||
@@ -6,12 +11,6 @@ from api_v2.key_helper import KeyHelper
|
||||
from api_v2.models import APIKey
|
||||
from api_v2.permission import IsOwnerOrOrganizationMember
|
||||
from api_v2.serializers import APIKeyListSerializer, APIKeySerializer
|
||||
from pipeline_v2.exceptions import PipelineNotFound
|
||||
from pipeline_v2.pipeline_processor import PipelineProcessor
|
||||
from rest_framework import serializers, viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
|
||||
|
||||
class APIKeyViewSet(viewsets.ModelViewSet):
|
||||
@@ -27,8 +26,8 @@ class APIKeyViewSet(viewsets.ModelViewSet):
|
||||
def api_keys(
|
||||
self,
|
||||
request: Request,
|
||||
api_id: Optional[str] = None,
|
||||
pipeline_id: Optional[str] = None,
|
||||
api_id: str | None = None,
|
||||
pipeline_id: str | None = None,
|
||||
) -> Response:
|
||||
"""Custom action to fetch api keys of an api deployment."""
|
||||
if api_id:
|
||||
|
||||
@@ -1,18 +1,7 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from api_v2.api_key_validator import BaseAPIKeyValidator
|
||||
from api_v2.exceptions import (
|
||||
ApiKeyCreateException,
|
||||
APINotFound,
|
||||
InactiveAPI,
|
||||
InvalidAPIRequest,
|
||||
)
|
||||
from api_v2.key_helper import KeyHelper
|
||||
from api_v2.models import APIDeployment, APIKey
|
||||
from api_v2.serializers import APIExecutionResponseSerializer
|
||||
from api_v2.utils import APIDeploymentUtils
|
||||
from django.conf import settings
|
||||
from django.core.files.uploadedfile import UploadedFile
|
||||
from rest_framework.request import Request
|
||||
@@ -29,6 +18,18 @@ from workflow_manager.workflow_v2.execution import WorkflowExecutionServiceHelpe
|
||||
from workflow_manager.workflow_v2.models import Workflow, WorkflowExecution
|
||||
from workflow_manager.workflow_v2.workflow_helper import WorkflowHelper
|
||||
|
||||
from api_v2.api_key_validator import BaseAPIKeyValidator
|
||||
from api_v2.exceptions import (
|
||||
ApiKeyCreateException,
|
||||
APINotFound,
|
||||
InactiveAPI,
|
||||
InvalidAPIRequest,
|
||||
)
|
||||
from api_v2.key_helper import KeyHelper
|
||||
from api_v2.models import APIDeployment, APIKey
|
||||
from api_v2.serializers import APIExecutionResponseSerializer
|
||||
from api_v2.utils import APIDeploymentUtils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -55,7 +56,7 @@ class DeploymentHelper(BaseAPIKeyValidator):
|
||||
return func(self, request, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def validate_api(api_deployment: Optional[APIDeployment], api_key: str) -> None:
|
||||
def validate_api(api_deployment: APIDeployment | None, api_key: str) -> None:
|
||||
"""Validating API and API key.
|
||||
|
||||
Args:
|
||||
@@ -75,11 +76,12 @@ class DeploymentHelper(BaseAPIKeyValidator):
|
||||
@staticmethod
|
||||
def validate_and_get_workflow(workflow_id: str) -> Workflow:
|
||||
"""Validate that the specified workflow_id exists in the Workflow
|
||||
model."""
|
||||
model.
|
||||
"""
|
||||
return WorkflowHelper.get_workflow_by_id(workflow_id)
|
||||
|
||||
@staticmethod
|
||||
def get_api_by_id(api_id: str) -> Optional[APIDeployment]:
|
||||
def get_api_by_id(api_id: str) -> APIDeployment | None:
|
||||
return APIDeploymentUtils.get_api_by_id(api_id=api_id)
|
||||
|
||||
@staticmethod
|
||||
@@ -102,7 +104,7 @@ class DeploymentHelper(BaseAPIKeyValidator):
|
||||
@staticmethod
|
||||
def get_deployment_by_api_name(
|
||||
api_name: str,
|
||||
) -> Optional[APIDeployment]:
|
||||
) -> APIDeployment | None:
|
||||
"""Get and return the APIDeployment object by api_name."""
|
||||
try:
|
||||
api: APIDeployment = APIDeployment.objects.get(api_name=api_name)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from rest_framework.exceptions import APIException
|
||||
|
||||
|
||||
@@ -24,9 +22,7 @@ class ApiKeyCreateException(APIException):
|
||||
|
||||
class Forbidden(APIException):
|
||||
status_code = 403
|
||||
default_detail = (
|
||||
"User is forbidden from performing this action. Please contact admin"
|
||||
)
|
||||
default_detail = "User is forbidden from performing this action. Please contact admin"
|
||||
|
||||
|
||||
class APINotFound(NotFoundException):
|
||||
@@ -53,8 +49,8 @@ class NoActiveAPIKeyError(APIException):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
detail: Optional[str] = None,
|
||||
code: Optional[str] = None,
|
||||
detail: str | None = None,
|
||||
code: str | None = None,
|
||||
deployment_name: str = "this deployment",
|
||||
):
|
||||
if detail is None:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from api_v2.api_deployment_views import DeploymentExecution
|
||||
from django.urls import re_path
|
||||
from rest_framework.urlpatterns import format_suffix_patterns
|
||||
|
||||
from api_v2.api_deployment_views import DeploymentExecution
|
||||
|
||||
execute = DeploymentExecution.as_view()
|
||||
|
||||
|
||||
|
||||
@@ -1,22 +1,20 @@
|
||||
import logging
|
||||
from typing import Union
|
||||
|
||||
from api_v2.exceptions import UnauthorizedKey
|
||||
from api_v2.models import APIDeployment, APIKey
|
||||
from api_v2.serializers import APIKeySerializer
|
||||
from django.core.exceptions import ValidationError
|
||||
from pipeline_v2.models import Pipeline
|
||||
from rest_framework.request import Request
|
||||
from workflow_manager.workflow_v2.workflow_helper import WorkflowHelper
|
||||
|
||||
from api_v2.exceptions import UnauthorizedKey
|
||||
from api_v2.models import APIDeployment, APIKey
|
||||
from api_v2.serializers import APIKeySerializer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KeyHelper:
|
||||
@staticmethod
|
||||
def validate_api_key(
|
||||
api_key: str, instance: Union[APIDeployment, Pipeline]
|
||||
) -> None:
|
||||
def validate_api_key(api_key: str, instance: APIDeployment | Pipeline) -> None:
|
||||
"""Validate api key.
|
||||
|
||||
Args:
|
||||
@@ -44,7 +42,7 @@ class KeyHelper:
|
||||
return api_keys
|
||||
|
||||
@staticmethod
|
||||
def has_access(api_key: APIKey, instance: Union[APIDeployment, Pipeline]) -> bool:
|
||||
def has_access(api_key: APIKey, instance: APIDeployment | Pipeline) -> bool:
|
||||
"""Check if the provided API key has access to the specified API
|
||||
instance.
|
||||
|
||||
@@ -66,15 +64,15 @@ class KeyHelper:
|
||||
@staticmethod
|
||||
def validate_workflow_exists(workflow_id: str) -> None:
|
||||
"""Validate that the specified workflow_id exists in the Workflow
|
||||
model."""
|
||||
model.
|
||||
"""
|
||||
WorkflowHelper.get_workflow_by_id(workflow_id)
|
||||
|
||||
@staticmethod
|
||||
def create_api_key(
|
||||
deployment: Union[APIDeployment, Pipeline], request: Request
|
||||
) -> APIKey:
|
||||
def create_api_key(deployment: APIDeployment | Pipeline, request: Request) -> APIKey:
|
||||
"""Create an APIKey entity using the data from the provided
|
||||
APIDeployment or Pipeline instance."""
|
||||
APIDeployment or Pipeline instance.
|
||||
"""
|
||||
api_key_serializer = APIKeySerializer(
|
||||
data=deployment.api_key_data,
|
||||
context={"deployment": deployment, "request": request},
|
||||
|
||||
@@ -8,7 +8,6 @@ from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
|
||||
@@ -2,7 +2,6 @@ import uuid
|
||||
from typing import Any
|
||||
|
||||
from account_v2.models import User
|
||||
from api_v2.constants import ApiExecution
|
||||
from django.db import models
|
||||
from pipeline_v2.models import Pipeline
|
||||
from utils.models.base_model import BaseModel
|
||||
@@ -13,6 +12,8 @@ from utils.models.organization_mixin import (
|
||||
from utils.user_context import UserContext
|
||||
from workflow_manager.workflow_v2.models.workflow import Workflow
|
||||
|
||||
from api_v2.constants import ApiExecution
|
||||
|
||||
API_NAME_MAX_LENGTH = 30
|
||||
DESCRIPTION_MAX_LENGTH = 255
|
||||
API_ENDPOINT_MAX_LENGTH = 255
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
import logging
|
||||
|
||||
from api_v2.models import APIDeployment
|
||||
from notification_v2.helper import NotificationHelper
|
||||
from notification_v2.models import Notification
|
||||
from pipeline_v2.dto import PipelineStatusPayload
|
||||
from workflow_manager.workflow_v2.models.execution import WorkflowExecution
|
||||
|
||||
from api_v2.models import APIDeployment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APINotification:
|
||||
def __init__(
|
||||
self, api: APIDeployment, workflow_execution: WorkflowExecution
|
||||
) -> None:
|
||||
def __init__(self, api: APIDeployment, workflow_execution: WorkflowExecution) -> None:
|
||||
self.notifications = Notification.objects.filter(api=api, is_active=True)
|
||||
self.api = api
|
||||
self.workflow_execution = workflow_execution
|
||||
|
||||
@@ -4,7 +4,8 @@ from utils.user_context import UserContext
|
||||
|
||||
class IsOwnerOrOrganizationMember(IsOwner):
|
||||
"""Permission that grants access if the user is the owner or belongs to the
|
||||
same organization."""
|
||||
same organization.
|
||||
"""
|
||||
|
||||
def has_object_permission(self, request, view, obj):
|
||||
# Check if the user is the owner via base class logic
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
class CollectionKey:
|
||||
POSTMAN_COLLECTION_V210 = "https://schema.getpostman.com/json/collection/v2.1.0/collection.json" # noqa: E501
|
||||
POSTMAN_COLLECTION_V210 = (
|
||||
"https://schema.getpostman.com/json/collection/v2.1.0/collection.json" # noqa: E501
|
||||
)
|
||||
EXECUTE_API_KEY = "Process document"
|
||||
EXECUTE_PIPELINE_API_KEY = "Process pipeline"
|
||||
STATUS_API_KEY = "Execution status"
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode, urljoin
|
||||
|
||||
from django.conf import settings
|
||||
from pipeline_v2.models import Pipeline
|
||||
from utils.request import HTTPMethod
|
||||
|
||||
from api_v2.constants import ApiExecution
|
||||
from api_v2.models import APIDeployment
|
||||
from api_v2.postman_collection.constants import CollectionKey
|
||||
from django.conf import settings
|
||||
from pipeline_v2.models import Pipeline
|
||||
from utils.request import HTTPMethod
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -21,8 +22,8 @@ class HeaderItem:
|
||||
class FormDataItem:
|
||||
key: str
|
||||
type: str
|
||||
src: Optional[str] = None
|
||||
value: Optional[str] = None
|
||||
src: str | None = None
|
||||
value: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.type == "file":
|
||||
@@ -46,7 +47,7 @@ class RequestItem:
|
||||
method: HTTPMethod
|
||||
url: str
|
||||
header: list[HeaderItem]
|
||||
body: Optional[BodyItem] = None
|
||||
body: BodyItem | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -63,7 +64,6 @@ class PostmanInfo:
|
||||
|
||||
|
||||
class APIBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get_form_data_items(self) -> list[FormDataItem]:
|
||||
pass
|
||||
@@ -189,7 +189,7 @@ class PostmanCollection:
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
instance: Union[APIDeployment, Pipeline],
|
||||
instance: APIDeployment | Pipeline,
|
||||
api_key: str = CollectionKey.AUTH_QUERY_PARAM_DEFAULT,
|
||||
) -> "PostmanCollection":
|
||||
"""Creates a PostmanCollection instance.
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
from api_v2.constants import ApiExecution
|
||||
from api_v2.models import APIDeployment, APIKey
|
||||
from django.core.validators import RegexValidator
|
||||
from pipeline_v2.models import Pipeline
|
||||
from rest_framework.serializers import (
|
||||
@@ -22,6 +20,8 @@ from utils.serializer.integrity_error_mixin import IntegrityErrorMixin
|
||||
from workflow_manager.workflow_v2.exceptions import ExecutionDoesNotExistError
|
||||
from workflow_manager.workflow_v2.models.execution import WorkflowExecution
|
||||
|
||||
from api_v2.constants import ApiExecution
|
||||
from api_v2.models import APIDeployment, APIKey
|
||||
from backend.serializers import AuditSerializer
|
||||
|
||||
|
||||
@@ -76,8 +76,9 @@ class APIKeySerializer(AuditSerializer):
|
||||
|
||||
def to_representation(self, instance: APIKey) -> OrderedDict[str, Any]:
|
||||
"""Override the to_representation method to include additional
|
||||
context."""
|
||||
deployment: Union[APIDeployment, Pipeline] = self.context.get("deployment")
|
||||
context.
|
||||
"""
|
||||
deployment: APIDeployment | Pipeline = self.context.get("deployment")
|
||||
representation: OrderedDict[str, Any] = super().to_representation(instance)
|
||||
|
||||
if deployment:
|
||||
@@ -89,9 +90,7 @@ class APIKeySerializer(AuditSerializer):
|
||||
elif isinstance(deployment, Pipeline):
|
||||
representation["api"] = None
|
||||
representation["pipeline"] = deployment.id
|
||||
representation["description"] = (
|
||||
f"API Key for {deployment.pipeline_name}"
|
||||
)
|
||||
representation["description"] = f"API Key for {deployment.pipeline_name}"
|
||||
else:
|
||||
raise ValueError(
|
||||
"Context must be an instance of APIDeployment or Pipeline"
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from api_v2.api_deployment_views import APIDeploymentViewSet, DeploymentExecution
|
||||
from api_v2.api_key_views import APIKeyViewSet
|
||||
from django.urls import path
|
||||
from rest_framework.urlpatterns import format_suffix_patterns
|
||||
|
||||
from api_v2.api_deployment_views import APIDeploymentViewSet, DeploymentExecution
|
||||
from api_v2.api_key_views import APIKeyViewSet
|
||||
|
||||
deployment = APIDeploymentViewSet.as_view(
|
||||
{
|
||||
"get": APIDeploymentViewSet.list.__name__,
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
from typing import Optional
|
||||
from workflow_manager.workflow_v2.models.execution import WorkflowExecution
|
||||
|
||||
from api_v2.models import APIDeployment
|
||||
from api_v2.notification import APINotification
|
||||
from workflow_manager.workflow_v2.models.execution import WorkflowExecution
|
||||
|
||||
|
||||
class APIDeploymentUtils:
|
||||
@staticmethod
|
||||
def get_api_by_id(api_id: str) -> Optional[APIDeployment]:
|
||||
def get_api_by_id(api_id: str) -> APIDeployment | None:
|
||||
"""Retrieves an APIDeployment instance by its unique ID.
|
||||
|
||||
Args:
|
||||
@@ -39,7 +38,5 @@ class APIDeploymentUtils:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
api_notification = APINotification(
|
||||
api=api, workflow_execution=workflow_execution
|
||||
)
|
||||
api_notification = APINotification(api=api, workflow_execution=workflow_execution)
|
||||
api_notification.send()
|
||||
|
||||
@@ -32,9 +32,7 @@ TaskRegistry()
|
||||
app.config_from_object("backend.celery_config.CeleryConfig")
|
||||
app.autodiscover_tasks()
|
||||
|
||||
logger.debug(
|
||||
f"Celery Configuration:\n" f"{pformat(app.conf.table(with_defaults=True))}"
|
||||
)
|
||||
logger.debug(f"Celery Configuration:\n {pformat(app.conf.table(with_defaults=True))}")
|
||||
|
||||
# Define the queues to purge when the Celery broker is restarted.
|
||||
queues_to_purge = [ExecutionLogConstants.CELERY_QUEUE_NAME]
|
||||
|
||||
@@ -10,7 +10,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class DatabaseWrapper(PostgresDatabaseWrapper):
|
||||
"""Custom DatabaseWrapper to manage PostgreSQL connections and set the
|
||||
search path."""
|
||||
search path.
|
||||
"""
|
||||
|
||||
def get_new_connection(self, conn_params):
|
||||
"""Establish a new database connection or reuse an existing one, and
|
||||
@@ -45,8 +46,6 @@ class DatabaseWrapper(PostgresDatabaseWrapper):
|
||||
)
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(f"SET search_path TO {settings.DB_SCHEMA}")
|
||||
logger.debug(
|
||||
f"Successfully set search_path for DB connection ID {conn_id}."
|
||||
)
|
||||
logger.debug(f"Successfully set search_path for DB connection ID {conn_id}.")
|
||||
finally:
|
||||
connection.autocommit = original_autocommit
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from rest_framework.exceptions import APIException
|
||||
from rest_framework.response import Response
|
||||
@@ -12,8 +12,8 @@ class UnstractBaseException(APIException):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
detail: Optional[str] = None,
|
||||
core_err: Optional[ConnectorBaseException] = None,
|
||||
detail: str | None = None,
|
||||
core_err: ConnectorBaseException | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if detail is None:
|
||||
@@ -21,9 +21,7 @@ class UnstractBaseException(APIException):
|
||||
if core_err and core_err.user_message:
|
||||
detail = core_err.user_message
|
||||
if detail and "Name or service not known" in str(detail):
|
||||
detail = (
|
||||
"Failed to establish a new connection: " "Name or service not known"
|
||||
)
|
||||
detail = "Failed to establish a new connection: " "Name or service not known"
|
||||
super().__init__(detail=detail, **kwargs)
|
||||
self._core_err = core_err
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
The `urlpatterns` list routes URLs to views. For more information please see:
|
||||
https://docs.djangoproject.com/en/4.2/topics/http/urls/
|
||||
|
||||
Examples:
|
||||
Function views
|
||||
1. Add an import: from my_app import views
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
The `urlpatterns` list routes URLs to views. For more information please see:
|
||||
https://docs.djangoproject.com/en/4.2/topics/http/urls/
|
||||
|
||||
Examples:
|
||||
Function views
|
||||
1. Add an import: from my_app import views
|
||||
|
||||
@@ -11,7 +11,6 @@ https://docs.djangoproject.com/en/4.2/ref/settings/
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
from utils.common_utils import CommonUtils
|
||||
@@ -19,9 +18,7 @@ from utils.common_utils import CommonUtils
|
||||
missing_settings = []
|
||||
|
||||
|
||||
def get_required_setting(
|
||||
setting_key: str, default: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
def get_required_setting(setting_key: str, default: str | None = None) -> str | None:
|
||||
"""Get the value of an environment variable specified by the given key. Add
|
||||
missing keys to `missing_settings` so that exception can be raised at the
|
||||
end.
|
||||
@@ -64,9 +61,7 @@ LOGIN_NEXT_URL = os.environ.get("LOGIN_NEXT_URL", "http://localhost:3000/org")
|
||||
LANDING_URL = os.environ.get("LANDING_URL", "http://localhost:3000/landing")
|
||||
ERROR_URL = os.environ.get("ERROR_URL", "http://localhost:3000/error")
|
||||
|
||||
DJANGO_APP_BACKEND_URL = os.environ.get(
|
||||
"DJANGO_APP_BACKEND_URL", "http://localhost:8000"
|
||||
)
|
||||
DJANGO_APP_BACKEND_URL = os.environ.get("DJANGO_APP_BACKEND_URL", "http://localhost:8000")
|
||||
INTERNAL_SERVICE_API_KEY = os.environ.get("INTERNAL_SERVICE_API_KEY")
|
||||
|
||||
GOOGLE_STORAGE_ACCESS_KEY_ID = os.environ.get("GOOGLE_STORAGE_ACCESS_KEY_ID")
|
||||
@@ -481,7 +476,7 @@ for key in [
|
||||
"GOOGLE_OAUTH2_KEY",
|
||||
"GOOGLE_OAUTH2_SECRET",
|
||||
]:
|
||||
exec("SOCIAL_AUTH_{key} = os.environ.get('{key}')".format(key=key))
|
||||
exec(f"SOCIAL_AUTH_{key} = os.environ.get('{key}')")
|
||||
|
||||
SOCIAL_AUTH_PIPELINE = (
|
||||
# Checks if user is authenticated
|
||||
@@ -520,6 +515,4 @@ if missing_settings:
|
||||
)
|
||||
raise ValueError(ERROR_MESSAGE)
|
||||
|
||||
ENABLE_HIGHLIGHT_API_DEPLOYMENT = os.environ.get(
|
||||
"ENABLE_HIGHLIGHT_API_DEPLOYMENT", False
|
||||
)
|
||||
ENABLE_HIGHLIGHT_API_DEPLOYMENT = os.environ.get("ENABLE_HIGHLIGHT_API_DEPLOYMENT", False)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
The `urlpatterns` list routes URLs to views. For more information please see:
|
||||
https://docs.djangoproject.com/en/4.2/topics/http/urls/
|
||||
|
||||
Examples:
|
||||
Function views
|
||||
1. Add an import: from my_app import views
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
The `urlpatterns` list routes URLs to views. For more information please see:
|
||||
https://docs.djangoproject.com/en/4.2/topics/http/urls/
|
||||
|
||||
Examples:
|
||||
Function views
|
||||
1. Add an import: from my_app import views
|
||||
|
||||
@@ -16,9 +16,7 @@ class Command(BaseCommand):
|
||||
parser.add_argument(
|
||||
"--schema",
|
||||
type=str,
|
||||
help=(
|
||||
"Optional schema name to drop. Overrides env 'DB_SCHEMA' if specified"
|
||||
),
|
||||
help=("Optional schema name to drop. Overrides env 'DB_SCHEMA' if specified"),
|
||||
)
|
||||
|
||||
def handle(self, *args, **kwargs):
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from rest_framework.exceptions import APIException
|
||||
|
||||
|
||||
@@ -19,8 +17,8 @@ class MissingParamException(APIException):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
code: Optional[str] = None,
|
||||
param: Optional[str] = None,
|
||||
code: str | None = None,
|
||||
param: str | None = None,
|
||||
) -> None:
|
||||
detail = f"Bad request, missing parameter: {param}"
|
||||
super().__init__(detail, code)
|
||||
|
||||
@@ -9,7 +9,6 @@ from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
|
||||
@@ -3,8 +3,6 @@ import uuid
|
||||
from typing import Any
|
||||
|
||||
from account_v2.models import User
|
||||
from connector_auth_v2.constants import SocialAuthConstants
|
||||
from connector_auth_v2.pipeline.google import GoogleAuthHelper
|
||||
from django.db import models
|
||||
from django.db.models.query import QuerySet
|
||||
from rest_framework.request import Request
|
||||
@@ -12,6 +10,9 @@ from social_django.fields import JSONField
|
||||
from social_django.models import AbstractUserSocialAuth, DjangoStorage
|
||||
from social_django.strategy import DjangoStrategy
|
||||
|
||||
from connector_auth_v2.constants import SocialAuthConstants
|
||||
from connector_auth_v2.pipeline.google import GoogleAuthHelper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -63,7 +64,8 @@ class ConnectorAuth(AbstractUserSocialAuth):
|
||||
|
||||
def refresh_token(self, strategy, *args, **kwargs): # type: ignore
|
||||
"""Override of Python Social Auth (PSA)'s refresh_token functionality
|
||||
to store uid, provider."""
|
||||
to store uid, provider.
|
||||
"""
|
||||
token = self.extra_data.get("refresh_token") or self.extra_data.get(
|
||||
"access_token"
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from account_v2.models import User
|
||||
from connector_auth_v2.constants import ConnectorAuthKey, SocialAuthConstants
|
||||
@@ -68,7 +68,7 @@ class ConnectorAuthHelper:
|
||||
@staticmethod
|
||||
def get_oauth_creds_from_cache(
|
||||
cache_key: str, delete_key: bool = True
|
||||
) -> Optional[dict[str, str]]:
|
||||
) -> dict[str, str] | None:
|
||||
"""Retrieves oauth credentials from the cache.
|
||||
|
||||
Args:
|
||||
@@ -84,7 +84,8 @@ class ConnectorAuthHelper:
|
||||
|
||||
@staticmethod
|
||||
def get_or_create_connector_auth(
|
||||
oauth_credentials: dict[str, str], user: User = None # type: ignore
|
||||
oauth_credentials: dict[str, str],
|
||||
user: User = None, # type: ignore
|
||||
) -> ConnectorAuth:
|
||||
"""Gets or creates a ConnectorAuth object.
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from connector_auth_v2.constants import SocialAuthConstants
|
||||
from connector_auth_v2.exceptions import KeyNotConfigured
|
||||
from django.conf import settings
|
||||
from rest_framework import status, viewsets
|
||||
from rest_framework.request import Request
|
||||
@@ -10,6 +8,9 @@ from rest_framework.response import Response
|
||||
from rest_framework.versioning import URLPathVersioning
|
||||
from utils.user_session import UserSessionUtils
|
||||
|
||||
from connector_auth_v2.constants import SocialAuthConstants
|
||||
from connector_auth_v2.exceptions import KeyNotConfigured
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
# mypy: ignore-errors
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from connector_auth_v2.constants import ConnectorAuthKey
|
||||
from connector_auth_v2.pipeline.common import ConnectorAuthHelper
|
||||
from connector_v2.constants import ConnectorInstanceKey as CIKey
|
||||
|
||||
from backend.exceptions import UnstractFSException
|
||||
from connector_processor.constants import ConnectorKeys
|
||||
from connector_processor.exceptions import (
|
||||
InValidConnectorId,
|
||||
@@ -12,9 +15,6 @@ from connector_processor.exceptions import (
|
||||
OAuthTimeOut,
|
||||
TestConnectorInputError,
|
||||
)
|
||||
from connector_v2.constants import ConnectorInstanceKey as CIKey
|
||||
|
||||
from backend.exceptions import UnstractFSException
|
||||
from unstract.connectors.base import UnstractConnector
|
||||
from unstract.connectors.connectorkit import Connectorkit
|
||||
from unstract.connectors.enums import ConnectorMode
|
||||
@@ -25,10 +25,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def fetch_connectors_by_key_value(
|
||||
key: str, value: Any, connector_mode: Optional[ConnectorMode] = None
|
||||
key: str, value: Any, connector_mode: ConnectorMode | None = None
|
||||
) -> list[UnstractConnector]:
|
||||
"""Fetches a list of connectors that have an attribute matching key and
|
||||
value."""
|
||||
value.
|
||||
"""
|
||||
logger.info(f"Fetching connector list for {key} with {value}")
|
||||
connector_kit = Connectorkit()
|
||||
connectors = connector_kit.get_connectors_list(mode=connector_mode)
|
||||
@@ -42,9 +43,7 @@ class ConnectorProcessor:
|
||||
schema_details: dict = {}
|
||||
if connector_id == UnstractCloudStorage.get_id():
|
||||
return schema_details
|
||||
updated_connectors = fetch_connectors_by_key_value(
|
||||
ConnectorKeys.ID, connector_id
|
||||
)
|
||||
updated_connectors = fetch_connectors_by_key_value(ConnectorKeys.ID, connector_id)
|
||||
if len(updated_connectors) == 0:
|
||||
logger.error(
|
||||
f"Invalid connector Id : {connector_id} "
|
||||
@@ -70,7 +69,7 @@ class ConnectorProcessor:
|
||||
|
||||
@staticmethod
|
||||
def get_all_supported_connectors(
|
||||
type: str, connector_mode: Optional[ConnectorMode] = None
|
||||
type: str, connector_mode: ConnectorMode | None = None
|
||||
) -> list[dict]:
|
||||
"""Function to return list of all supported connectors except PCS."""
|
||||
supported_connectors = []
|
||||
@@ -119,9 +118,7 @@ class ConnectorProcessor:
|
||||
raise OAuthTimeOut()
|
||||
|
||||
try:
|
||||
connector_impl = Connectorkit().get_connector_by_id(
|
||||
connector_id, credentials
|
||||
)
|
||||
connector_impl = Connectorkit().get_connector_by_id(connector_id, credentials)
|
||||
test_result = connector_impl.test_credentials()
|
||||
logger.info(f"{connector_id} test result: {test_result}")
|
||||
return test_result
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from connector_processor.views import ConnectorViewSet
|
||||
from django.urls import path
|
||||
|
||||
from connector_processor.views import ConnectorViewSet
|
||||
|
||||
from . import views
|
||||
|
||||
connector_test = ConnectorViewSet.as_view({"post": "test"})
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
from connector_processor.connector_processor import ConnectorProcessor
|
||||
from connector_processor.constants import ConnectorKeys
|
||||
from connector_processor.exceptions import IdIsMandatory, InValidType
|
||||
from connector_processor.serializers import TestConnectorSerializer
|
||||
from connector_v2.constants import ConnectorInstanceKey as CIKey
|
||||
from django.http.request import HttpRequest
|
||||
from django.http.response import HttpResponse
|
||||
@@ -13,6 +9,11 @@ from rest_framework.serializers import Serializer
|
||||
from rest_framework.versioning import URLPathVersioning
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
from connector_processor.connector_processor import ConnectorProcessor
|
||||
from connector_processor.constants import ConnectorKeys
|
||||
from connector_processor.exceptions import IdIsMandatory, InValidType
|
||||
from connector_processor.serializers import TestConnectorSerializer
|
||||
|
||||
|
||||
@api_view(("GET",))
|
||||
def get_connector_schema(request: HttpRequest) -> HttpResponse:
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from account_v2.models import User
|
||||
from connector_v2.constants import ConnectorInstanceConstant
|
||||
from connector_v2.models import ConnectorInstance
|
||||
from connector_v2.unstract_account import UnstractAccount
|
||||
from django.conf import settings
|
||||
from utils.user_context import UserContext
|
||||
from workflow_manager.workflow_v2.models.workflow import Workflow
|
||||
|
||||
from connector_v2.constants import ConnectorInstanceConstant
|
||||
from connector_v2.models import ConnectorInstance
|
||||
from connector_v2.unstract_account import UnstractAccount
|
||||
from unstract.connectors.filesystems.ucs import UnstractCloudStorage
|
||||
from unstract.connectors.filesystems.ucs.constants import UCSKey
|
||||
|
||||
@@ -78,9 +78,9 @@ class ConnectorInstanceHelper:
|
||||
def get_connector_instances_by_workflow(
|
||||
workflow_id: str,
|
||||
connector_type: tuple[str, str],
|
||||
connector_mode: Optional[tuple[int, str]] = None,
|
||||
values: Optional[list[str]] = None,
|
||||
connector_name: Optional[str] = None,
|
||||
connector_mode: tuple[int, str] | None = None,
|
||||
values: list[str] | None = None,
|
||||
connector_name: str | None = None,
|
||||
) -> list[ConnectorInstance]:
|
||||
"""Method to get connector instances by workflow.
|
||||
|
||||
@@ -121,9 +121,9 @@ class ConnectorInstanceHelper:
|
||||
def get_connector_instance_by_workflow(
|
||||
workflow_id: str,
|
||||
connector_type: tuple[str, str],
|
||||
connector_mode: Optional[tuple[int, str]] = None,
|
||||
connector_name: Optional[str] = None,
|
||||
) -> Optional[ConnectorInstance]:
|
||||
connector_mode: tuple[int, str] | None = None,
|
||||
connector_name: str | None = None,
|
||||
) -> ConnectorInstance | None:
|
||||
"""Get one connector instance.
|
||||
|
||||
Use this method if the connector instance is unique for \
|
||||
@@ -162,7 +162,7 @@ class ConnectorInstanceHelper:
|
||||
def get_input_connector_instance_by_name_for_workflow(
|
||||
workflow_id: str,
|
||||
connector_name: str,
|
||||
) -> Optional[ConnectorInstance]:
|
||||
) -> ConnectorInstance | None:
|
||||
"""Method to get Input connector instance name from the workflow.
|
||||
|
||||
Args:
|
||||
@@ -182,7 +182,7 @@ class ConnectorInstanceHelper:
|
||||
def get_output_connector_instance_by_name_for_workflow(
|
||||
workflow_id: str,
|
||||
connector_name: str,
|
||||
) -> Optional[ConnectorInstance]:
|
||||
) -> ConnectorInstance | None:
|
||||
"""Method to get output connector name by Workflow.
|
||||
|
||||
Args:
|
||||
@@ -232,7 +232,7 @@ class ConnectorInstanceHelper:
|
||||
|
||||
@staticmethod
|
||||
def get_file_system_input_connector_instances_by_workflow(
|
||||
workflow_id: str, values: Optional[list[str]] = None
|
||||
workflow_id: str, values: list[str] | None = None
|
||||
) -> list[ConnectorInstance]:
|
||||
"""Method to fetch file system connector by workflow.
|
||||
|
||||
@@ -252,7 +252,7 @@ class ConnectorInstanceHelper:
|
||||
|
||||
@staticmethod
|
||||
def get_file_system_output_connector_instances_by_workflow(
|
||||
workflow_id: str, values: Optional[list[str]] = None
|
||||
workflow_id: str, values: list[str] | None = None
|
||||
) -> list[ConnectorInstance]:
|
||||
"""Method to get file system output connector by workflow.
|
||||
|
||||
@@ -272,7 +272,7 @@ class ConnectorInstanceHelper:
|
||||
|
||||
@staticmethod
|
||||
def get_database_input_connector_instances_by_workflow(
|
||||
workflow_id: str, values: Optional[list[str]] = None
|
||||
workflow_id: str, values: list[str] | None = None
|
||||
) -> list[ConnectorInstance]:
|
||||
"""Method to fetch input database connectors by workflow.
|
||||
|
||||
@@ -292,7 +292,7 @@ class ConnectorInstanceHelper:
|
||||
|
||||
@staticmethod
|
||||
def get_database_output_connector_instances_by_workflow(
|
||||
workflow_id: str, values: Optional[list[str]] = None
|
||||
workflow_id: str, values: list[str] | None = None
|
||||
) -> list[ConnectorInstance]:
|
||||
"""Method to fetch output database connectors by workflow.
|
||||
|
||||
|
||||
@@ -6,9 +6,7 @@ class ConnectorInstanceKey:
|
||||
CONNECTOR_VERSION = "connector_version"
|
||||
CONNECTOR_AUTH = "connector_auth"
|
||||
CONNECTOR_METADATA = "connector_metadata"
|
||||
CONNECTOR_EXISTS = (
|
||||
"Connector with this configuration already exists in this project."
|
||||
)
|
||||
CONNECTOR_EXISTS = "Connector with this configuration already exists in this project."
|
||||
DUPLICATE_API = "It appears that a duplicate call may have been made."
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
@@ -39,9 +38,7 @@ class Migration(migrations.Migration):
|
||||
("connector_version", models.CharField(default="", max_length=64)),
|
||||
(
|
||||
"connector_type",
|
||||
models.CharField(
|
||||
choices=[("INPUT", "Input"), ("OUTPUT", "Output")]
|
||||
),
|
||||
models.CharField(choices=[("INPUT", "Input"), ("OUTPUT", "Output")]),
|
||||
),
|
||||
(
|
||||
"connector_mode",
|
||||
|
||||
@@ -87,7 +87,8 @@ class ConnectorInstance(DefaultOrganizationMixin, BaseModel):
|
||||
# TODO: Remove if unused
|
||||
def get_connector_metadata(self) -> dict[str, str]:
|
||||
"""Gets connector metadata and refreshes the tokens if needed in case
|
||||
of OAuth."""
|
||||
of OAuth.
|
||||
"""
|
||||
tokens_refreshed = False
|
||||
if self.connector_auth:
|
||||
(
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
import json
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from connector_auth_v2.models import ConnectorAuth
|
||||
from connector_auth_v2.pipeline.common import ConnectorAuthHelper
|
||||
from connector_processor.connector_processor import ConnectorProcessor
|
||||
from connector_processor.constants import ConnectorKeys
|
||||
from connector_processor.exceptions import OAuthTimeOut
|
||||
from connector_v2.constants import ConnectorInstanceKey as CIKey
|
||||
from cryptography.fernet import Fernet
|
||||
from django.conf import settings
|
||||
from utils.serializer_utils import SerializerUtils
|
||||
|
||||
from backend.serializers import AuditSerializer
|
||||
from connector_v2.constants import ConnectorInstanceKey as CIKey
|
||||
from unstract.connectors.filesystems.ucs import UnstractCloudStorage
|
||||
|
||||
from .models import ConnectorInstance
|
||||
@@ -41,7 +41,7 @@ class ConnectorInstanceSerializer(AuditSerializer):
|
||||
def save(self, **kwargs): # type: ignore
|
||||
user = self.context.get("request").user or None
|
||||
connector_id: str = kwargs[CIKey.CONNECTOR_ID]
|
||||
connector_oauth: Optional[ConnectorAuth] = None
|
||||
connector_oauth: ConnectorAuth | None = None
|
||||
if (
|
||||
ConnectorInstance.supportsOAuth(connector_id=connector_id)
|
||||
and CIKey.CONNECTOR_METADATA in kwargs
|
||||
|
||||
@@ -12,7 +12,6 @@ pytestmark = pytest.mark.django_db
|
||||
class TestConnector(APITestCase):
|
||||
def test_connector_list(self) -> None:
|
||||
"""Tests to List the connectors."""
|
||||
|
||||
url = reverse("connectors_v1-list")
|
||||
response = self.client.get(url)
|
||||
|
||||
@@ -20,7 +19,6 @@ class TestConnector(APITestCase):
|
||||
|
||||
def test_connectors_detail(self) -> None:
|
||||
"""Tests to fetch a connector with given pk."""
|
||||
|
||||
url = reverse("connectors_v1-detail", kwargs={"pk": 1})
|
||||
response = self.client.get(url)
|
||||
|
||||
@@ -28,7 +26,6 @@ class TestConnector(APITestCase):
|
||||
|
||||
def test_connectors_detail_not_found(self) -> None:
|
||||
"""Tests for negative case to fetch non exiting key."""
|
||||
|
||||
url = reverse("connectors_v1-detail", kwargs={"pk": 768})
|
||||
response = self.client.get(url)
|
||||
|
||||
@@ -36,7 +33,6 @@ class TestConnector(APITestCase):
|
||||
|
||||
def test_connectors_create(self) -> None:
|
||||
"""Tests to create a new ConnectorInstance."""
|
||||
|
||||
url = reverse("connectors_v1-list")
|
||||
data = {
|
||||
"org": 1,
|
||||
@@ -57,8 +53,8 @@ class TestConnector(APITestCase):
|
||||
|
||||
def test_connectors_create_with_json_list(self) -> None:
|
||||
"""Tests to create a new connector with list included in the json
|
||||
field."""
|
||||
|
||||
field.
|
||||
"""
|
||||
url = reverse("connectors_v1-list")
|
||||
data = {
|
||||
"org": 1,
|
||||
@@ -80,7 +76,6 @@ class TestConnector(APITestCase):
|
||||
|
||||
def test_connectors_create_with_nested_json(self) -> None:
|
||||
"""Tests to create a new connector with json field as nested json."""
|
||||
|
||||
url = reverse("connectors_v1-list")
|
||||
data = {
|
||||
"org": 1,
|
||||
@@ -102,7 +97,6 @@ class TestConnector(APITestCase):
|
||||
|
||||
def test_connectors_create_bad_request(self) -> None:
|
||||
"""Tests for negative case to throw error on a wrong access."""
|
||||
|
||||
url = reverse("connectors_v1-list")
|
||||
data = {
|
||||
"org": 5,
|
||||
@@ -123,7 +117,6 @@ class TestConnector(APITestCase):
|
||||
|
||||
def test_connectors_update_json_field(self) -> None:
|
||||
"""Tests to update connector with json field update."""
|
||||
|
||||
url = reverse("connectors_v1-detail", kwargs={"pk": 1})
|
||||
data = {
|
||||
"org": 1,
|
||||
@@ -144,7 +137,6 @@ class TestConnector(APITestCase):
|
||||
|
||||
def test_connectors_update(self) -> None:
|
||||
"""Tests to update connector update single field."""
|
||||
|
||||
url = reverse("connectors_v1-detail", kwargs={"pk": 1})
|
||||
data = {
|
||||
"org": 1,
|
||||
@@ -166,7 +158,6 @@ class TestConnector(APITestCase):
|
||||
|
||||
def test_connectors_update_pk(self) -> None:
|
||||
"""Tests the PUT method for 400 error."""
|
||||
|
||||
url = reverse("connectors_v1-detail", kwargs={"pk": 1})
|
||||
data = {
|
||||
"org": 2,
|
||||
@@ -187,7 +178,6 @@ class TestConnector(APITestCase):
|
||||
|
||||
def test_connectors_update_json_fields(self) -> None:
|
||||
"""Tests to update ConnectorInstance."""
|
||||
|
||||
url = reverse("connectors_v1-detail", kwargs={"pk": 1})
|
||||
data = {
|
||||
"org": 1,
|
||||
@@ -203,16 +193,13 @@ class TestConnector(APITestCase):
|
||||
},
|
||||
}
|
||||
response = self.client.put(url, data, format="json")
|
||||
nested_value = response.data["connector_metadata"]["sample_metadata_json"][
|
||||
"key1"
|
||||
]
|
||||
nested_value = response.data["connector_metadata"]["sample_metadata_json"]["key1"]
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(nested_value, "value1")
|
||||
|
||||
def test_connectors_update_json_list_fields(self) -> None:
|
||||
"""Tests to update connector to the third second level of json."""
|
||||
|
||||
url = reverse("connectors_v1-detail", kwargs={"pk": 1})
|
||||
data = {
|
||||
"org": 1,
|
||||
@@ -229,9 +216,7 @@ class TestConnector(APITestCase):
|
||||
},
|
||||
}
|
||||
response = self.client.put(url, data, format="json")
|
||||
nested_value = response.data["connector_metadata"]["sample_metadata_json"][
|
||||
"key1"
|
||||
]
|
||||
nested_value = response.data["connector_metadata"]["sample_metadata_json"]["key1"]
|
||||
nested_list = response.data["connector_metadata"]["file_list"]
|
||||
last_val = nested_list.pop()
|
||||
|
||||
@@ -287,7 +272,6 @@ class TestConnector(APITestCase):
|
||||
|
||||
def test_connectors_update_field(self) -> None:
|
||||
"""Tests the PATCH method."""
|
||||
|
||||
url = reverse("connectors_v1-detail", kwargs={"pk": 1})
|
||||
data = {"connector_id": "e3a4512m-efgb-48d5-98a9-3983ntest"}
|
||||
response = self.client.patch(url, data, format="json")
|
||||
@@ -301,7 +285,6 @@ class TestConnector(APITestCase):
|
||||
|
||||
def test_connectors_update_json_field_patch(self) -> None:
|
||||
"""Tests the PATCH method."""
|
||||
|
||||
url = reverse("connectors_v1-detail", kwargs={"pk": 1})
|
||||
data = {
|
||||
"connector_metadata": {
|
||||
@@ -322,7 +305,6 @@ class TestConnector(APITestCase):
|
||||
|
||||
def test_connectors_delete(self) -> None:
|
||||
"""Tests the DELETE method."""
|
||||
|
||||
url = reverse("connectors_v1-detail", kwargs={"pk": 1})
|
||||
response = self.client.delete(url, format="json")
|
||||
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from account_v2.custom_exceptions import DuplicateData
|
||||
from connector_auth_v2.constants import ConnectorAuthKey
|
||||
from connector_auth_v2.exceptions import CacheMissException, MissingParamException
|
||||
from connector_auth_v2.pipeline.common import ConnectorAuthHelper
|
||||
from connector_processor.exceptions import OAuthTimeOut
|
||||
from connector_v2.constants import ConnectorInstanceKey as CIKey
|
||||
from django.db import IntegrityError
|
||||
from django.db.models import QuerySet
|
||||
from rest_framework import status, viewsets
|
||||
@@ -15,6 +14,7 @@ from rest_framework.versioning import URLPathVersioning
|
||||
from utils.filtering import FilterHelper
|
||||
|
||||
from backend.constants import RequestKey
|
||||
from connector_v2.constants import ConnectorInstanceKey as CIKey
|
||||
|
||||
from .models import ConnectorInstance
|
||||
from .serializers import ConnectorInstanceSerializer
|
||||
@@ -26,7 +26,7 @@ class ConnectorInstanceViewSet(viewsets.ModelViewSet):
|
||||
versioning_class = URLPathVersioning
|
||||
serializer_class = ConnectorInstanceSerializer
|
||||
|
||||
def get_queryset(self) -> Optional[QuerySet]:
|
||||
def get_queryset(self) -> QuerySet | None:
|
||||
filter_args = FilterHelper.build_filter_args(
|
||||
self.request,
|
||||
RequestKey.WORKFLOW,
|
||||
@@ -40,7 +40,7 @@ class ConnectorInstanceViewSet(viewsets.ModelViewSet):
|
||||
queryset = ConnectorInstance.objects.all()
|
||||
return queryset
|
||||
|
||||
def _get_connector_metadata(self, connector_id: str) -> Optional[dict[str, str]]:
|
||||
def _get_connector_metadata(self, connector_id: str) -> dict[str, str] | None:
|
||||
"""Gets connector metadata for the ConnectorInstance.
|
||||
|
||||
For non oauth based - obtains from request
|
||||
@@ -118,6 +118,4 @@ class ConnectorInstanceViewSet(viewsets.ModelViewSet):
|
||||
{CIKey.DUPLICATE_API}"
|
||||
)
|
||||
headers = self.get_success_headers(serializer.data)
|
||||
return Response(
|
||||
serializer.data, status=status.HTTP_201_CREATED, headers=headers
|
||||
)
|
||||
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
|
||||
|
||||
@@ -4,9 +4,10 @@ This module defines the URL patterns for the feature_flags app.
|
||||
"""
|
||||
|
||||
from django.urls import path
|
||||
from feature_flag.views import FeatureFlagViewSet
|
||||
from rest_framework.urlpatterns import format_suffix_patterns
|
||||
|
||||
from feature_flag.views import FeatureFlagViewSet
|
||||
|
||||
feature_flags_list = FeatureFlagViewSet.as_view(
|
||||
{
|
||||
"post": "evaluate",
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
"""
|
||||
Feature Flag view file
|
||||
"""Feature Flag view file
|
||||
Returns:
|
||||
evaluate response
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from feature_flag.helper import FeatureFlagHelper
|
||||
from rest_framework import status, viewsets
|
||||
from rest_framework.response import Response
|
||||
|
||||
from feature_flag.helper import FeatureFlagHelper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from rest_framework.exceptions import APIException
|
||||
|
||||
from backend.exceptions import UnstractBaseException
|
||||
@@ -69,7 +67,7 @@ class ValidationError(APIException):
|
||||
status_code = 400
|
||||
default_detail = "Validation Error"
|
||||
|
||||
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
|
||||
def __init__(self, detail: str | None = None, code: int | None = None):
|
||||
if detail is not None:
|
||||
self.detail = detail
|
||||
if code is not None:
|
||||
@@ -81,7 +79,7 @@ class FileDeletionFailed(APIException):
|
||||
status_code = 400
|
||||
default_detail = "Unable to delete file."
|
||||
|
||||
def __init__(self, detail: Optional[str] = None, code: Optional[int] = None):
|
||||
def __init__(self, detail: str | None = None, code: int | None = None):
|
||||
if detail is not None:
|
||||
self.detail = detail
|
||||
if code is not None:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from file_management.constants import FileInformationKey
|
||||
|
||||
@@ -10,12 +10,12 @@ from file_management.constants import FileInformationKey
|
||||
class FileInformation:
|
||||
name: str
|
||||
type: str
|
||||
modified_at: Optional[datetime]
|
||||
content_type: Optional[str]
|
||||
modified_at: datetime | None
|
||||
content_type: str | None
|
||||
size: int
|
||||
|
||||
def __init__(
|
||||
self, file_info: dict[str, Any], file_content_type: Optional[str] = None
|
||||
self, file_info: dict[str, Any], file_content_type: str | None = None
|
||||
) -> None:
|
||||
self.name = os.path.normpath(file_info[FileInformationKey.FILE_NAME])
|
||||
self.type = file_info[FileInformationKey.FILE_TYPE]
|
||||
@@ -27,7 +27,7 @@ class FileInformation:
|
||||
self.size = file_info[FileInformationKey.FILE_SIZE]
|
||||
|
||||
@staticmethod
|
||||
def parse_datetime(dt_string: Optional[Union[str, datetime]]) -> Optional[datetime]:
|
||||
def parse_datetime(dt_string: str | datetime | None) -> datetime | None:
|
||||
if isinstance(dt_string, str):
|
||||
return datetime.strptime(dt_string, "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||
return dt_string
|
||||
|
||||
@@ -10,6 +10,9 @@ from connector_v2.models import ConnectorInstance
|
||||
from deprecated import deprecated
|
||||
from django.conf import settings
|
||||
from django.http import StreamingHttpResponse
|
||||
from fsspec import AbstractFileSystem
|
||||
from pydrive2.files import ApiRequestError
|
||||
|
||||
from file_management.exceptions import (
|
||||
ConnectorApiRequestError,
|
||||
ConnectorClassNotFound,
|
||||
@@ -22,9 +25,6 @@ from file_management.exceptions import (
|
||||
TenantDirCreationError,
|
||||
)
|
||||
from file_management.file_management_dto import FileInformation
|
||||
from fsspec import AbstractFileSystem
|
||||
from pydrive2.files import ApiRequestError
|
||||
|
||||
from unstract.connectors.filesystems import connectors as fs_connectors
|
||||
from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem
|
||||
|
||||
@@ -32,7 +32,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileManagerHelper:
|
||||
|
||||
@staticmethod
|
||||
def get_file_system(connector: ConnectorInstance) -> UnstractFileSystem:
|
||||
"""Creates the `UnstractFileSystem` for the corresponding connector."""
|
||||
@@ -70,7 +69,8 @@ class FileManagerHelper:
|
||||
@staticmethod
|
||||
def get_files(fs: AbstractFileSystem, file_path: str) -> list[FileInformation]:
|
||||
"""Iterate through the directories and make a list of
|
||||
FileInformation."""
|
||||
FileInformation.
|
||||
"""
|
||||
if not file_path.endswith("/"):
|
||||
file_path += "/"
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from file_management.constants import FileInformationKey
|
||||
from rest_framework import serializers
|
||||
from utils.FileValidator import FileValidator
|
||||
|
||||
from file_management.constants import FileInformationKey
|
||||
|
||||
|
||||
class FileInfoSerializer(serializers.Serializer):
|
||||
name = serializers.CharField()
|
||||
|
||||
@@ -3,6 +3,14 @@ from typing import Any
|
||||
|
||||
from connector_v2.models import ConnectorInstance
|
||||
from django.http import HttpRequest
|
||||
from oauth2client.client import HttpAccessTokenRefreshError
|
||||
from prompt_studio.prompt_studio_document_manager_v2.models import DocumentManager
|
||||
from rest_framework import serializers, status, viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.versioning import URLPathVersioning
|
||||
from utils.user_session import UserSessionUtils
|
||||
|
||||
from file_management.exceptions import (
|
||||
ConnectorInstanceNotFound,
|
||||
ConnectorOAuthError,
|
||||
@@ -15,14 +23,6 @@ from file_management.serializer import (
|
||||
FileListRequestSerializer,
|
||||
FileUploadSerializer,
|
||||
)
|
||||
from oauth2client.client import HttpAccessTokenRefreshError
|
||||
from prompt_studio.prompt_studio_document_manager_v2.models import DocumentManager
|
||||
from rest_framework import serializers, status, viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.versioning import URLPathVersioning
|
||||
from utils.user_session import UserSessionUtils
|
||||
|
||||
from unstract.connectors.exceptions import ConnectorError
|
||||
from unstract.connectors.filesystems.local_storage.local_storage import LocalStorageFS
|
||||
|
||||
@@ -96,9 +96,7 @@ class FileManagementViewSet(viewsets.ModelViewSet):
|
||||
|
||||
for uploaded_file in uploaded_files:
|
||||
file_name = uploaded_file.name
|
||||
logger.info(
|
||||
f"Uploading file: {file_name}" if file_name else "Uploading file"
|
||||
)
|
||||
logger.info(f"Uploading file: {file_name}" if file_name else "Uploading file")
|
||||
FileManagerHelper.upload_file(file_system, path, uploaded_file, file_name)
|
||||
return Response({"message": "Files are uploaded successfully!"})
|
||||
|
||||
|
||||
@@ -3,6 +3,4 @@ from rest_framework.urlpatterns import format_suffix_patterns
|
||||
|
||||
from .views import health_check
|
||||
|
||||
urlpatterns = format_suffix_patterns(
|
||||
[path("health", health_check, name="health-check")]
|
||||
)
|
||||
urlpatterns = format_suffix_patterns([path("health", health_check, name="health-check")])
|
||||
|
||||
@@ -4,7 +4,6 @@ from utils.cache_service import CacheService
|
||||
class LogService:
|
||||
@staticmethod
|
||||
def remove_logs_on_logout(session_id: str) -> None:
|
||||
|
||||
if session_id:
|
||||
key_pattern = f"{LogService.generate_redis_key(session_id=session_id)}*"
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from django.conf import settings
|
||||
from django.http import HttpRequest
|
||||
@@ -57,11 +57,9 @@ class LogsHelperViewSet(viewsets.ModelViewSet):
|
||||
# Extract the log message from the validated data
|
||||
log: str = serializer.validated_data.get("log")
|
||||
log_data = json.loads(log)
|
||||
timestamp = datetime.now(timezone.utc).timestamp()
|
||||
timestamp = datetime.now(UTC).timestamp()
|
||||
|
||||
redis_key = (
|
||||
f"{LogService.generate_redis_key(session_id=session_id)}:{timestamp}"
|
||||
)
|
||||
redis_key = f"{LogService.generate_redis_key(session_id=session_id)}:{timestamp}"
|
||||
|
||||
CacheService.set_key(redis_key, log_data, logs_expiry)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""Django's command-line utility for administrative tasks."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Set via settings.REST_FRAMEWORK.EXCEPTION_HANDLER.
|
||||
def drf_logging_exc_handler(exc: Exception, context: Any) -> Optional[Response]:
|
||||
def drf_logging_exc_handler(exc: Exception, context: Any) -> Response | None:
|
||||
"""Custom exception handler for DRF.
|
||||
|
||||
DRF's exception handler takes care of Http404, PermissionDenied and
|
||||
@@ -31,7 +31,7 @@ def drf_logging_exc_handler(exc: Exception, context: Any) -> Optional[Response]:
|
||||
handled by another method in the middleware
|
||||
"""
|
||||
request = context.get("request")
|
||||
response: Optional[Response] = exception_handler(exc=exc, context=context)
|
||||
response: Response | None = exception_handler(exc=exc, context=context)
|
||||
ExceptionLoggingMiddleware.format_exc_and_log(
|
||||
request=request, response=response, exception=exc
|
||||
)
|
||||
@@ -54,7 +54,7 @@ class ExceptionLoggingMiddleware:
|
||||
|
||||
def process_exception(
|
||||
self, request: HttpRequest, exception: Exception
|
||||
) -> Optional[HttpResponse]:
|
||||
) -> HttpResponse | None:
|
||||
"""Django hook to handle exceptions by a middleware.
|
||||
|
||||
Args:
|
||||
@@ -78,7 +78,7 @@ class ExceptionLoggingMiddleware:
|
||||
|
||||
@staticmethod
|
||||
def format_exc_and_log(
|
||||
request: Request, exception: Exception, response: Optional[Response] = None
|
||||
request: Request, exception: Exception, response: Response | None = None
|
||||
) -> None:
|
||||
"""Format the exception to be logged and logs it.
|
||||
|
||||
@@ -90,18 +90,7 @@ class ExceptionLoggingMiddleware:
|
||||
if response:
|
||||
status_code = response.status_code
|
||||
if status_code >= 500:
|
||||
message = "{method} {url} {status}\n\n{error}\n\n````{tb}````".format(
|
||||
method=request.method,
|
||||
url=request.build_absolute_uri(),
|
||||
status=status_code,
|
||||
error=repr(exception),
|
||||
tb=traceback.format_exc(),
|
||||
)
|
||||
message = f"{request.method} {request.build_absolute_uri()} {status_code}\n\n{repr(exception)}\n\n````{traceback.format_exc()}````"
|
||||
else:
|
||||
message = "{method} {url} {status} {error}".format(
|
||||
method=request.method,
|
||||
url=request.build_absolute_uri(),
|
||||
status=status_code,
|
||||
error=repr(exception),
|
||||
)
|
||||
message = f"{request.method} {request.build_absolute_uri()} {status_code} {repr(exception)}"
|
||||
logger.error(message)
|
||||
|
||||
@@ -4,6 +4,5 @@ from log_request_id.middleware import RequestIDMiddleware
|
||||
|
||||
|
||||
class CustomRequestIDMiddleware(RequestIDMiddleware):
|
||||
|
||||
def _generate_id(self):
|
||||
return str(uuid.uuid4())
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import psycopg2
|
||||
from django.conf import settings
|
||||
@@ -111,9 +111,7 @@ class DataMigrator:
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def _fetch_schema_names(
|
||||
self, schemas_to_migrate: list[str]
|
||||
) -> list[tuple[int, str]]:
|
||||
def _fetch_schema_names(self, schemas_to_migrate: list[str]) -> list[tuple[int, str]]:
|
||||
"""Fetches schema names and their IDs from the destination database
|
||||
based on the provided schema list. Supports fetching all schemas if
|
||||
'_ALL_' is specified.
|
||||
@@ -126,7 +124,6 @@ class DataMigrator:
|
||||
list[tuple[int, str]]: A list of tuples containing the ID and schema name.
|
||||
"""
|
||||
with self._db_connect_and_cursor(self.dest_db_config) as (conn, cur):
|
||||
|
||||
# Process schemas_to_migrate: trim spaces and remove empty entries
|
||||
schemas_to_migrate = [
|
||||
schema.strip() for schema in schemas_to_migrate if schema.strip()
|
||||
@@ -153,7 +150,7 @@ class DataMigrator:
|
||||
dest_cursor: cursor,
|
||||
column_names: list[str],
|
||||
column_transformations: dict[str, dict[str, Any]],
|
||||
) -> Optional[tuple[Any, ...]]:
|
||||
) -> tuple[Any, ...] | None:
|
||||
"""Prepares and migrates the relational keys of a single row from the
|
||||
source database to the destination database, updating specific column
|
||||
values based on provided transformations.
|
||||
@@ -209,7 +206,8 @@ class DataMigrator:
|
||||
column_names: list[str],
|
||||
) -> tuple[Any, ...]:
|
||||
"""Convert specified field in the row to JSON format if it is a
|
||||
list."""
|
||||
list.
|
||||
"""
|
||||
if key not in column_names:
|
||||
return row
|
||||
|
||||
@@ -328,10 +326,12 @@ class DataMigrator:
|
||||
|
||||
This method retrieves the maximum ID value from the specified table
|
||||
and sets the next auto-increment value for the table accordingly.
|
||||
|
||||
Args:
|
||||
dest_cursor (cursor): The cursor for the destination database.
|
||||
dest_table (str): The name of the table to adjust.
|
||||
dest_conn (connection): The connection to the destination database.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
@@ -367,8 +367,7 @@ class DataMigrator:
|
||||
|
||||
if max_id is None:
|
||||
logger.info(
|
||||
f"Table '{dest_table}' is empty. No need to adjust "
|
||||
"auto-increment."
|
||||
f"Table '{dest_table}' is empty. No need to adjust " "auto-increment."
|
||||
)
|
||||
return
|
||||
|
||||
@@ -386,7 +385,7 @@ class DataMigrator:
|
||||
raise
|
||||
|
||||
def migrate(
|
||||
self, migrations: list[dict[str, str]], organization_id: Optional[str] = None
|
||||
self, migrations: list[dict[str, str]], organization_id: str | None = None
|
||||
) -> None:
|
||||
self._create_tracking_table_if_not_exists()
|
||||
|
||||
@@ -397,7 +396,6 @@ class DataMigrator:
|
||||
dest_cursor,
|
||||
),
|
||||
):
|
||||
|
||||
for migration in migrations:
|
||||
migration_name = migration["name"]
|
||||
logger.info(f"Migration '{migration_name}' started")
|
||||
@@ -504,9 +502,7 @@ class Command(BaseCommand):
|
||||
schemas_to_migrate = schemas_to_migrate.split(",")
|
||||
|
||||
# Organization Data (Schema)
|
||||
schema_names = migrator._fetch_schema_names(
|
||||
schemas_to_migrate=schemas_to_migrate
|
||||
)
|
||||
schema_names = migrator._fetch_schema_names(schemas_to_migrate=schemas_to_migrate)
|
||||
|
||||
for organization_id, schema_name in schema_names:
|
||||
if schema_name == "public":
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
class MigrationQuery:
|
||||
"""This class contains methods to generate SQL queries for various
|
||||
migration operations."""
|
||||
migration operations.
|
||||
"""
|
||||
|
||||
def __init__(self, v2_schema) -> None:
|
||||
self.v2_schema = v2_schema
|
||||
@@ -253,8 +254,8 @@ class MigrationQuery:
|
||||
def get_organization_migrations(
|
||||
self, schema: str, organization_id: str
|
||||
) -> list[dict[str, str]]:
|
||||
"""
|
||||
Returns a list of dictionaries containing the organization migration details.
|
||||
"""Returns a list of dictionaries containing the organization migration details.
|
||||
|
||||
Args:
|
||||
schema (str): The name of the schema for the organization.
|
||||
organization_id (str): The ID of the organization.
|
||||
|
||||
@@ -26,7 +26,7 @@ class NotificationHelper:
|
||||
payload (Any): The data to be sent with the notification. This can be any
|
||||
format expected by the provider
|
||||
|
||||
Returns:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for notification in notifications:
|
||||
|
||||
@@ -7,7 +7,6 @@ from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from notification_v2.models import Notification
|
||||
|
||||
|
||||
@@ -15,7 +16,8 @@ class NotificationProvider(ABC):
|
||||
@abstractmethod
|
||||
def send(self):
|
||||
"""Method to be overridden in child classes for sending the
|
||||
notification."""
|
||||
notification.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses should implement this method.")
|
||||
|
||||
def validate(self):
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from celery import shared_task
|
||||
|
||||
from notification_v2.enums import AuthorizationType
|
||||
from notification_v2.provider.notification_provider import NotificationProvider
|
||||
|
||||
@@ -48,8 +49,8 @@ class Webhook(NotificationProvider):
|
||||
return super().validate()
|
||||
|
||||
def get_headers(self):
|
||||
"""
|
||||
Get the headers for the notification based on the authorization type and key.
|
||||
"""Get the headers for the notification based on the authorization type and key.
|
||||
|
||||
Raises:
|
||||
ValueError: _description_
|
||||
|
||||
@@ -95,9 +96,7 @@ class Webhook(NotificationProvider):
|
||||
# Check if custom header type has required details
|
||||
if authorization_type == AuthorizationType.CUSTOM_HEADER:
|
||||
if not authorization_header or not authorization_key:
|
||||
raise ValueError(
|
||||
"Custom header or key missing for custom authorization."
|
||||
)
|
||||
raise ValueError("Custom header or key missing for custom authorization.")
|
||||
return headers
|
||||
|
||||
|
||||
@@ -108,7 +107,7 @@ def send_webhook_notification(
|
||||
payload: Any,
|
||||
headers: Any = None,
|
||||
timeout: int = 10,
|
||||
max_retries: Optional[int] = None,
|
||||
max_retries: int | None = None,
|
||||
retry_delay: int = 10,
|
||||
):
|
||||
"""Celery task to send a webhook with retries and error handling.
|
||||
|
||||
@@ -107,7 +107,8 @@ class NotificationSerializer(serializers.ModelSerializer):
|
||||
|
||||
def validate_name(self, value):
|
||||
"""Check uniqueness of the name with respect to either 'api' or
|
||||
'pipeline'."""
|
||||
'pipeline'.
|
||||
"""
|
||||
api = self.initial_data.get("api", getattr(self.instance, "api", None))
|
||||
pipeline = self.initial_data.get(
|
||||
"pipeline", getattr(self.instance, "pipeline", None)
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from api_v2.deployment_helper import DeploymentHelper
|
||||
from api_v2.exceptions import APINotFound
|
||||
from notification_v2.constants import NotificationUrlConstant
|
||||
from pipeline_v2.exceptions import PipelineNotFound
|
||||
from pipeline_v2.models import Pipeline
|
||||
from pipeline_v2.pipeline_processor import PipelineProcessor
|
||||
from rest_framework import viewsets
|
||||
|
||||
from notification_v2.constants import NotificationUrlConstant
|
||||
|
||||
from .models import Notification
|
||||
from .serializers import NotificationSerializer
|
||||
|
||||
|
||||
5736
backend/pdm.lock
generated
5736
backend/pdm.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -11,7 +11,6 @@ class IsOwner(permissions.BasePermission):
|
||||
"""Custom permission to only allow owners of an object."""
|
||||
|
||||
def has_object_permission(self, request: Request, view: APIView, obj: Any) -> bool:
|
||||
|
||||
return True if obj.created_by == request.user else False
|
||||
|
||||
|
||||
@@ -27,7 +26,6 @@ class IsOwnerOrSharedUser(permissions.BasePermission):
|
||||
"""Custom permission to only allow owners and shared users of an object."""
|
||||
|
||||
def has_object_permission(self, request: Request, view: APIView, obj: Any) -> bool:
|
||||
|
||||
return (
|
||||
True
|
||||
if (
|
||||
@@ -40,10 +38,10 @@ class IsOwnerOrSharedUser(permissions.BasePermission):
|
||||
|
||||
class IsOwnerOrSharedUserOrSharedToOrg(permissions.BasePermission):
|
||||
"""Custom permission to only allow owners and shared users of an object or
|
||||
if it is shared to org."""
|
||||
if it is shared to org.
|
||||
"""
|
||||
|
||||
def has_object_permission(self, request: Request, view: APIView, obj: Any) -> bool:
|
||||
|
||||
return (
|
||||
True
|
||||
if (
|
||||
@@ -57,12 +55,12 @@ class IsOwnerOrSharedUserOrSharedToOrg(permissions.BasePermission):
|
||||
|
||||
class IsFrictionLessAdapter(permissions.BasePermission):
|
||||
"""Hack for friction-less onboarding not allowing user to view or updating
|
||||
friction less adapter."""
|
||||
friction less adapter.
|
||||
"""
|
||||
|
||||
def has_object_permission(
|
||||
self, request: Request, view: APIView, obj: AdapterInstance
|
||||
) -> bool:
|
||||
|
||||
if obj.is_friction_less:
|
||||
return False
|
||||
|
||||
@@ -71,12 +69,12 @@ class IsFrictionLessAdapter(permissions.BasePermission):
|
||||
|
||||
class IsFrictionLessAdapterDelete(permissions.BasePermission):
|
||||
"""Hack for friction-less onboarding Allows frticon less adapter to rmoved
|
||||
by an org member."""
|
||||
by an org member.
|
||||
"""
|
||||
|
||||
def has_object_permission(
|
||||
self, request: Request, view: APIView, obj: AdapterInstance
|
||||
) -> bool:
|
||||
|
||||
if obj.is_friction_less:
|
||||
return True
|
||||
|
||||
|
||||
@@ -4,11 +4,12 @@ from typing import Any
|
||||
from api_v2.api_key_validator import BaseAPIKeyValidator
|
||||
from api_v2.exceptions import InvalidAPIRequest
|
||||
from api_v2.key_helper import KeyHelper
|
||||
from pipeline_v2.exceptions import PipelineNotFound
|
||||
from pipeline_v2.pipeline_processor import PipelineProcessor
|
||||
from rest_framework.request import Request
|
||||
from utils.user_context import UserContext
|
||||
|
||||
from pipeline_v2.exceptions import PipelineNotFound
|
||||
from pipeline_v2.pipeline_processor import PipelineProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
|
||||
class PipelineStatusPayload:
|
||||
@@ -8,8 +8,8 @@ class PipelineStatusPayload:
|
||||
pipeline_id: str,
|
||||
pipeline_name: str,
|
||||
status: str,
|
||||
execution_id: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
execution_id: str | None = None,
|
||||
error_message: str | None = None,
|
||||
):
|
||||
self.type = type
|
||||
self.pipeline_id = pipeline_id
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from rest_framework.exceptions import APIException
|
||||
|
||||
|
||||
@@ -24,14 +22,13 @@ class InactivePipelineError(APIException):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pipeline_name: Optional[str] = None,
|
||||
detail: Optional[str] = None,
|
||||
code: Optional[str] = None,
|
||||
pipeline_name: str | None = None,
|
||||
detail: str | None = None,
|
||||
code: str | None = None,
|
||||
):
|
||||
if pipeline_name:
|
||||
self.default_detail = (
|
||||
f"Pipeline '{pipeline_name}' is inactive, "
|
||||
"please activate the pipeline"
|
||||
f"Pipeline '{pipeline_name}' is inactive, " "please activate the pipeline"
|
||||
)
|
||||
super().__init__(detail, code)
|
||||
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
from django.urls import reverse
|
||||
from pipeline_v2.constants import PipelineKey, PipelineURL
|
||||
from pipeline_v2.models import Pipeline
|
||||
from pipeline_v2.pipeline_processor import PipelineProcessor
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from utils.request.constants import RequestConstants
|
||||
@@ -13,6 +10,9 @@ from workflow_manager.workflow_v2.constants import WorkflowExecutionKey, Workflo
|
||||
from workflow_manager.workflow_v2.views import WorkflowViewSet
|
||||
|
||||
from backend.constants import RequestHeader
|
||||
from pipeline_v2.constants import PipelineKey, PipelineURL
|
||||
from pipeline_v2.models import Pipeline
|
||||
from pipeline_v2.pipeline_processor import PipelineProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,7 +24,7 @@ class PipelineManager:
|
||||
def execute_pipeline(
|
||||
request: Request,
|
||||
pipeline_id: str,
|
||||
execution_id: Optional[str] = None,
|
||||
execution_id: str | None = None,
|
||||
) -> Response:
|
||||
"""Used to execute a pipeline.
|
||||
|
||||
@@ -45,9 +45,10 @@ class PipelineManager:
|
||||
@staticmethod
|
||||
def get_pipeline_execution_data_for_scheduled_run(
|
||||
pipeline_id: str,
|
||||
) -> Optional[dict[str, Any]]:
|
||||
) -> dict[str, Any] | None:
|
||||
"""Gets the required data to be passed while executing a pipeline Any
|
||||
changes to pipeline execution needs to be propagated here."""
|
||||
changes to pipeline execution needs to be propagated here.
|
||||
"""
|
||||
callback_url = settings.DJANGO_APP_BACKEND_URL + reverse(
|
||||
PipelineURL.EXECUTE_NAMESPACE
|
||||
)
|
||||
|
||||
@@ -8,7 +8,6 @@ from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user