Initial commit on Unstract

This commit is contained in:
jaseemjaskp
2024-02-25 16:19:36 +05:30
parent 4d0f0d26f1
commit 26ebb17d47
1048 changed files with 150540 additions and 0 deletions

35
.github/pull_request_template.md vendored Normal file
View File

@@ -0,0 +1,35 @@
## What
...
## Why
...
## How
...
## Relevant Docs
-
## Related Issues or PRs
-
## Dependencies Versions / Env Variables
-
## Notes on Testing
...
## Screenshots
...
## Checklist
I have read and understood the [Contribution Guidelines]().

View File

@@ -0,0 +1,86 @@
name: Container Image Build Test for PRs

env:
  VERSION: ci-test  # Used for docker tag; images are built but never pushed.

on:
  push:
    branches:
      - main
      - development
    paths:
      - 'backend/**'
      - 'frontend/**'
      - 'unstract/**'
      - 'document-service/**'
      - 'platform-service/**'
      - 'x2text-service/**'
      - 'worker/**'
      - 'docker/dockerfiles/**'
  pull_request:
    types: [opened, synchronize, reopened, ready_for_review]
    branches:
      - main
      - development
    paths:
      - 'backend/**'
      - 'frontend/**'
      - 'unstract/**'
      - 'document-service/**'
      - 'platform-service/**'
      - 'x2text-service/**'
      - 'worker/**'
      - 'docker/dockerfiles/**'

jobs:
  build:
    # Skip draft PRs. On push events `github.event.pull_request` is null,
    # which coerces to false in the expression, so the job still runs.
    if: github.event.pull_request.draft == false
    runs-on: custom-k8s-runner
    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Login to Docker Hub
        uses: docker/login-action@v2
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      - name: Container Build
        working-directory: ./docker
        run: |
          docker compose -f docker-compose.build.yaml build

      - name: Container Run
        working-directory: ./docker
        run: |
          # Seed every service with its sample env so the containers can boot.
          cp ../backend/sample.env ../backend/.env
          cp ../document-service/sample.env ../document-service/.env
          cp ../platform-service/sample.env ../platform-service/.env
          cp ../prompt-service/sample.env ../prompt-service/.env
          cp ../worker/sample.env ../worker/.env
          cp ../x2text-service/sample.env ../x2text-service/.env
          cp sample.essentials.env essentials.env
          docker compose -f docker-compose.yaml up -d
          sleep 10
          docker compose -f docker-compose.yaml ps -a
          # Get the names of exited containers (--format emits no header line)
          custom_format="{{.Name}}\t{{.Image}}\t{{.Service}}"
          EXITED_CONTAINERS=$(docker compose -f docker-compose.yaml ps -a --filter status=exited --format "$custom_format")
          # BUG FIX: `wc -l` yields 1 for both empty output and exactly one
          # exited container, so the old `line_count -gt 1` check missed the
          # single-container case. Test for non-empty output instead.
          if [ -n "$EXITED_CONTAINERS" ]; then
            echo "Exited Containers: $EXITED_CONTAINERS"
            SERVICES=$(echo "$EXITED_CONTAINERS" | awk '{print $3}')
            echo "Exited Services:"
            echo "$SERVICES"
            echo "There are exited containers."
            # Print logs of exited containers
            IFS=$'\n'
            for SERVICE in $SERVICES; do
              docker compose -f docker-compose.yaml logs "$SERVICE"
            done
            docker compose -f docker-compose.yaml down -v
            exit 1
          fi
          docker compose -f docker-compose.yaml down -v

View File

@@ -0,0 +1,64 @@
name: Unstract Docker Image Build and Push (Development)

on:
  workflow_dispatch:
    inputs:
      tag:
        description: "Docker image tag"
        required: true
        default: "latest"
      service_name:
        description: "Service to build"
        required: true
        default: "backend"  # Provide a default value
        type: choice
        options:  # Define available options
          - all-services
          - frontend
          - backend
          - document-service
          - platform-service
          - worker
          - prompt-service
          - x2text-service

run-name: "[${{ inputs.service_name }}] Docker Image Build and Push (Development)"

jobs:
  build-and-push:
    runs-on: custom-k8s-runner
    steps:
      - name: Output Inputs
        run: echo "${{ toJSON(github.event.inputs) }}"

      - name: Checkout code
        # v3 to match the PR build-test workflow (v2 runs an EOL Node runtime).
        uses: actions/checkout@v3

      - name: Login to Docker Hub
        uses: docker/login-action@v2
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      # Build and push Docker image for the specified service
      - name: Build and push image for ${{ github.event.inputs.service_name }}
        working-directory: ./docker
        if: github.event.inputs.service_name != 'all-services'
        run: |
          VERSION=${{ github.event.inputs.tag }} docker compose -f docker-compose.build.yaml build --no-cache ${{ github.event.inputs.service_name }}
          docker push unstract/${{ github.event.inputs.service_name }}:${{ github.event.inputs.tag }}

      # Build and push all service images
      - name: Build and push all images
        working-directory: ./docker
        if: github.event.inputs.service_name == 'all-services'
        run: |
          # BUG FIX: was legacy `docker-compose` (v1); use the compose v2
          # plugin form for consistency with the single-service path above.
          VERSION=${{ github.event.inputs.tag }} docker compose -f docker-compose.build.yaml build --no-cache
          # Push all built images
          docker push unstract/backend:${{ github.event.inputs.tag }}
          docker push unstract/frontend:${{ github.event.inputs.tag }}
          docker push unstract/document-service:${{ github.event.inputs.tag }}
          docker push unstract/platform-service:${{ github.event.inputs.tag }}
          docker push unstract/worker:${{ github.event.inputs.tag }}
          docker push unstract/prompt-service:${{ github.event.inputs.tag }}
          docker push unstract/x2text-service:${{ github.event.inputs.tag }}

View File

@@ -0,0 +1,65 @@
name: Unstract Tools Docker Image Build and Push (Development)

on:
  workflow_dispatch:
    inputs:
      tag:
        description: "Docker image tag"
        required: true
        default: "latest"
      service_name:
        description: "Tool to build"
        required: true
        default: "tool-classifier"  # Provide a default value
        type: choice
        options:  # Define available options
          - tool-classifier
          - tool-doc-pii-redactor
          - tool-indexer
          - tool-ocr
          - tool-translate
          - tool-structure
          - tool-text-extractor

run-name: "[${{ inputs.service_name }}] Docker Image Build and Push (Development)"

jobs:
  build-and-push:
    runs-on: custom-k8s-runner
    steps:
      - name: Output Inputs
        run: echo "${{ toJSON(github.event.inputs) }}"

      - name: Checkout code
        # v3 to match the PR build-test workflow (v2 runs an EOL Node runtime).
        uses: actions/checkout@v3

      - name: Login to Docker Hub
        uses: docker/login-action@v2
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      # Replaces seven copy-pasted per-tool conditional steps: every tool's
      # source directory is its image name with the "tool-" prefix stripped
      # and dashes turned into underscores, e.g.
      # tool-doc-pii-redactor -> tools/doc_pii_redactor.
      - name: Build ${{ github.event.inputs.service_name }}
        run: |
          TOOL_DIR=$(echo "${{ github.event.inputs.service_name }}" | sed 's/^tool-//; s/-/_/g')
          docker build -t unstract/${{ github.event.inputs.service_name }}:${{ github.event.inputs.tag }} "./tools/${TOOL_DIR}"

      - name: Push Docker image to Docker Hub
        run: docker push unstract/${{ github.event.inputs.service_name }}:${{ github.event.inputs.tag }}

33
.github/workflows/production-build.yaml vendored Normal file
View File

@@ -0,0 +1,33 @@
name: Unstract Docker Image Build and Push (Production)

on:
  release:
    types:
      - created

# BUG FIX: run-name previously said "(Development)" in the production workflow.
run-name: "[${{ github.event.release.tag_name }}] Docker Image Build and Push (Production)"

jobs:
  build-and-push:
    runs-on: custom-k8s-runner
    strategy:
      matrix:
        service_name: [backend, frontend, document-service, platform-service, prompt-service, worker, x2text-service]
    steps:
      - name: Checkout code
        # v3 to match the PR build-test workflow (v2 runs an EOL Node runtime).
        uses: actions/checkout@v3
        with:
          ref: ${{ github.event.release.tag_name }}

      - name: Login to Docker Hub
        uses: docker/login-action@v2
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      - name: Build and push image
        working-directory: ./docker
        run: |
          # BUG FIX: was legacy `docker-compose` (v1); use the compose v2
          # plugin form, consistent with the other workflows in this repo.
          VERSION=${{ github.event.release.tag_name }} docker compose -f docker-compose.build.yaml build --no-cache ${{ matrix.service_name }}
          docker push unstract/${{ matrix.service_name }}:${{ github.event.release.tag_name }}

631
.gitignore vendored Normal file
View File

@@ -0,0 +1,631 @@
# Created by https://www.toptal.com/developers/gitignore/api/windows,macos,linux,pycharm,pycharm+all,pycharm+iml,python,visualstudiocode,react,django
# Edit at https://www.toptal.com/developers/gitignore?templates=windows,macos,linux,pycharm,pycharm+all,pycharm+iml,python,visualstudiocode,react,django
### Django ###
*.log
*.pot
*.pyc
__pycache__/
local_settings.py
media
# If your build process includes running collectstatic, then you probably don't need or want to include staticfiles/
# in your Git repository. Update and uncomment the following line accordingly.
# <django-project-name>/staticfiles/
### Django.Python Stack ###
# Byte-compiled / optimized / DLL files
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
# Django stuff:
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
.pdm-build
.pdm-python
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.env.export
.venv*
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Linux ###
*~
# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*
# KDE directory preferences
.directory
# Linux trash folder which might appear on any partition or disk
.Trash-*
# .nfs files are created when an open file is removed but is still being accessed
.nfs*
### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
### macOS Patch ###
# iCloud generated files
*.icloud
### PyCharm ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf
# AWS User-specific
.idea/**/aws.xml
# Generated files
.idea/**/contentModel.xml
# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml
# Gradle
.idea/**/gradle.xml
.idea/**/libraries
# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
# CMake
cmake-build-*/
# Mongo Explorer plugin
.idea/**/mongoSettings.xml
# File-based project format
*.iws
# IntelliJ
out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Cursive Clojure plugin
.idea/replstate.xml
# SonarLint plugin
.idea/sonarlint/
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
# Editor-based Rest Client
.idea/httpRequests
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser
### PyCharm Patch ###
# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
# *.iml
# modules.xml
# .idea/misc.xml
# *.ipr
# Sonarlint plugin
# https://plugins.jetbrains.com/plugin/7973-sonarlint
.idea/**/sonarlint/
# SonarQube Plugin
# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
.idea/**/sonarIssues.xml
# Markdown Navigator plugin
# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
.idea/**/markdown-navigator.xml
.idea/**/markdown-navigator-enh.xml
.idea/**/markdown-navigator/
# Cache file creation bug
# See https://youtrack.jetbrains.com/issue/JBR-2257
.idea/$CACHE_FILE$
# CodeStream plugin
# https://plugins.jetbrains.com/plugin/12206-codestream
.idea/codestream.xml
# Azure Toolkit for IntelliJ plugin
# https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij
.idea/**/azureSettings.xml
### PyCharm+all ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff
# AWS User-specific
# Generated files
# Sensitive or high-churn files
# Gradle
# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
# CMake
# Mongo Explorer plugin
# File-based project format
# IntelliJ
# mpeltonen/sbt-idea plugin
# JIRA plugin
# Cursive Clojure plugin
# SonarLint plugin
# Crashlytics plugin (for Android Studio and IntelliJ)
# Editor-based Rest Client
# Android studio 3.1+ serialized cache file
### PyCharm+all Patch ###
# Ignore everything but code style settings and run configurations
# that are supposed to be shared within teams.
.idea/*
!.idea/codeStyles
!.idea/runConfigurations
### PyCharm+iml ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff
# AWS User-specific
# Generated files
# Sensitive or high-churn files
# Gradle
# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
# CMake
# Mongo Explorer plugin
# File-based project format
# IntelliJ
# mpeltonen/sbt-idea plugin
# JIRA plugin
# Cursive Clojure plugin
# SonarLint plugin
# Crashlytics plugin (for Android Studio and IntelliJ)
# Editor-based Rest Client
# Android studio 3.1+ serialized cache file
### PyCharm+iml Patch ###
# Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023
*.iml
modules.xml
.idea/misc.xml
*.ipr
### Python ###
# Byte-compiled / optimized / DLL files
# C extensions
# Distribution / packaging
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
# Installer logs
# Unit test / coverage reports
# Translations
# Django stuff:
# Flask stuff:
# Scrapy stuff:
# Sphinx documentation
# PyBuilder
# Jupyter Notebook
# IPython
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
# Celery stuff
# SageMath parsed files
# Environments
# Spyder project settings
# Rope project settings
# mkdocs documentation
# mypy
# Pyre type checker
# pytype static type analyzer
# Cython debug symbols
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# LSP config files
pyrightconfig.json
### react ###
.DS_*
logs
**/*.backup.*
**/*.back.*
node_modules
bower_components
*.sublime*
psd
thumb
sketch
### VisualStudioCode ###
.vscode/
**/.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
# !.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets
# Local History for Visual Studio Code
.history/
# Built Visual Studio Code Extensions
*.vsix
### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide
### Windows ###
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db
# Dump file
*.stackdump
# Folder config file
[Dd]esktop.ini
# Recycle Bin used on file shares
$RECYCLE.BIN/
# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp
# Windows shortcuts
*.lnk
### Unstract ###
# Plugins
backend/plugins/authentication/*
!backend/plugins/authentication/auth_sample
# Tool registry
unstract/tool-registry/src/unstract/tool_registry/*.json
unstract/tool-registry/tests/*.yaml
!unstract/tool-registry/src/unstract/tool_registry/public_tools.json
unstract/tool-registry/src/unstract/tool_registry/config/registry.yaml
# Docker related
# End of https://www.toptal.com/developers/gitignore/api/windows,macos,linux,pycharm,pycharm+all,pycharm+iml,python,visualstudiocode,react,django
docker/temp/*
docker/init.sql/*
docker/*.env
!docker/sample.*.env
docker/public_tools.json
docker/proxy_overrides.yaml
# Tool development
tools/*/sdks/
tools/*/data_dir/
docker/workflow_data/

3
.jshintrc Normal file
View File

@@ -0,0 +1,3 @@
{
"esversion": 9
}

200
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,200 @@
---
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
# - Added unstract feature flag auto generated code to flake8 exclude list
# Force all unspecified python hooks to run python 3.9
# (3.9 matches pyupgrade --py39-plus and mypy --python-version=3.9 below)
default_language_version:
  python: python3.9
default_stages:
  - commit
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
exclude_types:
- "markdown"
- id: end-of-file-fixer
- id: check-yaml
args: [--unsafe]
- id: check-added-large-files
args: ["--maxkb=10240"]
- id: check-case-conflict
- id: check-docstring-first
- id: check-ast
- id: check-json
exclude: ".vscode/launch.json"
- id: check-executables-have-shebangs
- id: check-shebang-scripts-are-executable
- id: check-toml
- id: debug-statements
- id: detect-private-key
- id: check-merge-conflict
- id: check-symlinks
- id: destroyed-symlinks
- id: forbid-new-submodules
- id: mixed-line-ending
- id: no-commit-to-branch
- repo: https://github.com/adrienverge/yamllint
rev: v1.35.1
hooks:
- id: yamllint
args: ["-d", "relaxed"]
language: system
- repo: https://github.com/rhysd/actionlint
rev: v1.6.26
hooks:
- id: actionlint-docker
args: [-ignore, 'label ".+" is unknown']
- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
args: [--config=pyproject.toml, -l 80]
language: system
exclude: |
(?x)^(
unstract/flags/src/unstract/flags/evaluation_.*\.py|
)$
- repo: https://github.com/pycqa/flake8
rev: 7.0.0
hooks:
- id: flake8
args: [--max-line-length=80]
exclude: |
(?x)^(
.*migrations/.*\.py|
core/tests/.*|
unstract/flags/src/unstract/flags/evaluation_.*\.py|
)$
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
files: "\\.(py)$"
args:
[
"--profile",
"black",
"--filter-files",
--settings-path=pyproject.toml,
]
- repo: https://github.com/hadialqattan/pycln
rev: v2.4.0
hooks:
- id: pycln
args: [--config=pyproject.toml]
- repo: https://github.com/pycqa/docformatter
rev: v1.7.5
hooks:
- id: docformatter
# - repo: https://github.com/MarcoGorelli/absolufy-imports
# rev: v0.3.1
# hooks:
# - id: absolufy-imports
# files: ^backend/
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
hooks:
- id: pyupgrade
entry: pyupgrade --py39-plus --keep-runtime-typing
types:
- python
- repo: https://github.com/gitleaks/gitleaks
rev: v8.18.2
hooks:
- id: gitleaks
- repo: https://github.com/hadolint/hadolint
rev: v2.12.1-beta
hooks:
- id: hadolint-docker
args:
- --ignore=DL3003
- --ignore=DL3008
- --ignore=DL3013
- --ignore=DL3018
- --ignore=SC1091
files: Dockerfile$
- repo: https://github.com/asottile/yesqa
rev: v1.5.0
hooks:
- id: yesqa
- repo: https://github.com/pre-commit/mirrors-eslint
rev: "v9.0.0-beta.0" # Use the sha / tag you want to point at
hooks:
- id: eslint
args: [--config=frontend/.eslintrc.json]
files: \.[jt]sx?$ # *.js, *.jsx, *.ts and *.tsx
types: [file]
additional_dependencies:
- eslint@8.41.0
- eslint-config-google@0.14.0
- eslint-config-prettier@8.8.0
- eslint-plugin-prettier@4.2.1
- eslint-plugin-react@7.32.2
- eslint-plugin-import@2.25.2
- repo: https://github.com/Lucas-C/pre-commit-hooks-nodejs
rev: v1.1.2
hooks:
- id: htmlhint
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
hooks:
- id: pyupgrade
entry: pyupgrade --py38-plus --keep-runtime-typing
types:
- python
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
hooks:
- id: mypy
language: system
entry: mypy .
pass_filenames: false
# IMPORTANT!
# Keep args same as tool.mypy section in pyproject.toml
args:
[
--allow-subclassing-any,
--allow-untyped-decorators,
--check-untyped-defs,
--exclude, ".*migrations/.*.py",
--exclude, "backend/prompt/.*",
--exclude, "document-service/.*",
--exclude, "unstract/connectors/tests/.*",
--exclude, "unstract/core/.*",
--exclude, "unstract/flags/src/unstract/flags/.*",
--exclude, "__pypackages__/.*",
--follow-imports, "silent",
--ignore-missing-imports,
--implicit-reexport,
--pretty,
--python-version=3.9,
--show-column-numbers,
--show-error-codes,
--strict,
--warn-redundant-casts,
--warn-return-any,
--warn-unreachable,
--warn-unused-configs,
--warn-unused-ignores,
]
- repo: https://github.com/igorshubovych/markdownlint-cli
rev: v0.39.0
hooks:
- id: markdownlint
args: [--disable, MD013]
- id: markdownlint-fix
args: [--disable, MD013]
- repo: https://github.com/pdm-project/pdm
rev: 2.12.3
hooks:
- id: pdm-lock-check
- repo: local
hooks:
- id: check-django-migrations
name: Check django migrations
entry: sh -c 'pdm run `find . -name "manage.py" -not -path "*/.venv/*"` makemigrations --check --dry-run --no-input'
language: system
types: [python] # hook only runs if a python file is staged
pass_filenames: false

3
CONTRIBUTE.md Normal file
View File

@@ -0,0 +1,3 @@
Conventions
- Where ever you are adding yaml files, preferred extension is `.yaml`

199
README.md Normal file
View File

@@ -0,0 +1,199 @@
# Unstract
[![pdm-managed](https://img.shields.io/badge/pdm-managed-blueviolet)](https://pdm-project.org)
TODO: Write few lines about the project.
## System Requirements
- docker
- git
## Running with docker compose
- All services needed by the backend can be run with
```
cd docker/
VERSION=test docker compose -f docker-compose.build.yaml build
VERSION=test docker compose -f docker-compose.yaml up -d
```
Additional information on running with Docker can be found in [DOCKERISING.md](/DOCKERISING.md)
- Use the `-f` flag to run all dependencies necessary for development, this runs containers needed for testing as well such as Minio.
```
docker compose -f docker-compose-dev-essentials.yaml up
```
- It might take some time on the first run to pull the images.
## Running locally
### Installation
- Install the below libraries which are needed to run Unstract
- Linux
```
sudo apt install build-essential pkg-config libpoppler-cpp-dev libmagic-dev python3-dev
```
- Mac
```
brew install pkg-config poppler freetds libmagic
```
### Create your virtual env
- In order to install dependencies and run a package, ensure that you've sourced a virtual environment within that package. All commands in this repository assumes that you have sourced your required venv.
```
cd <package_to_use>
python -m venv .venv
source .venv/bin/activate
```
### Install dependencies with PDM
- This repository makes use of [PDM](https://github.com/pdm-project/pdm) for managing dependencies with the help of a virtual
environment.
- If you haven't installed PDM in your machine yet,
- Install it using the below command
```
curl -sSL https://pdm.fming.dev/install-pdm.py | python3 -
```
- Or install it from PyPI using `pip`
```
pip install pdm
```
Ensure you're running the PDM commands from the corresponding package root
- Install dependencies for running the package with
```
pdm install
```
This installs dev dependencies as well by default
- For production, install the requirements with
```
pdm install --prod
```
- With PDM its possible to run some services from any directory within this
repository. To list the possible scripts that can be executed
```
pdm run -l
```
- Add a new dependency with (ensure you're running it from the correct project's root)
Perform an editable install with `-e` only for local development.
```
pdm add <package_from_PyPI>
pdm add -e <relative_path_to_local_package>
```
- List all dependencies with
```
pdm list
```
- After updating `pyproject.toml`s with a newly added dependency, the lock file can be updated with
```
pdm lock
```
- Refer [PDM's documentation](https://pdm.fming.dev/latest/reference/cli/) for further details.
### Configuring Postgres
- Create a Postgres user and DB for the BE and configure it like so
```
POSTGRES_USER: unstract_dev
POSTGRES_PASSWORD: unstract_pass
POSTGRES_DB: unstract_db
```
If you require a different config, make sure the necessary envs from [backend/sample.env](/backend/sample.env) are exported.
- Execute the script [backend/init.sql](/backend/init.sql) that adds roles and creates a DB and extension for ZS Document Indexer tool to work.
Make sure that [pgvector](https://github.com/pgvector/pgvector#installation) is installed.
### Pre-commit hooks
- We use pre-commit to run some hooks whenever code is pushed to perform linting and static code analysis among other checks.
- Ensure dev dependencies are installed and you're in the virtual env
- Install hooks with `pre-commit install` or `pdm run pre-commit install`
- Manually trigger pre-commit hooks in following ways:
```bash
#
# Using the tool directly
#
# Run all pre-commit hooks
pre-commit run
# Run specific pre-commit hook
pre-commit run flake8
# Run mypy pre-commit hook for selected folder
pre-commit run mypy --files prompt-service/**/*.py
# Run mypy for selected folder
mypy prompt-service/**/*.py
#
# Using pdm to run the scripts
#
# Run all pre-commit hooks
pdm run pre-commit run
# Run specific pre-commit hook
pdm run pre-commit run flake8
# Run mypy pre-commit hook for selected folder
pdm run pre-commit run mypy --files prompt-service/**/*.py
# Run mypy for selected folder
pdm run mypy prompt-service/**/*.py
```
### Backend
- Check [backend/README.md](/backend/README.md) for running the backend.
### Frontend
- Install dependencies with `npm install`
- Start the server with `npm start`
### Traefik Proxy Overrides
It is possible to simultaneously run few services directly on docker host while others are run as docker containers via docker compose.
This enables seamless development without worrying about deployment of other services which you are not concerned with.
We just need to override default Traefik proxy routing to allow this, that's all.
1. Copy `docker/sample.proxy_overrides.yaml` to `docker/proxy_overrides.yaml`.
Modify to update Traefik proxy routes for services running directly on docker host (`host.docker.internal:<port>`).
2. Update host name of dependency components in config of services running directly on docker host:
- Replace as `*.localhost` IF container port is exposed on docker host
- **OR** use container IPs obtained via `docker network inspect unstract-network`
- **OR** run `docker/scripts/resolve_container_svc_from_host.sh` IF container port is NOT exposed on docker host or if you want to keep dependency host names unchanged
Run the services.
#### Conflicting Host Names
When the same host name environment variables are used by both a service running locally and a service
running in a container (for example, one running from a tool), host name resolution conflicts can arise for the following:
- `localhost` -> Using this inside a container points to the container itself, and not the host.
- `host.docker.internal` -> Meant to be used inside containers only, to get host IP.
Does not make sense to use in services running locally.
*In such cases, use another host name and point the same to host IP in `/etc/hosts`.*
For example, the backend uses the PROMPT_HOST environment variable, which is also supplied
in the Tool configuration when spawning Tool containers. If the backend is running
locally and the Tools are in containers, we could set the value to
`prompt-service` and add it to `/etc/hosts` as shown below.
```
<host_local_ip> prompt-service
```

136
backend/README.md Normal file
View File

@@ -0,0 +1,136 @@
# Unstract Backend
Contains the backend services for Unstract written with Django and DRF.
## Dependencies
1. Postgres
1. Redis
## Getting started
**NOTE**: All commands are executed from `/backend` and require the venv to be active. Refer [these steps](/README.md#create-your-virtual-env) to create/activate your venv
### Install and run manually
- Ensure that you've sourced your virtual environment and installed dependencies mentioned [here](/README.md#create-your-virtual-env).
- If you plan to run the django server locally, make sure the dependent services are up (either locally or through docker compose)
- Copy `sample.env` into `.env` and update the necessary variables. For eg:
```
DJANGO_SETTINGS_MODULE='backend.settings.dev'
DB_HOST='localhost'
DB_USER='unstract_dev'
DB_PASSWORD='unstract_pass'
DB_NAME='unstract_db'
DB_PORT=5432
```
- If you've made changes to the model, run `python manage.py makemigrations`, else ignore this step
- Run the following to apply any migrations to the DB and start the server
```
python manage.py migrate
python manage.py runserver localhost:8000
```
- Server will start and run at port 8000. (<http://localhost:8000>)
## Asynchronous execution/pipeline execution
- Working with celery
- Each pipeline or shared task will be added to the queue (Redis), and the worker will consume from the queue
### Run Execution Worker
Run the following command to start the worker:
```bash
celery -A backend worker --loglevel=info
```
### Worker Dashboard
- We have to ensure the package flower is installed in the current environment
- Run command
```bash
celery -A backend flower
```
This command will start Flower on the default port (5555) and can be accessed via a web browser. Flower provides a user-friendly interface for monitoring and managing Celery tasks
## Connecting to Postgres
Follow the below steps to connect to the postgres DB running with `docker compose`.
1. Exec into a shell within the postgres container
```
docker compose exec -it db bash
```
2. Connect to the db as the specified user
```
psql -d unstract_db -U unstract_dev
```
3. Execute PSQL commands within the shell.
## API Docs
While running the backend server locally, access the API documentation that's auto generated at
the backend endpoint `/api/v1/doc/`.
**NOTE:** There exists issues accessing this when the django server is run with gunicorn (in case of running with
a container)
- [Account](account/api_doc.md)
- [FileManagement](file_management/api_doc.md)
## Connectors
### Google Drive
The Google Drive connector makes use of [PyDrive2](https://pypi.org/project/PyDrive2/) library and supports only OAuth 2.0 authentication.
To set it up, follow the first step highlighted in [Google's docs](https://developers.google.com/identity/protocols/oauth2#1.-obtain-oauth-2.0-credentials-from-the-dynamic_data.setvar.console_name-.) and set the client ID and client secret
as envs in `backend/.env`
```
GOOGLE_OAUTH2_KEY="<client-id>"
GOOGLE_OAUTH2_SECRET="<client-secret>"
```
# Archived - (EXPERIMENTAL)
## Accessing the admin site
- If its the first time, create a super user and follow the on-screen instructions
```
python manage.py createsuperuser
```
- Register your models in `<app>/admin.py`, for example
```
from django.contrib import admin
from .models import Prompt
admin.site.register(Prompt)
```
- Make sure the server is running and hit the `/admin` endpoint
## Running unit tests
Unit tests are run with [pytest](https://docs.pytest.org/en/7.3.x/) and [pytest-django](https://pytest-django.readthedocs.io/en/latest/index.html)
```
pytest
pytest prompt # To run for an app named prompt
```
All tests are organized within an app, for eg: `prompt/tests/test_urls.py`
**NOTE:** The django server need not be up to run the tests, however the DB needs to be running.

26
backend/account/ReadMe.md Normal file
View File

@@ -0,0 +1,26 @@
# Basic WorkFlow
`We can Add Workflows Here`
## Login
### Step
1. Login
2. Get Organizations
3. Set Organization
4. Use organizational APIs /unstract/<org_id>/
## Switch organization
1. Get Organizations
2. Set Organization
3. Use organizational APIs /unstract/<org_id>/
## Get current user and Organization data
- Use Get User Profile and Get Organization Info APIs
## Signout
1. Signout API

View File

5
backend/account/admin.py Normal file
View File

@@ -0,0 +1,5 @@
from django.contrib import admin

from .models import Organization, User

# Expose the account app's models in Django's admin site.
admin.site.register([Organization, User])

202
backend/account/api_doc.md Normal file
View File

@@ -0,0 +1,202 @@
# Authentication APIs
### Unauthorized Attempt
**Status Code:** 401
**Response:**
```json
{
"message": "Unauthorized"
}
```
### Login
```http
GET /login
```
Sample URL:
> <base_url>/login
Response: Auth0 login page
### Logout
```http
GET /logout
```
Sample URL:
> <base_url>/logout
Response: OK (redirect to Auth0 page)
### Signup
```http
GET /signup
```
Sample URL:
> <base_url>/signup
Response: Auth0 signup page
### Get User Profile
```http
GET /profile
```
Sample URL:
> <base_url>/unstract/<org_id>/profile
Response:
```json
{
"user": {
"id": "6",
"email": "iamali003@gmail.com",
"name": "Ali",
"display_name": "Ali",
"family_name": null,
"picture": null
}
}
```
### Password Reset
```http
POST /reset_password
```
Sample URL:
> <base_url>/unstract/<org_id>/reset_password
Response:
```json
{
"status": "failed",
"message": "user doesn't have Username-Password-Authentication"
}
```
### Get Organizations of User
```http
GET /organization
```
Sample URL:
> <base_url>/organization
Response:
```json
{
"message": "success",
"organizations": [
{
"id": "org_Z12elHhCcPH5rPD7",
"display_name": "Personal Org",
"name": "personal"
},
{
"id": "org_CB46CBskR8BxFjVV",
"display_name": "ali Test",
"name": "ali1"
}
]
}
```
### Select an Organization to Use
```http
POST /set
```
Sample URL:
> <base_url>/organization/<org_id>/set
Response:
```json
{
"user": {
"id": "6",
"email": "iamali003@gmail.com",
"name": "Ali",
"display_name": "Ali",
"family_name": null,
"picture": null
},
"organization": {
"display_name": "ali Test",
"name": "ali1",
"organization_id": "org_CB46CBskR8BxFjVV"
}
}
```
### Get Organization Members
```http
GET /members
```
Sample URL:
> <base_url>/unstract/<org_id>/members
Response:
```json
{
"message": "success",
"members": [
{
"user_id": "google-oauth2|102763382532901780910",
"email": "iamali003@gmail.com",
"name": "",
"picture": null
}
]
}
```
### Get Organization Info
```http
GET /organization
```
Sample URL:
> <base_url>/unstract/<org_id>/organization
Response:
```json
{
"message": "success",
"organization": {
"name": "ali1",
"display_name": "ali Test",
"organization_id": "org_CB46CBskR8BxFjVV",
"created_at": "2023-06-26 04:42:40.905458+00:00"
}
}
```
## Ref
- Postman Collection: [Postman collection link](https://api.postman.com/collections/24537488-9380ea92-d1e0-45f4-827c-c5cc9d0370b8?access_key=PMAT-01H3VGHTM9SR01MHXA95G1RTWB)
- By testing Postman
- set cookies
- set X-CSRFToken in header for POST request

6
backend/account/apps.py Normal file
View File

@@ -0,0 +1,6 @@
from django.apps import AppConfig


class AccountConfig(AppConfig):
    """Django application configuration for the `account` app."""

    # Use 64-bit auto-incrementing primary keys for models in this app.
    default_auto_field = "django.db.models.BigAutoField"
    name = "account"

View File

@@ -0,0 +1,508 @@
import logging
from typing import Any, Optional, Union
from urllib.parse import urlencode
from account.authentication_helper import AuthenticationHelper
from account.authentication_plugin_registry import AuthenticationPluginRegistry
from account.authentication_service import AuthenticationService
from account.cache_service import CacheService
from account.constants import (
AuthoErrorCode,
Cookie,
ErrorMessage,
OrganizationMemberModel,
)
from account.custom_exceptions import (
DuplicateData,
Forbidden,
UserNotExistError,
)
from account.dto import (
CallbackData,
MemberInvitation,
OrganizationData,
UserInfo,
UserInviteResponse,
UserRoleData,
)
from account.exceptions import OrganizationNotExist
from account.models import Organization, User
from account.organization import OrganizationService
from account.serializer import (
GetOrganizationsResponseSerializer,
OrganizationSerializer,
SetOrganizationsResponseSerializer,
)
from account.user import UserService
from django.conf import settings
from django.contrib.auth import login as django_login
from django.contrib.auth import logout as django_logout
from django.db.utils import IntegrityError
from django.middleware import csrf
from django.shortcuts import redirect
from django_tenants.utils import tenant_context
from psycopg2.errors import UndefinedTable
from rest_framework import status
from rest_framework.request import Request
from rest_framework.response import Response
from tenant_account.models import OrganizationMember as OrganizationMember
from tenant_account.organization_member_service import OrganizationMemberService
Logger = logging.getLogger(__name__)
class AuthenticationController:
    """Manages user authentication and organization-membership flows.

    Delegates identity operations to an installed authentication plugin
    when one is available (via ``AuthenticationPluginRegistry``), and
    falls back to the default no-op ``AuthenticationService`` otherwise.
    """

    def __init__(self) -> None:
        """Initialize the controller, selecting the appropriate
        authentication plugin based on availability."""
        self.authentication_helper = AuthenticationHelper()
        self.organization_member_service = OrganizationMemberService()
        if AuthenticationPluginRegistry.is_plugin_available():
            self.auth_service: AuthenticationService = (
                AuthenticationPluginRegistry.get_plugin()
            )
        else:
            # No plugin installed: use the default mock/no-op service.
            self.auth_service = AuthenticationService()

    def user_login(
        self,
        request: Request,
    ) -> Any:
        """Delegate login to the configured auth service."""
        return self.auth_service.user_login(request)

    def user_signup(self, request: Request) -> Any:
        """Delegate signup to the configured auth service."""
        return self.auth_service.user_signup(request)

    def authorization_callback(
        self, request: Request, backend: str = settings.DEFAULT_MODEL_BACKEND
    ) -> Any:
        """Handle authorization callback.

        This function processes the authorization callback from
        an external service: it resolves (or creates) the local user,
        handles pending invitations, provisions the tenant user, and
        finally logs the user into the Django session.

        Args:
            request (Request): Request instance
            backend (str, optional): auth backend used for login.
                Defaults: settings.DEFAULT_MODEL_BACKEND.

        Returns:
            Any: Redirect response
        """
        callback_data: CallbackData = self.auth_service.get_callback_data(
            request=request
        )
        user: User = self.get_or_create_user_by_email(request, callback_data)
        try:
            member = self.auth_service.handle_invited_user_while_callback(
                request=request, user=user
            )
        except Exception as ex:
            # Error code reference:
            # frontend/src/components/error/GenericError/GenericError.jsx.
            # NOTE(review): assumes plugin exceptions carry a `code`
            # attribute — confirm against the installed plugin.
            if ex.code == AuthoErrorCode.IDM:  # type: ignore
                query_params = {"code": AuthoErrorCode.IDM}
                return redirect(
                    f"{settings.ERROR_URL}?{urlencode(query_params)}"
                )
            elif ex.code == AuthoErrorCode.UMM:  # type: ignore
                query_params = {"code": AuthoErrorCode.UMM}
                return redirect(
                    f"{settings.ERROR_URL}?{urlencode(query_params)}"
                )
            return redirect(f"{settings.ERROR_URL}")
        if member.organization_id and member.role and len(member.role) > 0:
            organization: Optional[Organization] = (
                OrganizationService.get_organization_by_org_id(
                    member.organization_id
                )
            )
            if organization:
                try:
                    self.create_tenant_user(
                        organization=organization, user=user
                    )
                except UndefinedTable:
                    # NOTE(review): silently ignores missing tenant tables —
                    # presumably the tenant schema is not migrated yet; confirm.
                    pass
        response = self.auth_service.handle_authorization_callback(
            user=user,
            data=callback_data,
            redirect_url=request.GET.get("redirect_url"),
        )
        django_login(request, user, backend)
        return response

    def user_organizations(self, request: Request) -> Any:
        """List the organizations of the requesting user.

        Also caches the user's organization ids and ensures a CSRF
        cookie is present on the response.

        Args:
            request (Request): authenticated request

        Returns:
            Response: serialized list of organizations
        """
        try:
            organizations = self.auth_service.user_organizations(request)
        except Exception as ex:
            # The session is cleared on any failure before inspecting it.
            self.user_logout(request)
            # NOTE(review): if the exception code is not USF, execution
            # falls through with `organizations` unbound (NameError) —
            # confirm intended handling for other failure codes.
            if ex.code == AuthoErrorCode.USF:  # type: ignore
                response = Response(
                    status=status.HTTP_412_PRECONDITION_FAILED,
                    data={"domain": ex.data.get("domain")},  # type: ignore
                )
                return response
        user: User = request.user
        org_ids = {org.id for org in organizations}
        CacheService.set_user_organizations(user.user_id, list(org_ids))
        serialized_organizations = GetOrganizationsResponseSerializer(
            organizations, many=True
        ).data
        response = Response(
            status=status.HTTP_200_OK,
            data={
                "message": "success",
                "organizations": serialized_organizations,
            },
        )
        if Cookie.CSRFTOKEN not in request.COOKIES:
            csrf_token = csrf.get_token(request)
            response.set_cookie(Cookie.CSRFTOKEN, csrf_token)
        return response

    def set_user_organization(
        self, request: Request, organization_id: str
    ) -> Response:
        """Select one of the user's organizations as the active one.

        Creates the local Organization record (from the auth provider's
        data) and the tenant user on first use, then records the chosen
        org in the cached session and in the ORG_ID cookie.

        Returns 403 when the org is not among the user's organizations.
        """
        user: User = request.user
        organization_ids = CacheService.get_user_organizations(user.user_id)
        if not organization_ids:
            # Cache miss: re-fetch the user's orgs from the auth service.
            z_organizations: list[OrganizationData] = (
                self.auth_service.get_organizations_by_user_id(user.user_id)
            )
            organization_ids = {org.id for org in z_organizations}
        if organization_id and organization_id in organization_ids:
            organization = OrganizationService.get_organization_by_org_id(
                organization_id
            )
            if not organization:
                # First time this org is used locally: pull its details
                # from the auth provider and persist it.
                try:
                    organization_data: OrganizationData = (
                        self.auth_service.get_organization_by_org_id(
                            organization_id
                        )
                    )
                except ValueError:
                    raise OrganizationNotExist()
                try:
                    organization = OrganizationService.create_organization(
                        organization_data.name,
                        organization_data.display_name,
                        organization_data.id,
                    )
                except IntegrityError:
                    raise DuplicateData(
                        f"{ErrorMessage.ORGANIZATION_EXIST}, \
                            {ErrorMessage.DUPLICATE_API}"
                    )
            self.create_tenant_user(organization=organization, user=user)
            user_info: Optional[UserInfo] = self.get_user_info(request)
            serialized_user_info = SetOrganizationsResponseSerializer(
                user_info
            ).data
            organization_info = OrganizationSerializer(organization).data
            response: Response = Response(
                status=status.HTTP_200_OK,
                data={
                    "user": serialized_user_info,
                    "organization": organization_info,
                },
            )
            # Update user session data in redis
            user_session_info: dict[str, Any] = (
                CacheService.get_user_session_info(user.email)
            )
            user_session_info["current_org"] = organization_id
            CacheService.set_user_session_info(user_session_info)
            response.set_cookie(Cookie.ORG_ID, organization_id)
            return response
        return Response(status=status.HTTP_403_FORBIDDEN)

    def get_user_info(self, request: Request) -> Optional[UserInfo]:
        """Return the requesting user's profile info from the auth service."""
        return self.auth_service.get_user_info(request)

    def is_admin_by_role(self, role: str) -> bool:
        """Check whether the role acts as admin in the context of the
        authentication plugin.

        Args:
            role (str): role name

        Returns:
            bool: True if the role is an admin role.
        """
        return self.auth_service.is_admin_by_role(role=role)

    def get_organization_info(self, org_id: str) -> Optional[Organization]:
        """Fetch org info from the auth service, falling back to the
        local OrganizationService when the plugin has no record."""
        organization = self.auth_service.get_organization_info(org_id)
        if not organization:
            organization = OrganizationService.get_organization_by_org_id(
                org_id=org_id
            )
        return organization

    def make_organization_and_add_member(
        self,
        user_id: str,
        user_name: str,
        organization_name: Optional[str] = None,
        display_name: Optional[str] = None,
    ) -> Optional[OrganizationData]:
        """Create an organization at the auth provider and add the user."""
        return self.auth_service.make_organization_and_add_member(
            user_id, user_name, organization_name, display_name
        )

    def make_user_organization_name(self) -> str:
        """Generate an internal organization name via the auth service."""
        return self.auth_service.make_user_organization_name()

    def make_user_organization_display_name(self, user_name: str) -> str:
        """Generate a user-facing organization display name."""
        return self.auth_service.make_user_organization_display_name(user_name)

    def user_logout(self, request: Request) -> Response:
        """Log out at the auth provider, then clear the Django session."""
        response = self.auth_service.user_logout(request=request)
        django_logout(request)
        return response

    def get_organization_members_by_org_id(
        self, organization_id: Optional[str] = None
    ) -> list[OrganizationMember]:
        """Return members of the current tenant.

        NOTE(review): `organization_id` is accepted but ignored — all
        OrganizationMember rows of the active schema are returned;
        confirm this relies on tenant scoping.
        """
        members: list[OrganizationMember] = OrganizationMember.objects.all()
        return members

    def get_organization_members_by_user(
        self, user: User
    ) -> OrganizationMember:
        """Return the OrganizationMember row for a user (None if absent)."""
        member: OrganizationMember = OrganizationMember.objects.filter(
            user=user
        ).first()
        return member

    def get_user_roles(self) -> list[UserRoleData]:
        """List the role definitions known to the auth service."""
        return self.auth_service.get_roles()

    def get_user_invitations(
        self, organization_id: str
    ) -> list[MemberInvitation]:
        """List pending invitations for an organization."""
        return self.auth_service.get_invitations(
            organization_id=organization_id
        )

    def delete_user_invitation(
        self, organization_id: str, invitation_id: str
    ) -> bool:
        """Delete a pending invitation; True on success."""
        return self.auth_service.delete_invitation(
            organization_id=organization_id, invitation_id=invitation_id
        )

    def reset_user_password(self, user: User) -> Response:
        """Trigger a password reset via the auth service."""
        return self.auth_service.reset_user_password(user)

    def invite_user(
        self,
        admin: User,
        org_id: str,
        user_list: list[dict[str, Union[str, None]]],
    ) -> list[UserInviteResponse]:
        """Invites users to join an organization.

        Args:
            admin (User): Admin user initiating the invitation.
            org_id (str): ID of the organization to which users are invited.
            user_list (list[dict[str, Union[str, None]]]):
                List of user details for invitation.

        Returns:
            list[UserInviteResponse]: List of responses for each
                user invitation.

        Raises:
            Forbidden: if the initiating user is not an org admin.
        """
        admin_user = OrganizationMember.objects.get(user=admin.id)
        if not self.auth_service.is_organization_admin(admin_user):
            raise Forbidden()
        response = []
        for user_item in user_list:
            email = user_item.get("email")
            role = user_item.get("role")
            if email:
                user = self.organization_member_service.get_user_by_email(
                    email=email
                )
                user_response = {}
                user_response["email"] = email
                message: str = "User invitation successful"
                # NOTE(review): the local `status` below shadows the
                # imported rest_framework `status` module inside this method.
                if user:
                    # Already in organization
                    status = False
                    message = "User is already part of current organization."
                else:
                    try:
                        self.auth_service.check_user_organization_association(
                            user_email=email
                        )
                        status = self.auth_service.invite_user(
                            admin_user, org_id, email, role=role
                        )
                    except Exception as exception:
                        status = False
                        message = exception.message  # type: ignore
                response.append(
                    UserInviteResponse(
                        email=email,
                        status="success" if status else "failed",
                        message=message,
                    )
                )
        return response

    def remove_users_from_organization(
        self, admin: User, organization_id: str, user_emails: list[str]
    ) -> bool:
        """Remove users (by email) from the organization.

        Removes them at the auth provider first; only on success are the
        local OrganizationMember rows deleted.
        """
        admin_user = OrganizationMember.objects.get(user=admin.id)
        user_ids = OrganizationMember.objects.filter(
            user__email__in=user_emails
        ).values_list(
            OrganizationMemberModel.USER_ID, OrganizationMemberModel.ID
        )
        user_ids_list: list[str] = []
        ids_list: list[str] = []
        for user in user_ids:
            user_ids_list.append(user[0])
            ids_list.append(user[1])
        if len(user_ids_list) > 0:
            is_removed = self.auth_service.remove_users_from_organization(
                admin=admin_user,
                organization_id=organization_id,
                user_ids=user_ids_list,
            )
        else:
            is_removed = False
        if is_removed:
            OrganizationMember.objects.filter(user__in=ids_list).delete()
        return is_removed

    def add_user_role(
        self, admin: User, org_id: str, email: str, role: str
    ) -> Optional[str]:
        """Grant a role to a user and persist it locally.

        Returns the resulting role name, or None when the user is unknown.
        """
        admin_user = OrganizationMember.objects.get(user=admin.id)
        user = self.organization_member_service.get_user_by_email(email=email)
        if user:
            current_roles = self.auth_service.add_organization_user_role(
                admin_user, org_id, user.user.user_id, [role]
            )
            if current_roles:
                self.save_orgnanization_user_role(
                    user_id=user.user.user_id, role=current_roles[0]
                )
            return current_roles[0]
        else:
            return None

    def remove_user_role(
        self, admin: User, org_id: str, email: str, role: str
    ) -> Optional[str]:
        """Revoke a role from a user and persist the remaining role locally.

        Returns the resulting role name, or None when the user is unknown.
        """
        admin_user = OrganizationMember.objects.get(user=admin.id)
        organization_member = (
            self.organization_member_service.get_user_by_email(email=email)
        )
        if organization_member:
            current_roles = self.auth_service.remove_organization_user_role(
                admin_user, org_id, organization_member.user.user_id, [role]
            )
            if current_roles:
                self.save_orgnanization_user_role(
                    user_id=organization_member.user.user_id,
                    role=current_roles[0],
                )
            return current_roles[0]
        else:
            return None

    def save_orgnanization_user_role(self, user_id: str, role: str) -> None:
        """Persist a user's (single) role on their OrganizationMember row.

        NOTE(review): method name has a typo ("orgnanization"); kept as-is
        since callers within this class use the same spelling.
        """
        organization_user = (
            self.organization_member_service.get_user_by_user_id(
                user_id=user_id
            )
        )
        if organization_user:
            # consider single role
            organization_user.role = role
            organization_user.save()

    def create_tenant_user(
        self, organization: Organization, user: User
    ) -> None:
        """Create the OrganizationMember row for `user` inside the
        tenant schema of `organization`, if it does not exist yet.

        Raises:
            UserNotExistError: if no account user can be resolved/created.
        """
        with tenant_context(organization):
            existing_tenant_user = (
                self.organization_member_service.get_user_by_id(id=user.id)
            )
            if existing_tenant_user:
                Logger.info(f"{existing_tenant_user.user.email} Already exist")
            else:
                account_user = self.get_or_create_user(user=user)
                if account_user:
                    user_roles = (
                        self.auth_service.get_organization_role_of_user(
                            user_id=account_user.user_id,
                            organization_id=organization.organization_id,
                        )
                    )
                    # Only the first role is stored on the member row.
                    user_role = user_roles[0]
                    tenant_user: OrganizationMember = OrganizationMember(
                        user=user, role=user_role
                    )
                    tenant_user.save()
                else:
                    raise UserNotExistError()

    def get_or_create_user_by_email(
        self, request: Request, callback_data: CallbackData
    ) -> Union[User, OrganizationMember]:
        """Resolve a local User by the callback email, creating one if
        it does not exist yet."""
        email = callback_data.email
        user_service = UserService()
        user = user_service.get_user_by_email(email)
        if not user:
            user_id = callback_data.user_id
            user = user_service.create_user(email, user_id)
        return user

    def get_or_create_user(
        self, user: User
    ) -> Optional[Union[User, OrganizationMember]]:
        """Resolve an existing account user by id or email, saving the
        given instance when only a provider user_id is present.

        NOTE(review): the final `elif user.email and user.user_id` branch
        is unreachable — any truthy `user.email` is consumed by the
        preceding `elif`; confirm intended fallback behavior.
        """
        user_service = UserService()
        if user.id:
            account_user: Optional[User] = user_service.get_user_by_id(user.id)
            if account_user:
                return account_user
        elif user.email:
            account_user = user_service.get_user_by_email(email=user.email)
            if account_user:
                return account_user
            if user.user_id:
                user.save()
                return user
        elif user.email and user.user_id:
            account_user = user_service.create_user(
                email=user.email, user_id=user.user_id
            )
            return account_user
        return None

View File

@@ -0,0 +1,38 @@
from typing import Any
from account.constants import DefaultOrg
from account.dto import CallbackData, MemberData, OrganizationData
from rest_framework.request import Request
class AuthenticationHelper:
    """Plugin-agnostic helpers used by the authentication controller and
    the default (no-auth) service."""

    def __init__(self) -> None:
        pass

    def get_organizations_by_user_id(self) -> list[OrganizationData]:
        """Return the single mock organization used when no auth plugin
        is installed."""
        organizationData: OrganizationData = OrganizationData(
            id=DefaultOrg.MOCK_ORG,
            display_name=DefaultOrg.MOCK_ORG,
            name=DefaultOrg.MOCK_ORG,
        )
        return [organizationData]

    def get_authorize_token(self, request: Request) -> CallbackData:
        """Return mock callback data for the default (no-auth) flow.

        Fix: the original signature was `(rself, equest: Request)` — a
        typo that bound the request object as `self` and made the method
        unusable as a normal instance method. The request is currently
        unused, but the corrected signature matches callers' expectations.
        """
        return CallbackData(
            user_id=DefaultOrg.MOCK_USER_ID,
            email=DefaultOrg.MOCK_USER_EMAIL,
            token="",
        )

    def list_of_members_from_user_model(
        self, model_data: list[Any]
    ) -> list[MemberData]:
        """Map user-model rows (objects exposing user_id/email/username)
        to MemberData DTOs."""
        members: list[MemberData] = []
        for data in model_data:
            user_id = data.user_id
            email = data.email
            name = data.username
            members.append(MemberData(user_id=user_id, email=email, name=name))
        return members

View File

@@ -0,0 +1,98 @@
import logging
import os
from importlib import import_module
from typing import Any
from account.constants import PluginConfig
from django.apps import apps
Logger = logging.getLogger(__name__)
def _load_plugins() -> dict[str, dict[str, Any]]:
    """Discover and import authentication plugins, returning their metadata.

    Scans the plugins app's auth directory for entries whose name starts
    with the auth prefix, importing each as a module. Only modules whose
    metadata marks them active are kept.

    Returns:
        dict[str, dict[str, Any]]: mapping of module name to its module
            object and metadata.

    Raises:
        ValueError: if more than one active auth module is found (only a
            single authentication method is allowed).
    """
    auth_app = apps.get_app_config(PluginConfig.PLUGINS_APP)
    auth_package_path = auth_app.module.__package__
    auth_dir = os.path.join(auth_app.path, PluginConfig.AUTH_PLUGIN_DIR)
    auth_package_path = f"{auth_package_path}.{PluginConfig.AUTH_PLUGIN_DIR}"
    auth_modules: dict[str, dict[str, Any]] = {}
    for item in os.listdir(auth_dir):
        # Loads a plugin only if name starts with `auth`.
        if not item.startswith(PluginConfig.AUTH_MODULE_PREFIX):
            continue
        # Loads a plugin if it is in a directory.
        if os.path.isdir(os.path.join(auth_dir, item)):
            auth_module_name = item
        # Loads a plugin if it is a shared library.
        # Module name is extracted from shared library name.
        # `auth.platform_architecture.so` will be file name and
        # `auth` will be the module name.
        elif item.endswith(".so"):
            auth_module_name = item.split(".")[0]
        else:
            continue
        try:
            full_module_path = f"{auth_package_path}.{auth_module_name}"
            module = import_module(full_module_path)
            metadata = getattr(module, PluginConfig.AUTH_METADATA, {})
            if metadata.get(PluginConfig.METADATA_IS_ACTIVE, False):
                auth_modules[auth_module_name] = {
                    PluginConfig.AUTH_MODULE: module,
                    PluginConfig.AUTH_METADATA: module.metadata,
                }
                Logger.info(
                    "Loaded auth plugin: %s, is_active: %s",
                    module.metadata["name"],
                    module.metadata["is_active"],
                )
            else:
                Logger.warning(
                    "Metadata is not active for %s authentication module.",
                    auth_module_name,
                )
        except ModuleNotFoundError as exception:
            # A broken plugin is skipped rather than aborting startup.
            Logger.error(
                "Error while importing authentication module : %s",
                exception,
            )
    if len(auth_modules) > 1:
        raise ValueError(
            "Multiple authentication modules found."
            "Only one authentication method is allowed."
        )
    elif len(auth_modules) == 0:
        Logger.warning(
            "No authentication modules found."
            "Application will start without authentication module"
        )
    return auth_modules
class AuthenticationPluginRegistry:
    """Registry of discovered authentication plugins.

    Plugins are loaded once at import time of this module (class-body
    evaluation of ``_load_plugins``); at most one active plugin exists.
    """

    auth_modules: dict[str, dict[str, Any]] = _load_plugins()

    @classmethod
    def is_plugin_available(cls) -> bool:
        """Check if any authentication plugin is available.

        Returns:
            bool: True if a plugin is available, False otherwise.
        """
        return len(cls.auth_modules) > 0

    @classmethod
    def get_plugin(cls) -> Any:
        """Get the selected authentication plugin.

        Returns:
            AuthenticationService: Selected authentication plugin instance.
        """
        # Exactly one module can be present (enforced in _load_plugins),
        # so taking the first entry is safe.
        chosen_auth_module = next(iter(cls.auth_modules.values()))
        chosen_metadata = chosen_auth_module[PluginConfig.AUTH_METADATA]
        service_class_name = chosen_metadata[
            PluginConfig.METADATA_SERVICE_CLASS
        ]
        # The metadata entry is the service class itself; instantiate it.
        return service_class_name()

View File

@@ -0,0 +1,326 @@
import logging
import uuid
from typing import Any, Optional
from account.authentication_helper import AuthenticationHelper
from account.cache_service import CacheService
from account.constants import Common, DefaultOrg
from account.custom_exceptions import Forbidden, MethodNotImplemented
from account.dto import (
CallbackData,
MemberData,
MemberInvitation,
OrganizationData,
ResetUserPasswordDto,
UserInfo,
UserRoleData,
UserSessionInfo,
)
from account.enums import UserRole
from account.models import Organization, User
from account.organization import OrganizationService
from django.http import HttpRequest
from rest_framework.request import Request
from rest_framework.response import Response
from tenant_account.models import OrganizationMember as OrganizationMember
Logger = logging.getLogger(__name__)
class AuthenticationService:
def __init__(self) -> None:
self.authentication_helper = AuthenticationHelper()
self.default_user: User = self.get_user()
self.default_organization: Organization = self.user_organization()
self.user_session_info = self.get_user_session_info()
def get_current_organization(self) -> Organization:
return self.default_organization
def get_current_user(self) -> User:
return self.default_user
def get_current_user_session(self) -> UserSessionInfo:
return self.user_session_info
def user_login(self, request: HttpRequest) -> Any:
raise MethodNotImplemented()
def user_signup(self, request: HttpRequest) -> Any:
raise MethodNotImplemented()
def is_admin_by_role(self, role: str) -> bool:
"""Check the role with actual admin Role.
Args:
role (str): input string
Returns:
bool: _description_
"""
try:
return UserRole(role.lower()) == UserRole.ADMIN
except ValueError:
return False
def get_callback_data(self, request: Request) -> CallbackData:
return CallbackData(
user_id=DefaultOrg.MOCK_USER_ID,
email=DefaultOrg.MOCK_USER_EMAIL,
token="",
)
def user_organization(self) -> Organization:
return Organization(
name=DefaultOrg.ORGANIZATION_NAME,
display_name=DefaultOrg.ORGANIZATION_NAME,
organization_id=DefaultOrg.ORGANIZATION_NAME,
schema_name=DefaultOrg.ORGANIZATION_NAME,
)
def handle_invited_user_while_callback(
self, request: Request, user: User
) -> MemberData:
member_data: MemberData = MemberData(
user_id=self.default_user.user_id,
organization_id=self.default_organization.organization_id,
role=[UserRole.ADMIN.value],
)
return member_data
def handle_authorization_callback(
self, user: User, data: CallbackData, redirect_url: str = ""
) -> Response:
return Response()
def add_to_organization(
self,
request: Request,
user: User,
data: Optional[dict[str, Any]] = None,
) -> MemberData:
member_data: MemberData = MemberData(
user_id=self.default_user.user_id,
organization_id=self.default_organization.organization_id,
)
return member_data
def remove_users_from_organization(
self,
admin: OrganizationMember,
organization_id: str,
user_ids: list[str],
) -> bool:
raise MethodNotImplemented()
def user_organizations(self, request: Request) -> list[OrganizationData]:
organizationData: OrganizationData = OrganizationData(
id=self.default_organization.organization_id,
display_name=self.default_organization.display_name,
name=self.default_organization.name,
)
return [organizationData]
def get_organizations_by_user_id(self, id: str) -> list[OrganizationData]:
organizationData: OrganizationData = OrganizationData(
id=self.default_organization.organization_id,
display_name=self.default_organization.display_name,
name=self.default_organization.name,
)
return [organizationData]
def get_organization_role_of_user(
self, user_id: str, organization_id: str
) -> list[str]:
return [UserRole.ADMIN.value]
def is_organization_admin(self, member: OrganizationMember) -> bool:
"""Check if the organization member has administrative privileges.
Args:
member (OrganizationMember): The organization member to check.
Returns:
bool: True if the user has administrative privileges,
False otherwise.
"""
try:
return UserRole(member.role) == UserRole.ADMIN
except ValueError:
return False
def check_user_organization_association(self, user_email: str) -> None:
"""Check if the user is already associated with any organizations.
Raises:
- UserAlreadyAssociatedException:
If the user is already associated with organizations.
"""
return None
def get_roles(self) -> list[UserRoleData]:
return [
UserRoleData(name=UserRole.ADMIN.value),
UserRoleData(name=UserRole.USER.value),
]
def get_invitations(self, organization_id: str) -> list[MemberInvitation]:
raise MethodNotImplemented()
def delete_invitation(
self, organization_id: str, invitation_id: str
) -> bool:
raise MethodNotImplemented()
def add_organization_user_role(
self,
admin: User,
organization_id: str,
user_id: str,
role_ids: list[str],
) -> list[str]:
if admin.role == UserRole.ADMIN.value:
return role_ids
raise Forbidden
def remove_organization_user_role(
self,
admin: User,
organization_id: str,
user_id: str,
role_ids: list[str],
) -> list[str]:
if admin.role == UserRole.ADMIN.value:
return role_ids
raise Forbidden
def get_organization_by_org_id(self, id: str) -> OrganizationData:
organizationData: OrganizationData = OrganizationData(
id=DefaultOrg.ORGANIZATION_NAME,
display_name=DefaultOrg.ORGANIZATION_NAME,
name=DefaultOrg.ORGANIZATION_NAME,
)
return organizationData
def get_user(self) -> User:
user = CacheService.get_user_session_info(DefaultOrg.MOCK_USER_EMAIL)
if not user:
try:
user = User.objects.get(email=DefaultOrg.MOCK_USER_EMAIL)
except User.DoesNotExist:
user = User(
username=DefaultOrg.MOCK_USER,
user_id=DefaultOrg.MOCK_USER_ID,
email=DefaultOrg.MOCK_USER_EMAIL,
)
user.save()
if isinstance(user, User):
id = user.id
user_id = user.user_id
email = user.email
else:
id = user[Common.ID]
user_id = user[Common.USER_ID]
email = user[Common.USER_EMAIL]
current_org = Common.PUBLIC_SCHEMA_NAME
user_session_info: UserSessionInfo = UserSessionInfo(
id=id,
user_id=user_id,
email=email,
current_org=current_org,
)
CacheService.set_user_session_info(user_session_info)
user_info = User(id=id, user_id=user_id, username=email, email=email)
return user_info
def get_user_info(self, request: Request) -> Optional[UserInfo]:
user: User = request.user
if user:
return UserInfo(
id=user.id,
user_id=user.user_id,
name=user.username,
display_name=user.username,
email=user.email,
)
else:
user = self.get_user()
return UserInfo(
id=user.id,
user_id=user.user_id,
name=user.username,
display_name=user.username,
email=user.email,
)
def get_user_session_info(self) -> UserSessionInfo:
user_session_info_dict = CacheService.get_user_session_info(
self.default_user.email
)
if not user_session_info_dict:
user_session_info: UserSessionInfo = UserSessionInfo(
id=self.default_user.id,
user_id=self.default_user.user_id,
email=self.default_user.email,
current_org=self.default_organization.organization_id,
)
CacheService.set_user_session_info(user_session_info)
else:
user_session_info = UserSessionInfo.from_dict(
user_session_info_dict
)
return user_session_info
def get_organization_info(self, org_id: str) -> Optional[Organization]:
return OrganizationService.get_organization_by_org_id(org_id=org_id)
def make_organization_and_add_member(
    self,
    user_id: str,
    user_name: str,
    organization_name: Optional[str] = None,
    display_name: Optional[str] = None,
) -> Optional[OrganizationData]:
    """Return a mock organization; no member is actually added.

    NOTE(review): all parameters are ignored and a fresh OrganizationData
    with a random id and the mock org name is returned — presumably a
    stub for the no-auth deployment; confirm callers expect this.
    """
    organization: OrganizationData = OrganizationData(
        id=str(uuid.uuid4()),
        display_name=DefaultOrg.MOCK_ORG,
        name=DefaultOrg.MOCK_ORG,
    )
    return organization
def make_user_organization_name(self) -> str:
    """Generate a random, collision-resistant organization name (UUID4)."""
    return f"{uuid.uuid4()}"
def make_user_organization_display_name(self, user_name: str) -> str:
    """Build a display name like "Alice's organization".

    Falls back to "Your organization" when ``user_name`` is falsy.
    """
    if user_name:
        return f"{user_name}'s organization"
    return "Your organization"
def user_logout(self, request: HttpRequest) -> Response:
    """Logout is not supported without an authentication plugin (HTTP 501)."""
    raise MethodNotImplemented()
def get_user_id_from_token(
    self, token: Optional[dict[str, Any]]
) -> str:
    """Return the mock user id; ``token`` is ignored in the no-auth setup.

    Note: the return annotation was corrected from ``Response`` — the
    method returns a plain string constant.
    """
    return DefaultOrg.MOCK_USER_ID
def get_organization_members_by_org_id(
    self, organization_id: str
) -> list[MemberData]:
    """List all members as MemberData DTOs.

    NOTE(review): ``organization_id`` is unused; the query returns every
    OrganizationMember row in the current schema.
    """
    users: list[OrganizationMember] = OrganizationMember.objects.all()
    return self.authentication_helper.list_of_members_from_user_model(users)
def reset_user_password(self, user: User) -> ResetUserPasswordDto:
    """Password reset is not supported without an auth plugin (HTTP 501)."""
    raise MethodNotImplemented()
def invite_user(
    self,
    admin: OrganizationMember,
    org_id: str,
    email: str,
    role: Optional[str] = None,
) -> bool:
    """User invitations are not supported without an auth plugin (HTTP 501)."""
    raise MethodNotImplemented()

View File

@@ -0,0 +1,134 @@
from typing import Any, Optional, Union
import redis
from account.custom_cache import CustomCache
from account.dto import UserSessionInfo
from django.conf import settings
from django.core.cache import cache
class CacheService:
    """Cache helpers for sessions, cookies and per-user organization lists.

    Instance methods talk to a dedicated Redis connection built from
    settings; the static methods go through Django's configured default
    cache backend (``django.core.cache.cache``).
    """

    def __init__(self) -> None:
        # Direct Redis client, separate from Django's cache framework.
        self.cache = redis.Redis(
            host=settings.REDIS_HOST,
            port=int(settings.REDIS_PORT),
            password=settings.REDIS_PASSWORD,
            username=settings.REDIS_USER,
        )

    def get_a_key(self, key: str) -> Optional[Any]:
        """Return the UTF-8 decoded value for ``key``, or None when absent."""
        data = self.cache.get(str(key))
        if data is not None:
            return data.decode("utf-8")
        return data

    def set_a_key(self, key: str, value: Any) -> None:
        """Store ``value`` with the workflow-action expiry (seconds)."""
        self.cache.set(
            str(key),
            value,
            int(settings.WORKFLOW_ACTION_EXPIRATION_TIME_IN_SECOND),
        )

    @staticmethod
    def set_cookie(cookie: str, token: dict[str, Any]) -> None:
        """Cache a token keyed by its cookie id (default backend timeout)."""
        cache.set(cookie, token)

    @staticmethod
    def get_cookie(cookie: str) -> dict[str, Any]:
        """Return the token cached for ``cookie`` (None on miss)."""
        data: dict[str, Any] = cache.get(cookie)
        return data

    @staticmethod
    def set_user_session_info(
        user_session_info: Union[UserSessionInfo, dict[str, Any]]
    ) -> None:
        """Cache session info (DTO or dict) keyed by the user's email."""
        if isinstance(user_session_info, UserSessionInfo):
            email = user_session_info.email
            user_session = user_session_info.to_dict()
        else:
            email = user_session_info["email"]
            user_session = user_session_info
        session_info_key: str = CacheService.get_user_session_info_key(email)
        cache.set(
            session_info_key,
            user_session,
            int(settings.SESSION_EXPIRATION_TIME_IN_SECOND),
        )

    @staticmethod
    def get_user_session_info(email: str) -> dict[str, Any]:
        """Return the cached session-info dict for ``email`` (None on miss)."""
        session_info_key: str = CacheService.get_user_session_info_key(email)
        data: dict[str, Any] = cache.get(session_info_key)
        return data

    @staticmethod
    def get_user_session_info_key(email: str) -> str:
        """Cache key for a user's session info: ``session:<email>``."""
        session_info_key: str = f"session:{email}"
        return session_info_key

    @staticmethod
    def check_a_key_exist(key: str, version: Any = None) -> bool:
        """True if ``key`` exists in the default cache backend."""
        data: bool = cache.has_key(key, version)
        return data

    @staticmethod
    def delete_a_key(key: str, version: Any = None) -> None:
        """Remove ``key`` from the default cache backend."""
        cache.delete(key, version)

    @staticmethod
    def set_user_organizations(user_id: str, organizations: list[str]) -> None:
        """Cache the organization-id list for a user (no explicit expiry)."""
        key: str = f"{user_id}|organizations"
        cache.set(key, list(organizations))

    @staticmethod
    def get_user_organizations(user_id: str) -> Any:
        """Return the cached organization-id list for a user."""
        key: str = f"{user_id}|organizations"
        return cache.get(key)

    @staticmethod
    def remove_user_organizations(user_id: str) -> Any:
        """Drop the cached organization-id list for a user."""
        key: str = f"{user_id}|organizations"
        return cache.delete(key)

    @staticmethod
    def add_cookie_id_to_user(user_id: str, cookie_id: str) -> None:
        """Append a cookie id to the user's Redis list of active cookies."""
        custom_cache = CustomCache()
        key: str = f"{user_id}|cookies"
        custom_cache.rpush(key, cookie_id)

    @staticmethod
    def remove_cookie_id_from_user(user_id: str, cookie_id: str) -> None:
        """Remove a cookie id from the user's Redis list of active cookies."""
        custom_cache = CustomCache()
        key: str = f"{user_id}|cookies"
        custom_cache.lrem(key, cookie_id)

    @staticmethod
    def remove_all_session_keys(
        user_id: Optional[str] = None,
        cookie_id: Optional[str] = None,
        key: Optional[str] = None,
    ) -> None:
        """Best-effort cleanup of session-related cache entries.

        Deletes whichever of cookie id, user id (plus the user's cached
        organizations) and arbitrary key were supplied.
        """
        if cookie_id is not None:
            cache.delete(cookie_id)
        if user_id is not None:
            cache.delete(user_id)
            CacheService.remove_user_organizations(user_id)
        if key is not None:
            cache.delete(key)

    # @staticmethod
    # def get_cookie_ids_for_user(user_id: str) -> list[str]:
    #     custom_cache = CustomCache()
    #     key: str = f"{user_id}|cookies"
    #     cookie_ids = custom_cache.lrange(key, 0, -1)
    #     return cookie_ids
# KEY_FUNCTION for cache settings
def custom_key_function(key: str, key_prefix: str, version: int) -> str:
    """Compose a Django cache key, omitting the version for version <= 1.

    ``prefix:version:key`` for version > 1; ``prefix:key`` when only a
    prefix is set; the bare key otherwise.
    """
    if version > 1:
        return f"{key_prefix}:{version}:{key}"
    return f"{key_prefix}:{key}" if key_prefix else key

View File

@@ -0,0 +1,74 @@
class OAuthConstant:
    """Claim/field names read from OAuth tokens.

    NOTE(review): "FEILD" is a typo for "FIELD"; kept as-is because these
    attribute names are referenced by callers.
    """

    TOKEN_USER_INFO_FEILD = "userinfo"
    TOKEN_ORG_ID_FEILD = "org_id"
    TOKEN_EMAIL_FEILD = "email"
    TOKEN_Z_ID_FEILD = "sub"
    TOKEN_USER_NAME_FEILD = "name"
    TOKEN_PRIMARY_Z_ID_FEILD = "primary_sub"
class LoginConstant:
    """Query/session parameter names used during login flows."""

    INVITATION = "invitation"
    ORGANIZATION = "organization"
    ORGANIZATION_NAME = "organization_name"
class Common:
    """Shared keys and limits used across the account app."""

    NEXT_URL_VARIABLE = "next"
    # django-tenants shared schema name.
    PUBLIC_SCHEMA_NAME = "public"
    ID = "id"
    USER_ID = "user_id"
    USER_EMAIL = "email"
    USER_EMAILS = "emails"
    USER_IDS = "user_ids"
    USER_ROLE = "role"
    # Cap on emails accepted per invite request.
    MAX_EMAIL_IN_REQUEST = 10
class UserModel:
    """Field names on the User model used in query lookups."""

    USER_ID = "user_id"
    ID = "id"
class OrganizationMemberModel:
    """Related-field lookups (via ``user``) on OrganizationMember."""

    USER_ID = "user__user_id"
    ID = "user__id"
class Cookie:
    """Cookie names read by the auth middleware."""

    ORG_ID = "org_id"
    Z_CODE = "z_code"
    CSRFTOKEN = "csrftoken"
class ErrorMessage:
    """User-facing error strings."""

    ORGANIZATION_EXIST = "Organization already exists"
    DUPLICATE_API = "It appears that a duplicate call may have been made."
class DefaultOrg:
    """Identity used when no authentication plugin is installed."""

    ORGANIZATION_NAME = "mock_org"
    MOCK_ORG = "mock_org"
    MOCK_USER = "mock_user"
    MOCK_USER_ID = "mock_user_id"
    MOCK_USER_EMAIL = "email@mock.com"
class PluginConfig:
    """Names used to discover and load authentication plugins."""

    PLUGINS_APP = "plugins"
    AUTH_MODULE_PREFIX = "auth"
    AUTH_PLUGIN_DIR = "authentication"
    AUTH_MODULE = "module"
    AUTH_METADATA = "metadata"
    METADATA_SERVICE_CLASS = "service_class"
    METADATA_IS_ACTIVE = "is_active"
class AuthoErrorCode:
    """Error code reference
    frontend/src/components/error/GenericError/GenericError.jsx.

    NOTE(review): class name looks like a typo for ``AuthErrorCode``;
    kept as-is since it may be imported elsewhere.
    """

    IDM = "IDM"
    UMM = "UMM"
    INF = "INF"
    USF = "USF"

View File

@@ -0,0 +1,144 @@
from typing import Optional
from account.authentication_plugin_registry import AuthenticationPluginRegistry
from account.authentication_service import AuthenticationService
from account.cache_service import CacheService
from account.constants import Common, Cookie, DefaultOrg
from account.dto import UserSessionInfo
from account.models import User
from account.user import UserService
from backend.constants import RequestHeader
from django.conf import settings
from django.core.cache import cache
from django.db import connection
from django.http import HttpRequest, HttpResponse, JsonResponse
from tenant_account.organization_member_service import OrganizationMemberService
class CustomAuthMiddleware:
    """Request authentication middleware.

    Whitelisted paths and internal API-key requests pass straight
    through. Otherwise the request is authenticated either with the
    default mock user (no auth plugin installed) or via cookies; any
    request that ends up without a user + session is rejected with 401.
    """

    def __init__(self, get_response: HttpResponse):
        self.get_response = get_response
        # One-time configuration and initialization.

    def __call__(self, request: HttpRequest) -> HttpResponse:
        user = None
        request.user = user
        # Returns result without authenticated if added in whitelisted paths
        if any(
            request.path.startswith(path) for path in settings.WHITELISTED_PATHS
        ):
            return self.get_response(request)
        tenantAccessiblePublicPath = False
        if any(
            request.path.startswith(path)
            for path in settings.TENANT_ACCESSIBLE_PUBLIC_PATHS
        ):
            tenantAccessiblePublicPath = True
        # Authenticating With API_KEY
        x_api_key = request.headers.get(RequestHeader.X_API_KEY)
        if (
            settings.INTERNAL_SERVICE_API_KEY
            and x_api_key == settings.INTERNAL_SERVICE_API_KEY
        ):  # Should API Key be in settings or just env alone?
            return self.get_response(request)
        if not AuthenticationPluginRegistry.is_plugin_available():
            self.without_authentication(request, user)
        elif request.COOKIES:
            self.authenticate_with_cookies(request, tenantAccessiblePublicPath)
        # Proceed only when a helper above attached a user and session.
        if (
            request.user  # type: ignore
            and request.session
            and "user" in request.session
        ):
            response = self.get_response(request)  # type: ignore
            return response
        return JsonResponse({"message": "Unauthorized"}, status=401)

    def without_authentication(
        self, request: HttpRequest, user: Optional[User]
    ) -> None:
        """Attach the default mock user/session (no auth plugin installed)."""
        org_id = DefaultOrg.MOCK_ORG
        user_id = DefaultOrg.MOCK_USER_ID
        email = DefaultOrg.MOCK_USER_EMAIL
        user_session_info = CacheService.get_user_session_info(email)
        if user is None:
            try:
                user_service = UserService()
                user = user_service.get_user_by_user_id(user_id)
                if not user:
                    # NOTE(review): this repeats the exact same lookup; a
                    # membership lookup (e.g. OrganizationMemberService)
                    # was probably intended — confirm.
                    member = user_service.get_user_by_user_id(user_id)
                    if member:
                        user = member.user
            except AttributeError:
                pass
        if user is None:
            # Last resort: let the base authentication service supply one.
            authentication_service = AuthenticationService()
            user = authentication_service.get_current_user()
        if user:
            if not user_session_info:
                user_info: UserSessionInfo = UserSessionInfo(
                    id=user.id,
                    user_id=user.user_id,
                    email=user.email,
                    current_org=org_id,
                )
                CacheService.set_user_session_info(user_info)
                user_session_info = CacheService.get_user_session_info(email)
            request.user = user
            request.org_id = org_id
            request.session["user"] = user_session_info
            request.session.save()

    def authenticate_with_cookies(
        self,
        request: HttpRequest,
        tenantAccessiblePublicPath: bool,
    ) -> None:
        """Resolve the user from the z_code cookie plus cached session info.

        Silently returns (leaving the request unauthenticated) on any
        missing token, session, org mismatch or unknown user.
        """
        z_code: str = request.COOKIES.get(Cookie.Z_CODE)
        token = cache.get(z_code) if z_code else None
        if not token:
            return
        user_email = token["userinfo"]["email"]
        user_session_info = CacheService.get_user_session_info(user_email)
        if not user_session_info:
            return
        current_org = user_session_info["current_org"]
        if not current_org:
            return
        # Reject cross-tenant access unless the path is tenant-accessible.
        if (
            current_org != connection.get_schema()
            and not tenantAccessiblePublicPath
        ):
            return
        if (
            current_org == Common.PUBLIC_SCHEMA_NAME
            or tenantAccessiblePublicPath
        ):
            user_service = UserService()
        else:
            # Tenant schema: verify the user is actually a member first.
            organization_member_service = OrganizationMemberService()
            member = organization_member_service.get_user_by_email(user_email)
            if not member:
                return
            user_service = UserService()
        user = user_service.get_user_by_email(user_email)
        if not user:
            return
        request.user = user
        request.org_id = current_org
        request.session["user"] = token
        request.session.save()

View File

@@ -0,0 +1,13 @@
from typing import Any
from django.http import HttpRequest
from rest_framework.exceptions import AuthenticationFailed
def api_login_required(view_func: Any) -> Any:
    """Decorator: allow the view only for authenticated sessions.

    Passes through to ``view_func`` when the request carries a user and a
    session containing the "user" key (set by the auth middleware);
    otherwise raises AuthenticationFailed (HTTP 401).
    """
    # Local import keeps this module's import surface unchanged.
    from functools import wraps

    # Fix: the wrapper previously did not preserve the wrapped view's
    # metadata (__name__, __doc__, __module__), which breaks DRF
    # introspection, logging and debugging of decorated views.
    @wraps(view_func)
    def wrapper(request: "HttpRequest", *args: Any, **kwargs: Any) -> Any:
        if request.user and request.session and "user" in request.session:
            return view_func(request, *args, **kwargs)
        raise AuthenticationFailed("Unauthorized")

    return wrapper

View File

@@ -0,0 +1,12 @@
from django_redis import get_redis_connection
class CustomCache:
    """Thin wrapper exposing Redis list operations on the default connection."""

    def __init__(self) -> None:
        # Raw redis client from django-redis's "default" cache alias.
        self.cache = get_redis_connection("default")

    def rpush(self, key: str, value: str) -> None:
        """Append ``value`` to the list stored at ``key``."""
        self.cache.rpush(key, value)

    def lrem(self, key: str, value: str) -> None:
        """Remove ``value`` from the list at ``key``.

        NOTE(review): redis-py >= 3.0 defines ``lrem(name, count, value)``;
        calling it with two arguments treats ``value`` as ``count`` and
        raises/misbehaves — confirm the installed redis-py version.
        """
        self.cache.lrem(key, value)

View File

@@ -0,0 +1,66 @@
from typing import Optional
from rest_framework.exceptions import APIException
class ConflictError(Exception):
    """Raised when an operation conflicts with existing state."""

    def __init__(self, message: str) -> None:
        # Expose the message as an attribute in addition to str(exc).
        self.message = message
        super().__init__(message)
class MethodNotImplemented(APIException):
    """Raised for features unavailable in this deployment (HTTP 501)."""

    status_code = 501
    default_detail = "Method Not Implemented"
class DuplicateData(APIException):
    """Raised when a request would create duplicate data (HTTP 400)."""

    status_code = 400
    default_detail = "Duplicate Data"

    def __init__(
        self, detail: Optional[str] = None, code: Optional[int] = None
    ):
        # Pre-setting detail/code is redundant with the super() call below
        # (APIException derives self.detail from its arguments), but kept
        # byte-identical here.
        if detail is not None:
            self.detail = detail
        if code is not None:
            self.code = code
        super().__init__(detail, code)
class TableNotExistError(APIException):
    """Raised when a requested table is unknown (HTTP 400)."""

    status_code = 400
    default_detail = "Unknown Table"

    def __init__(
        self, detail: Optional[str] = None, code: Optional[int] = None
    ):
        # Fix: previously self.detail/self.code were assigned and then
        # ``super().__init__()`` was called with NO arguments, so
        # APIException reset self.detail back to default_detail and any
        # custom detail passed by the caller was silently lost. Forward
        # both arguments, mirroring DuplicateData.
        super().__init__(detail, code)
class UserNotExistError(APIException):
    """Raised when a referenced user is unknown (HTTP 400)."""

    status_code = 400
    default_detail = "Unknown User"

    def __init__(
        self, detail: Optional[str] = None, code: Optional[int] = None
    ):
        # Fix: as with TableNotExistError, the no-arg ``super().__init__()``
        # clobbered any custom detail with default_detail. Forward both
        # arguments, mirroring DuplicateData.
        super().__init__(detail, code)
class Forbidden(APIException):
    """Raised when the user lacks permission for the action (HTTP 403)."""

    status_code = 403
    default_detail = "Do not have permission to perform this action."
class UserAlreadyAssociatedException(APIException):
    """Raised when a user already belongs to an organization (HTTP 400)."""

    status_code = 400
    default_detail = "User is already associated with one organization."

130
backend/account/dto.py Normal file
View File

@@ -0,0 +1,130 @@
from dataclasses import dataclass
from typing import Any, Optional
@dataclass
class MemberData:
    """Organization member as returned by auth providers; only user_id is required."""

    user_id: str
    email: Optional[str] = None
    name: Optional[str] = None
    picture: Optional[str] = None
    role: Optional[list[str]] = None
    organization_id: Optional[str] = None
@dataclass
class OrganizationData:
    """Minimal organization descriptor (id plus names)."""

    id: str
    display_name: str
    name: str
@dataclass
class CallbackData:
    """Payload extracted from an auth provider's callback."""

    user_id: str
    email: str
    # Raw provider token; shape depends on the auth plugin.
    token: Any
@dataclass
class OrganizationSignupRequestBody:
    """Validated body of an organization-signup request."""

    name: str
    display_name: str
    organization_id: str
@dataclass
class OrganizationSignupResponse:
    """Response payload after an organization is created."""

    name: str
    display_name: str
    organization_id: str
    created_at: str
@dataclass
class UserInfo:
    """Public profile details of a user; email and user_id are required."""

    email: str
    user_id: str
    id: Optional[str] = None
    name: Optional[str] = None
    display_name: Optional[str] = None
    family_name: Optional[str] = None
    picture: Optional[str] = None
@dataclass
class UserSessionInfo:
    """Serializable snapshot of a logged-in user's session.

    ``to_dict``/``from_dict`` round-trip through the plain-dict form
    used for caching.
    """

    id: str
    user_id: str
    email: str
    current_org: str

    @staticmethod
    def from_dict(data: dict[str, Any]) -> "UserSessionInfo":
        """Rebuild a UserSessionInfo from its dict representation."""
        return UserSessionInfo(
            data["id"],
            data["user_id"],
            data["email"],
            data["current_org"],
        )

    def to_dict(self) -> Any:
        """Return a plain-dict representation suitable for caching."""
        return {
            "id": self.id,
            "user_id": self.user_id,
            "email": self.email,
            "current_org": self.current_org,
        }
@dataclass
class GetUserReposne:
    """User profile plus their organizations.

    NOTE(review): class name is a typo for "GetUserResponse"; kept
    because callers may import it by this name.
    """

    user: UserInfo
    organizations: list[OrganizationData]
@dataclass
class ResetUserPasswordDto:
    """Outcome of a password-reset attempt."""

    status: bool
    message: str
@dataclass
class UserInviteResponse:
    """Per-email outcome of an invite request."""

    email: str
    status: str
    message: Optional[str] = None
@dataclass
class UserRoleData:
    """A role a user may hold within an organization."""

    name: str
    id: Optional[str] = None
    description: Optional[str] = None
@dataclass
class MemberInvitation:
    """Represents an invitation to join an organization in Auth0.

    Attributes:
        id (str): The unique identifier for the invitation.
        email (str): The user email.
        roles (List[str]): The roles assigned to the invitee.
        created_at (Optional[str]): The timestamp when the invitation
            was created.
        expires_at (Optional[str]): The timestamp when the invitation expires.
    """

    id: str
    email: str
    roles: list[str]
    created_at: Optional[str] = None
    expires_at: Optional[str] = None
@dataclass
class UserOrganizationRole:
    """Binding of a user to a role within a specific organization."""

    user_id: str
    role: UserRoleData
    organization_id: str
6
backend/account/enums.py Normal file
View File

@@ -0,0 +1,6 @@
from enum import Enum
class UserRole(Enum):
    """Roles a member can hold within an organization."""

    USER = "user"
    ADMIN = "admin"

View File

@@ -0,0 +1,26 @@
from rest_framework.exceptions import APIException
class UserIdNotExist(APIException):
    """Raised when a referenced user id is unknown (HTTP 404)."""

    status_code = 404
    default_detail = "User ID does not exist"
class UserAlreadyExistInOrganization(APIException):
    """Raised when adding a user who is already a member (HTTP 403)."""

    status_code = 403
    # Fix: corrected user-facing message ("allready exist" -> "already exists").
    default_detail = "User already exists in the organization"
class OrganizationNotExist(APIException):
    """Raised when a referenced organization is unknown (HTTP 404)."""

    status_code = 404
    default_detail = "Organization does not exist"
class UnknownException(APIException):
    """Catch-all server error (HTTP 500)."""

    status_code = 500
    default_detail = "An unexpected error occurred"
class BadRequestException(APIException):
    """Generic malformed-request error (HTTP 400)."""

    status_code = 400
    default_detail = "Bad Request"

View File

@@ -0,0 +1,237 @@
# Generated by Django 4.2.1 on 2023-07-18 10:39
import django.contrib.auth.models
import django.contrib.auth.validators
import django.db.models.deletion
import django.utils.timezone
import django_tenants.postgresql_backend.base
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
    """Initial schema: custom User, tenant Organization, and Domain.

    Auto-generated by Django; do not hand-edit the operations.
    """

    initial = True

    dependencies = [
        ("auth", "0012_alter_user_first_name_max_length"),
    ]

    operations = [
        migrations.CreateModel(
            name="User",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                ("password", models.CharField(max_length=128, verbose_name="password")),
                (
                    "last_login",
                    models.DateTimeField(
                        blank=True, null=True, verbose_name="last login"
                    ),
                ),
                (
                    "is_superuser",
                    models.BooleanField(
                        default=False,
                        help_text="Designates that this user has all permissions without explicitly assigning them.",
                        verbose_name="superuser status",
                    ),
                ),
                (
                    "username",
                    models.CharField(
                        error_messages={
                            "unique": "A user with that username already exists."
                        },
                        help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.",
                        max_length=150,
                        unique=True,
                        validators=[
                            django.contrib.auth.validators.UnicodeUsernameValidator()
                        ],
                        verbose_name="username",
                    ),
                ),
                (
                    "first_name",
                    models.CharField(
                        blank=True, max_length=150, verbose_name="first name"
                    ),
                ),
                (
                    "last_name",
                    models.CharField(
                        blank=True, max_length=150, verbose_name="last name"
                    ),
                ),
                (
                    "email",
                    models.EmailField(
                        blank=True, max_length=254, verbose_name="email address"
                    ),
                ),
                (
                    "is_staff",
                    models.BooleanField(
                        default=False,
                        help_text="Designates whether the user can log into this admin site.",
                        verbose_name="staff status",
                    ),
                ),
                (
                    "is_active",
                    models.BooleanField(
                        default=True,
                        help_text="Designates whether this user should be treated as active. Unselect this instead of deleting accounts.",
                        verbose_name="active",
                    ),
                ),
                (
                    "date_joined",
                    models.DateTimeField(
                        default=django.utils.timezone.now, verbose_name="date joined"
                    ),
                ),
                ("user_id", models.CharField()),
                ("project_storage_created", models.BooleanField(default=False)),
                ("modified_at", models.DateTimeField(auto_now=True)),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                (
                    "created_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="created_users",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "groups",
                    models.ManyToManyField(
                        blank=True,
                        related_name="customuser_set",
                        related_query_name="customuser",
                        to="auth.group",
                    ),
                ),
                (
                    "modified_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="modified_users",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "user_permissions",
                    models.ManyToManyField(
                        blank=True,
                        related_name="customuser_set",
                        related_query_name="customuser",
                        to="auth.permission",
                    ),
                ),
            ],
            options={
                "verbose_name": "user",
                "verbose_name_plural": "users",
                "abstract": False,
            },
            managers=[
                ("objects", django.contrib.auth.models.UserManager()),
            ],
        ),
        migrations.CreateModel(
            name="Organization",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                (
                    "schema_name",
                    models.CharField(
                        db_index=True,
                        max_length=63,
                        unique=True,
                        validators=[
                            django_tenants.postgresql_backend.base._check_schema_name
                        ],
                    ),
                ),
                ("name", models.CharField(max_length=64)),
                ("display_name", models.CharField(max_length=64)),
                ("organization_id", models.CharField(max_length=64)),
                ("modified_at", models.DateTimeField(auto_now=True)),
                ("created_at", models.DateTimeField(auto_now=True)),
                (
                    "created_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="created_orgs",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "modified_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="modified_orgs",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
            ],
            options={
                "abstract": False,
            },
        ),
        migrations.CreateModel(
            name="Domain",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                (
                    "domain",
                    models.CharField(db_index=True, max_length=253, unique=True),
                ),
                ("is_primary", models.BooleanField(db_index=True, default=True)),
                (
                    "tenant",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="domains",
                        to="account.organization",
                    ),
                ),
            ],
            options={
                "abstract": False,
            },
        ),
    ]

View File

@@ -0,0 +1,39 @@
# mypy: ignore-errors
# Generated by Django 4.2.1 on 2023-07-18 10:40
from django.contrib.auth.hashers import make_password
from django.db import migrations
def create_public_tenant_and_domain(apps, schema_editor):
    """Seed the public tenant and a default superuser (data migration).

    NOTE(review): despite the name, no Domain row is created here. The
    admin password is hard-coded ("ascon") — rotate it outside of dev
    environments.
    """
    Organization = apps.get_model("account", "Organization")
    # public tenant
    tenant = Organization(
        name="public",
        display_name="public",
        organization_id="public",
        schema_name="public",
    )
    tenant.save()

    User = apps.get_model("account", "User")
    # public User admin
    user = User(
        username="admin",
        email="admin@zipstack.com",
        is_superuser=True,
        is_staff=True,
        password=make_password("ascon"),
    )
    user.save()
class Migration(migrations.Migration):
    """Seeds the public tenant and default admin user.

    NOTE(review): no reverse_code is given, so this migration is
    irreversible.
    """

    dependencies = [
        ("account", "0001_initial"),
    ]

    operations = [
        migrations.RunPython(create_public_tenant_and_domain),
    ]

View File

@@ -0,0 +1,65 @@
# Generated by Django 4.2.1 on 2023-11-02 05:22
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import uuid
class Migration(migrations.Migration):
    """Adds the PlatformKey model (UUID-keyed API platform keys).

    Auto-generated by Django; do not hand-edit the operations.
    """

    dependencies = [
        ("account", "0002_auto_20230718_1040"),
    ]

    operations = [
        migrations.CreateModel(
            name="PlatformKey",
            fields=[
                (
                    "id",
                    models.UUIDField(
                        default=uuid.uuid4,
                        editable=False,
                        primary_key=True,
                        serialize=False,
                    ),
                ),
                ("key", models.UUIDField(default=uuid.uuid4)),
                (
                    "key_name",
                    models.CharField(blank=True, max_length=64, null=True, unique=True),
                ),
                ("is_active", models.BooleanField(default=False)),
                (
                    "created_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="created_keys",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "modified_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="modified_keys",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "organization",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="related_org",
                        to="account.organization",
                    ),
                ),
            ],
        ),
    ]

View File

@@ -0,0 +1,23 @@
# Generated by Django 4.2.1 on 2023-11-15 11:37
from django.db import migrations, models
class Migration(migrations.Migration):
    """Drops global uniqueness on PlatformKey.key_name and makes it unique
    per (key_name, organization) instead."""

    dependencies = [
        ("account", "0003_platformkey"),
    ]

    operations = [
        migrations.AlterField(
            model_name="platformkey",
            name="key_name",
            field=models.CharField(blank=True, default="", max_length=64),
        ),
        migrations.AddConstraint(
            model_name="platformkey",
            constraint=models.UniqueConstraint(
                fields=("key_name", "organization"), name="unique_key_name"
            ),
        ),
    ]

View File

@@ -0,0 +1,39 @@
# Generated by Django 4.2.1 on 2024-02-13 11:52
from typing import Any
from account.models import EncryptionSecret
from cryptography.fernet import Fernet
from django.db import migrations, models
class Migration(migrations.Migration):
    """Creates EncryptionSecret and seeds it with a generated Fernet key."""

    dependencies = [
        ("account", "0004_alter_platformkey_key_name_and_more"),
    ]

    def initialize_secret(apps: Any, schema_editor: Any) -> None:
        # NOTE(review): this uses the *live* EncryptionSecret model imported
        # at module level instead of apps.get_model(); that breaks if the
        # model changes in later migrations — confirm intent.
        EncryptionSecret.objects.create(
            key=Fernet.generate_key().decode("utf-8")
        )

    operations = [
        migrations.CreateModel(
            name="EncryptionSecret",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                ("key", models.CharField(blank=True, max_length=64)),
            ],
        ),
        migrations.RunPython(
            initialize_secret, reverse_code=migrations.RunPython.noop
        ),
    ]

View File

141
backend/account/models.py Normal file
View File

@@ -0,0 +1,141 @@
import uuid
from backend.constants import FieldLengthConstants as FieldLength
from django.contrib.auth.models import AbstractUser, Group, Permission
from django.db import models
from django_tenants.models import DomainMixin, TenantMixin
# Maximum lengths for name-like and key-like CharFields in this module.
NAME_SIZE = 64
KEY_SIZE = 64
class Organization(TenantMixin):
    """Stores data related to an organization.

    The fields created_by and modified_by is updated after a
    :model:`account.User` is created.
    """

    name = models.CharField(max_length=NAME_SIZE)
    display_name = models.CharField(max_length=NAME_SIZE)
    # External/business identifier (also used as the tenant schema name
    # by OrganizationService).
    organization_id = models.CharField(max_length=FieldLength.ORG_NAME_SIZE)
    created_by = models.ForeignKey(
        "User",
        on_delete=models.SET_NULL,
        related_name="created_orgs",
        null=True,
        blank=True,
    )
    modified_by = models.ForeignKey(
        "User",
        on_delete=models.SET_NULL,
        related_name="modified_orgs",
        null=True,
        blank=True,
    )
    modified_at = models.DateTimeField(auto_now=True)
    # NOTE(review): auto_now=True refreshes this on every save;
    # auto_now_add was probably intended for a creation timestamp —
    # confirm (changing it requires a migration).
    created_at = models.DateTimeField(auto_now=True)
    # django-tenants: create the tenant schema automatically on save.
    auto_create_schema = True
class Domain(DomainMixin):
    """Tenant domain mapping; all behavior comes from django-tenants' DomainMixin."""

    pass
class User(AbstractUser):
    """Stores data related to a user belonging to any organization.

    Every org, user is assumed to be unique.
    """

    # Third Party Authentication User ID
    # NOTE(review): CharField() without max_length is only valid on
    # Django 4.2+ with PostgreSQL — confirm supported backends.
    user_id = models.CharField()
    project_storage_created = models.BooleanField(default=False)
    created_by = models.ForeignKey(
        "User",
        on_delete=models.SET_NULL,
        related_name="created_users",
        null=True,
        blank=True,
    )
    modified_by = models.ForeignKey(
        "User",
        on_delete=models.SET_NULL,
        related_name="modified_users",
        null=True,
        blank=True,
    )
    modified_at = models.DateTimeField(auto_now=True)
    created_at = models.DateTimeField(auto_now_add=True)
    # Specify a unique related_name for the groups field
    groups = models.ManyToManyField(
        Group,
        related_name="customuser_set",
        related_query_name="customuser",
        blank=True,
    )
    # Specify a unique related_name for the user_permissions field
    user_permissions = models.ManyToManyField(
        Permission,
        related_name="customuser_set",
        related_query_name="customuser",
        blank=True,
    )

    def __str__(self):  # type: ignore
        """Human-readable identifier used in logs and the admin."""
        return f"User({self.id}, email: {self.email}, userId: {self.user_id})"
class PlatformKey(models.Model):
    """Model to hold details of Platform keys.

    Only users with admin role are allowed to perform any operation
    related keys.
    """

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    # The secret value itself; a fresh UUID per key.
    key = models.UUIDField(default=uuid.uuid4)
    key_name = models.CharField(
        max_length=KEY_SIZE, null=False, blank=True, default=""
    )
    is_active = models.BooleanField(default=False)
    # NOTE(review): related_name="related_org" is the reverse accessor ON
    # Organization (org.related_org -> keys); "platform_keys" would read
    # better — renaming requires a migration.
    organization = models.ForeignKey(
        "Organization",
        on_delete=models.SET_NULL,
        related_name="related_org",
        null=True,
        blank=True,
    )
    created_by = models.ForeignKey(
        "User",
        on_delete=models.SET_NULL,
        related_name="created_keys",
        null=True,
        blank=True,
    )
    modified_by = models.ForeignKey(
        "User",
        on_delete=models.SET_NULL,
        related_name="modified_keys",
        null=True,
        blank=True,
    )

    class Meta:
        constraints = [
            # Key names are unique within an organization, not globally.
            models.UniqueConstraint(
                fields=["key_name", "organization"],
                name="unique_key_name",
            ),
        ]
class EncryptionSecret(models.Model):
    """Holds the symmetric encryption key seeded by migration 0005.

    # presumably a Fernet key (44 base64 chars) — confirm consumers.
    """

    key = models.CharField(
        max_length=KEY_SIZE,
        null=False,
        blank=True,
    )

View File

@@ -0,0 +1,42 @@
import logging
from typing import Optional
from account.models import Domain, Organization
from django.db import IntegrityError
Logger = logging.getLogger(__name__)
class OrganizationService:
    """CRUD helpers for tenant Organizations."""

    def __init__(self):  # type: ignore
        pass

    @staticmethod
    def get_organization_by_org_id(org_id: str) -> Optional[Organization]:
        """Return the Organization for ``org_id``, or None when absent."""
        try:
            return Organization.objects.get(organization_id=org_id)  # type: ignore
        except Organization.DoesNotExist:
            return None

    @staticmethod
    def create_organization(
        name: str, display_name: str, organization_id: str
    ) -> Organization:
        """Create a tenant Organization plus its primary Domain.

        The organization_id doubles as the tenant schema name; saving the
        org triggers schema creation (auto_create_schema).

        Raises:
            IntegrityError: when the organization already exists.
        """
        try:
            organization: Organization = Organization(
                name=name,
                display_name=display_name,
                organization_id=organization_id,
                schema_name=organization_id,
            )
            organization.save()
        except IntegrityError as error:
            Logger.info(f"[Duplicate Id] Failed to create Organization Error: {error}")
            raise error
        # Add one or more domains for the tenant
        domain = Domain()
        domain.domain = organization_id
        domain.tenant = organization
        domain.is_primary = True
        domain.save()
        return organization

View File

@@ -0,0 +1,86 @@
import re
# from account.enums import Region
from account.models import Organization, User
from rest_framework import serializers
class OrganizationSignupSerializer(serializers.Serializer):
    """Validates the organization-signup payload."""

    name = serializers.CharField(required=True, max_length=150)
    display_name = serializers.CharField(required=True, max_length=150)
    organization_id = serializers.CharField(required=True, max_length=30)

    def validate_organization_id(self, value):  # type: ignore
        """Restrict ids to lowercase alphanumerics, '_' and '-'."""
        if not re.match(r"^[a-z0-9_-]+$", value):
            # Fix: the message referred to a non-existent field
            # ("organization_code") and claimed "alphanumeric" although
            # the regex only accepts lowercase letters and digits.
            raise serializers.ValidationError(
                "organization_id should only contain lowercase "
                "alphanumeric characters, _ and -."
            )
        return value
class OrganizationCallbackSerializer(serializers.Serializer):
    """Optional organization id carried in an auth callback."""

    id = serializers.CharField(required=False)
class GetOrganizationsResponseSerializer(serializers.Serializer):
    """Serializes an organization entry in the "list organizations" response."""

    id = serializers.CharField()
    display_name = serializers.CharField()
    name = serializers.CharField()
    # Add more fields as needed

    def to_representation(self, instance):  # type: ignore
        # Currently a pass-through; hook point for response shaping.
        data = super().to_representation(instance)
        # Modify the representation if needed
        return data
class GetOrganizationMembersResponseSerializer(serializers.Serializer):
    """Serializes a member entry in the "list members" response."""

    user_id = serializers.CharField()
    email = serializers.CharField()
    name = serializers.CharField()
    picture = serializers.CharField()
    # Add more fields as needed

    def to_representation(self, instance):  # type: ignore
        # Currently a pass-through; hook point for response shaping.
        data = super().to_representation(instance)
        # Modify the representation if needed
        return data
class OrganizationSerializer(serializers.Serializer):
    """Minimal organization representation (name + id)."""

    name = serializers.CharField()
    organization_id = serializers.CharField()
class SetOrganizationsResponseSerializer(serializers.Serializer):
    """Serializes the user profile returned after switching organization."""

    id = serializers.CharField()
    email = serializers.CharField()
    name = serializers.CharField()
    display_name = serializers.CharField()
    family_name = serializers.CharField()
    picture = serializers.CharField()
    # Add more fields as needed

    def to_representation(self, instance):  # type: ignore
        # Currently a pass-through; hook point for response shaping.
        data = super().to_representation(instance)
        # Modify the representation if needed
        return data
class ModelTenantSerializer(serializers.ModelSerializer):
    """Minimal Organization representation (name + creation timestamp)."""

    class Meta:
        model = Organization
        # Fix: was ``fields = fields = (...)`` (accidental double
        # assignment) and referenced "created_on", which does not exist
        # on Organization — the model defines ``created_at``.
        fields = ("name", "created_at")
class UserSerializer(serializers.ModelSerializer):
    """Minimal User representation (id + email)."""

    class Meta:
        model = User
        fields = ("id", "email")
class OrganizationSignupResponseSerializer(serializers.Serializer):
    """Serializes the payload returned after creating an organization."""

    name = serializers.CharField()
    display_name = serializers.CharField()
    organization_id = serializers.CharField()
    created_at = serializers.CharField()

View File

@@ -0,0 +1,11 @@
<!DOCTYPE html>
<!-- Minimal guest landing page with a Django-template link to the login view. -->
<html lang="en" xml:lang="en">
  <head>
    <meta charset="utf-8" />
    <title>ZipstackID Django App Example</title>
  </head>
  <body>
    <h1 id="profileDropDown">Welcome Guest</h1>
    <p><a href="{% url 'login' %}" id="qsLoginBtn">Login</a></p>
  </body>
</html>

1
backend/account/tests.py Normal file
View File

@@ -0,0 +1 @@
# Create your tests here.

20
backend/account/urls.py Normal file
View File

@@ -0,0 +1,20 @@
from account.views import (
callback,
create_organization,
get_organizations,
login,
logout,
set_organization,
signup,
)
from django.urls import path
urlpatterns = [
    # Session / OAuth endpoints.
    path("login", login, name="login"),
    path("signup", signup, name="signup"),
    path("logout", logout, name="logout"),
    path("callback", callback, name="callback"),
    # Organization listing, selection and creation for the current user.
    path("organization", get_organizations, name="get_organizations"),
    path("organization/<str:id>/set", set_organization, name="set_organization"),
    path("organization/create", create_organization, name="create_organization"),
]

50
backend/account/user.py Normal file
View File

@@ -0,0 +1,50 @@
import logging
from typing import Any, Optional
from account.models import User
from django.db import IntegrityError
Logger = logging.getLogger(__name__)
class UserService:
    """Lookup and creation helpers for User rows."""

    def __init__(
        self,
    ) -> None:
        pass

    def create_user(self, email: str, user_id: str) -> User:
        """Create and save a User; email doubles as the username.

        Raises:
            IntegrityError: on duplicate username/email.
        """
        try:
            user: User = User(email=email, user_id=user_id, username=email)
            user.save()
        except IntegrityError as error:
            Logger.info(f"[Duplicate Id] Failed to create User Error: {error}")
            raise error
        return user

    def get_user_by_email(self, email: str) -> Optional[User]:
        """Return the User with this email, or None when absent."""
        try:
            user: User = User.objects.get(email=email)
            return user
        except User.DoesNotExist:
            return None

    def get_user_by_user_id(self, user_id: str) -> Any:
        """Return the User with this external user_id, or None when absent."""
        try:
            return User.objects.get(user_id=user_id)
        except User.DoesNotExist:
            return None

    def get_user_by_id(self, id: str) -> Any:
        """Retrieve a user by their ID, taking into account the schema context.

        Args:
            id (str): The ID of the user.

        Returns:
            Any: The user object if found, or None if not found.
        """
        try:
            return User.objects.get(id=id)
        except User.DoesNotExist:
            return None

125
backend/account/views.py Normal file
View File

@@ -0,0 +1,125 @@
import logging
from typing import Any
from account.authentication_controller import AuthenticationController
from account.dto import (
OrganizationSignupRequestBody,
OrganizationSignupResponse,
)
from account.models import Organization
from account.organization import OrganizationService
from account.serializer import (
OrganizationSignupResponseSerializer,
OrganizationSignupSerializer,
)
from rest_framework import status
from rest_framework.decorators import api_view
from rest_framework.request import Request
from rest_framework.response import Response
Logger = logging.getLogger(__name__)
@api_view(["POST"])
def create_organization(request: Request) -> Response:
serializer = OrganizationSignupSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
try:
requestBody: OrganizationSignupRequestBody = makeSignupRequestParams(
serializer
)
organization: Organization = OrganizationService.create_organization(
requestBody.name,
requestBody.display_name,
requestBody.organization_id,
)
response = makeSignupResponse(organization)
return Response(
status=status.HTTP_201_CREATED,
data={"message": "success", "tenant": response},
)
except Exception as error:
Logger.error(error)
return Response(
status=status.HTTP_500_INTERNAL_SERVER_ERROR, data="Unknown Error"
)
@api_view(["GET"])
def callback(request: Request) -> Response:
auth_controller = AuthenticationController()
return auth_controller.authorization_callback(request)
@api_view(["GET"])
def login(request: Request) -> Response:
auth_controller = AuthenticationController()
return auth_controller.user_login(request)
@api_view(["GET"])
def signup(request: Request) -> Response:
auth_controller = AuthenticationController()
return auth_controller.user_signup(request)
@api_view(["GET"])
def logout(request: Request) -> Response:
auth_controller = AuthenticationController()
return auth_controller.user_logout(request)
@api_view(["GET"])
def get_organizations(request: Request) -> Response:
"""get_organizations.
Retrieve the list of organizations to which the user belongs.
Args:
request (HttpRequest): _description_
Returns:
Response: A list of organizations with associated information.
"""
auth_controller = AuthenticationController()
return auth_controller.user_organizations(request)
@api_view(["POST"])
def set_organization(request: Request, id: str) -> Response:
"""set_organization.
Set the current organization to use.
Args:
request (HttpRequest): _description_
id (String): organization Id
Returns:
Response: Contains the User and Current organization details.
"""
auth_controller = AuthenticationController()
return auth_controller.set_user_organization(request, id)
def makeSignupRequestParams(
    serializer: OrganizationSignupSerializer,
) -> OrganizationSignupRequestBody:
    """Build a signup-request DTO from the serializer's validated data."""
    validated = serializer.validated_data
    return OrganizationSignupRequestBody(
        validated["name"],
        validated["display_name"],
        validated["organization_id"],
    )
def makeSignupResponse(
    organization: Organization,
) -> Any:
    """Serialize an ``Organization`` into the signup response payload."""
    response_dto = OrganizationSignupResponse(
        organization.name,
        organization.display_name,
        organization.organization_id,
        organization.created_at,
    )
    return OrganizationSignupResponseSerializer(response_dto).data

View File

View File

@@ -0,0 +1,280 @@
import json
import logging
from typing import Any, Optional
import adapter_processor
from account.models import User
from adapter_processor.constants import AdapterKeys
from adapter_processor.exceptions import (
InternalServiceError,
InValidAdapterId,
TestAdapterException,
TestAdapterInputException,
)
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
from platform_settings.exceptions import ActiveKeyNotFound
from platform_settings.platform_auth_service import (
PlatformAuthenticationService,
)
from unstract.adapters.adapterkit import Adapterkit
from unstract.adapters.base import Adapter
from unstract.adapters.enums import AdapterTypes
from unstract.adapters.exceptions import AdapterError
from unstract.adapters.x2text.constants import X2TextConstants
from .models import AdapterInstance
logger = logging.getLogger(__name__)
class AdapterProcessor:
    """Static helpers bridging the adapter SDK (``unstract.adapters``) and
    the ``AdapterInstance`` model: schema lookup, adapter listing,
    connection testing and default-adapter management."""

    @staticmethod
    def get_json_schema(adapter_id: str) -> dict[str, Any]:
        """Return the JSON Schema for the adapter identified by ``adapter_id``.

        Raises:
            InValidAdapterId: if no adapter matches ``adapter_id``.
            InternalServiceError: if the stored schema is not valid JSON.
        """
        schema_details: dict[str, Any] = {}
        updated_adapters = AdapterProcessor.__fetch_adapters_by_key_value(
            AdapterKeys.ID, adapter_id
        )
        if len(updated_adapters) != 0:
            try:
                schema_details[AdapterKeys.JSON_SCHEMA] = json.loads(
                    updated_adapters[0].get(AdapterKeys.JSON_SCHEMA)
                )
            except Exception as exc:
                logger.error(f"Error occurred while parsing JSON Schema : {exc}")
                raise InternalServiceError()
        else:
            logger.error(
                f"Invalid adapter Id : {adapter_id} while fetching JSON Schema"
            )
            raise InValidAdapterId()
        return schema_details

    @staticmethod
    def get_all_supported_adapters(type: str) -> list[dict[Any, Any]]:
        """Return summary info (id, name, description, icon, type) for every
        SDK adapter of the given adapter type."""
        supported_adapters = []
        updated_adapters = AdapterProcessor.__fetch_adapters_by_key_value(
            AdapterKeys.ADAPTER_TYPE, type
        )
        for each_adapter in updated_adapters:
            supported_adapters.append(
                {
                    AdapterKeys.ID: each_adapter.get(AdapterKeys.ID),
                    AdapterKeys.NAME: each_adapter.get(AdapterKeys.NAME),
                    AdapterKeys.DESCRIPTION: each_adapter.get(
                        AdapterKeys.DESCRIPTION
                    ),
                    AdapterKeys.ICON: each_adapter.get(AdapterKeys.ICON),
                    AdapterKeys.ADAPTER_TYPE: each_adapter.get(
                        AdapterKeys.ADAPTER_TYPE
                    ),
                }
            )
        return supported_adapters

    @staticmethod
    def get_adapter_data_with_key(adapter_id: str, key_value: str) -> Any:
        """Return the ``key_value`` attribute of the adapter with the given id.

        Raises:
            InValidAdapterId: if no adapter matches ``adapter_id``.
        """
        # Fetch once and reuse (the original implementation queried the
        # adapter kit twice for the same id).
        updated_adapters = AdapterProcessor.__fetch_adapters_by_key_value(
            "id", adapter_id
        )
        if len(updated_adapters) == 0:
            logger.error(
                f"Invalid adapter ID {adapter_id} while invoking utility"
            )
            raise InValidAdapterId()
        return updated_adapters[0].get(key_value)

    @staticmethod
    def test_adapter(adapter_id: str, adapter_metadata: dict[str, Any]) -> bool:
        """Instantiate the adapter with the given metadata and run its
        connection test.

        Note: pops ``adapter_type`` from ``adapter_metadata`` (mutates the
        caller's dict) and, for X2TEXT adapters, injects the x2text host,
        port and an active platform API key before instantiation.

        Raises:
            ActiveKeyNotFound: propagated unchanged from key lookup.
            TestAdapterInputException: wraps ``AdapterError`` (bad config).
            TestAdapterException: for any other failure.
        """
        logger.info(f"Testing adapter: {adapter_id}")
        try:
            adapter_class = Adapterkit().get_adapter_class_by_adapter_id(
                adapter_id
            )
            if (
                adapter_metadata.pop(AdapterKeys.ADAPTER_TYPE)
                == AdapterKeys.X2TEXT
            ):
                adapter_metadata[
                    X2TextConstants.X2TEXT_HOST
                ] = settings.X2TEXT_HOST
                adapter_metadata[
                    X2TextConstants.X2TEXT_PORT
                ] = settings.X2TEXT_PORT
                platform_key = (
                    PlatformAuthenticationService.get_active_platform_key()
                )
                adapter_metadata[
                    X2TextConstants.PLATFORM_SERVICE_API_KEY
                ] = str(platform_key.key)
            adapter_instance = adapter_class(adapter_metadata)
            test_result: bool = adapter_instance.test_connection()
            logger.info(f"{adapter_id} test result: {test_result}")
            return test_result
        except ActiveKeyNotFound:
            # Let the caller surface the missing-platform-key error as-is.
            raise
        except AdapterError as e:
            # Bad configuration supplied by the user.
            logger.error(f"Error while testing {adapter_id}: {e}")
            raise TestAdapterInputException(str(e.message))
        except Exception as e:
            # NOTE: the previous ``elif isinstance(e, ActiveKeyNotFound)``
            # branch here was unreachable (already caught above) and has
            # been removed.
            logger.error(f"Error while testing {adapter_id}: {e}")
            raise TestAdapterException()

    @staticmethod
    def __fetch_adapters_by_key_value(key: str, value: Any) -> list[Any]:
        """Return the SDK adapters whose ``key`` attribute equals ``value``."""
        logger.info(f"Fetching adapter list for {key} with {value}")
        adapter_kit = Adapterkit()
        adapters = adapter_kit.get_adapters_list()
        return [iterate for iterate in adapters if iterate[key] == value]

    @staticmethod
    def set_default_triad(default_triad: dict[str, str], user: User) -> None:
        """Set the user's default LLM / embedding / vector DB adapters.

        For each recognised key in ``default_triad``, clears the existing
        default of that adapter type for ``user`` and marks the referenced
        adapter instance as the new default.

        Raises:
            InValidAdapterId: if a referenced adapter instance is missing.
            InternalServiceError: for any other failure.
        """
        try:
            for key in default_triad:
                adapter_id = default_triad[key]
                if key == AdapterKeys.LLM_DEFAULT:
                    adapter_type = AdapterTypes.LLM.name
                elif key == AdapterKeys.EMBEDDING_DEFAULT:
                    adapter_type = AdapterTypes.EMBEDDING.name
                elif key == AdapterKeys.VECTOR_DB_DEFAULT:
                    adapter_type = AdapterTypes.VECTOR_DB.name
                else:
                    # Previously an unknown key reused the adapter_type from
                    # the prior iteration (or raised NameError); skip it.
                    logger.warning(f"Ignoring unknown default key: {key}")
                    continue
                # Clear the current default of this type for the user.
                AdapterInstance.objects.filter(
                    adapter_type=adapter_type,
                    is_default=True,
                    created_by=user,
                ).update(is_default=False)
                # Promote the requested adapter instance to default.
                try:
                    new_adapter_default: AdapterInstance = (
                        AdapterInstance.objects.get(pk=adapter_id)
                    )
                except AdapterInstance.DoesNotExist as e:
                    logger.error(
                        f"Error while retrieving adapter: {adapter_id} "
                        f"reason: {e}"
                    )
                    raise InValidAdapterId()
                new_adapter_default.is_default = True
                new_adapter_default.save()
            logger.info("Changed defaults successfully")
        except Exception as e:
            logger.error(f"Unable to save defaults because: {e}")
            if isinstance(e, InValidAdapterId):
                raise e
            else:
                raise InternalServiceError()

    @staticmethod
    def get_adapter_instance_by_id(adapter_instance_id: str) -> str:
        """Return the ``adapter_name`` of the instance with the given ID.

        Raises:
            Exception: re-raised from the lookup if the instance cannot be
                fetched (previously this path crashed with UnboundLocalError).
        """
        try:
            adapter = AdapterInstance.objects.get(id=adapter_instance_id)
        except Exception as e:
            logger.error(f"Unable to fetch adapter: {e}")
            raise
        return adapter.adapter_name

    @staticmethod
    def get_adapters_by_type(
        adapter_type: AdapterTypes, user: User
    ) -> list[AdapterInstance]:
        """Return the user's adapter instances of the given type.

        Args:
            adapter_type (AdapterTypes): type of adapters to retrieve.
            user (User): owner whose adapters are listed.

        Returns:
            list[AdapterInstance]: matching adapter instances.
        """
        adapters: list[AdapterInstance] = AdapterInstance.objects.filter(
            adapter_type=adapter_type.value, created_by=user
        )
        return adapters

    @staticmethod
    def get_adapter_by_name_and_type(
        adapter_type: AdapterTypes,
        adapter_name: Optional[str] = None,
    ) -> Optional[AdapterInstance]:
        """Return the adapter with the given name and type; when no name is
        given, return the default adapter of that type (or ``None``)."""
        if adapter_name:
            adapter: AdapterInstance = AdapterInstance.objects.get(
                adapter_name=adapter_name, adapter_type=adapter_type.value
            )
        else:
            try:
                adapter = AdapterInstance.objects.get(
                    adapter_type=adapter_type.value, is_default=True
                )
            except AdapterInstance.DoesNotExist:
                return None
        return adapter

    @staticmethod
    def get_default_adapters(user: User) -> list[AdapterInstance]:
        """Return all adapter instances marked default for ``user``.

        Raises:
            InternalServiceError: if the database query fails.
        """
        try:
            adapters: list[AdapterInstance] = AdapterInstance.objects.filter(
                is_default=True, created_by=user
            )
            return adapters
        except ObjectDoesNotExist as e:
            # NOTE(review): filter() does not raise DoesNotExist; kept for
            # backward compatibility with the original error mapping.
            logger.error(f"No default adapters found: {e}")
            raise InternalServiceError("No default adapters found")
        except Exception as e:
            logger.error(f"Error occurred while fetching default adapters: {e}")
            raise InternalServiceError("Error fetching default adapters")

View File

@@ -0,0 +1,23 @@
class AdapterKeys:
    """String constants shared by adapter views, serializers and the
    adapter processor (JSON keys, adapter types and default markers)."""

    # Payload / serializer field names.
    JSON_SCHEMA = "json_schema"
    ADAPTER_TYPE = "adapter_type"
    IS_DEFAULT = "is_default"
    # Supported adapter type values.
    LLM = "LLM"
    X2TEXT = "X2TEXT"
    VECTOR_DB = "VECTOR_DB"
    EMBEDDING = "EMBEDDING"
    # Adapter metadata attribute names.
    NAME = "name"
    DESCRIPTION = "description"
    ICON = "icon"
    ADAPTER_ID = "adapter_id"
    ADAPTER_METADATA = "adapter_metadata"
    ADAPTER_METADATA_B = "adapter_metadata_b"
    ID = "id"
    IS_VALID = "is_valid"
    # Keys of the default-triad payload (see AdapterProcessor.set_default_triad).
    LLM_DEFAULT = "llm_default"
    VECTOR_DB_DEFAULT = "vector_db_default"
    EMBEDDING_DEFAULT = "embedding_default"
    # User-facing message for unique-constraint violations.
    ADAPTER_NAME_EXISTS = (
        "Configuration with this ID already exists. "
        "Please try with a different ID"
    )

View File

@@ -0,0 +1,55 @@
from backend.exceptions import UnstractBaseException
from rest_framework.exceptions import APIException
# API exceptions raised by the adapter_processor app. DRF converts each
# into an HTTP response with the given status code and detail message.
class IdIsMandatory(APIException):
    status_code = 400
    default_detail = "ID is Mandatory."


class InValidType(APIException):
    status_code = 400
    default_detail = "Type is not Valid."


class InValidAdapterId(APIException):
    status_code = 400
    default_detail = "Adapter ID is not Valid."


class JSONParseException(APIException):
    status_code = 500
    # Fixed typo in the user-visible message: "occured" -> "occurred".
    default_detail = "Exception occurred while Parsing JSON Schema."


class InternalServiceError(APIException):
    status_code = 500
    default_detail = "Internal Service error"


class CannotDeleteDefaultAdapter(APIException):
    # NOTE(review): this is a client-resolvable condition; 400/409 may be
    # more appropriate than 500 — confirm before changing the API contract.
    status_code = 500
    default_detail = (
        "This is configured as default and cannot be deleted. "
        "Please configure a different default before you try again!"
    )


class UniqueConstraintViolation(APIException):
    status_code = 400
    default_detail = "Unique constraint violated"


class TestAdapterException(APIException):
    status_code = 500
    default_detail = "Error while testing adapter."


class TestAdapterInputException(UnstractBaseException):
    status_code = 400
    default_detail = "Connection test failed using the given configuration data"


class ErrorFetchingAdapterData(UnstractBaseException):
    status_code = 400
    default_detail = "Error while fetching adapter data."

View File

@@ -0,0 +1,109 @@
# Generated by Django 4.2.1 on 2024-01-23 11:18
import uuid
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
    """Initial schema for the adapter_processor app: creates the
    ``AdapterInstance`` table (auto-generated by Django; do not edit
    the operations by hand)."""

    initial = True

    dependencies = [
        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
    ]

    operations = [
        migrations.CreateModel(
            name="AdapterInstance",
            fields=[
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("modified_at", models.DateTimeField(auto_now=True)),
                (
                    "id",
                    models.UUIDField(
                        db_comment="Unique identifier for the Adapter Instance",
                        default=uuid.uuid4,
                        editable=False,
                        primary_key=True,
                        serialize=False,
                    ),
                ),
                (
                    "adapter_name",
                    models.TextField(
                        db_comment="Name of the Adapter Instance",
                        max_length=128,
                    ),
                ),
                (
                    "adapter_id",
                    models.CharField(
                        db_comment="Unique identifier of the Adapter",
                        default="",
                        max_length=128,
                    ),
                ),
                (
                    "adapter_metadata",
                    models.JSONField(
                        db_column="adapter_metadata",
                        db_comment="JSON adapter metadata submitted by the user",
                        default=dict,
                    ),
                ),
                (
                    "adapter_type",
                    models.CharField(
                        choices=[
                            ("UNKNOWN", "UNKNOWN"),
                            ("LLM", "LLM"),
                            ("EMBEDDING", "EMBEDDING"),
                            ("VECTOR_DB", "VECTOR_DB"),
                        ],
                        db_comment="Type of adapter LLM/EMBEDDING/VECTOR_DB",
                    ),
                ),
                (
                    "is_active",
                    models.BooleanField(
                        db_comment="Is the adapter instance currently being used",
                        default=False,
                    ),
                ),
                (
                    "is_default",
                    models.BooleanField(
                        db_comment="Is the adapter instance default",
                        default=False,
                    ),
                ),
                (
                    "created_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="created_adapters",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "modified_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="modified_adapters",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
            ],
            options={
                "verbose_name": "adapter_adapterinstance",
                "verbose_name_plural": "adapter_adapterinstance",
                "db_table": "adapter_adapterinstance",
            },
        ),
    ]

View File

@@ -0,0 +1,18 @@
# Generated by Django 4.2.1 on 2024-01-20 08:32
from django.db import migrations, models
class Migration(migrations.Migration):
    """Adds the (adapter_name, adapter_type) unique constraint to
    ``AdapterInstance`` (auto-generated by Django)."""

    dependencies = [
        ("adapter_processor", "0001_initial"),
    ]

    operations = [
        migrations.AddConstraint(
            model_name="adapterinstance",
            constraint=models.UniqueConstraint(
                fields=("adapter_name", "adapter_type"), name="unique_adapter"
            ),
        ),
    ]

View File

@@ -0,0 +1,41 @@
# Generated by Django 4.2.1 on 2024-02-13 13:09
import json
from typing import Any
from account.models import EncryptionSecret
from adapter_processor.models import AdapterInstance
from cryptography.fernet import Fernet
from django.db import migrations, models
class Migration(migrations.Migration):
    """Adds the encrypted ``adapter_metadata_b`` column and backfills it by
    Fernet-encrypting each row's existing plain-JSON ``adapter_metadata``.

    NOTE(review): this data migration imports current models directly
    instead of using ``apps.get_model`` (historical models), and uses
    ``print`` rather than logging — confirm this is intentional before
    relying on it for fresh installs.
    """

    dependencies = [
        ("adapter_processor", "0002_adapterinstance_unique_adapter"),
        ("account", "0005_encryptionsecret"),
    ]

    def EncryptCredentials(apps: Any, schema_editor: Any) -> None:
        # Assumes exactly one EncryptionSecret row exists; objects.get()
        # raises otherwise.
        encryption_secret: EncryptionSecret = EncryptionSecret.objects.get()
        f: Fernet = Fernet(encryption_secret.key.encode("utf-8"))
        queryset = AdapterInstance.objects.all()
        for obj in queryset:  # type: ignore
            # Access attributes of the object
            print(f"Object ID: {obj.id}, Name: {obj.adapter_name}")
            if hasattr(obj, "adapter_metadata"):
                json_string: str = json.dumps(obj.adapter_metadata)
                obj.adapter_metadata_b = f.encrypt(json_string.encode("utf-8"))
                obj.save()

    operations = [
        migrations.AddField(
            model_name="adapterinstance",
            name="adapter_metadata_b",
            field=models.BinaryField(null=True),
        ),
        migrations.RunPython(
            EncryptCredentials, reverse_code=migrations.RunPython.noop
        ),
    ]

View File

@@ -0,0 +1,26 @@
# Generated by Django 4.2.1 on 2024-02-23 09:29
from django.db import migrations, models
class Migration(migrations.Migration):
    """Adds the X2TEXT choice to ``AdapterInstance.adapter_type``
    (auto-generated by Django)."""

    dependencies = [
        ("adapter_processor", "0003_adapterinstance_adapter_metadata_b"),
    ]

    operations = [
        migrations.AlterField(
            model_name="adapterinstance",
            name="adapter_type",
            field=models.CharField(
                choices=[
                    ("UNKNOWN", "UNKNOWN"),
                    ("LLM", "LLM"),
                    ("EMBEDDING", "EMBEDDING"),
                    ("VECTOR_DB", "VECTOR_DB"),
                    ("X2TEXT", "X2TEXT"),
                ],
                db_comment="Type of adapter LLM/EMBEDDING/VECTOR_DB",
            ),
        ),
    ]

View File

@@ -0,0 +1,78 @@
import uuid
from account.models import User
from django.db import models
from unstract.adapters.enums import AdapterTypes
from utils.models.base_model import BaseModel
# Field-size constants for the AdapterInstance model.
ADAPTER_NAME_SIZE = 128
VERSION_NAME_SIZE = 64
ADAPTER_ID_LENGTH = 128


class AdapterInstance(BaseModel):
    """A user-configured adapter (LLM / embedding / vector DB / x2text):
    identifies the SDK adapter via ``adapter_id`` and stores its metadata,
    both as plain JSON and Fernet-encrypted bytes."""

    id = models.UUIDField(
        primary_key=True,
        default=uuid.uuid4,
        editable=False,
        db_comment="Unique identifier for the Adapter Instance",
    )
    adapter_name = models.TextField(
        max_length=ADAPTER_NAME_SIZE,
        null=False,
        blank=False,
        db_comment="Name of the Adapter Instance",
    )
    adapter_id = models.CharField(
        max_length=ADAPTER_ID_LENGTH,
        default="",
        db_comment="Unique identifier of the Adapter",
    )

    # TODO to be removed once the migration for encryption
    adapter_metadata = models.JSONField(
        db_column="adapter_metadata",
        null=False,
        blank=False,
        default=dict,
        db_comment="JSON adapter metadata submitted by the user",
    )
    # Fernet-encrypted metadata (see migration 0003); nullable for legacy rows.
    adapter_metadata_b = models.BinaryField(null=True)
    adapter_type = models.CharField(
        choices=[(tag.value, tag.name) for tag in AdapterTypes],
        db_comment="Type of adapter LLM/EMBEDDING/VECTOR_DB",
    )
    created_by = models.ForeignKey(
        User,
        on_delete=models.SET_NULL,
        related_name="created_adapters",
        null=True,
        blank=True,
    )
    modified_by = models.ForeignKey(
        User,
        on_delete=models.SET_NULL,
        related_name="modified_adapters",
        null=True,
        blank=True,
    )
    is_active = models.BooleanField(
        default=False,
        db_comment="Is the adapter instance currently being used",
    )
    # At most one default per (adapter_type, user); managed by
    # AdapterProcessor.set_default_triad.
    is_default = models.BooleanField(
        default=False,
        db_comment="Is the adapter instance default",
    )

    class Meta:
        verbose_name = "adapter_adapterinstance"
        verbose_name_plural = "adapter_adapterinstance"
        db_table = "adapter_adapterinstance"
        constraints = [
            models.UniqueConstraint(
                fields=["adapter_name", "adapter_type"],
                name="unique_adapter",
            ),
        ]

View File

@@ -0,0 +1,91 @@
import json
from typing import Any
from account.models import EncryptionSecret
from adapter_processor.adapter_processor import AdapterProcessor
from adapter_processor.constants import AdapterKeys
from cryptography.fernet import Fernet
from rest_framework import serializers
from unstract.adapters.constants import Common as common
from utils.serializer_utils import SerializerUtils
from backend.constants import FieldLengthConstants as FLC
from backend.serializers import AuditSerializer
from .models import AdapterInstance
class TestAdapterSerializer(serializers.Serializer):
    """Payload for the test-adapter endpoint: adapter id, metadata and type."""

    adapter_id = serializers.CharField(max_length=FLC.ADAPTER_ID_LENGTH)
    adapter_metadata = serializers.JSONField()
    adapter_type = serializers.JSONField()
class BaseAdapterSerializer(AuditSerializer):
    """Base model serializer exposing all ``AdapterInstance`` fields."""

    class Meta:
        model = AdapterInstance
        fields = "__all__"
class DefaultAdapterSerializer(serializers.Serializer):
    """Payload for configuring the default adapter triad; each field is an
    optional adapter-instance UUID."""

    llm_default = serializers.CharField(
        max_length=FLC.UUID_LENGTH, required=False
    )
    embedding_default = serializers.CharField(
        max_length=FLC.UUID_LENGTH, required=False
    )
    vector_db_default = serializers.CharField(
        max_length=FLC.UUID_LENGTH, required=False
    )
class AdapterInstanceSerializer(BaseAdapterSerializer):
    """Inherits BaseAdapterSerializer.

    Used for GET/POST request for adapter
    """

    class Meta(BaseAdapterSerializer.Meta):
        pass

    def to_internal_value(self, data: dict[str, Any]) -> dict[str, Any]:
        # Encrypts the incoming adapter_metadata with the stored Fernet key
        # and stores it under adapter_metadata_b.
        # NOTE(review): this returns the raw data dict instead of calling
        # super().to_internal_value(), so DRF field validation is bypassed
        # here — confirm this is intentional.
        encryption_secret: EncryptionSecret = EncryptionSecret.objects.get()
        f: Fernet = Fernet(encryption_secret.key.encode("utf-8"))
        json_string: str = json.dumps(data.pop(AdapterKeys.ADAPTER_METADATA))
        data[AdapterKeys.ADAPTER_METADATA_B] = f.encrypt(
            json_string.encode("utf-8")
        )
        return data

    def to_representation(self, instance: AdapterInstance) -> dict[str, str]:
        rep: dict[str, str] = super().to_representation(instance)
        # Only attach the adapter icon for plain GET/POST contexts.
        if SerializerUtils.check_context_for_GET_or_POST(context=self.context):
            rep[common.ICON] = AdapterProcessor.get_adapter_data_with_key(
                instance.adapter_id, common.ICON
            )
        # Never expose the encrypted metadata blob in responses.
        rep.pop(AdapterKeys.ADAPTER_METADATA_B)
        return rep
class AdapterDetailSerializer(BaseAdapterSerializer):
    """Inherits BaseAdapterSerializer.

    Used for GET/UPDATE/DELETE request for adapter/<uuid:pk>
    """

    def to_representation(self, instance: AdapterInstance) -> dict[str, str]:
        # Replace the encrypted blob with the decrypted metadata dict.
        # NOTE(review): assumes adapter_metadata_b is populated — a None
        # value would raise here; confirm all rows were backfilled by
        # migration 0003.
        rep: dict[str, str] = super().to_representation(instance)
        encryption_secret: EncryptionSecret = EncryptionSecret.objects.get()
        f: Fernet = Fernet(encryption_secret.key.encode("utf-8"))
        rep.pop(AdapterKeys.ADAPTER_METADATA_B)
        adapter_metadata = json.loads(
            f.decrypt(bytes(instance.adapter_metadata_b).decode("utf-8"))
        )
        rep[AdapterKeys.ADAPTER_METADATA] = adapter_metadata
        return rep

View File

@@ -0,0 +1,34 @@
from adapter_processor.views import (
AdapterDetailViewSet,
AdapterInstanceViewSet,
AdapterViewSet,
DefaultAdapterViewSet,
)
from django.urls import path
from rest_framework.urlpatterns import format_suffix_patterns
# Bind viewset actions to concrete view callables.
adapter = AdapterViewSet.as_view({"get": "list"})
default_triad = DefaultAdapterViewSet.as_view(
    {"post": "configure_default_triad"}
)
adapter_schema = AdapterViewSet.as_view({"get": "get_adapter_schema"})
adapter_list = AdapterInstanceViewSet.as_view({"post": "create", "get": "list"})
adapter_detail = AdapterDetailViewSet.as_view(
    {
        "get": "retrieve",
        "put": "update",
        "patch": "partial_update",
        "delete": "destroy",
    }
)
adapter_test = AdapterViewSet.as_view({"post": "test"})

# NOTE(review): the route name "adapter-list" is used for both
# supported_adapters/ and adapter/ below; reverse("adapter-list") is
# therefore ambiguous — confirm which one callers rely on before renaming.
urlpatterns = format_suffix_patterns(
    [
        path("adapter_schema/", adapter_schema, name="get_adapter_schema"),
        path("supported_adapters/", adapter, name="adapter-list"),
        path("adapter/", adapter_list, name="adapter-list"),
        path("adapter/default_triad/", default_triad, name="default_triad"),
        path("adapter/<uuid:pk>/", adapter_detail, name="adapter_detail"),
        path("test_adapters/", adapter_test, name="adapter-test"),
    ]
)

View File

@@ -0,0 +1,182 @@
import logging
from typing import Any, Optional
from account.models import User
from adapter_processor.adapter_processor import AdapterProcessor
from adapter_processor.constants import AdapterKeys
from adapter_processor.exceptions import (
CannotDeleteDefaultAdapter,
IdIsMandatory,
InValidType,
UniqueConstraintViolation,
)
from adapter_processor.serializers import (
AdapterDetailSerializer,
AdapterInstanceSerializer,
DefaultAdapterSerializer,
TestAdapterSerializer,
)
from django.db import IntegrityError
from django.db.models import QuerySet
from django.http.response import HttpResponse
from permissions.permission import IsOwner
from rest_framework import status
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.versioning import URLPathVersioning
from rest_framework.viewsets import GenericViewSet, ModelViewSet
from utils.filtering import FilterHelper
from .constants import AdapterKeys as constant
from .exceptions import InternalServiceError
from .models import AdapterInstance
logger = logging.getLogger(__name__)
class DefaultAdapterViewSet(ModelViewSet):
    """Configures the caller's default LLM/embedding/vector-DB adapters."""

    versioning_class = URLPathVersioning
    serializer_class = DefaultAdapterSerializer

    def configure_default_triad(
        self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any]
    ) -> HttpResponse:
        """Validate the triad payload and apply it as the user's defaults."""
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        # Convert request data to json
        # NOTE(review): raw request.data is forwarded rather than
        # serializer.validated_data — unknown keys pass through; confirm.
        default_triad = request.data
        AdapterProcessor.set_default_triad(default_triad, request.user)
        return Response(status=status.HTTP_200_OK)
class AdapterViewSet(GenericViewSet):
    """Endpoints for listing supported SDK adapters, fetching an adapter's
    JSON schema, and testing an adapter configuration."""

    versioning_class = URLPathVersioning
    serializer_class = TestAdapterSerializer

    # Adapter types accepted by the `list` endpoint.
    SUPPORTED_ADAPTER_TYPES = frozenset(
        {
            AdapterKeys.LLM,
            AdapterKeys.EMBEDDING,
            AdapterKeys.VECTOR_DB,
            AdapterKeys.X2TEXT,
        }
    )

    def list(
        self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any]
    ) -> HttpResponse:
        """List all supported adapters of the requested ``adapter_type``.

        Raises:
            InValidType: if the type query param is missing or unsupported.
        """
        # Routed only for GET (see urls.py); guard kept for parity.
        if request.method == "GET":
            adapter_type = request.GET.get(AdapterKeys.ADAPTER_TYPE)
            # Membership test replaces the original four-way or-chain.
            if adapter_type in self.SUPPORTED_ADAPTER_TYPES:
                json_schema = AdapterProcessor.get_all_supported_adapters(
                    type=adapter_type
                )
                return Response(json_schema, status=status.HTTP_200_OK)
            else:
                raise InValidType

    def get_adapter_schema(
        self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any]
    ) -> HttpResponse:
        """Return the JSON schema for the adapter given by the `id` param.

        Raises:
            IdIsMandatory: if the `id` query param is missing or empty.
        """
        if request.method == "GET":
            adapter_name = request.GET.get(AdapterKeys.ID)
            if adapter_name is None or adapter_name == "":
                raise IdIsMandatory()
            json_schema = AdapterProcessor.get_json_schema(
                adapter_id=adapter_name
            )
            return Response(data=json_schema, status=status.HTTP_200_OK)

    def test(self, request: Request) -> Response:
        """Tests the connector against the credentials passed."""
        serializer: AdapterInstanceSerializer = self.get_serializer(
            data=request.data
        )
        serializer.is_valid(raise_exception=True)
        adapter_id = serializer.validated_data.get(AdapterKeys.ADAPTER_ID)
        adapter_metadata = serializer.validated_data.get(
            AdapterKeys.ADAPTER_METADATA
        )
        # test_adapter() pops this key again, so it must be present.
        adapter_metadata[
            AdapterKeys.ADAPTER_TYPE
        ] = serializer.validated_data.get(AdapterKeys.ADAPTER_TYPE)
        test_result = AdapterProcessor.test_adapter(
            adapter_id=adapter_id, adapter_metadata=adapter_metadata
        )
        return Response(
            {AdapterKeys.IS_VALID: test_result},
            status=status.HTTP_200_OK,
        )
class AdapterInstanceViewSet(ModelViewSet):
    """CRUD endpoints for adapter instances owned by the requesting user."""

    queryset = AdapterInstance.objects.all()
    serializer_class = AdapterInstanceSerializer

    def get_queryset(self) -> Optional[QuerySet]:
        """Return the caller's adapters, optionally filtered by adapter_type."""
        if filter_args := FilterHelper.build_filter_args(
            self.request,
            constant.ADAPTER_TYPE,
        ):
            queryset = AdapterInstance.objects.filter(
                created_by=self.request.user, **filter_args
            )
        else:
            queryset = AdapterInstance.objects.filter(
                created_by=self.request.user
            )
        return queryset

    def create(self, request: Any) -> Response:
        """Create an adapter instance; the first adapter of a type for a
        user automatically becomes the default of that type.

        Raises:
            UniqueConstraintViolation: duplicate (name, type) combination.
            InternalServiceError: any other persistence failure.
        """
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        try:
            # Check to see if there is a default configured
            # for this adapter_type and for the current user
            existing_adapter_default = self.get_existing_defaults(
                request.data, request.user
            )
            # If there is no default, then make this one as default
            if existing_adapter_default is None:
                serializer.validated_data[AdapterKeys.IS_DEFAULT] = True
            serializer.save()
        except IntegrityError:
            # Constant is passed directly (the original wrapped it in a
            # no-op f-string).
            raise UniqueConstraintViolation(AdapterKeys.ADAPTER_NAME_EXISTS)
        except Exception as e:
            logger.error(f"Error saving adapter to DB: {e}")
            raise InternalServiceError()
        headers = self.get_success_headers(serializer.data)
        return Response(
            serializer.data, status=status.HTTP_201_CREATED, headers=headers
        )

    def get_existing_defaults(
        self, adapter_config: dict[str, Any], user: User
    ) -> Optional[AdapterInstance]:
        """Return the user's current default adapter of the payload's
        adapter_type, or ``None`` when no default exists."""
        adapter_type = adapter_config.get(AdapterKeys.ADAPTER_TYPE)
        existing_adapter_default: Optional[AdapterInstance] = (
            AdapterInstance.objects.filter(
                adapter_type=adapter_type,
                is_default=True,
                created_by=user,
            ).first()
        )
        return existing_adapter_default
class AdapterDetailViewSet(ModelViewSet):
    """Retrieve/update/delete endpoints for a single adapter instance,
    restricted to its owner."""

    queryset = AdapterInstance.objects.all()
    serializer_class = AdapterDetailSerializer
    permission_classes = [IsOwner]

    def destroy(
        self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any]
    ) -> Response:
        """Delete the adapter instance unless it is a configured default."""
        instance: AdapterInstance = self.get_object()
        if instance.is_default:
            logger.error("Cannot delete a default adapter")
            raise CannotDeleteDefaultAdapter
        super().perform_destroy(instance)
        return Response(status=status.HTTP_204_NO_CONTENT)

0
backend/api/__init__.py Normal file
View File

5
backend/api/admin.py Normal file
View File

@@ -0,0 +1,5 @@
from django.contrib import admin
from .models import APIDeployment, APIKey
admin.site.register([APIDeployment, APIKey])

View File

@@ -0,0 +1,113 @@
from typing import Any, Optional
from api.deployment_helper import DeploymentHelper
from api.exceptions import InvalidAPIRequest
from api.models import APIDeployment
from api.serializers import (
APIDeploymentListSerializer,
APIDeploymentSerializer,
DeploymentResponseSerializer,
ExecutionRequestSerializer,
)
from django.db.models import QuerySet
from permissions.permission import IsOwner
from rest_framework import serializers, status, views, viewsets
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import Serializer
from workflow_manager.workflow.dto import ExecutionResponse
class DeploymentExecution(views.APIView):
    """Public endpoint for a deployed workflow: POST executes it against
    uploaded files; GET reports execution status. Authenticated via API
    key (see the decorator), so CSRF is disabled."""

    def initialize_request(
        self, request: Request, *args: Any, **kwargs: Any
    ) -> Request:
        """Mark CSRF processing done so the public API is not rejected by
        the CSRF middleware.

        Args:
            request (Request): incoming request.

        Returns:
            Request: the initialized DRF request.
        """
        setattr(request, "csrf_processing_done", True)
        return super().initialize_request(request, *args, **kwargs)

    @DeploymentHelper.validate_api_key
    def post(
        self, request: Request, org_name: str, api_name: str, api: APIDeployment
    ) -> Response:
        """Execute the deployed workflow on the uploaded files.

        Raises:
            InvalidAPIRequest: if no files were uploaded.
        """
        file_objs = request.FILES.getlist("files")
        serializer = ExecutionRequestSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        timeout = serializer.get_timeout(serializer.validated_data)
        # An empty list is falsy; the original `len(...) == 0` check was
        # redundant.
        if not file_objs:
            raise InvalidAPIRequest("File shouldn't be empty")
        response = DeploymentHelper.execute_workflow(
            organization_name=org_name,
            api=api,
            file_objs=file_objs,
            timeout=timeout,
        )
        return Response({"message": response}, status=status.HTTP_200_OK)

    @DeploymentHelper.validate_api_key
    def get(
        self, request: Request, org_name: str, api_name: str, api: APIDeployment
    ) -> Response:
        """Return the status and result of a previously started execution.

        Raises:
            InvalidAPIRequest: if the execution_id query param is missing.
        """
        execution_id = request.query_params.get("execution_id")
        if not execution_id:
            raise InvalidAPIRequest("execution_id shouldn't be empty")
        response: ExecutionResponse = DeploymentHelper.get_execution_status(
            execution_id=execution_id
        )
        return Response(
            {"status": response.execution_status, "message": response.result},
            status=status.HTTP_200_OK,
        )
class APIDeploymentViewSet(viewsets.ModelViewSet):
    """CRUD endpoints for API deployments owned by the requesting user."""

    permission_classes = [IsOwner]

    def get_queryset(self) -> Optional[QuerySet]:
        # Restrict results to deployments created by the caller.
        return APIDeployment.objects.filter(created_by=self.request.user)

    def get_serializer_class(self) -> type[serializers.Serializer]:
        # Lightweight serializer for list views; full serializer otherwise.
        if self.action in ["list"]:
            return APIDeploymentListSerializer
        return APIDeploymentSerializer

    @action(detail=True, methods=["get"])
    def fetch_one(self, request: Request, pk: Optional[str] = None) -> Response:
        """Custom action to fetch a single instance."""
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return Response(serializer.data)

    def create(
        self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any]
    ) -> Response:
        """Create a deployment and an associated API key; the key is
        included once in the creation response."""
        serializer: Serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        api_key = DeploymentHelper.create_api_key(serializer=serializer)
        response_serializer = DeploymentResponseSerializer(
            {"api_key": api_key.api_key, **serializer.data}
        )
        headers = self.get_success_headers(serializer.data)
        return Response(
            response_serializer.data,
            status=status.HTTP_201_CREATED,
            headers=headers,
        )
def get_error_from_serializer(error_details: dict[str, Any]) -> Optional[str]:
    """Return the first error message from serializer error details.

    Args:
        error_details: Mapping of field name to list of error strings,
            as produced by DRF serializer validation.

    Returns:
        A "message : field" string for the first error, or None when
        there are no errors (previously `next(iter({}))` raised
        StopIteration for an empty mapping).
    """
    if not error_details:
        return None
    error_key = next(iter(error_details))
    # Get the first error message
    error_message: str = f"{error_details[error_key][0]} : {error_key}"
    return error_message

View File

@@ -0,0 +1,28 @@
from api.deployment_helper import DeploymentHelper
from api.exceptions import APINotFound
from api.key_helper import KeyHelper
from api.models import APIKey
from api.serializers import APIKeyListSerializer, APIKeySerializer
from rest_framework import serializers, viewsets
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response
class APIKeyViewSet(viewsets.ModelViewSet):
    """CRUD viewset for APIKey, plus per-deployment key listing."""

    queryset = APIKey.objects.all()

    def get_serializer_class(self) -> serializers.Serializer:
        # The `api_keys` action returns the trimmed listing payload.
        if self.action in ["api_keys"]:
            return APIKeyListSerializer
        return APIKeySerializer

    @action(detail=True, methods=["get"])
    def api_keys(self, request: Request, api_id: str) -> Response:
        """Custom action to fetch api keys of an api deployment."""
        api = DeploymentHelper.get_api_by_id(api_id=api_id)
        if not api:
            raise APINotFound()
        keys = KeyHelper.list_api_keys_of_api(api_instance=api)
        serializer = self.get_serializer(keys, many=True)
        return Response(serializer.data)

5
backend/api/apps.py Normal file
View File

@@ -0,0 +1,5 @@
from django.apps import AppConfig
class ApiConfig(AppConfig):
    """Django application config for the `api` app."""

    name = "api"

3
backend/api/constants.py Normal file
View File

@@ -0,0 +1,3 @@
class ApiExecution:
    """Constants used when building and executing API deployment endpoints."""

    # URL path prefix for deployed-API execution endpoints.
    PATH: str = "deployment/api"
    MAXIMUM_TIMEOUT_IN_SEC: int = 300  # 5 minutes

View File

@@ -0,0 +1,254 @@
import logging
import uuid
from functools import wraps
from typing import Any, Optional
from urllib.parse import urlencode
from api.constants import ApiExecution
from api.exceptions import (
ApiKeyCreateException,
APINotFound,
Forbidden,
InactiveAPI,
UnauthorizedKey,
)
from api.key_helper import KeyHelper
from api.models import APIDeployment, APIKey
from api.serializers import APIExecutionResponseSerializer
from django.core.files.uploadedfile import UploadedFile
from django.db import connection
from django_tenants.utils import get_tenant_model, tenant_context
from rest_framework import status
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import Serializer
from rest_framework.utils.serializer_helpers import ReturnDict
from workflow_manager.endpoint.destination import DestinationConnector
from workflow_manager.endpoint.source import SourceConnector
from workflow_manager.workflow.dto import ExecutionResponse
from workflow_manager.workflow.models.workflow import Workflow
from workflow_manager.workflow.workflow_helper import WorkflowHelper
logger = logging.getLogger(__name__)
class DeploymentHelper:
    """Helpers to validate, execute and query API deployments."""

    @staticmethod
    def validate_api_key(func: Any) -> Any:
        """Decorator that validates the API key.

        Sample header:
            Authorization: Bearer 123e4567-e89b-12d3-a456-426614174001

        Args:
            func (Any): Function to wrap for validation
        """

        @wraps(func)
        def wrapper(
            self: Any, request: Request, *args: Any, **kwargs: Any
        ) -> Any:
            """Wrapper to validate the inputs and key.

            Args:
                request (Request): Request context

            Raises:
                Forbidden: when the bearer key or api_name is missing
                APINotFound: when no deployment matches api_name

            Returns:
                Any: the wrapped view's response, or a 403 Response on
                any unexpected failure (e.g. unknown tenant schema)
            """
            try:
                # Expect an "Authorization: Bearer <key>" header.
                authorization_header = request.headers.get("Authorization")
                api_key = None
                if authorization_header and authorization_header.startswith(
                    "Bearer "
                ):
                    api_key = authorization_header.split(" ")[1]
                if not api_key:
                    raise Forbidden("Missing api key")
                # org/api names may come from URL kwargs or the body.
                org_name = kwargs.get("org_name") or request.data.get(
                    "org_name"
                )
                api_name = kwargs.get("api_name") or request.data.get(
                    "api_name"
                )
                if not api_name:
                    raise Forbidden("Missing api_name")
                # Resolve tenant by schema and validate within its context;
                # a missing/invalid org_name raises here and maps to 403.
                tenant = get_tenant_model().objects.get(schema_name=org_name)
                with tenant_context(tenant):
                    api_deployment = (
                        DeploymentHelper.get_deployment_by_api_name(
                            api_name=api_name
                        )
                    )
                    DeploymentHelper.validate_api(
                        api_deployment=api_deployment, api_key=api_key
                    )
                    # Hand the resolved deployment to the wrapped view.
                    kwargs["api"] = api_deployment
                    return func(self, request, *args, **kwargs)
            except (UnauthorizedKey, InactiveAPI, APINotFound):
                # Known auth errors keep their own status codes.
                raise
            except Exception as exception:
                logger.error(f"Exception: {exception}")
                return Response(
                    {"error": str(exception)}, status=status.HTTP_403_FORBIDDEN
                )

        return wrapper

    @staticmethod
    def validate_api(
        api_deployment: Optional[APIDeployment], api_key: str
    ) -> None:
        """Validating API and API key.

        Args:
            api_deployment (Optional[APIDeployment]): deployment to check
            api_key (str): bearer key presented by the caller

        Raises:
            APINotFound: when api_deployment is None
            InactiveAPI: when the deployment is marked inactive
        """
        if not api_deployment:
            raise APINotFound()
        if not api_deployment.is_active:
            raise InactiveAPI()
        KeyHelper.validate_api_key(api_key=api_key, api_instance=api_deployment)

    @staticmethod
    def validate_and_get_workflow(workflow_id: str) -> Workflow:
        """Validate that the specified workflow_id exists in the Workflow
        model."""
        return WorkflowHelper.get_workflow_by_id(workflow_id)

    @staticmethod
    def get_api_by_id(api_id: str) -> Optional[APIDeployment]:
        # Returns None (rather than raising) when the id is unknown.
        try:
            api_deployment: APIDeployment = APIDeployment.objects.get(pk=api_id)
            return api_deployment
        except APIDeployment.DoesNotExist:
            return None

    @staticmethod
    def construct_complete_endpoint(api_name: str) -> str:
        """Constructs the complete API endpoint by appending organization
        schema, endpoint path, and Django app backend URL.

        Parameters:
        - api_name (str): The api name appended to the endpoint path.

        Returns:
        - str: The complete API endpoint URL.
        """
        org_schema = connection.get_tenant().schema_name
        return f"{ApiExecution.PATH}/{org_schema}/{api_name}/"

    @staticmethod
    def construct_status_endpoint(api_endpoint: str, execution_id: str) -> str:
        """Construct a complete status endpoint URL by appending the
        execution_id as a query parameter.

        Args:
            api_endpoint (str): The base API endpoint.
            execution_id (str): The execution ID to be included as
                a query parameter.

        Returns:
            str: The complete status endpoint URL.
        """
        query_parameters = urlencode({"execution_id": execution_id})
        complete_endpoint = f"{api_endpoint}?{query_parameters}"
        return complete_endpoint

    @staticmethod
    def get_deployment_by_api_name(
        api_name: str,
    ) -> Optional[APIDeployment]:
        """Get and return the APIDeployment object by api_name."""
        try:
            api: APIDeployment = APIDeployment.objects.get(api_name=api_name)
            return api
        except APIDeployment.DoesNotExist:
            return None

    @staticmethod
    def create_api_key(serializer: Serializer) -> APIKey:
        """To make API key for an API.

        Args:
            serializer (Serializer): Request serializer

        Raises:
            ApiKeyCreateException: when key creation fails; the freshly
                created deployment instance is deleted before raising so
                no deployment is left without a key.
        """
        api_deployment: APIDeployment = serializer.instance
        try:
            api_key: APIKey = KeyHelper.create_api_key(api_deployment)
            return api_key
        except Exception as error:
            logger.error(f"Error while creating API key error: {str(error)}")
            api_deployment.delete()
            logger.info("Deleted the deployment instance")
            raise ApiKeyCreateException()

    @staticmethod
    def execute_workflow(
        organization_name: str,
        api: APIDeployment,
        file_objs: list[UploadedFile],
        timeout: int,
    ) -> ReturnDict:
        """Execute workflow by api.

        Args:
            organization_name (str): organization name
                (NOTE(review): currently unused in this body — confirm
                whether it is needed or can be dropped upstream)
            api (APIDeployment): api model object
            file_objs (list[UploadedFile]): input files
            timeout (int): wait time for the async execution result

        Returns:
            ReturnDict: execution status/ result
        """
        workflow_id = api.workflow.id
        pipeline_id = api.id
        execution_id = str(uuid.uuid4())
        # Stage the uploaded files into API storage before execution.
        hash_values_of_files = SourceConnector.add_input_file_to_api_storage(
            workflow_id=workflow_id,
            execution_id=execution_id,
            file_objs=file_objs,
        )
        try:
            result = WorkflowHelper.execute_workflow_async(
                workflow_id=workflow_id,
                pipeline_id=pipeline_id,
                hash_values_of_files=hash_values_of_files,
                timeout=timeout,
                execution_id=execution_id,
            )
            # Tell the caller where to poll for status.
            result.status_api = DeploymentHelper.construct_status_endpoint(
                api_endpoint=api.api_endpoint, execution_id=execution_id
            )
        except Exception:
            # Clean up the staged files if execution could not proceed.
            DestinationConnector.delete_api_storage_dir(
                workflow_id=workflow_id, execution_id=execution_id
            )
            raise
        return APIExecutionResponseSerializer(result).data

    @staticmethod
    def get_execution_status(execution_id: str) -> ExecutionResponse:
        """Current status of api execution.

        Args:
            execution_id (str): execution id

        Returns:
            ExecutionResponse: status/result of execution
        """
        execution_response: ExecutionResponse = (
            WorkflowHelper.get_status_of_async_task(execution_id=execution_id)
        )
        return execution_response

94
backend/api/exceptions.py Normal file
View File

@@ -0,0 +1,94 @@
from typing import Optional
from rest_framework.exceptions import APIException
class MandatoryWorkflowId(APIException):
    """Raised when a request omits the required workflow ID (HTTP 400)."""

    status_code = 400
    default_detail = "Workflow ID is mandatory"
class ApiKeyCreateException(APIException):
    """Raised when creating an API key for a deployment fails (HTTP 500)."""

    status_code = 500
    # Fixed grammar: was "Exception while create API key".
    default_detail = "Exception while creating API key"

    def __init__(
        self, detail: Optional[str] = None, code: Optional[int] = None
    ):
        # Allow per-instance overrides of detail/code before delegating.
        if detail is not None:
            self.detail = detail
        if code is not None:
            self.code = code
        super().__init__(detail, code)
class Forbidden(APIException):
    """Raised when the caller lacks permission for the action (HTTP 403)."""

    status_code = 403
    default_detail = (
        "User is forbidden from performing this action. Please contact admin"
    )

    def __init__(
        self, detail: Optional[str] = None, code: Optional[int] = None
    ):
        # Allow per-instance overrides of detail/code before delegating.
        if detail is not None:
            self.detail = detail
        if code is not None:
            self.code = code
        super().__init__(detail, code)
class APINotFound(APIException):
    """Raised when no API deployment matches the request (HTTP 404)."""

    status_code = 404
    default_detail = "Api not found"

    def __init__(
        self, detail: Optional[str] = None, code: Optional[int] = None
    ):
        # Allow per-instance overrides of detail/code before delegating.
        if detail is not None:
            self.detail = detail
        if code is not None:
            self.code = code
        super().__init__(detail, code)
class InvalidAPIRequest(APIException):
    """Raised for malformed API execution requests (HTTP 400)."""

    status_code = 400
    default_detail = "Bad request"

    def __init__(
        self, detail: Optional[str] = None, code: Optional[int] = None
    ):
        # Allow per-instance overrides of detail/code before delegating.
        if detail is not None:
            self.detail = detail
        if code is not None:
            self.code = code
        super().__init__(detail, code)
class InactiveAPI(APIException):
    """Raised when the targeted API deployment is inactive (HTTP 404)."""

    status_code = 404
    default_detail = "API not found or Inactive"

    def __init__(
        self, detail: Optional[str] = None, code: Optional[int] = None
    ):
        # Allow per-instance overrides of detail/code before delegating.
        if detail is not None:
            self.detail = detail
        if code is not None:
            self.code = code
        super().__init__(detail, code)
class UnauthorizedKey(APIException):
    """Raised when the presented API key is unknown or unauthorized (401)."""

    status_code = 401
    default_detail = "Unauthorized"

    def __init__(
        self, detail: Optional[str] = None, code: Optional[int] = None
    ):
        # Allow per-instance overrides of detail/code before delegating.
        if detail is not None:
            self.detail = detail
        if code is not None:
            self.code = code
        super().__init__(detail, code)

72
backend/api/key_helper.py Normal file
View File

@@ -0,0 +1,72 @@
import logging
from api.exceptions import Forbidden, UnauthorizedKey
from api.models import APIDeployment, APIKey
from api.serializers import APIKeySerializer
from workflow_manager.workflow.workflow_helper import WorkflowHelper
logger = logging.getLogger(__name__)
class KeyHelper:
    """Helpers for creating, listing and validating API keys."""

    @staticmethod
    def validate_api_key(api_key: str, api_instance: APIDeployment) -> None:
        """Validate api key.

        Args:
            api_key (str): api key from request
            api_instance (APIDeployment): api deployment instance

        Raises:
            UnauthorizedKey: when the key is unknown or not permitted
                for this deployment
            Forbidden: when the deployment lookup itself fails
        """
        try:
            api_key_instance: APIKey = APIKey.objects.get(api_key=api_key)
            if not KeyHelper.has_access(api_key_instance, api_instance):
                raise UnauthorizedKey()
        except APIKey.DoesNotExist:
            raise UnauthorizedKey()
        except APIDeployment.DoesNotExist:
            raise Forbidden("API not found.")

    @staticmethod
    def list_api_keys_of_api(api_instance: APIDeployment) -> list[APIKey]:
        # All keys associated with the given deployment.
        api_keys: list[APIKey] = APIKey.objects.filter(api=api_instance).all()
        return api_keys

    @staticmethod
    def has_access(api_key: APIKey, api_instance: APIDeployment) -> bool:
        """Check if the provided API key has access to the specified API
        instance.

        Args:
            api_key (APIKey): api key associated with the api
            api_instance (APIDeployment): api model

        Returns:
            bool: True if allowed to execute, False otherwise
        """
        # Inactive keys are rejected outright.
        if not api_key.is_active:
            return False
        # The key must belong to exactly this deployment.
        if isinstance(api_key.api, APIDeployment):
            return api_key.api == api_instance
        return False

    @staticmethod
    def validate_workflow_exists(workflow_id: str) -> None:
        """Validate that the specified workflow_id exists in the Workflow
        model."""
        WorkflowHelper.get_workflow_by_id(workflow_id)

    @staticmethod
    def create_api_key(deployment: APIDeployment) -> APIKey:
        """Create an APIKey entity with the data from the provided
        APIDeployment instance."""
        # Create an instance of the APIKey model
        api_key_serializer = APIKeySerializer(
            data={"api": deployment.id, "description": "Initial Access Key"},
            context={"deployment": deployment},
        )
        api_key_serializer.is_valid(raise_exception=True)
        api_key: APIKey = api_key_serializer.save()
        return api_key

View File

@@ -0,0 +1,185 @@
# Generated by Django 4.2.1 on 2024-01-23 11:18
import uuid
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
    """Initial schema for the `api` app (auto-generated; avoid hand edits)."""

    initial = True

    dependencies = [
        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
        ("workflow", "0001_initial"),
    ]

    operations = [
        # APIDeployment: a deployed API bound to a workflow.
        migrations.CreateModel(
            name="APIDeployment",
            fields=[
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("modified_at", models.DateTimeField(auto_now=True)),
                (
                    "id",
                    models.UUIDField(
                        default=uuid.uuid4,
                        editable=False,
                        primary_key=True,
                        serialize=False,
                    ),
                ),
                (
                    "display_name",
                    models.CharField(
                        db_comment="User-given display name for the API.",
                        default="default api",
                        max_length=30,
                        unique=True,
                    ),
                ),
                (
                    "description",
                    models.CharField(
                        blank=True,
                        db_comment="User-given description for the API.",
                        default="",
                        max_length=255,
                    ),
                ),
                (
                    "is_active",
                    models.BooleanField(
                        db_comment="Flag indicating whether the API is active or not.",
                        default=True,
                    ),
                ),
                (
                    "api_endpoint",
                    models.CharField(
                        db_comment="URL endpoint for the API deployment.",
                        editable=False,
                        max_length=255,
                        unique=True,
                    ),
                ),
                (
                    "api_name",
                    models.CharField(
                        db_comment="Short name for the API deployment.",
                        default="default",
                        max_length=30,
                        unique=True,
                    ),
                ),
                (
                    "created_by",
                    models.ForeignKey(
                        blank=True,
                        editable=False,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="api_created_by",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "modified_by",
                    models.ForeignKey(
                        blank=True,
                        editable=False,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="api_modified_by",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "workflow",
                    models.ForeignKey(
                        db_comment="Foreign key reference to the Workflow model.",
                        on_delete=django.db.models.deletion.CASCADE,
                        to="workflow.workflow",
                    ),
                ),
            ],
            options={
                "abstract": False,
            },
        ),
        # APIKey: bearer keys granting access to an APIDeployment.
        migrations.CreateModel(
            name="APIKey",
            fields=[
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("modified_at", models.DateTimeField(auto_now=True)),
                (
                    "id",
                    models.UUIDField(
                        db_comment="Unique identifier for the API key.",
                        default=uuid.uuid4,
                        editable=False,
                        primary_key=True,
                        serialize=False,
                    ),
                ),
                (
                    "api_key",
                    models.UUIDField(
                        db_comment="Actual key UUID.",
                        default=uuid.uuid4,
                        editable=False,
                        unique=True,
                    ),
                ),
                (
                    "description",
                    models.CharField(
                        db_comment="Description of the API key.",
                        max_length=255,
                        null=True,
                    ),
                ),
                (
                    "is_active",
                    models.BooleanField(
                        db_comment="Flag indicating whether the API key is active or not.",
                        default=True,
                    ),
                ),
                (
                    "api",
                    models.ForeignKey(
                        db_comment="Foreign key reference to the APIDeployment model.",
                        on_delete=django.db.models.deletion.CASCADE,
                        to="api.apideployment",
                    ),
                ),
                (
                    "created_by",
                    models.ForeignKey(
                        blank=True,
                        editable=False,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="api_key_created_by",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "modified_by",
                    models.ForeignKey(
                        blank=True,
                        editable=False,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="api_key_modified_by",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
            ],
            options={
                "abstract": False,
            },
        ),
    ]

View File

141
backend/api/models.py Normal file
View File

@@ -0,0 +1,141 @@
import uuid
from typing import Any
from account.models import User
from api.constants import ApiExecution
from django.db import connection, models
from utils.models.base_model import BaseModel
from workflow_manager.workflow.models.workflow import Workflow
# Field-length limits shared by the models below.
API_NAME_MAX_LENGTH = 30
DESCRIPTION_MAX_LENGTH = 255
API_ENDPOINT_MAX_LENGTH = 255
class APIDeployment(BaseModel):
    """A workflow exposed as a deployed API endpoint."""

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    display_name = models.CharField(
        max_length=API_NAME_MAX_LENGTH,
        unique=True,
        default="default api",
        db_comment="User-given display name for the API.",
    )
    description = models.CharField(
        max_length=DESCRIPTION_MAX_LENGTH,
        blank=True,
        default="",
        db_comment="User-given description for the API.",
    )
    workflow = models.ForeignKey(
        Workflow,
        on_delete=models.CASCADE,
        db_comment="Foreign key reference to the Workflow model.",
    )
    is_active = models.BooleanField(
        default=True,
        db_comment="Flag indicating whether the API is active or not.",
    )
    api_endpoint = models.CharField(
        max_length=API_ENDPOINT_MAX_LENGTH,
        unique=True,
        editable=False,
        db_comment="URL endpoint for the API deployment.",
    )
    api_name = models.CharField(
        max_length=API_NAME_MAX_LENGTH,
        unique=True,
        default="default",
        db_comment="Short name for the API deployment.",
    )
    created_by = models.ForeignKey(
        User,
        on_delete=models.SET_NULL,
        related_name="api_created_by",
        null=True,
        blank=True,
        editable=False,
    )
    modified_by = models.ForeignKey(
        User,
        on_delete=models.SET_NULL,
        related_name="api_modified_by",
        null=True,
        blank=True,
        editable=False,
    )

    def __str__(self) -> str:
        return f"{self.id} - {self.display_name}"

    def save(self, *args: Any, **kwargs: Any) -> None:
        """Save hook to update api_endpoint.

        Custom save hook for updating the 'api_endpoint' based on
        'api_name'. If the instance is being updated, it checks for
        changes in 'api_name' and adjusts 'api_endpoint'
        accordingly. If the instance is new, 'api_endpoint' is set
        based on 'api_name' and the current database schema.
        """
        # NOTE: `id` has a uuid4 default, so self.pk is non-None even for
        # unsaved instances — the "new instance" path is actually the
        # DoesNotExist branch below.
        if self.pk is not None:
            try:
                original = APIDeployment.objects.get(pk=self.pk)
                # Keep the endpoint in sync when api_name changes.
                if original.api_name != self.api_name:
                    org_schema = connection.get_tenant().schema_name
                    self.api_endpoint = (
                        f"{ApiExecution.PATH}/{org_schema}/{self.api_name}/"
                    )
            except APIDeployment.DoesNotExist:
                # Fresh row: derive the endpoint from the tenant schema.
                org_schema = connection.get_tenant().schema_name
                self.api_endpoint = (
                    f"{ApiExecution.PATH}/{org_schema}/{self.api_name}/"
                )
        super().save(*args, **kwargs)
class APIKey(BaseModel):
    """Bearer key granting execute access to one APIDeployment."""

    id = models.UUIDField(
        primary_key=True,
        editable=False,
        default=uuid.uuid4,
        db_comment="Unique identifier for the API key.",
    )
    api_key = models.UUIDField(
        default=uuid.uuid4,
        editable=False,
        unique=True,
        db_comment="Actual key UUID.",
    )
    api = models.ForeignKey(
        APIDeployment,
        on_delete=models.CASCADE,
        db_comment="Foreign key reference to the APIDeployment model.",
    )
    description = models.CharField(
        max_length=DESCRIPTION_MAX_LENGTH,
        null=True,
        db_comment="Description of the API key.",
    )
    is_active = models.BooleanField(
        default=True,
        db_comment="Flag indicating whether the API key is active or not.",
    )
    created_by = models.ForeignKey(
        User,
        on_delete=models.SET_NULL,
        related_name="api_key_created_by",
        null=True,
        blank=True,
        editable=False,
    )
    modified_by = models.ForeignKey(
        User,
        on_delete=models.SET_NULL,
        related_name="api_key_modified_by",
        null=True,
        blank=True,
        editable=False,
    )

    def __str__(self) -> str:
        return f"{self.api.api_name} - {self.id} - {self.api_key}"

124
backend/api/serializers.py Normal file
View File

@@ -0,0 +1,124 @@
from collections import OrderedDict
from typing import Any, Union
from api.constants import ApiExecution
from api.models import APIDeployment, APIKey
from backend.serializers import AuditSerializer
from django.core.validators import RegexValidator
from rest_framework.serializers import (
CharField,
IntegerField,
JSONField,
ModelSerializer,
Serializer,
ValidationError,
)
class APIDeploymentSerializer(AuditSerializer):
    """Full serializer for APIDeployment create/update/retrieve."""

    class Meta:
        model = APIDeployment
        fields = "__all__"

    def validate_api_name(self, value: str) -> str:
        """Restrict api_name to URL-safe characters.

        Fix: the original message used a backslash line-continuation
        inside the string literal, which embedded the next line's
        indentation into the user-facing error text.
        """
        api_name_validator = RegexValidator(
            regex=r"^[a-zA-Z0-9_-]+$",
            message="Only letters, numbers, hyphen and "
            "underscores are allowed.",
            code="invalid_api_name",
        )
        api_name_validator(value)
        return value
class APIKeySerializer(AuditSerializer):
    """Serializer for APIKey with optional deployment-context overrides."""

    class Meta:
        model = APIKey
        fields = "__all__"

    def to_representation(self, instance: APIKey) -> OrderedDict[str, Any]:
        """Override the to_representation method to include additional
        context."""
        context = self.context.get("context", {})
        deployment: APIDeployment = context.get("deployment")
        representation: OrderedDict[str, Any] = super().to_representation(
            instance
        )
        if deployment:
            representation["api"] = deployment.id
            # Fix: APIDeployment has no `name` field — `deployment.name`
            # raised AttributeError; use `display_name` instead.
            representation["description"] = (
                f"API Key for {deployment.display_name}"
            )
            representation["is_active"] = True
        return representation
class ExecutionRequestSerializer(Serializer):
    """Execution request serializer

    timeout: 0: maximum value of timeout, -1: async execution
    """

    timeout = IntegerField(
        min_value=-1, max_value=ApiExecution.MAXIMUM_TIMEOUT_IN_SEC, default=-1
    )

    def validate_timeout(self, value: Any) -> int:
        # Fix: error message grammar ("a integer" -> "an integer").
        if not isinstance(value, int):
            raise ValidationError("timeout must be an integer.")
        # 0 is shorthand for "wait the maximum allowed time".
        if value == 0:
            value = ApiExecution.MAXIMUM_TIMEOUT_IN_SEC
        return value

    def get_timeout(self, validated_data: dict[str, Union[int, None]]) -> int:
        """Return the validated timeout, defaulting to -1 (async)."""
        value = validated_data.get("timeout", -1)
        if not isinstance(value, int):
            raise ValidationError("timeout must be an integer.")
        return value
class APIDeploymentListSerializer(ModelSerializer):
    """Trimmed serializer used by the deployment `list` action."""

    # Flattened convenience field from the related workflow.
    workflow_name = CharField(source="workflow.workflow_name", read_only=True)

    class Meta:
        model = APIDeployment
        fields = [
            "id",
            "workflow",
            "workflow_name",
            "display_name",
            "description",
            "is_active",
            "api_endpoint",
            "api_name",
            "created_by",
        ]
class APIKeyListSerializer(ModelSerializer):
    """Trimmed serializer for listing an API deployment's keys."""

    class Meta:
        model = APIKey
        fields = [
            "id",
            "created_at",
            "modified_at",
            "api_key",
            "is_active",
            "description",
            "api",
        ]
class DeploymentResponseSerializer(Serializer):
    """Response shape for deployment creation, including the minted key."""

    is_active = CharField()
    id = CharField()
    api_key = CharField()
    api_endpoint = CharField()
    display_name = CharField()
    description = CharField()
    api_name = CharField()
class APIExecutionResponseSerializer(Serializer):
    """Response shape for workflow execution status/results."""

    execution_status = CharField()
    # URL the client can poll for the execution's status.
    status_api = CharField()
    error = CharField()
    result = JSONField()

1
backend/api/tests.py Normal file
View File

@@ -0,0 +1 @@
# Create your tests here.

52
backend/api/urls.py Normal file
View File

@@ -0,0 +1,52 @@
from api.api_deployment_views import APIDeploymentViewSet, DeploymentExecution
from api.api_key_views import APIKeyViewSet
from django.urls import path, re_path
from rest_framework.urlpatterns import format_suffix_patterns
# Collection routes for deployments: list + create.
deployment = APIDeploymentViewSet.as_view(
    {
        "get": APIDeploymentViewSet.list.__name__,
        "post": APIDeploymentViewSet.create.__name__,
    }
)
# Detail routes for a single deployment (by primary key).
deployment_details = APIDeploymentViewSet.as_view(
    {
        "get": APIDeploymentViewSet.retrieve.__name__,
        "put": APIDeploymentViewSet.update.__name__,
        "patch": APIDeploymentViewSet.partial_update.__name__,
        "delete": APIDeploymentViewSet.destroy.__name__,
    }
)
execute = DeploymentExecution.as_view()
# Detail routes for a single API key.
key_details = APIKeyViewSet.as_view(
    {
        "get": APIKeyViewSet.retrieve.__name__,
        "put": APIKeyViewSet.update.__name__,
        "delete": APIKeyViewSet.destroy.__name__,
    }
)
# Keys of one deployment: list existing + create new.
api_key = APIKeyViewSet.as_view(
    {
        "get": APIKeyViewSet.api_keys.__name__,
        "post": APIKeyViewSet.create.__name__,
    }
)

urlpatterns = format_suffix_patterns(
    [
        path("deployment/", deployment, name="api_deployment"),
        path(
            "deployment/<uuid:pk>/",
            deployment_details,
            name="api_deployment_details",
        ),
        # Execution endpoint: api/<org_name>/<api_name>/ (trailing slash
        # optional per the `/?$` in the pattern).
        re_path(
            r"^api/(?P<org_name>[\w-]+)/(?P<api_name>[\w-]+)/?$",
            execute,
            name="api_deployment_execution",
        ),
        path("keys/<uuid:pk>/", key_details, name="key_details"),
        path("keys/api/<str:api_id>/", api_key, name="api_key"),
    ]
)

0
backend/apps/__init__.py Normal file
View File

View File

@@ -0,0 +1,4 @@
class AppConstants:
    """Constants for Apps."""
    # Currently empty; placeholder for app-related constants.

View File

@@ -0,0 +1,6 @@
from rest_framework.exceptions import APIException
class FetchAppListFailed(APIException):
    """Raised when the app-listing endpoint fails (HTTP 400)."""

    status_code = 400
    default_detail = "Failed to fetch App list."

9
backend/apps/urls.py Normal file
View File

@@ -0,0 +1,9 @@
from django.urls import path
from apps import views
from rest_framework.urlpatterns import format_suffix_patterns
# Routes for the `apps` module; only the listing endpoint for now.
urlpatterns = format_suffix_patterns(
    [
        path("app/", views.get_app_list, name="app-list"),
    ]
)

22
backend/apps/views.py Normal file
View File

@@ -0,0 +1,22 @@
import logging
from apps.exceptions import FetchAppListFailed
from rest_framework import status
from rest_framework.decorators import api_view
from rest_framework.request import Request
from rest_framework.response import Response
logger = logging.getLogger(__name__)
@api_view(("GET",))
def get_app_list(request: Request) -> Response:
"""API to fetch List of Apps."""
if request.method == "GET":
try:
return Response(data=[], status=status.HTTP_200_OK)
# Refactored dated: 19/12/2023
# ( Removed -> backend/apps/app_processor.py )
except Exception as exe:
logger.error(f"Error occured while fetching app list {exe}")
raise FetchAppListFailed()

View File

@@ -0,0 +1,3 @@
# Ensure the Celery app is loaded when Django starts so that
# shared tasks bind to it.
from .celery import app as celery_app

__all__ = ["celery_app"]

20
backend/backend/asgi.py Normal file
View File

@@ -0,0 +1,20 @@
"""ASGI config for backend project.
It exposes the ASGI callable as a module-level variable named ``application``.
For more information on this file, see
https://docs.djangoproject.com/en/4.2/howto/deployment/asgi/
"""
import os

from django.core.asgi import get_asgi_application
from dotenv import load_dotenv

# Load .env before Django settings are imported.
load_dotenv()

# Respect an externally-set DJANGO_SETTINGS_MODULE; default to dev settings.
os.environ.setdefault(
    "DJANGO_SETTINGS_MODULE",
    os.environ.get("DJANGO_SETTINGS_MODULE", "backend.settings.dev"),
)

application = get_asgi_application()

29
backend/backend/celery.py Normal file
View File

@@ -0,0 +1,29 @@
"""This module contains the Celery configuration for the backend
project."""
import os
from celery import Celery
from django.conf import settings
# Set the default Django settings module for the 'celery' program.
os.environ.setdefault(
    "DJANGO_SETTINGS_MODULE",
    os.environ.get("DJANGO_SETTINGS_MODULE", "backend.settings.dev"),
)

# Create a Celery instance. Default time zone is UTC.
app = Celery("backend")

# Use Redis as the message broker.
app.conf.broker_url = settings.CELERY_BROKER_URL
app.conf.result_backend = settings.CELERY_RESULT_BACKEND

# Load task modules from all registered Django app configs.
# Settings prefixed with CELERY_ are picked up via the namespace.
app.config_from_object("django.conf:settings", namespace="CELERY")

# Autodiscover tasks in all installed apps.
app.autodiscover_tasks()

# Use the Django-Celery-Beat scheduler.
app.conf.beat_scheduler = "django_celery_beat.schedulers:DatabaseScheduler"

View File

@@ -0,0 +1,30 @@
class RequestKey:
    """Commonly used keys in requests/responses."""

    REQUEST = "request"
    PROJECT = "project"
    WORKFLOW = "workflow"
    CREATED_BY = "created_by"
    MODIFIED_BY = "modified_by"
    MODIFIED_AT = "modified_at"
class FieldLengthConstants:
    """Used to determine length of fields in a model."""

    ORG_NAME_SIZE = 64
    CRON_LENGTH = 256
    UUID_LENGTH = 36
    # Not to be confused with a connector instance
    CONNECTOR_ID_LENGTH = 128
    ADAPTER_ID_LENGTH = 128
class RequestHeader:
    """Request header constants."""

    X_API_KEY = "X-API-KEY"
class UrlPathConstants:
    """URL path prefixes shared across modules."""

    PROMPT_STUDIO = "prompt-studio/"

View File

@@ -0,0 +1,36 @@
from typing import Any, Optional
from rest_framework.exceptions import APIException
from rest_framework.response import Response
from rest_framework.views import exception_handler
from unstract.connectors.exceptions import ConnectorBaseException
class UnstractBaseException(APIException):
    """Base APIException that can carry a connector-level cause.

    When `core_err` supplies a `user_message`, that message overrides
    the provided/default detail.
    """

    default_detail = "Error occurred"

    def __init__(
        self,
        detail: Optional[str] = None,
        core_err: Optional[ConnectorBaseException] = None,
        **kwargs: Any,
    ) -> None:
        if detail is None:
            detail = self.default_detail
        # Prefer the connector's user-facing message when available.
        if core_err and core_err.user_message:
            detail = core_err.user_message
        super().__init__(detail=detail, **kwargs)
        # Retained for callers that need the underlying cause.
        self._core_err = core_err
class LLMHelperError(Exception):
    """Marker exception type for LLM-helper failures; no extra behavior."""
    pass
def custom_exception_handler(exc, context) -> Response:  # type: ignore
    """DRF exception handler that appends the HTTP status code to the payload.

    Fix: `response.data` is a list for some validation errors (e.g.
    ListSerializer), where key assignment raised TypeError — only
    annotate dict payloads.
    """
    response = exception_handler(exc, context)
    if response is not None and isinstance(response.data, dict):
        response.data["status_code"] = response.status_code
    return response

View File

@@ -0,0 +1,15 @@
# Flower is a real-time web based monitor and administration tool
# for Celery. Its under active development,
# but is already an essential tool.
from django.conf import settings
# Broker URL (shared with the Celery worker configuration).
BROKER_URL = settings.CELERY_BROKER_URL

# Flower web port
PORT = 5555

# Enable basic authentication (when required)
# basic_auth = {
#     'username': 'password'
# }

View File

@@ -0,0 +1,40 @@
"""URL configuration for backend project.
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/4.2/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: path('', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
from account.admin import admin
from django.conf import settings
from django.conf.urls import * # noqa: F401, F403
from django.urls import include, path
# Two prefixes: one for app routes, one for deployed-API routes.
path_prefix = settings.PATH_PREFIX
api_path_prefix = settings.API_DEPLOYMENT_PATH_PREFIX

urlpatterns = [
    path(f"{path_prefix}/", include("account.urls")),
    # Admin URLs
    path(f"{path_prefix}/admin/doc/", include("django.contrib.admindocs.urls")),
    path(f"{path_prefix}/admin/", admin.site.urls),
    # Connector OAuth
    path(f"{path_prefix}/", include("connector_auth.urls")),
    # Docs
    path(f"{path_prefix}/", include("docs.urls")),
    # Socket.io
    path(f"{path_prefix}/", include("log_events.urls")),
    # API deployment
    path(f"{api_path_prefix}/", include("api.urls")),
    # Feature flags
    path(f"{path_prefix}/flags/", include("feature_flag.urls")),
]

View File

@@ -0,0 +1,23 @@
from typing import Any
from backend.constants import RequestKey
from rest_framework.serializers import ModelSerializer
class AuditSerializer(ModelSerializer):
    """ModelSerializer that stamps audit fields from the request user.

    On create, both `created_by` and `modified_by` are set; on update,
    only `modified_by` changes. Fields are stamped only when a request
    object is present in the serializer context.
    """

    def create(self, validated_data: dict[str, Any]) -> Any:
        request = self.context.get(RequestKey.REQUEST)
        if request:
            validated_data[RequestKey.CREATED_BY] = request.user
            validated_data[RequestKey.MODIFIED_BY] = request.user
        return super().create(validated_data)

    def update(self, instance: Any, validated_data: dict[str, Any]) -> Any:
        request = self.context.get(RequestKey.REQUEST)
        if request:
            validated_data[RequestKey.MODIFIED_BY] = request.user
        return super().update(instance, validated_data)

View File

View File

@@ -0,0 +1,458 @@
"""Django settings for backend project.
Generated by 'django-admin startproject' using Django 4.2.1.
For more information on this file, see
https://docs.djangoproject.com/en/4.2/topics/settings/
For the full list of settings and their values, see
https://docs.djangoproject.com/en/4.2/ref/settings/
"""
import os
from pathlib import Path
from typing import Optional
from dotenv import find_dotenv, load_dotenv
missing_settings = []


def get_required_setting(
    setting_key: str, default: Optional[str] = None
) -> Optional[str]:
    """Read an environment variable, recording it when missing.

    Keys that resolve to a falsy value are appended to
    ``missing_settings`` so a single exception can be raised for all of
    them at the end of settings loading.

    Args:
        setting_key (str): Name of the environment variable to read.
        default (Optional[str], optional): Fallback when the variable is
            not set. Defaults to None.

    Returns:
        Optional[str]: The resolved value if found, otherwise the
        default value.
    """
    value = os.environ.get(setting_key, default)
    if not value:
        missing_settings.append(setting_key)
    return value
# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent

# Console-only logging; both handler and root logger are pinned to INFO.
LOGGING = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "verbose": {
            "format": "[%(asctime)s] %(levelname)s %(name)s: %(message)s",
            "datefmt": "%d/%b/%Y %H:%M:%S",
        },
        "simple": {
            "format": "{levelname} {message}",
            "style": "{",
        },
    },
    "handlers": {
        "console": {
            "level": "INFO",  # Set the desired logging level here
            "class": "logging.StreamHandler",
            "formatter": "verbose",
        },
    },
    "root": {
        "handlers": ["console"],
        "level": "INFO",  # Set the desired logging level here as well
    },
}

# Load a .env file if one is discoverable from the working directory.
ENV_FILE = find_dotenv()
if ENV_FILE:
    load_dotenv(ENV_FILE)

# Loading environment variables
WORKFLOW_ACTION_EXPIRATION_TIME_IN_SECOND = os.environ.get(
    "WORKFLOW_ACTION_EXPIRATION_TIME_IN_SECOND", 10800
)
WEB_APP_ORIGIN_URL = os.environ.get(
    "WEB_APP_ORIGIN_URL", "http://localhost:3000"
)
LOGIN_NEXT_URL = os.environ.get("LOGIN_NEXT_URL", "http://localhost:3000/org")
LANDING_URL = os.environ.get("LANDING_URL", "http://localhost:3000/landing")
ERROR_URL = os.environ.get("ERROR_URL", "http://localhost:3000/error")
DJANGO_APP_BACKEND_URL = os.environ.get(
    "DJANGO_APP_BACKEND_URL", "http://localhost:8000"
)
INTERNAL_SERVICE_API_KEY = os.environ.get("INTERNAL_SERVICE_API_KEY")
GOOGLE_STORAGE_ACCESS_KEY_ID = os.environ.get("GOOGLE_STORAGE_ACCESS_KEY_ID")
GOOGLE_STORAGE_SECRET_ACCESS_KEY = os.environ.get(
    "GOOGLE_STORAGE_SECRET_ACCESS_KEY"
)
UNSTRACT_FREE_STORAGE_BUCKET_NAME = os.environ.get(
    "UNSTRACT_FREE_STORAGE_BUCKET_NAME", "pandora-user-storage"
)
GOOGLE_STORAGE_BASE_URL = os.environ.get("GOOGLE_STORAGE_BASE_URL")
REDIS_USER = os.environ.get("REDIS_USER", "default")
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD", "")
REDIS_HOST = os.environ.get("REDIS_HOST", "localhost")
REDIS_PORT = os.environ.get("REDIS_PORT", "6379")
REDIS_DB = os.environ.get("REDIS_DB", "")
SESSION_EXPIRATION_TIME_IN_SECOND = os.environ.get(
    "SESSION_EXPIRATION_TIME_IN_SECOND", 3600
)
# Leading/trailing slashes are stripped; prefixes are interpolated into URLs.
PATH_PREFIX = os.environ.get("PATH_PREFIX", "api/v1").strip("/")
API_DEPLOYMENT_PATH_PREFIX = os.environ.get(
    "API_DEPLOYMENT_PATH_PREFIX", "deployment"
).strip("/")
DB_NAME = os.environ.get("DB_NAME", "unstract_db")
DB_USER = os.environ.get("DB_USER", "unstract_dev")
DB_HOST = os.environ.get("DB_HOST", "backend-db-1")
DB_PASSWORD = os.environ.get("DB_PASSWORD", "unstract_pass")
DB_PORT = os.environ.get("DB_PORT", 5432)
DEFAULT_ORGANIZATION = "default_org"
FLIPT_BASE_URL = os.environ.get("FLIPT_BASE_URL", "http://localhost:9005")
PLATFORM_HOST = os.environ.get("PLATFORM_SERVICE_HOST", "http://localhost")
PLATFORM_PORT = os.environ.get("PLATFORM_SERVICE_PORT", 3001)
PROMPT_HOST = os.environ.get("PROMPT_HOST", "http://localhost")
PROMPT_PORT = os.environ.get("PROMPT_PORT", 3003)
PROMPT_STUDIO_FILE_PATH = os.environ.get(
    "PROMPT_STUDIO_FILE_PATH", "/app/prompt-studio-data"
)
X2TEXT_HOST = os.environ.get("X2TEXT_HOST", "http://localhost")
X2TEXT_PORT = os.environ.get("X2TEXT_PORT", 3004)
# These three are mandatory; absence is collected and raised at module end.
STRUCTURE_TOOL_IMAGE_URL = get_required_setting("STRUCTURE_TOOL_IMAGE_URL")
STRUCTURE_TOOL_IMAGE_NAME = get_required_setting("STRUCTURE_TOOL_IMAGE_NAME")
STRUCTURE_TOOL_IMAGE_TAG = get_required_setting("STRUCTURE_TOOL_IMAGE_TAG")
WORKFLOW_DATA_DIR = os.environ.get("WORKFLOW_DATA_DIR")
API_STORAGE_DIR = os.environ.get("API_STORAGE_DIR")

# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/
# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = get_required_setting("DJANGO_SECRET_KEY")
# SECURITY WARNING: don't run with debug turned on in production!
# NOTE(review): DEBUG=True and ALLOWED_HOSTS=["*"] in the BASE settings leak
# into every profile that does not override them -- confirm production
# profiles override both before deploying.
DEBUG = True
ALLOWED_HOSTS = ["*"]
CSRF_TRUSTED_ORIGINS = [WEB_APP_ORIGIN_URL]
CORS_ALLOW_ALL_ORIGINS = False
SESSION_COOKIE_AGE = 86400
# Application definition
# django-tenants splits apps into shared (public schema) and tenant schemas.
SHARED_APPS = (
    # Multitenancy
    "django_tenants",
    "corsheaders",
    # For the organization model
    "account",
    # Django apps should go below this line
    "django.contrib.admin",
    "django.contrib.auth",
    "django.contrib.contenttypes",
    "django.contrib.sessions",
    "django.contrib.messages",
    "django.contrib.staticfiles",
    "django.contrib.admindocs",
    # Third party apps should go below this line,
    "rest_framework",
    # Connector OAuth
    "connector_auth",
    "social_django",
    # Doc generator
    "drf_yasg",
    "docs",
    # Plugins
    "plugins",
    "log_events",
    "feature_flag",
    "django_celery_beat",
)
TENANT_APPS = (
    # your tenant-specific apps
    "django.contrib.admin",
    "django.contrib.auth",
    "django.contrib.contenttypes",
    "django.contrib.sessions",
    "django.contrib.messages",
    "django.contrib.staticfiles",
    "tenant_account",
    "project",
    "prompt",
    "connector",
    "adapter_processor",
    "file_management",
    "workflow_manager.endpoint",
    "workflow_manager.workflow",
    "tool_instance",
    "pipeline",
    "cron_expression_generator",
    "platform_settings",
    "api",
    "prompt_studio.prompt_profile_manager",
    "prompt_studio.prompt_studio",
    "prompt_studio.prompt_studio_core",
    "prompt_studio.prompt_studio_registry",
    "prompt_studio.prompt_studio_output_manager",
)
# Shared apps first, then tenant-only apps (deduplicated, order preserved).
INSTALLED_APPS = list(SHARED_APPS) + [
    app for app in TENANT_APPS if app not in SHARED_APPS
]
DEFAULT_MODEL_BACKEND = "django.contrib.auth.backends.ModelBackend"
GOOGLE_MODEL_BACKEND = "social_core.backends.google.GoogleOAuth2"
AUTHENTICATION_BACKENDS = (
    DEFAULT_MODEL_BACKEND,
    GOOGLE_MODEL_BACKEND,
)
# django-tenants model wiring.
TENANT_MODEL = "account.Organization"
TENANT_DOMAIN_MODEL = "account.Domain"
AUTH_USER_MODEL = "account.User"
PUBLIC_ORG_ID = "public"
# Order matters: CORS first, tenant resolution before auth middleware.
MIDDLEWARE = [
    "corsheaders.middleware.CorsMiddleware",
    "django_tenants.middleware.TenantSubfolderMiddleware",
    "django.middleware.security.SecurityMiddleware",
    "django.contrib.sessions.middleware.SessionMiddleware",
    "django.middleware.csrf.CsrfViewMiddleware",
    "django.middleware.common.CommonMiddleware",
    "django.contrib.auth.middleware.AuthenticationMiddleware",
    "django.contrib.messages.middleware.MessageMiddleware",
    "django.middleware.clickjacking.XFrameOptionsMiddleware",
    "account.custom_auth_middleware.CustomAuthMiddleware",
    "middleware.exception.ExceptionLoggingMiddleware",
    "social_django.middleware.SocialAuthExceptionMiddleware",
]
PUBLIC_SCHEMA_URLCONF = "backend.public_urls"
ROOT_URLCONF = "backend.urls"
# Tenant routes live under <PATH_PREFIX>/unstract/<tenant>/...
TENANT_SUBFOLDER_PREFIX = f"/{PATH_PREFIX}/unstract"
SHOW_PUBLIC_IF_NO_TENANT_FOUND = True
TEMPLATES = [
    {
        "BACKEND": "django.template.backends.django.DjangoTemplates",
        "DIRS": [],
        "APP_DIRS": True,
        "OPTIONS": {
            "context_processors": [
                "django.template.context_processors.debug",
                "django.template.context_processors.request",
                "django.contrib.auth.context_processors.auth",
                "django.contrib.messages.context_processors.messages",
            ],
        },
    },
]
WSGI_APPLICATION = "backend.wsgi.application"

# Database
# https://docs.djangoproject.com/en/4.2/ref/settings/#databases
DATABASES = {
    "default": {
        "ENGINE": "django_tenants.postgresql_backend",
        "NAME": f"{DB_NAME}",
        "USER": f"{DB_USER}",
        "HOST": f"{DB_HOST}",
        "PASSWORD": f"{DB_PASSWORD}",
        "PORT": f"{DB_PORT}",
        "ATOMIC_REQUESTS": True,
    }
}
DATABASE_ROUTERS = ("django_tenants.routers.TenantSyncRouter",)
CACHES = {
    "default": {
        "BACKEND": "django_redis.cache.RedisCache",
        "LOCATION": f"redis://{REDIS_HOST}:{REDIS_PORT}",
        "OPTIONS": {
            "CLIENT_CLASS": "django_redis.client.DefaultClient",
            "SERIALIZER": "django_redis.serializers.json.JSONSerializer",
            "DB": REDIS_DB,
            "USERNAME": REDIS_USER,
            "PASSWORD": REDIS_PASSWORD,
        },
        "KEY_FUNCTION": "account.cache_service.custom_key_function",
    }
}
RQ_QUEUES = {
    "default": {"USE_REDIS_CACHE": "default"},
}
# Used for asynchronous/Queued execution
# Celery based scheduler
CELERY_BROKER_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}"
# CELERY_RESULT_BACKEND = f"redis://{REDIS_HOST}:{REDIS_PORT}/1"
# Postgres as result backend
CELERY_RESULT_BACKEND = (
    f"db+postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)
CELERY_ACCEPT_CONTENT = ["json"]
CELERY_TASK_SERIALIZER = "json"
CELERY_RESULT_SERIALIZER = "json"
CELERY_TIMEZONE = "UTC"
CELERY_TASK_MAX_RETRIES = 3
CELERY_TASK_RETRY_BACKOFF = 60  # Time in seconds before retrying the task

# Feature Flag
FEATURE_FLAG_SERVICE_URL = {
    "evaluate": f"{FLIPT_BASE_URL}/api/v1/flags/evaluate/"
}
# APScheduler job defaults.
SCHEDULER_KWARGS = {
    "coalesce": True,
    "misfire_grace_time": 300,
    "max_instances": 1,
    "replace_existing": True,
}

# Password validation
# https://docs.djangoproject.com/en/4.2/ref/settings/#auth-password-validators
AUTH_PASSWORD_VALIDATORS = [
    {
        "NAME": "django.contrib.auth.password_validation."
        "UserAttributeSimilarityValidator",
    },
    {
        "NAME": "django.contrib.auth.password_validation."
        "MinimumLengthValidator",
    },
    {
        "NAME": "django.contrib.auth.password_validation."
        "CommonPasswordValidator",
    },
    {
        "NAME": "django.contrib.auth.password_validation."
        "NumericPasswordValidator",
    },
]

# Internationalization
# https://docs.djangoproject.com/en/4.2/topics/i18n/
LANGUAGE_CODE = "en-us"
TIME_ZONE = "UTC"
USE_I18N = True
USE_TZ = True

# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/4.2/howto/static-files/
STATIC_URL = f"/{PATH_PREFIX}/static/"

# Default primary key field type
# https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field
DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
REST_FRAMEWORK = {
    "DEFAULT_PERMISSION_CLASSES": [],  # TODO: Update once auth is figured
    "TEST_REQUEST_DEFAULT_FORMAT": "json",
    "EXCEPTION_HANDLER": "middleware.exception.drf_logging_exc_handler",
}

# These paths will work without authentication
WHITELISTED_PATHS_LIST = [
    "/login",
    "/home",
    "/callback",
    "/favicon.ico",
    "/logout",
    "/signup",
]
WHITELISTED_PATHS = [f"/{PATH_PREFIX}{PATH}" for PATH in WHITELISTED_PATHS_LIST]
# White lists workflow-api-deployment path
WHITELISTED_PATHS.append(f"/{API_DEPLOYMENT_PATH_PREFIX}")
# White list paths under tenant paths
TENANT_ACCESSIBLE_PUBLIC_PATHS_LIST = ["/oauth", "/organization", "/doc"]
TENANT_ACCESSIBLE_PUBLIC_PATHS = [
    f"/{PATH_PREFIX}{PATH}" for PATH in TENANT_ACCESSIBLE_PUBLIC_PATHS_LIST
]

# API Doc Generator Settings
# https://drf-yasg.readthedocs.io/en/stable/settings.html
REDOC_SETTINGS = {
    "PATH_IN_MIDDLE": True,
    "REQUIRED_PROPS_FIRST": True,
}

# Social Auth Settings
SOCIAL_AUTH_LOGIN_REDIRECT_URL = (
    f"{WEB_APP_ORIGIN_URL}/oauth-status/?status=success"
)
SOCIAL_AUTH_LOGIN_ERROR_URL = f"{WEB_APP_ORIGIN_URL}/oauth-status/?status=error"
SOCIAL_AUTH_EXTRA_DATA_EXPIRATION_TIME_IN_SECOND = os.environ.get(
    "SOCIAL_AUTH_EXTRA_DATA_EXPIRATION_TIME_IN_SECOND", 3600
)
SOCIAL_AUTH_USER_MODEL = "account.User"
SOCIAL_AUTH_STORAGE = "connector_auth.models.ConnectorDjangoStorage"
SOCIAL_AUTH_JSONFIELD_ENABLED = True
SOCIAL_AUTH_URL_NAMESPACE = "social"
SOCIAL_AUTH_FIELDS_STORED_IN_SESSION = ["oauth-key", "connector-guid"]
SOCIAL_AUTH_TRAILING_SLASH = False
# Pull provider credentials from the environment into module-level settings
# (SOCIAL_AUTH_GOOGLE_OAUTH2_KEY / _SECRET). Assigning through globals()
# replaces the previous exec() on a string-built statement: same effect,
# but with no dynamic code execution and friendlier to static analysis.
for key in [
    "GOOGLE_OAUTH2_KEY",
    "GOOGLE_OAUTH2_SECRET",
]:
    globals()[f"SOCIAL_AUTH_{key}"] = os.environ.get(key)
# Custom pipeline: only detail/uid extraction and credential caching --
# no user creation, since connector OAuth is not a sign-in flow.
SOCIAL_AUTH_PIPELINE = (
    # Checks if user is authenticated
    "connector_auth.pipeline.common.check_user_exists",
    # Gets user details from provider
    "social_core.pipeline.social_auth.social_details",
    "social_core.pipeline.social_auth.social_uid",
    # Cache secrets and fields in redis
    "connector_auth.pipeline.common.cache_oauth_creds",
)

# Social Auth: Google OAuth2
# Default takes care of sign in flow which we don't need for connectors
SOCIAL_AUTH_GOOGLE_OAUTH2_IGNORE_DEFAULT_SCOPE = True
SOCIAL_AUTH_GOOGLE_OAUTH2_SCOPE = [
    "https://www.googleapis.com/auth/userinfo.profile",
    "https://www.googleapis.com/auth/drive",
]
# Request offline access so a refresh token is issued.
SOCIAL_AUTH_GOOGLE_OAUTH2_AUTH_EXTRA_ARGUMENTS = {
    "access_type": "offline",
    "include_granted_scopes": "true",
    "prompt": "consent",
}
SOCIAL_AUTH_GOOGLE_OAUTH2_USE_UNIQUE_USER_ID = True

# Always keep this line at the bottom of the file.
# Fails fast if any get_required_setting() call above found no value.
if missing_settings:
    ERROR_MESSAGE = "Below required settings are missing.\n" + ",\n".join(
        missing_settings
    )
    raise ValueError(ERROR_MESSAGE)

View File

@@ -0,0 +1,27 @@
from backend.settings.base import * # noqa: F401, F403
# Local development overrides applied on top of backend.settings.base.
DEBUG = True

# NOTE(review): a duplicate assignment X_FRAME_OPTIONS = "http://localhost:3000"
# was dead code (immediately overwritten) and has been removed. "ALLOW-FROM"
# is obsolete and ignored by modern browsers -- consider a CSP
# frame-ancestors directive instead; confirm before changing behavior.
X_FRAME_OPTIONS = "ALLOW-FROM http://localhost:3000"

# Origins allowed by django-cors-headers.
CORS_ALLOWED_ORIGINS = [
    "http://localhost:3000",
    "http://127.0.0.1:3000",
    "https://dev-3xlzwou1raoituv0.us.auth0.com",
    # Other allowed origins if needed
]
# Legacy alias of CORS_ALLOWED_ORIGINS kept for older django-cors-headers.
CORS_ORIGIN_WHITELIST = [
    "http://localhost:3000",
    "http://127.0.0.1:3000",
    "https://dev-3xlzwou1raoituv0.us.auth0.com",
    # Other allowed origins if needed
]
CORS_ALLOW_METHODS = ["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]
CORS_ALLOW_HEADERS = [
    "authorization",
    "content-type",
]

View File

@@ -0,0 +1,3 @@
from backend.settings.base import * # noqa: F401, F403
DEBUG = True

53
backend/backend/urls.py Normal file
View File

@@ -0,0 +1,53 @@
"""URL configuration for backend project.
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/4.2/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: path('', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
from backend.constants import UrlPathConstants
from django.conf.urls import * # noqa: F401, F403
from django.urls import include, path
# Tenant-scoped URL routes; empty-prefix includes are resolved in order.
urlpatterns = [
    path("", include("tenant_account.urls")),
    path("", include("prompt.urls")),
    path("", include("project.urls")),
    path("", include("connector.urls")),
    path("", include("connector_processor.urls")),
    path("", include("adapter_processor.urls")),
    path("", include("file_management.urls")),
    path("", include("tool_instance.urls")),
    path("", include("cron_expression_generator.urls")),
    path("", include("pipeline.urls")),
    path("", include("apps.urls")),
    path("workflow/", include("workflow_manager.urls")),
    path("platform/", include("platform_settings.urls")),
    path("api/", include("api.urls")),
    # Prompt Studio sub-apps share a common prefix from UrlPathConstants.
    path(
        UrlPathConstants.PROMPT_STUDIO,
        include("prompt_studio.prompt_profile_manager.urls"),
    ),
    path(
        UrlPathConstants.PROMPT_STUDIO,
        include("prompt_studio.prompt_studio.urls"),
    ),
    path("", include("prompt_studio.prompt_studio_core.urls")),
    path(
        UrlPathConstants.PROMPT_STUDIO,
        include("prompt_studio.prompt_studio_registry.urls"),
    ),
    path(
        UrlPathConstants.PROMPT_STUDIO,
        include("prompt_studio.prompt_studio_output_manager.urls"),
    ),
]

25
backend/backend/wsgi.py Normal file
View File

@@ -0,0 +1,25 @@
"""WSGI config for backend project.
It exposes the WSGI callable as a module-level variable named ``application``.
For more information on this file, see
https://docs.djangoproject.com/en/4.2/howto/deployment/wsgi/
"""
import os
import socketio
from django.conf import settings
from django.core.wsgi import get_wsgi_application
from dotenv import load_dotenv
from log_events.views import sio
# Load environment variables before Django settings are configured.
load_dotenv()

# Select the settings module BEFORE any settings attribute is read.
# (Previously settings.PATH_PREFIX was accessed first, which raises
# ImproperlyConfigured when DJANGO_SETTINGS_MODULE is not already set.)
# setdefault() already leaves an existing value untouched, so the former
# redundant os.environ.get() fallback is unnecessary.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "backend.settings.dev")

django_app = get_wsgi_application()

path_prefix = settings.PATH_PREFIX

# Wrap the Django WSGI app so Socket.IO traffic on <prefix>/socket is
# served by the shared `sio` server instance.
application = socketio.WSGIApp(
    sio, django_app, socketio_path=f"{path_prefix}/socket"
)

View File

View File

@@ -0,0 +1,5 @@
from django.contrib import admin
from .models import ConnectorInstance
admin.site.register(ConnectorInstance)

View File

@@ -0,0 +1,5 @@
from django.apps import AppConfig
class ConnectorConfig(AppConfig):
    """AppConfig for the connector app."""

    # App label used by Django's app registry and migrations.
    name = "connector"

View File

@@ -0,0 +1,316 @@
import logging
from typing import Any, Optional
from account.models import User
from connector.constants import ConnectorInstanceConstant
from connector.models import ConnectorInstance
from connector.unstract_account import UnstractAccount
from django.conf import settings
from django.db import connection
from unstract.connectors.filesystems.ucs import UnstractCloudStorage
from unstract.connectors.filesystems.ucs.constants import UCSKey
from workflow_manager.workflow.models.workflow import Workflow
logger = logging.getLogger(__name__)
class ConnectorInstanceHelper:
    """Query and provisioning helpers for ``ConnectorInstance`` records."""

    @staticmethod
    def create_default_gcs_connector(workflow: Workflow, user: User) -> None:
        """Method to create default storage connector.

        Provisions the user's free cloud storage once (flagged on the user)
        and attaches an input and an output file-system connector pointing
        at the user's bucket paths to the given workflow.

        Args:
            workflow (Workflow): Workflow to attach the connectors to
            user (User): Owner of the provisioned storage
        """
        org_schema = connection.tenant.schema_name
        if not user.project_storage_created:
            logger.info("Creating default storage")
            account = UnstractAccount(org_schema, user.email)
            account.provision_s3_storage()
            account.upload_sample_files()
            user.project_storage_created = True
            user.save()
            logger.info("default storage created successfully.")
        logger.info("Adding connectors to Unstract")
        connector_name = ConnectorInstanceConstant.USER_STORAGE
        gcs_id = UnstractCloudStorage.get_id()
        bucket_name = settings.UNSTRACT_FREE_STORAGE_BUCKET_NAME
        base_path = f"{bucket_name}/{org_schema}/{user.email}"
        # Shared credentials; input/output differ only in their sub-path.
        connector_metadata = {
            UCSKey.KEY: settings.GOOGLE_STORAGE_ACCESS_KEY_ID,
            UCSKey.SECRET: settings.GOOGLE_STORAGE_SECRET_ACCESS_KEY,
            UCSKey.BUCKET: bucket_name,
            UCSKey.ENDPOINT_URL: settings.GOOGLE_STORAGE_BASE_URL,
        }
        connector_metadata__input = {
            **connector_metadata,
            UCSKey.PATH: base_path + "/input",
        }
        connector_metadata__output = {
            **connector_metadata,
            UCSKey.PATH: base_path + "/output",
        }
        ConnectorInstance.objects.create(
            connector_name=connector_name,
            workflow=workflow,
            created_by=user,
            connector_id=gcs_id,
            connector_metadata=connector_metadata__input,
            connector_type=ConnectorInstance.ConnectorType.INPUT,
            connector_mode=ConnectorInstance.ConnectorMode.FILE_SYSTEM,
        )
        ConnectorInstance.objects.create(
            connector_name=connector_name,
            workflow=workflow,
            created_by=user,
            connector_id=gcs_id,
            connector_metadata=connector_metadata__output,
            connector_type=ConnectorInstance.ConnectorType.OUTPUT,
            connector_mode=ConnectorInstance.ConnectorMode.FILE_SYSTEM,
        )
        logger.info("Connectors added successfully.")

    @staticmethod
    def get_connector_instances_by_workflow(
        workflow_id: str,
        connector_type: tuple[str, str],
        connector_mode: Optional[tuple[int, str]] = None,
        values: Optional[list[str]] = None,
        connector_name: Optional[str] = None,
    ) -> list[ConnectorInstance]:
        """Method to get connector instances by workflow.

        Args:
            workflow_id (str)
            connector_type (tuple[str, str]): Specifies input/output
            connector_mode (Optional[tuple[int, str]], optional):
                Specifies database/file
            values (Optional[list[str]], optional): Field names to project
                via .values(); full instances returned when None.
            connector_name (Optional[str], optional): Defaults to None.

        Returns:
            list[ConnectorInstance]
        """
        logger.info(f"Setting connector mode to {connector_mode}")
        filter_params: dict[str, Any] = {
            "workflow": workflow_id,
            "connector_type": connector_type,
        }
        if connector_mode is not None:
            filter_params["connector_mode"] = connector_mode
        if connector_name is not None:
            filter_params["connector_name"] = connector_name
        connector_instances = ConnectorInstance.objects.filter(
            **filter_params
        ).all()
        logger.info(f"Retrieved connector instance values {connector_instances}")
        if values is not None:
            filtered_connector_instances = connector_instances.values(*values)
            logger.info(
                "Returning filtered connector instance value "
                f"{filtered_connector_instances}"
            )
            return list(filtered_connector_instances)
        logger.info(f"Returning connector instances {connector_instances}")
        return list(connector_instances)

    @staticmethod
    def get_connector_instance_by_workflow(
        workflow_id: str,
        connector_type: tuple[str, str],
        connector_mode: Optional[tuple[int, str]] = None,
        connector_name: Optional[str] = None,
    ) -> Optional[ConnectorInstance]:
        """Get one connector instance.

        Use this method if the connector instance is unique for the
        given filter_params.

        Args:
            workflow_id (str): Workflow to filter by
            connector_type (tuple[str, str]): Specifies input/output
            connector_mode (Optional[tuple[int, str]], optional):
                Specifies database/filesystem
            connector_name (Optional[str], optional): Connector name filter

        Returns:
            Optional[ConnectorInstance]: First match, or None if absent
        """
        logger.info("Fetching connector instance by workflow")
        filter_params: dict[str, Any] = {
            "workflow": workflow_id,
            "connector_type": connector_type,
        }
        if connector_mode is not None:
            filter_params["connector_mode"] = connector_mode
        if connector_name is not None:
            filter_params["connector_name"] = connector_name
        try:
            # .first() yields None when no row matches.
            connector_instance: Optional[ConnectorInstance] = (
                ConnectorInstance.objects.filter(**filter_params).first()
            )
        except Exception as exc:
            logger.error(
                f"Error occurred while fetching connector instances {exc}"
            )
            raise exc
        return connector_instance

    @staticmethod
    def get_input_connector_instance_by_name_for_workflow(
        workflow_id: str,
        connector_name: str,
    ) -> Optional[ConnectorInstance]:
        """Method to get Input connector instance name from the workflow.

        Args:
            workflow_id (str)
            connector_name (str)

        Returns:
            Optional[ConnectorInstance]
        """
        return ConnectorInstanceHelper.get_connector_instance_by_workflow(
            workflow_id=workflow_id,
            connector_type=ConnectorInstance.ConnectorType.INPUT,
            connector_name=connector_name,
        )

    @staticmethod
    def get_output_connector_instance_by_name_for_workflow(
        workflow_id: str,
        connector_name: str,
    ) -> Optional[ConnectorInstance]:
        """Method to get output connector name by Workflow.

        Args:
            workflow_id (str)
            connector_name (str)

        Returns:
            Optional[ConnectorInstance]
        """
        return ConnectorInstanceHelper.get_connector_instance_by_workflow(
            workflow_id=workflow_id,
            connector_type=ConnectorInstance.ConnectorType.OUTPUT,
            connector_name=connector_name,
        )

    @staticmethod
    def get_input_connector_instances_by_workflow(
        workflow_id: str,
    ) -> list[ConnectorInstance]:
        """Method to get connector instances by workflow.

        Args:
            workflow_id (str)

        Returns:
            list[ConnectorInstance]
        """
        return ConnectorInstanceHelper.get_connector_instances_by_workflow(
            workflow_id, ConnectorInstance.ConnectorType.INPUT
        )

    @staticmethod
    def get_output_connector_instances_by_workflow(
        workflow_id: str,
    ) -> list[ConnectorInstance]:
        """Method to get output connector instances by workflow.

        Args:
            workflow_id (str): Workflow to filter by

        Returns:
            list[ConnectorInstance]: Output connectors of the workflow
        """
        return ConnectorInstanceHelper.get_connector_instances_by_workflow(
            workflow_id, ConnectorInstance.ConnectorType.OUTPUT
        )

    @staticmethod
    def get_file_system_input_connector_instances_by_workflow(
        workflow_id: str, values: Optional[list[str]] = None
    ) -> list[ConnectorInstance]:
        """Method to fetch file system connector by workflow.

        Args:
            workflow_id (str)
            values (Optional[list[str]], optional)

        Returns:
            list[ConnectorInstance]
        """
        return ConnectorInstanceHelper.get_connector_instances_by_workflow(
            workflow_id,
            ConnectorInstance.ConnectorType.INPUT,
            ConnectorInstance.ConnectorMode.FILE_SYSTEM,
            values,
        )

    @staticmethod
    def get_file_system_output_connector_instances_by_workflow(
        workflow_id: str, values: Optional[list[str]] = None
    ) -> list[ConnectorInstance]:
        """Method to get file system output connector by workflow.

        Args:
            workflow_id (str)
            values (Optional[list[str]], optional)

        Returns:
            list[ConnectorInstance]
        """
        return ConnectorInstanceHelper.get_connector_instances_by_workflow(
            workflow_id,
            ConnectorInstance.ConnectorType.OUTPUT,
            ConnectorInstance.ConnectorMode.FILE_SYSTEM,
            values,
        )

    @staticmethod
    def get_database_input_connector_instances_by_workflow(
        workflow_id: str, values: Optional[list[str]] = None
    ) -> list[ConnectorInstance]:
        """Method to fetch input database connectors by workflow.

        Args:
            workflow_id (str)
            values (Optional[list[str]], optional)

        Returns:
            list[ConnectorInstance]
        """
        return ConnectorInstanceHelper.get_connector_instances_by_workflow(
            workflow_id,
            ConnectorInstance.ConnectorType.INPUT,
            ConnectorInstance.ConnectorMode.DATABASE,
            values,
        )

    @staticmethod
    def get_database_output_connector_instances_by_workflow(
        workflow_id: str, values: Optional[list[str]] = None
    ) -> list[ConnectorInstance]:
        """Method to fetch output database connectors by workflow.

        Args:
            workflow_id (str)
            values (Optional[list[str]], optional)

        Returns:
            list[ConnectorInstance]
        """
        return ConnectorInstanceHelper.get_connector_instances_by_workflow(
            workflow_id,
            ConnectorInstance.ConnectorType.OUTPUT,
            ConnectorInstance.ConnectorMode.DATABASE,
            values,
        )

View File

@@ -0,0 +1,17 @@
class ConnectorInstanceKey:
    """Field-name keys and error messages for ConnectorInstance payloads."""

    # Serializer / request payload keys.
    CONNECTOR_ID = "connector_id"
    CONNECTOR_NAME = "connector_name"
    CONNECTOR_TYPE = "connector_type"
    CONNECTOR_MODE = "connector_mode"
    CONNECTOR_VERSION = "connector_version"
    CONNECTOR_AUTH = "connector_auth"
    CONNECTOR_METADATA = "connector_metadata"
    CONNECTOR_METADATA_B = "connector_metadata_b"
    # User-facing messages for uniqueness/duplicate-call violations.
    CONNECTOR_EXISTS = (
        "Connector with this configuration already exists in this project."
    )
    DUPLICATE_API = "It appears that a duplicate call may have been made."
class ConnectorInstanceConstant:
    """Miscellaneous ConnectorInstance constants."""

    # Display name of the auto-provisioned default storage connector.
    USER_STORAGE = "User Storage"

View File

@@ -0,0 +1,38 @@
from datetime import datetime
from connector_auth.constants import SocialAuthConstants
from connector_auth.models import ConnectorAuth
from django.db import models
import logging
logger = logging.getLogger(__name__)
class ConnectorAuthJSONField(models.JSONField):
    """JSONField that transparently refreshes stale OAuth tokens on read.

    When the stored metadata carries a provider/uid pair and its
    ``refresh_after`` timestamp has passed, the tokens are refreshed via
    the associated ``ConnectorAuth`` record and the fresh metadata is
    returned instead of the stored one.
    """

    def from_db_value(self, value, expression, connection):  # type: ignore
        """Overriding default function.

        Deserializes the column and refreshes tokens when due.
        """
        metadata = super().from_db_value(value, expression, connection)
        # NULL columns deserialize to None; nothing to inspect or refresh.
        # (Previously this crashed with AttributeError on .get().)
        if metadata is None:
            return metadata
        provider = metadata.get(SocialAuthConstants.PROVIDER)
        uid = metadata.get(SocialAuthConstants.UID)
        if provider and uid:
            refresh_after_str = metadata.get(SocialAuthConstants.REFRESH_AFTER)
            if refresh_after_str:
                refresh_after = datetime.strptime(
                    refresh_after_str, SocialAuthConstants.REFRESH_AFTER_FORMAT
                )
                if datetime.now() > refresh_after:
                    metadata = self._refresh_tokens(provider, uid)
        return metadata

    def _refresh_tokens(self, provider: str, uid: str) -> dict[str, str]:
        """Retrieves PSA object and refreshes the token if necessary.

        Returns an empty dict when no matching ConnectorAuth exists
        (previously this path raised UnboundLocalError).
        """
        connector_metadata: dict[str, str] = {}
        connector_auth: ConnectorAuth = ConnectorAuth.get_social_auth(
            provider=provider, uid=uid
        )
        if connector_auth:
            (
                connector_metadata,
                _tokens_refreshed,
            ) = connector_auth.get_and_refresh_tokens()
        return connector_metadata

View File

@@ -0,0 +1,122 @@
# Generated by Django 4.2.1 on 2024-01-23 11:18
import uuid
import connector.fields
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
    """Initial schema for the connector app (auto-generated by Django).

    Do not hand-edit applied migrations; create a new migration instead.
    """

    initial = True

    dependencies = [
        ("project", "0001_initial"),
        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
        ("workflow", "0001_initial"),
        ("connector_auth", "0001_initial"),
    ]

    operations = [
        migrations.CreateModel(
            name="ConnectorInstance",
            fields=[
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("modified_at", models.DateTimeField(auto_now=True)),
                (
                    "id",
                    models.UUIDField(
                        default=uuid.uuid4,
                        editable=False,
                        primary_key=True,
                        serialize=False,
                    ),
                ),
                ("connector_name", models.TextField(max_length=128)),
                ("connector_id", models.CharField(default="", max_length=128)),
                (
                    "connector_metadata",
                    connector.fields.ConnectorAuthJSONField(
                        db_column="connector_metadata", default=dict
                    ),
                ),
                (
                    "connector_version",
                    models.CharField(default="", max_length=64),
                ),
                (
                    "connector_type",
                    models.CharField(
                        choices=[("INPUT", "Input"), ("OUTPUT", "Output")]
                    ),
                ),
                (
                    "connector_mode",
                    models.CharField(
                        choices=[
                            (0, "UNKNOWN"),
                            (1, "FILE_SYSTEM"),
                            (2, "DATABASE"),
                        ],
                        db_comment="0: UNKNOWN, 1: FILE_SYSTEM, 2: DATABASE",
                        default=0,
                    ),
                ),
                (
                    "connector_auth",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        to="connector_auth.connectorauth",
                    ),
                ),
                (
                    "created_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="created_connectors",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "modified_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="modified_connectors",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "project",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="project_connectors",
                        to="project.project",
                    ),
                ),
                (
                    "workflow",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="workflow_connectors",
                        to="workflow.workflow",
                    ),
                ),
            ],
        ),
        # One connector per (name, workflow, type) combination.
        migrations.AddConstraint(
            model_name="connectorinstance",
            constraint=models.UniqueConstraint(
                fields=("connector_name", "workflow", "connector_type"),
                name="unique_connector",
            ),
        ),
    ]

View File

@@ -0,0 +1,42 @@
# Generated by Django 4.2.1 on 2024-02-16 06:50
import json
from typing import Any
from account.models import EncryptionSecret
from connector.models import ConnectorInstance
from cryptography.fernet import Fernet
from django.db import migrations, models
class Migration(migrations.Migration):
    """Adds ``connector_metadata_b`` and backfills it with an encrypted
    copy of each row's plain-text ``connector_metadata``."""

    dependencies = [
        ("connector", "0001_initial"),
        ("account", "0005_encryptionsecret"),
    ]

    # NOTE(review): this data migration imports live models instead of
    # apps.get_model(); it can break once the model definitions drift
    # from this migration's historical state -- verify before reordering.
    def EncryptCredentials(apps: Any, schema_editor: Any) -> None:
        # NOTE(review): .get() assumes exactly one EncryptionSecret row
        # exists; it raises DoesNotExist/MultipleObjectsReturned otherwise.
        encryption_secret: EncryptionSecret = EncryptionSecret.objects.get()
        f: Fernet = Fernet(encryption_secret.key.encode("utf-8"))
        queryset = ConnectorInstance.objects.all()
        for obj in queryset:  # type: ignore
            # Access attributes of the object
            if hasattr(obj, "connector_metadata"):
                json_string: str = json.dumps(obj.connector_metadata)
                obj.connector_metadata_b = f.encrypt(
                    json_string.encode("utf-8")
                )
                obj.save()

    operations = [
        migrations.AddField(
            model_name="connectorinstance",
            name="connector_metadata_b",
            field=models.BinaryField(null=True),
        ),
        # Backfill only; decryption is not implemented for rollback.
        migrations.RunPython(
            EncryptCredentials, reverse_code=migrations.RunPython.noop
        ),
    ]

View File

116
backend/connector/models.py Normal file
View File

@@ -0,0 +1,116 @@
import uuid
from account.models import User
from connector.fields import ConnectorAuthJSONField
from connector_auth.models import ConnectorAuth
from connector_processor.connector_processor import ConnectorProcessor
from connector_processor.constants import ConnectorKeys
from django.db import models
from project.models import Project
from utils.models.base_model import BaseModel
from workflow_manager.workflow.models import Workflow
from backend.constants import FieldLengthConstants as FLC
CONNECTOR_NAME_SIZE = 128
VERSION_NAME_SIZE = 64
class ConnectorInstance(BaseModel):
    """A configured connector (source or destination) attached to a workflow."""

    class ConnectorType(models.TextChoices):
        INPUT = "INPUT", "Input"
        OUTPUT = "OUTPUT", "Output"

    class ConnectorMode(models.IntegerChoices):
        UNKNOWN = 0, "UNKNOWN"
        FILE_SYSTEM = 1, "FILE_SYSTEM"
        DATABASE = 2, "DATABASE"

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    connector_name = models.TextField(
        max_length=CONNECTOR_NAME_SIZE, null=False, blank=False
    )
    # Optional legacy link to a project; workflow is the primary owner.
    project = models.ForeignKey(
        Project,
        on_delete=models.CASCADE,
        related_name="project_connectors",
        null=True,
        blank=True,
    )
    workflow = models.ForeignKey(
        Workflow,
        on_delete=models.CASCADE,
        related_name="workflow_connectors",
        null=False,
        blank=False,
    )
    connector_id = models.CharField(
        max_length=FLC.CONNECTOR_ID_LENGTH, default=""
    )
    # TODO Required to be removed
    # Custom JSONField that refreshes OAuth tokens on read (see fields.py).
    connector_metadata = ConnectorAuthJSONField(
        db_column="connector_metadata", null=False, blank=False, default=dict
    )
    # Encrypted replacement for connector_metadata (see migration 0002).
    connector_metadata_b = models.BinaryField(null=True)
    connector_version = models.CharField(
        max_length=VERSION_NAME_SIZE, default=""
    )
    # NOTE(review): CharField without max_length -- allowed on PostgreSQL
    # with Django 4.2, but confirm this is intentional.
    connector_type = models.CharField(choices=ConnectorType.choices)
    connector_auth = models.ForeignKey(
        ConnectorAuth, on_delete=models.SET_NULL, null=True, blank=True
    )
    # NOTE(review): IntegerChoices stored in a CharField -- values persist
    # as strings; verify comparisons against ConnectorMode members.
    connector_mode = models.CharField(
        choices=ConnectorMode.choices,
        default=ConnectorMode.UNKNOWN,
        db_comment="0: UNKNOWN, 1: FILE_SYSTEM, 2: DATABASE",
    )
    created_by = models.ForeignKey(
        User,
        on_delete=models.SET_NULL,
        related_name="created_connectors",
        null=True,
        blank=True,
    )
    modified_by = models.ForeignKey(
        User,
        on_delete=models.SET_NULL,
        related_name="modified_connectors",
        null=True,
        blank=True,
    )

    def get_connector_metadata(self) -> dict[str, str]:
        """Gets connector metadata and refreshes the tokens if needed in case
        of OAuth."""
        tokens_refreshed = False
        if self.connector_auth:
            (
                self.connector_metadata,
                tokens_refreshed,
            ) = self.connector_auth.get_and_refresh_tokens()
            # Persist the refreshed metadata so subsequent reads reuse it.
            if tokens_refreshed:
                self.save()
        return self.connector_metadata

    @staticmethod
    def supportsOAuth(connector_id: str) -> bool:
        """True when the registered connector declares OAuth support."""
        return bool(
            ConnectorProcessor.get_connector_data_with_key(
                connector_id, ConnectorKeys.OAUTH
            )
        )

    def __str__(self) -> str:
        return (
            f"Connector({self.id}, type{self.connector_type},"
            f" workflow: {self.workflow})"
        )

    class Meta:
        constraints = [
            models.UniqueConstraint(
                fields=["connector_name", "workflow", "connector_type"],
                name="unique_connector",
            ),
        ]

Some files were not shown because too many files have changed in this diff Show More